Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 27, 2023
1 parent d220ca5 commit d4bae69
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 97 deletions.
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,7 @@ def _calculate_activation_scale(
"""
activation_ports_map = {node: self._backend_entity.get_activations_port_id(node, nncf_graph) for node in nodes}
channel_axes = [
self._backend_entity.get_activation_channel_axis(node, port, activations_shape)
for node, port in activation_ports_map.items()
self._backend_entity.get_activation_channel_axis(node, port) for node, port in activation_ports_map.items()
]
channel_axis = channel_axes[0]

Expand Down Expand Up @@ -341,7 +340,7 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode,
shape = nncf_graph.get_input_edges(node)[input_port].tensor_shape
reduction_axes = tuple([])
if len(shape) > 1:
channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port, shape)
channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port)
reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes(channel_axis, shape)
return reduction_axes

Expand Down
13 changes: 1 addition & 12 deletions nncf/quantization/algorithms/smooth_quant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def scale_insertion_command(

@staticmethod
@abstractmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
"""
Returns axis number of the activation tensor which correspond to it channel.
Expand All @@ -242,17 +242,6 @@ def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
:return: Channel axis number.
"""

@staticmethod
@abstractmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
"""
Returns port-based channel axis.
:param port_id: Specified input port id.
:param transpose: Transpose position.
:return: Channel axis.
"""

@staticmethod
@abstractmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
Expand Down
8 changes: 4 additions & 4 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def scale_insertion_command(
)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
channel_axis = 1

if port_id > 1:
Expand All @@ -174,7 +174,7 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape:
and "transpose" in node.layer_attributes.input_attributes
):
transpose = node.layer_attributes.input_attributes["transpose"]
channel_axis = OVSmoothQuantAlgoBackend.calculate_port_based_channel_axis(port_id, transpose)
channel_axis = OVSmoothQuantAlgoBackend._calculate_port_based_channel_axis(port_id, transpose)

return channel_axis

Expand All @@ -189,12 +189,12 @@ def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
if node.metatype == OVMatMulMetatype:
if "transpose" in node.layer_attributes.constant_attributes[port_id]:
transpose = node.layer_attributes.constant_attributes[port_id]["transpose"]
channel_axis = OVSmoothQuantAlgoBackend.calculate_port_based_channel_axis(port_id, transpose)
channel_axis = OVSmoothQuantAlgoBackend._calculate_port_based_channel_axis(port_id, transpose)

return channel_axis

@staticmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
def _calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
return -2 + port_id if transpose else -1 - port_id

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,17 @@ def scale_insertion_command(
return multiply_insertion_command(nodes, scale_value, scale_node_name, input_port_id)

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
return len(activations_shape) - 1
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
if node.metatype == om.PTModuleLinearMetatype:
return -1
# TODO: Add activation axis calculation when MatMul wiil be supported
return 1

@staticmethod
def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
# TODO: Add activation axis calculation when MatMul wiil be supported
return 1

@staticmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
return -2 + port_id if transpose else -1 - port_id

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
return node.is_shared()
Expand Down
2 changes: 0 additions & 2 deletions nncf/torch/dynamic_graph/layer_attributes_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@
from nncf.common.graph.layer_attributes import PermuteLayerAttributes
from nncf.common.graph.layer_attributes import ReshapeLayerAttributes
from nncf.common.graph.layer_attributes import TransposeLayerAttributes
from nncf.common.graph.layer_attributes import MatMulLayerAttributes
from nncf.common.graph.utils import get_concat_axis
from nncf.common.graph.utils import get_split_axis
from nncf.torch.graph.operator_metatypes import PTCatMetatype
from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype
from nncf.torch.graph.operator_metatypes import PTPadMetatype
from nncf.torch.graph.operator_metatypes import PTMatMulMetatype
from nncf.torch.graph.operator_metatypes import PTReshapeMetatype
from nncf.torch.graph.operator_metatypes import PTSplitMetatype
from nncf.torch.graph.operator_metatypes import PTSqueezeMetatype
Expand Down
17 changes: 9 additions & 8 deletions nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
return PTWeightUpdateCommand(target_point, weight_value)


class SQMultiply(torch.nn.Module):
def __init__(self, scale_value):
super().__init__()
self._scale_value = scale_value

def forward(self, x):
return torch.mul(x, self._scale_value)


def multiply_insertion_command(
target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
) -> PTInsertionCommand:
Expand All @@ -56,12 +65,4 @@ def multiply_insertion_command(
target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY))

class SQMultiply(torch.nn.Module):
def __init__(self, scale_value):
super().__init__()
self._scale_value = scale_value

def forward(self, x):
return torch.mul(x, self._scale_value)

return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
2 changes: 1 addition & 1 deletion nncf/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ def device(self) -> torch.device:
return self._tensor.device

def is_empty(self) -> bool:
return self.tensor.size == 0
return self.tensor.numel() == 0
55 changes: 51 additions & 4 deletions tests/openvino/native/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,61 @@
import torch
from openvino.tools.mo import convert_model

from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm

OV_LINEAR_MODEL_MM_OP_MAP = {
"MatMul1": "/MatMul",
"MatMul2": "/MatMul_1",
"MatMul3": "/MatMul_2",
"MatMul4": "/MatMul_4",
"MatMul5": "32",
"MatMul6": "37",
"MatMul7": "54",
"MatMul8": "68",
"Linear1": "/linear_2/MatMul",
"Linear2": "/linear_1/MatMul",
"Linear3": "/linear_3/MatMul",
"Linear4": "/linear_4/MatMul",
}


OV_LINEAR_MODEL_SQ_OP_MAP = {
"MatMul1": "/Reshape_0_0/nncf_smooth_quant",
"MatMul2": "/Reshape_0_0/nncf_smooth_quant",
"MatMul3": "/Reshape_1_0_0/nncf_smooth_quant",
"MatMul4": "/Reshape_1_0_1/nncf_smooth_quant",
"MatMul5": "/Reshape_2_0_0/nncf_smooth_quant",
"MatMul6": "/ReduceMax_0_0/nncf_smooth_quant",
"MatMul7": "/Reshape_3_0_0/nncf_smooth_quant",
"MatMul8": "/Reshape_4_0_0/nncf_smooth_quant",
"Linear1": "/Split_1_0/nncf_smooth_quant",
"Linear2": "/Split_0_0/nncf_smooth_quant",
"Linear3": "/Add_0_0/nncf_smooth_quant",
"Linear4": "/Add_0_0/nncf_smooth_quant",
}


class TestOVSQAlgorithm(TemplateTestSQAlgorithm):
@staticmethod
def fn_to_type(tensor) -> np.ndarray:
return np.array(tensor)

@pytest.fixture(params=[False, True], ids=["out_of_palce", "inplace"])
def inplace_statistics(self, request) -> bool:
return request.param

def get_node_name_map(self) -> Dict[str, str]:
return OV_LINEAR_MODEL_MM_OP_MAP

@staticmethod
def get_target_node_name(command: TransformationCommand):
return command.target_point.target_node_name

@staticmethod
def get_transform_fn() -> Callable:
def transform_fn(data_item):
Expand All @@ -53,10 +96,14 @@ def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> ov.Model:
@staticmethod
def check_scales(model: ov.Model, reference_values: Dict[str, np.ndarray]) -> None:
ops_list = {op.get_friendly_name(): op for op in model.get_ops()}
for ref_name, ref_value in reference_values.items():
node = ops_list[ref_name]
const_node = node.input(1).get_source_output().get_node()

for ref_names, ref_value in reference_values.items():
const_nodes = []
for ref_name in ref_names:
node = ops_list[OV_LINEAR_MODEL_SQ_OP_MAP[ref_name]]
const_nodes.append(node.input(1).get_source_output().get_node())
# Check unified group acutally shares one constant
assert all(node is const_nodes[0] for node in const_nodes[1:])
const_node = const_nodes[0]
assert const_node.get_type_name() == "Constant"

value = const_node.data
Expand Down
31 changes: 29 additions & 2 deletions tests/post_training/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,30 @@ def __init__(self) -> None:
self.matmul_7_data = torch.randn((6, 6), dtype=torch.float32)
self.matmul_8_data = torch.randn((10, 6), dtype=torch.float32)

self.linear_3 = nn.Linear(4, 4)
self.linear_3.weight.data = torch.randn((4, 4), dtype=torch.float32)
self.linear_3.bias.data = torch.randn((1, 4), dtype=torch.float32)

self.linear_4 = nn.Linear(4, 4)
self.linear_4.weight.data = torch.randn((4, 4), dtype=torch.float32)
self.linear_4.bias.data = torch.randn((1, 4), dtype=torch.float32)

def forward(self, x):
x = torch.reshape(x, (1, 3, 2, 4))

x_1 = torch.matmul(x, self.matmul_1_data)
x_2 = torch.matmul(x, self.matmul_2_data)

x = torch.add(x_1, x_2)

x_3 = self.linear_3(x)
x_4 = self.linear_4(x)

x_ = torch.add(x_3, x_4)

x = torch.add(x, x_)
x = torch.sub(x, x_)

x_1 = torch.reshape(x, (1, 3, 8))

x_1_1 = torch.matmul(x_1, self.matmul_3_data)
Expand All @@ -189,13 +206,23 @@ def forward(self, x):
class NonZeroLinearModel(nn.Module):
INPUT_SIZE = [10]

def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 5)
self.linear.weight.data = torch.ones((5, 1))
self.linear.bias.data = torch.zeros((1, 1))

self.linear1 = nn.Linear(10, 10)
self.linear1.weight.data = torch.ones((10, 10))
self.linear1.bias.data = torch.zeros((1, 1))

def forward(self, x):
zeros = (x > torch.inf).float()
empty = torch.nonzero(zeros).reshape((-1, 1, 1)).float()
y = torch.matmul(empty, torch.ones((1, 5)))
y = self.linear(empty)
y += 5
y = torch.cat((torch.ones((1, 10)), y.reshape(1, -1)), dim=1)
y = torch.matmul(y, torch.ones(10, 10))
y = self.linear1(y)
y += 5
return y

Expand Down
Loading

0 comments on commit d4bae69

Please sign in to comment.