Skip to content

Commit

Permalink
Test covering
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 24, 2024
1 parent e391e29 commit be4d6f6
Show file tree
Hide file tree
Showing 15 changed files with 113 additions and 142 deletions.
1 change: 0 additions & 1 deletion nncf/common/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class TransformationPriority(IntEnum):
FP32_TENSOR_STATISTICS_OBSERVATION = 1
PRUNING_PRIORITY = 2
SPARSIFICATION_PRIORITY = 3
OP_INSERTION_PRIORITY = 4
QUANTIZATION_PRIORITY = 11


Expand Down
2 changes: 2 additions & 0 deletions nncf/experimental/tensor/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from nncf.experimental.tensor.functions.numeric import moveaxis as moveaxis
from nncf.experimental.tensor.functions.numeric import multiply as multiply
from nncf.experimental.tensor.functions.numeric import ones_like as ones_like
from nncf.experimental.tensor.functions.numeric import power as power
from nncf.experimental.tensor.functions.numeric import quantile as quantile
from nncf.experimental.tensor.functions.numeric import reshape as reshape
from nncf.experimental.tensor.functions.numeric import round as round
from nncf.experimental.tensor.functions.numeric import squeeze as squeeze
Expand Down
30 changes: 23 additions & 7 deletions nncf/experimental/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,23 +383,39 @@ def round(a: Tensor, decimals=0) -> Tensor:

@functools.singledispatch
@tensor_guard
def power(a: Tensor, pwr: float) -> Tensor:
return Tensor(power(a.data, pwr))
def power(a: Tensor, exponent: float) -> Tensor:
"""
Takes the power of each element in input with given power and
returns a tensor with the result.
:param a: Input data.
:param exponent: Exponent value.
:return: The result of the power of each element in input with given exponent.
"""
return Tensor(power(a.data, exponent))


@functools.singledispatch
@tensor_guard
def quantile(
a: Tensor,
q: Union[float, List[float]],
axis: Union[int, List[int]] = None,
axis: Union[int, Tuple[int]] = None,
keepdims: Optional[bool] = None,
) -> Union[float, Tensor]:
retval = quantile(a.data, q, axis, keepdims)
"""
Compute the quantile(s) of the data along the specified axis.
if isinstance(retval, float):
return retval
return Tensor(retval)
:param a: Given tensor.
:params q: Quantile or sequence of quantiles to compute, which must be between
0 and 1 inclusive.
:param axis: Axis or axes along which the quantiles are computed.
:param keepdims: If True, the axes which are reduced are left in the result
as dimensions with size one.
:return: An tensor with quantiles, the first axis of the result corresponds
to the quantiles, the second axis of the result corresponds to the quantiles values.
"""
return Tensor(quantile(a.data, q, axis, keepdims))


