Skip to content

Commit

Permalink
Weights layout in conv/matmul layer attributes is introduced
Browse files Browse the repository at this point in the history
Refactor smooth quant to use weights layout

Tests
  • Loading branch information
daniil-lyakhov committed Aug 25, 2023
1 parent 0dcca5a commit 9da200c
Show file tree
Hide file tree
Showing 14 changed files with 545 additions and 322 deletions.
24 changes: 22 additions & 2 deletions nncf/common/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Tuple, Union
from typing import Any, List, Optional, Tuple, Union


class Dtype(Enum):
FLOAT = "float"
INTEGER = "int"


class LayoutElem(Enum):
C_IN = "channels_in"
C_OUT = "channels_out"
SPATIAL = "spatial"
GROUPS = "groups"


class BaseLayerAttributes(ABC):
"""
This class stores base useful for some algorithms attributes
Expand All @@ -30,6 +37,9 @@ class BaseLayerAttributes(ABC):
def __eq__(self, __o: object) -> bool:
return isinstance(__o, self.__class__) and self.__dict__ == __o.__dict__

def get_backend_agnostic_attributes(self) -> "BaseLayerAttributes":
return self


class MultipleInputLayerAttributes(BaseLayerAttributes):
def __init__(self, axis: int):
Expand Down Expand Up @@ -109,7 +119,14 @@ def get_target_dim_for_compression(self) -> int:


class LinearLayerAttributes(WeightedLayerAttributes):
def __init__(self, weight_requires_grad: bool, in_features: int, out_features: int, with_bias: bool = True):
def __init__(
self,
weight_requires_grad: bool,
in_features: int,
out_features: int,
with_bias: bool = True,
weights_layout: Optional[Tuple[LayoutElem, ...]] = None,
):
"""
:param weight_requires_grad: Is True if gradients need to be computed for the corresponding Tensor,
Expand All @@ -120,6 +137,7 @@ def __init__(self, weight_requires_grad: bool, in_features: int, out_features: i
super().__init__(weight_requires_grad, with_bias=with_bias)
self.in_features = in_features
self.out_features = out_features
self.weights_layout = weights_layout

def get_weight_shape(self) -> List[int]:
return [self.out_features, self.in_features]
Expand All @@ -144,6 +162,7 @@ def __init__(
transpose: bool,
padding_values: Tuple[int, ...],
with_bias: bool = False,
weights_layout: Optional[Tuple[LayoutElem, ...]] = None,
):
"""
Expand All @@ -167,6 +186,7 @@ def __init__(
self.groups = groups
self.transpose = transpose
self.padding_values = padding_values
self.weights_layout = weights_layout

def get_weight_shape(self) -> List[int]:
if not self.transpose:
Expand Down
122 changes: 72 additions & 50 deletions nncf/openvino/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.layer_attributes import GenericWeightedLayerAttributes
from nncf.common.graph.layer_attributes import LayoutElem
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionBackpropDataMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionBackpropDataMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVOpMetatype


Expand All @@ -33,7 +36,7 @@ class OVLayerAttributes(BaseLayerAttributes):
def __init__(
self,
constant_attributes: Dict[int, Any],
layer_attributes: Optional[Dict[int, BaseLayerAttributes]] = None,
layer_attributes: Optional[BaseLayerAttributes] = None,
inputs_attributes: Optional[Dict[Any, Any]] = None,
):
"""
Expand All @@ -49,10 +52,6 @@ def __init__(
def constant_attributes(self) -> Dict[int, Any]:
return self._constant_attributes

@property
def layer_attributes(self) -> Optional[Dict[int, BaseLayerAttributes]]:
return self._layer_attributes

@property
def input_attributes(self) -> Optional[Dict[Any, Any]]:
return self._inputs_attributes
Expand All @@ -67,6 +66,9 @@ def get_const_port_ids(self) -> List[int]:
return list(self._constant_attributes.keys())
return []

def get_backend_agnostic_attributes(self):
return self._layer_attributes


def get_weighted_layer_attributes(
ov_node: ov.Node, ov_metatype: OVOpMetatype, constant_attributes: Dict[str, Any]
Expand All @@ -79,51 +81,71 @@ def get_weighted_layer_attributes(
:param constant_attributes: Constant attributes collected for the given node.
:return: Weighted layer attributes for the given node.
"""
retval = {}
for port_id, attrs in constant_attributes.items():
if ov_metatype in [
OVConvolutionMetatype,
OVDepthwiseConvolutionMetatype,
OVGroupConvolutionMetatype,
OVConvolutionBackpropDataMetatype,
OVGroupConvolutionBackpropDataMetatype,
]:
node_attrs = ov_node.get_attributes()
kwargs = {
"weight_requires_grad": False,
"stride": tuple(node_attrs["strides"]),
"dilations": node_attrs["dilations"],
"transpose": ov_metatype in [OVConvolutionBackpropDataMetatype, OVGroupConvolutionBackpropDataMetatype],
# TODO: ticket 114378: unify pad attribute
"padding_values": tuple(node_attrs["pads_begin"] + node_attrs["pads_end"]),
if len(constant_attributes) != 1:
return None

port_id, attrs = constant_attributes.copy().popitem()
if ov_metatype in [
OVConvolutionMetatype,
OVDepthwiseConvolutionMetatype,
OVGroupConvolutionMetatype,
OVConvolutionBackpropDataMetatype,
OVGroupConvolutionBackpropDataMetatype,
]:
node_attrs = ov_node.get_attributes()
kwargs = {
"weight_requires_grad": False,
"stride": tuple(node_attrs["strides"]),
"dilations": node_attrs["dilations"],
"transpose": ov_metatype in [OVConvolutionBackpropDataMetatype, OVGroupConvolutionBackpropDataMetatype],
# TODO: ticket 114378: unify pad attribute
"padding_values": tuple(node_attrs["pads_begin"] + node_attrs["pads_end"]),
}

weights_layout_map = {
OVConvolutionMetatype: [LayoutElem.C_OUT, LayoutElem.C_IN],
OVGroupConvolutionMetatype: [LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN],
OVDepthwiseConvolutionMetatype: [LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN],
OVConvolutionBackpropDataMetatype: [LayoutElem.C_IN, LayoutElem.C_OUT],
OVGroupConvolutionBackpropDataMetatype: [LayoutElem.GROUPS, LayoutElem.C_IN, LayoutElem.C_OUT],
}

weights_layout = weights_layout_map[ov_metatype]
weights_shape = attrs["shape"]
kwargs.update(
{
"in_channels": weights_shape[weights_layout.index(LayoutElem.C_IN)],
"out_channels": weights_shape[weights_layout.index(LayoutElem.C_OUT)],
"kernel_size": tuple(weights_shape[len(weights_layout) :]),
"groups": weights_shape[weights_layout.index(LayoutElem.GROUPS)]
if LayoutElem.GROUPS in weights_layout
else 1,
}

const_shape = attrs["shape"]
if ov_metatype in [OVConvolutionMetatype, OVConvolutionBackpropDataMetatype]:
kwargs.update(
{
"in_channels": const_shape[1],
"out_channels": const_shape[0],
"kernel_size": tuple(const_shape[2:]),
"groups": 1,
}
)
)
kwargs.update({"weights_layout": tuple(weights_layout + len(kwargs["kernel_size"]) * [LayoutElem.SPATIAL])})

return ConvolutionLayerAttributes(**kwargs)
if ov_metatype == OVMatMulMetatype:
weights_shape = attrs["shape"]

weights_layout = [LayoutElem.SPATIAL] * (len(weights_shape) - 2)
if len(weights_shape) > 1:
transpose = attrs.get("transpose", False)
if (transpose and port_id == 0) or (not transpose and port_id == 1):
weights_layout += [LayoutElem.C_IN, LayoutElem.C_OUT]
else:
kwargs.update(
{
"in_channels": const_shape[2],
"out_channels": const_shape[1],
"kernel_size": tuple(const_shape[3:]),
"groups": const_shape[0],
}
)
if kwargs["transpose"]:
kwargs["in_channels"], kwargs["out_channels"] = kwargs["out_channels"], kwargs["in_channels"]

common_layer_attr = ConvolutionLayerAttributes(**kwargs)
weights_layout += [LayoutElem.C_OUT, LayoutElem.C_IN]
else:
common_layer_attr = GenericWeightedLayerAttributes(
weight_requires_grad=False, weight_shape=attrs.get("shape", None)
)
retval[port_id] = common_layer_attr
return retval
weights_layout += [LayoutElem.C_IN]

kwargs = {
"weight_requires_grad": False,
"in_features": weights_shape[weights_layout.index(LayoutElem.C_IN)],
"out_features": weights_shape[weights_layout.index(LayoutElem.C_OUT)]
if LayoutElem.C_OUT in weights_layout
else None,
"with_bias": False,
"weights_layout": weights_layout,
}
return LinearLayerAttributes(**kwargs)
return GenericWeightedLayerAttributes(weight_requires_grad=False, weight_shape=attrs.get("shape", None))
39 changes: 35 additions & 4 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import LayoutElem
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend
Expand Down Expand Up @@ -112,10 +115,23 @@ def filter_func(point: StatisticPoint) -> bool:
assert len(tensor_collectors) == 1
stat = tensor_collectors[0].get_statistics()
if stat.min_values is None or stat.max_values is None:
nncf_logger.debug(
f"Skipping channel alignment for pairs {conv_in.node_name}, {conv_out.node_name} "
"because statistics were not collected for this pair."
)
continue

conv_in_cont = ConvParamsContainer(conv_in, model, graph, self._backend_entity)
conv_out_cont = ConvParamsContainer(conv_out, model, graph, self._backend_entity)
if (
conv_in_cont.dims.conv_weight_out_channels_dim is None
or conv_out_cont.dims.conv_weight_out_channels_dim is None
):
nncf_logger.debug(
f"Skipping channel alignment for pairs {conv_in.node_name}, {conv_out.node_name} "
" because one of the node is 1D MatMul, 1D Matmuls are not supported by CA algortihm yet."
)
continue

amean = (stat.max_values + stat.min_values) * 0.5
conv_in_cont.bias, conv_out_cont.bias = self._align_means(
Expand Down Expand Up @@ -247,7 +263,7 @@ def _align_scales(
return updated_conv_in_value, updated_conv_out_value, updated_bias_in_value

def _check_consumer_conv_node(self, conv_node: NNCFNode) -> bool:
attrs = self._backend_entity.get_conv_layer_attributes(conv_node)
attrs = conv_node.layer_attributes.get_backend_agnostic_attributes()
if attrs is None:
return False
# Check groups amount == 1
Expand Down Expand Up @@ -373,9 +389,10 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
statistic_container = StatisticPointsContainer()
for conv_in, add_in, _ in self._get_node_pairs(graph):
target_point, node_in = self._get_target_point_and_node_in(conv_in, add_in)

channel_axis = conv_in.metatype.output_channel_axis
reduction_shape = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape)))
reduction_shape.remove(channel_axis)
activation_shape = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape)))
reduction_shape = get_channel_agnostic_reduction_shape([channel_axis], activation_shape)

statistic_collector = self._backend_entity.get_statistic_collector(
tuple(reduction_shape), self._quantile, self.subset_size, self.inplace_statistics
Expand Down Expand Up @@ -447,7 +464,21 @@ def __init__(
bias = backend_entity.create_bias_tensor(conv_op, nncf_graph, 0)
self.stated_bias = StatedTensor(bias)
self._op = conv_op
self._dims = backend_entity.get_dims_descriptor(conv_op)
weights_layout = conv_op.layer_attributes.get_backend_agnostic_attributes().weights_layout
if LayoutElem.GROUPS in weights_layout:
# Using groups dim as output channels dim for ChannelAlignment algorithm
# TODO(dlyakhov) support group convolutions with groups number not in [1, out_channels]
self._dims = LayoutDescriptor(
weights_layout.index(LayoutElem.GROUPS),
weights_layout.index(LayoutElem.C_IN),
conv_op.metatype.output_channel_axis,
)
else:
self._dims = LayoutDescriptor(
weights_layout.index(LayoutElem.C_OUT) if LayoutElem.C_OUT in weights_layout else None,
weights_layout.index(LayoutElem.C_IN),
conv_op.metatype.output_channel_axis,
)

@property
def weight(self):
Expand Down
21 changes: 0 additions & 21 deletions nncf/quantization/algorithms/channel_alignment/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,6 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
(bias is added to the output tensor of that operation), `False` otherwise.
"""

@staticmethod
@abstractmethod
def get_dims_descriptor(node: NNCFNode) -> LayoutDescriptor:
"""
Return weights layout descriptor of the given node if it is possible and None otherwise.
Only convolutional and linear nodes are supported.
:param node: NNCFNode to get layout descriptor from.
:return: Weights layout descriptor of the given node if it is possible and None otherwise.
"""

@staticmethod
@abstractmethod
def get_conv_layer_attributes(node: NNCFNode) -> Optional[ConvolutionLayerAttributes]:
"""
Returns convolutional layer attributes of given node if they are present and None otherwise.
:param node: NNCFNode to take convolutional layer attributes from.
:return: Convolutional layer attributes of given node if they are present and None otherwise
"""

@staticmethod
@abstractmethod
def create_bias_tensor(node: NNCFNode, nncf_graph: NNCFGraph, value: Any):
Expand Down
Loading

0 comments on commit 9da200c

Please sign in to comment.