@functools.singledispatch
Expand Down
21 changes: 7 additions & 14 deletions nncf/experimental/tensor/functions/numpy_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,31 +186,24 @@ def _(
return np.clip(a, a_min=min_val, a_max=max_val)


@register_numpy_types(numeric.eps)
def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> float:
return np.finfo(DTYPE_MAP[dtype]).eps


@register_numpy_types(numeric.power)
def _(a: Union[np.ndarray, np.generic], pwr: float) -> Union[np.ndarray, np.generic]:
return np.power(a, pwr)
def _(a: Union[np.ndarray, np.generic], exponent: float) -> Union[np.ndarray, np.generic]:
return np.power(a, exponent)


@register_numpy_types(numeric.quantile)
def _(
a: Union[np.ndarray, np.generic],
q: Union[float, List[float]],
axis: Union[int, List[int]] = None,
axis: Union[int, Tuple[int]] = None,
keepdims: Optional[bool] = None,
) -> Union[float, Union[np.ndarray, np.generic]]:
if keepdims is None:
keepdims = np._NoValue
return np.quantile(a, q=q, axis=axis, keepdims=keepdims)


@register_numpy_types(numeric.size)
def _(a: Union[np.ndarray, np.generic]) -> int:
return a.size
ret_val = np.quantile(a, q=q, axis=axis, keepdims=keepdims)
if isinstance(ret_val, np.ndarray):
return ret_val
return np.array(ret_val)


@register_numpy_types(numeric._binary_op_nowarn)
Expand Down
34 changes: 14 additions & 20 deletions nncf/experimental/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,37 +197,31 @@ def _(a: torch.Tensor, min_val: float, max_val: Optional[float] = None) -> torch
return torch.clip(a, min=min_val, max=max_val)


@numeric.eps.register(torch.Tensor)
def _(a: torch.Tensor, dtype: TensorDataType) -> float:
return torch.finfo(DTYPE_MAP[dtype]).eps


@numeric.power.register(torch.Tensor)
def _(a: torch.Tensor, pwr: float) -> torch.Tensor:
return torch.pow(a, exponent=pwr)
def _(a: torch.Tensor, exponent: float) -> torch.Tensor:
return torch.pow(a, exponent=exponent)


@numeric.quantile.register(torch.Tensor)
def _(
a: torch.Tensor,
q: Union[float, List[float]],
axis: Union[int, List[int]] = None,
axis: Union[int, Tuple[int]] = None,
keepdims: Optional[bool] = None,
) -> Union[float, torch.Tensor]:
device = a.device
# See https://github.com/pytorch/pytorch/issues/61582
# https://github.com/pytorch/pytorch/issues/64947
device = a.device
if keepdims is None:
keepdims = np._NoValue
np_result = np.quantile(a.detach().cpu().numpy(), q=q, axis=axis, keepdims=keepdims)
if isinstance(np_result, np.ndarray):
return torch.tensor(np_result).type(a.dtype).to(device)
return np_result


@numeric.size.register(torch.Tensor)
def _(a: torch.Tensor) -> int:
return a.numel()
if len(a) <= 16_000_000 and isinstance(axis, int):
result = torch.quantile(
a,
torch.tensor(q, dtype=a.dtype, device=a.device),
axis,
keepdims,
)
else:
result = torch.tensor(np.quantile(a.detach().cpu().numpy(), q=q, axis=axis, keepdims=keepdims))
return result.type(a.dtype).to(device)


@numeric._binary_op_nowarn.register(torch.Tensor)
Expand Down
16 changes: 12 additions & 4 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,14 @@ def apply(
if any(val.data is None for val in activations_value):
empty_statistic = True
break
assert len(activations_value) == 1
if len(activations_value) != 1:
raise RuntimeError(
(
"More than one statistic is collected for one node during"
f"Smooth Quanti algorithm: {node_to_smooth.node_name}"
)
)

activations_value = self._clip_statistics(activations_value)

weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
Expand Down Expand Up @@ -194,7 +201,7 @@ def _calculate_scale_and_ratio(
a_min = fns.quantile(scales, quantile, keepdims=False)
a_max = 1e2

scales = fns.clip(scales, min_val=a_min, max_val=a_max)
scales = fns.clip(scales, a_min=a_min, a_max=a_max)
ratio = scales.min() / (scales.max() + eps)
return scales, ratio

Expand Down Expand Up @@ -253,6 +260,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
for node_data in nodes_to_smooth_data:
node_to_smooth = node_data["node_to_smooth"]
target_point = self._backend_entity.target_point(
target_type=self._backend_entity.pre_layer_target_type(),
target_node_name=node_to_smooth.node_name,
port_id=node_data["input_act_port"],
)
Expand Down Expand Up @@ -305,7 +313,7 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[
nodes_to_smooth_data.append(
{
"node_to_smooth": node_with_weight,
"input_act_port": self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph),
"input_act_port": activation_port_id,
}
)
return nodes_to_smooth_data
Expand Down Expand Up @@ -435,4 +443,4 @@ def _clip_statistics(statistics: List[Tensor]) -> Tensor:

statistics = fns.stack(statistics)
squeezed = fns.squeeze(statistics)
return fns.clip(squeezed, min_val=a_min, max_val=None)
return fns.clip(squeezed, a_min=a_min, a_max=None)
34 changes: 28 additions & 6 deletions nncf/quantization/algorithms/smooth_quant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from abc import ABC
from abc import abstractmethod
from typing import List, Tuple, TypeVar
from typing import Callable, List, Tuple, TypeVar

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor import Tensor

Expand Down Expand Up @@ -55,10 +57,20 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:

@staticmethod
@abstractmethod
def target_point(target_node_name: str, port_id: int) -> TargetPoint:
def pre_layer_target_type() -> TargetType:
"""
Returns backend-specific pre layer target type.
:returns: Backend-specific pre layer target type.
"""

@staticmethod
@abstractmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint:
"""
Returns backend-specific target point.
:param target_type: Type of the location that should be modified.
:param target_node_name: Name of the located node.
:param port_id: Port ID of the tensor for the statistics distribution.
:return: Backend-specific TargetPoint.
Expand Down Expand Up @@ -184,10 +196,20 @@ def get_weight_channel_axis(node: NNCFNode) -> int:

@staticmethod
@abstractmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
pass
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
"""
Returns true if given node shares constant with a different node.
:param node: NNCFNode instance.
:param nncf_graph: NNCFGraph instance.
:return: Whether the given node is shares weights with a different node or not.
"""

@staticmethod
@abstractmethod
def get_filter_fn_for_statistics(activation_port_id: int):
pass
def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
"""
Returns backend-specific callable to filter statistic containers according to its statistic point.
:param activation_port_id: Activation port id for the statistic collection target node.
"""
22 changes: 11 additions & 11 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import Callable, List, Tuple

import numpy as np
import openvino.runtime as ov
Expand Down Expand Up @@ -52,8 +52,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
return QUANTIZE_AGNOSTIC_OPERATIONS

@staticmethod
def target_point(target_node_name: str, port_id: int) -> OVTargetPoint:
return OVTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, port_id)
def pre_layer_target_type() -> TargetType:
return TargetType.PRE_LAYER_OPERATION

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
return OVTargetPoint(target_type, target_node_name, port_id)

@staticmethod
def is_node_with_weights(node: NNCFNode) -> bool:
Expand Down Expand Up @@ -92,23 +96,19 @@ def get_abs_max_channel_collector(

@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: ov.Model) -> Tensor:
port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight)
port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight)
return Tensor(get_weight_value(node_with_weight, model, port_id))

@staticmethod
def get_weight_tensor_port_id(node: NNCFNode) -> int:
return OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node)

@staticmethod
def _get_weight_tensor_port_id(node: NNCFNode) -> int:
const_ids = node.layer_attributes.get_const_port_ids()
if len(const_ids) != 1:
raise RuntimeError(f"Found more than 1 port for {node.node_name} node")
return const_ids[0]

@staticmethod
def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand:
weight_port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight)
weight_port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight)
return OVCommandCreator.create_command_to_update_weight(node_with_weight, weight_value, weight_port_id)

@staticmethod
Expand Down Expand Up @@ -154,13 +154,13 @@ def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
return -2 + port_id if transpose else -1 - port_id

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
weight_port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node)
weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node
return len(nncf_graph.get_next_nodes(weight_node)) > 1

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int):
def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
def filter_func(point: StatisticPoint) -> bool:
return point.target_point.port_id == activation_port_id

Expand Down
16 changes: 10 additions & 6 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import Callable, List, Tuple

import numpy as np

Expand Down Expand Up @@ -63,8 +63,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT[QuantizationTrait.QUANTIZATION_AGNOSTIC]

@staticmethod
def target_point(target_node_name: str, port_id: int) -> PTTargetPoint:
return PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name, input_port_id=port_id)
def pre_layer_target_type() -> TargetType:
return TargetType.OPERATOR_PRE_HOOK

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)

@staticmethod
def is_node_with_weights(node: NNCFNode) -> bool:
Expand Down Expand Up @@ -92,7 +96,7 @@ def get_abs_max_channel_collector(
def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor:
node_module = model.nncf.get_containing_module(node_with_weight.node_name)
if node_module.weight is None:
return None
raise RuntimeError(f"{node_module} module has no .weight attribute.")
return Tensor(node_module.weight.data)

@staticmethod
Expand Down Expand Up @@ -130,11 +134,11 @@ def get_weight_channel_axis(node: NNCFNode) -> int:
return 1

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return node.is_shared()

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int):
def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
def filter_func(point: StatisticPoint) -> bool:
return True

Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def wrap_operator(operator, operator_info: PatchedOperatorInfo):
Wraps the input callable object (`operator`) with the functionality that allows the calls to this object
to be tracked by the currently set global TracingContext. The wrapped functions can be then intercepted,
their arguments and return values modified arbitrarily and, for functions that correspond to operations on
tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.
tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.
:param: operator: A callable object to be wrapped.
:param: operator_info (PatchedOperatorInfo): An informational struct containing the specifics of wrapping
Expand Down
Loading

0 comments on commit be4d6f6

Please sign in to comment.