diff --git a/nncf/experimental/common/tensor_statistics/statistical_functions.py b/nncf/experimental/common/tensor_statistics/statistical_functions.py new file mode 100644 index 00000000000..24fc115a058 --- /dev/null +++ b/nncf/experimental/common/tensor_statistics/statistical_functions.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import functions as fns + + +def mean_per_channel(x: Tensor, axis: int) -> Tensor: + """ + Computes the per-channel mean of a Tensor, i.e. the mean over all dimensions except the channel one. + + :param x: Tensor to reduce. + :param axis: The channel axis to keep; all other axes are reduced (inputs with ndim < 3 are reduced over axis 0). + :return: Reduced Tensor. + """ + if len(x.shape) < 3: + return fns.mean(x, axis=0) + pos_axis = axis + x.ndim if axis < 0 else axis + if pos_axis < 0 or pos_axis >= x.ndim: + raise ValueError(f"axis {axis} is out of bounds for array of dimension {x.ndim}") + axis = tuple(i for i in range(x.ndim) if i != pos_axis) + return fns.mean(x, axis=axis) diff --git a/nncf/experimental/tensor/README.md b/nncf/experimental/tensor/README.md index 09e3dc6a1e0..ea8ae2168c9 100644 --- a/nncf/experimental/tensor/README.md +++ b/nncf/experimental/tensor/README.md @@ -6,7 +6,7 @@ making them more portable and reusable. ## Usage -The main idea is common algorithms should use wrapped tensors and provide to backend-specific function unwrapped tensor. 
+Common algorithms should use wrapped tensors and provide the unwrapped tensor to the backend-specific function. ### Initialization Tensor @@ -32,6 +32,8 @@ tenor_b = Tensor(np.array([1,2])) tensor_a + tenor_b # Tensor(array([2, 4])) ``` +**NOTE** Division operations for the numpy backend are performed with warnings disabled, so that the behavior is the same for all backends. + ### Comparison operators All math operations are overrided to operated with wrapped object and return `Tensor` @@ -55,16 +57,16 @@ nncf_tensor.max() # Tensor(2) All available functions you can found in [functions.py](functions.py). ```python -from nncf.experimental.tensor import functions -functions.max(nncf_tensor) # Tensor(2) +from nncf.experimental.tensor import functions as fns +fns.max(nncf_tensor) # Tensor(2) ``` **NOTE** A function requires at least one positional argument, which is used to dispatch the function to the appropriate implementation depending on the type of argument. ```python -functions.max(nncf_tensor) # Correct -functions.max(a=nncf_tensor) # TypeError: wrapper requires at least 1 positional argument +fns.max(nncf_tensor) # Correct +fns.max(a=nncf_tensor) # TypeError: wrapper requires at least 1 positional argument ``` ### Loop over Tensor @@ -100,7 +102,7 @@ tensor_a[0:2] # Tensor(array([[1],[2]])) class Tensor: ... def foo(self, arg1: Type) -> "Tensor": - return functions.foo(self, arg1) + return fns.foo(self, arg1) ``` 2. Add function to [function.py](function.py) @@ -120,28 +122,36 @@ tensor_a[0:2] # Tensor(array([[1],[2]])) return NotImplemented(f"Function `foo` is not implemented for {type(a)}") ``` -3. Add function name to `__all__` in [function.py](function.py) + **NOTE** For the case when the first argument has type `List[Tensor]`, use the `_dispatch_list` function. This function dispatches the call by the type of the data wrapped by the first element of the first argument. 
+ + ```python + @functools.singledispatch + def foo(x: List[Tensor], axis: int = 0) -> Tensor: + if isinstance(x, List): + # `_dispatch_list` unwraps the Tensors itself, so pass the list of Tensors as is + return Tensor(_dispatch_list(foo, x, axis=axis)) + raise NotImplementedError(f"Function `foo` is not implemented for {type(x)}") + ``` -4. Add backend specific implementation of method to: +3. Add backend specific implementation of method to: - - [numpy_function.py](numpy_function.py) + - [numpy_functions.py](numpy_functions.py) ```python - @functions.foo.register(np.ndarray) - @functions.foo.register(np.number) + @_register_numpy_types(fns.foo) def _(a: TType, arg1: Type) -> np.ndarray: return np.foo(a, arg1) ``` - - [torch_function.py](torch_function.py) + - [torch_functions.py](torch_functions.py) ```python - @functions.foo.register(torch.Tensor) + @fns.foo.register(torch.Tensor) def _(a: torch.Tensor, arg1: Type) -> torch.Tensor: return torch.foo(a, arg1) ``` -5. Add test of method to [test template](tests/shared/test_templates/template_test_nncf_tensor.py) for Tensor class +4. Add test of method to [test template](../../../tests/shared/test_templates/template_test_nncf_tensor.py) for Tensor class ### Add new backend diff --git a/nncf/experimental/tensor/functions.py b/nncf/experimental/tensor/functions.py index 30f27a65cce..33e55c1ef50 100644 --- a/nncf/experimental/tensor/functions.py +++ b/nncf/experimental/tensor/functions.py @@ -10,14 +10,12 @@ # limitations under the License. 
import functools -from typing import List, Optional, Tuple, TypeVar, Union +from typing import Callable, List, Optional, Tuple, Union -from nncf.experimental.tensor import Tensor -from nncf.experimental.tensor import unwrap_tensor_data from nncf.experimental.tensor.enums import TensorDataType from nncf.experimental.tensor.enums import TensorDeviceType - -TTensor = TypeVar("TTensor") +from nncf.experimental.tensor.tensor import Tensor +from nncf.experimental.tensor.tensor import unwrap_tensor_data def _tensor_guard(func: callable): @@ -36,7 +34,7 @@ def wrapper(*args, **kwargs): @functools.singledispatch @_tensor_guard -def device(a: TTensor) -> TensorDeviceType: +def device(a: Tensor) -> TensorDeviceType: """ Return the device of the tensor. @@ -48,7 +46,7 @@ def device(a: TTensor) -> TensorDeviceType: @functools.singledispatch @_tensor_guard -def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: +def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Remove axes of length one from a. @@ -63,7 +61,7 @@ def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTenso @functools.singledispatch @_tensor_guard -def flatten(a: TTensor) -> TTensor: +def flatten(a: Tensor) -> Tensor: """ Return a copy of the tensor collapsed into one dimension. @@ -75,7 +73,7 @@ def flatten(a: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def max(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Return the maximum of an array or maximum along an axis. 
@@ -88,7 +86,7 @@ def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def min(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Return the minimum of an array or minimum along an axis. @@ -101,7 +99,7 @@ def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def abs(a: TTensor) -> TTensor: # pylint: disable=redefined-builtin +def abs(a: Tensor) -> Tensor: # pylint: disable=redefined-builtin """ Calculate the absolute value element-wise. @@ -113,7 +111,7 @@ def abs(a: TTensor) -> TTensor: # pylint: disable=redefined-builtin @functools.singledispatch @_tensor_guard -def astype(a: TTensor, data_type: TensorDataType) -> TTensor: +def astype(a: Tensor, data_type: TensorDataType) -> Tensor: """ Copy of the tensor, cast to a specified type. @@ -127,7 +125,7 @@ def astype(a: TTensor, data_type: TensorDataType) -> TTensor: @functools.singledispatch @_tensor_guard -def dtype(a: TTensor) -> TensorDataType: +def dtype(a: Tensor) -> TensorDataType: """ Return data type of the tensor. @@ -139,7 +137,7 @@ def dtype(a: TTensor) -> TensorDataType: @functools.singledispatch @_tensor_guard -def reshape(a: TTensor, shape: List[int]) -> TTensor: +def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor: """ Gives a new shape to a tensor without changing its data. 
@@ -152,7 +150,7 @@ def reshape(a: TTensor, shape: List[int]) -> TTensor: @functools.singledispatch @_tensor_guard -def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Test whether all tensor elements along a given axis evaluate to True. @@ -165,7 +163,9 @@ def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor: +def allclose( + a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> Tensor: """ Returns True if two arrays are element-wise equal within a tolerance. @@ -191,7 +191,7 @@ def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, e @functools.singledispatch @_tensor_guard -def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Test whether any tensor elements along a given axis evaluate to True. @@ -204,7 +204,7 @@ def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: +def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Counts the number of non-zero values in the tensor input. @@ -218,19 +218,21 @@ def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> @functools.singledispatch @_tensor_guard -def isempty(a: TTensor) -> TTensor: +def isempty(a: Tensor) -> bool: """ Return True if input tensor is empty. 
:param a: The input tensor. :return: True if tensor is empty, otherwise False. """ - return Tensor(isempty(a.data)) + return isempty(a.data) @functools.singledispatch @_tensor_guard -def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor: +def isclose( + a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> Tensor: """ Returns a boolean array where two arrays are element-wise equal within a tolerance. @@ -256,7 +258,7 @@ def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, eq @functools.singledispatch @_tensor_guard -def maximum(x1: TTensor, x2: TTensor) -> TTensor: +def maximum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise maximum of tensor elements. @@ -269,7 +271,7 @@ def maximum(x1: TTensor, x2: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def minimum(x1: TTensor, x2: TTensor) -> TTensor: +def minimum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise minimum of tensor elements. @@ -282,7 +284,7 @@ def minimum(x1: TTensor, x2: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def ones_like(a: TTensor) -> TTensor: +def ones_like(a: Tensor) -> Tensor: """ Return a tensor of ones with the same shape and type as a given tensor. @@ -294,7 +296,7 @@ def ones_like(a: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor: +def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) -> Tensor: """ Return elements chosen from x or y depending on condition. @@ -314,7 +316,7 @@ def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def zeros_like(a: TTensor) -> TTensor: +def zeros_like(a: Tensor) -> Tensor: """ Return an tensor of zeros with the same shape and type as a given tensor. 
@@ -324,28 +326,114 @@ def zeros_like(a: TTensor) -> TTensor: return Tensor(zeros_like(a.data)) -__all__ = [ - "device", - "squeeze", - "flatten", - "max", - "min", - "abs", - "astype", - "reshape", - "all", - "allclose", - "any", - "count_nonzero", - "isempty", - "isclose", - "maximum", - "minimum", - "ones_like", - "minimum", - "where", - "zeros_like", -] +@functools.singledispatch +def stack(x: List[Tensor], axis: int = 0) -> Tensor: + """ + Stacks a list of rank-R Tensors into one rank-(R+1) Tensor. + + :param x: List of Tensors. + :param axis: The axis to stack along. + :return: Stacked Tensor. + """ + if isinstance(x, List): + return Tensor(_dispatch_list(stack, x, axis=axis)) + raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}") + + +@functools.singledispatch +@_tensor_guard +def unstack(a: Tensor, axis: int = 0) -> List[Tensor]: + """ + Unstack a Tensor into a list of Tensors along the given axis. + + :param a: Tensor to unstack. + :param axis: The axis to unstack along. + :return: List of Tensors. + """ + res = unstack(a.data, axis=axis) + return [Tensor(i) for i in res] + + +@functools.singledispatch +@_tensor_guard +def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> Tensor: + """ + Move axes of an array to new positions. + + :param a: The array whose axes should be reordered. + :param source: Original positions of the axes to move. These must be unique. + :param destination: Destination positions for each of the original axes. These must also be unique. + :return: Array with moved axes. + """ + return Tensor(moveaxis(a.data, source, destination)) + + +@functools.singledispatch +@_tensor_guard +def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: + """ + Compute the arithmetic mean along the specified axis. + + :param a: Array containing numbers whose mean is desired. + :param axis: Axis or axes along which the means are computed. 
+ :param keepdims: If True, the reduced axes are kept in the result as dimensions of size one. + :return: Tensor with the mean values. + """ + return Tensor(mean(a.data, axis, keepdims)) + + +@functools.singledispatch +@_tensor_guard +def round(a: Tensor, decimals=0) -> Tensor: # pylint: disable=redefined-builtin + """ + Evenly round to the given number of decimals. + + :param a: Input data. + :param decimals: Number of decimal places to round to (default: 0). If decimals is negative, + it specifies the number of positions to the left of the decimal point. + :return: An array of the same type as a, containing the rounded values. + """ + return Tensor(round(a.data, decimals)) + + +@functools.singledispatch +@_tensor_guard +def _binary_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: + """ + Applies a binary operation with warnings disabled. + + :param a: The first tensor. + :param b: The second tensor. + :param operator_fn: The binary operation function. + :return: The result of the binary operation. + """ + return Tensor(_binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) + + +@functools.singledispatch +@_tensor_guard +def _binary_reverse_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: + """ + Applies a binary reverse operation with warnings disabled. + + :param a: The first tensor. + :param b: The second tensor. + :param operator_fn: The binary operation function. + :return: The result of the binary operation. + """ + return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) + + +def _dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs): + """ + Dispatches the function to the type of the wrapped data of the first element in tensor_list. + + :param fn: A function wrapped by `functools.singledispatch`. + :param tensor_list: List of Tensors. + :return: The result value of the function call. 
+ """ + unwrapped_list = [i.data for i in tensor_list] + return fn.dispatch(type(unwrapped_list[0]))(unwrapped_list, *args, **kwargs) def _initialize_backends(): diff --git a/nncf/experimental/tensor/numpy_functions.py b/nncf/experimental/tensor/numpy_functions.py index be070db4bdb..b4c515b1da1 100644 --- a/nncf/experimental/tensor/numpy_functions.py +++ b/nncf/experimental/tensor/numpy_functions.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np -from nncf.experimental.tensor import functions +from nncf.experimental.tensor import functions as fns from nncf.experimental.tensor.enums import TensorDataType from nncf.experimental.tensor.enums import TensorDeviceType @@ -28,137 +28,179 @@ DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} -@functions.device.register(np.ndarray) -@functions.device.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> TensorDeviceType: +def _register_numpy_types(singledispatch_fn): + """ + Decorator to register function to singledispatch for numpy classes. + + :param singledispatch_fn: singledispatch function. 
+ """ + + def inner(func): + singledispatch_fn.register(np.ndarray)(func) + singledispatch_fn.register(np.generic)(func) + return func + + return inner + + +@_register_numpy_types(fns.device) +def _(a: Union[np.ndarray, np.generic]) -> TensorDeviceType: return TensorDeviceType.CPU -@functions.squeeze.register(np.ndarray) -@functions.squeeze.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +@_register_numpy_types(fns.squeeze) +def _( + a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[np.ndarray, np.generic]: return np.squeeze(a, axis=axis) -@functions.flatten.register(np.ndarray) -@functions.flatten.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.flatten) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return a.flatten() -@functions.max.register(np.ndarray) -@functions.max.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +@_register_numpy_types(fns.max) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: return np.max(a, axis=axis) -@functions.min.register(np.ndarray) -@functions.min.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +@_register_numpy_types(fns.min) +def _( + a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[np.ndarray, np.generic]: return np.min(a, axis=axis) -@functions.abs.register(np.ndarray) -@functions.abs.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.abs) +def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]: return np.absolute(a) -@functions.astype.register(np.ndarray) -@functions.astype.register(np.number) -def _(a: Union[np.ndarray, 
np.number], dtype: TensorDataType) -> np.ndarray: +@_register_numpy_types(fns.astype) +def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> Union[np.ndarray, np.generic]: return a.astype(DTYPE_MAP[dtype]) -@functions.dtype.register(np.ndarray) -@functions.dtype.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> TensorDataType: +@_register_numpy_types(fns.dtype) +def _(a: Union[np.ndarray, np.generic]) -> TensorDataType: return DTYPE_MAP_REV[np.dtype(a.dtype)] -@functions.reshape.register(np.ndarray) -@functions.reshape.register(np.number) -def _(a: Union[np.ndarray, np.number], shape: Union[int, Tuple[int]]) -> np.ndarray: +@_register_numpy_types(fns.reshape) +def _(a: Union[np.ndarray, np.generic], shape: Union[int, Tuple[int, ...]]) -> np.ndarray: return a.reshape(shape) -@functions.all.register(np.ndarray) -@functions.all.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]: +@_register_numpy_types(fns.all) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]: return np.all(a, axis=axis) -@functions.allclose.register(np.ndarray) -@functions.allclose.register(np.number) +@_register_numpy_types(fns.allclose) def _( - a: Union[np.ndarray, np.number], - b: Union[np.ndarray, np.number], + a: Union[np.ndarray, np.generic], + b: Union[np.ndarray, np.generic, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, -) -> bool: +) -> np.ndarray: return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@functions.any.register(np.ndarray) -@functions.any.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]: +@_register_numpy_types(fns.any) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]: return np.any(a, axis=axis) 
-@functions.count_nonzero.register(np.ndarray) -@functions.count_nonzero.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: - return np.count_nonzero(a, axis=axis) +@_register_numpy_types(fns.count_nonzero) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: + return np.array(np.count_nonzero(a, axis=axis)) -@functions.isempty.register(np.ndarray) -@functions.isempty.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> bool: +@_register_numpy_types(fns.isempty) +def _(a: Union[np.ndarray, np.generic]) -> bool: return a.size == 0 -@functions.isclose.register(np.ndarray) -@functions.isclose.register(np.number) +@_register_numpy_types(fns.isclose) def _( - a: Union[np.ndarray, np.number], - b: np.ndarray, + a: Union[np.ndarray, np.generic], + b: Union[np.ndarray, np.generic, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, -): +) -> Union[np.ndarray, bool]: return np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@functions.maximum.register(np.ndarray) -@functions.maximum.register(np.number) -def _(x1: Union[np.ndarray, np.number], x2: np.ndarray) -> np.ndarray: +@_register_numpy_types(fns.maximum) +def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.maximum(x1, x2) -@functions.minimum.register(np.ndarray) -@functions.minimum.register(np.number) -def _(x1: Union[np.ndarray, np.number], x2: np.ndarray) -> np.ndarray: +@_register_numpy_types(fns.minimum) +def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.minimum(x1, x2) -@functions.ones_like.register(np.ndarray) -@functions.ones_like.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.ones_like) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.ones_like(a) 
-@functions.where.register(np.ndarray) -@functions.where.register(np.number) +@_register_numpy_types(fns.where) def _( - condition: Union[np.ndarray, np.number], - x: Union[np.ndarray, np.number, float, bool], - y: Union[np.ndarray, float, bool], + condition: Union[np.ndarray, np.generic], + x: Union[np.ndarray, np.generic, float], + y: Union[np.ndarray, np.generic, float], ) -> np.ndarray: return np.where(condition, x, y) -@functions.zeros_like.register(np.ndarray) -@functions.zeros_like.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.zeros_like) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.zeros_like(a) + + +@_register_numpy_types(fns.stack) +def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: + return np.stack(x, axis=axis) + + +@_register_numpy_types(fns.unstack) +def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: + return [np.squeeze(e, axis) for e in np.split(x, x.shape[axis], axis=axis)] + + +@_register_numpy_types(fns.moveaxis) +def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> np.ndarray: + return np.moveaxis(a, source, destination) + + +@_register_numpy_types(fns.mean) +def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray: + return np.mean(a, axis=axis, keepdims=keepdims) + + +@_register_numpy_types(fns.round) +def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray: + return np.round(a, decimals=decimals) + + +@_register_numpy_types(fns._binary_op_nowarn) # pylint: disable=protected-access +def _( + a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable +) -> Union[np.ndarray, np.generic]: + # Run operator with disabled warning + with np.errstate(invalid="ignore", divide="ignore"): + return operator_fn(a, b) + + 
+@_register_numpy_types(fns._binary_reverse_op_nowarn) # pylint: disable=protected-access +def _( + a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable +) -> Union[np.ndarray, np.generic]: + # Run operator with disabled warning + with np.errstate(invalid="ignore", divide="ignore"): + return operator_fn(b, a) diff --git a/nncf/experimental/tensor/tensor.py b/nncf/experimental/tensor/tensor.py index daa8e37aff4..76fd05c4ff1 100644 --- a/nncf/experimental/tensor/tensor.py +++ b/nncf/experimental/tensor/tensor.py @@ -8,9 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations - -from typing import Any, List, Optional, Tuple, TypeVar, Union +import operator +from typing import Any, Optional, Tuple, TypeVar, Union from nncf.experimental.tensor.enums import TensorDataType from nncf.experimental.tensor.enums import TensorDeviceType @@ -31,8 +32,12 @@ def data(self) -> TTensor: return self._data @property - def shape(self) -> List[int]: - return list(self.data.shape) + def shape(self) -> Tuple[int, ...]: + return tuple(self.data.shape) + + @property + def ndim(self) -> int: + return self.data.ndim @property def device(self) -> TensorDeviceType: @@ -48,7 +53,7 @@ def __bool__(self) -> bool: def __iter__(self): return TensorIterator(self.data) - def __getitem__(self, index: int) -> "Tensor": + def __getitem__(self, index: int) -> Tensor: return Tensor(self.data[index]) def __str__(self) -> str: @@ -59,86 +64,86 @@ def __repr__(self) -> str: # built-in operations - def __add__(self, other: TTensor) -> "Tensor": + def __add__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data + unwrap_tensor_data(other)) - def __radd__(self, other: TTensor) -> "Tensor": + def __radd__(self, other: Union[Tensor, float]) -> Tensor: return 
Tensor(unwrap_tensor_data(other) + self.data) - def __sub__(self, other: TTensor) -> "Tensor": + def __sub__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data - unwrap_tensor_data(other)) - def __rsub__(self, other: TTensor) -> "Tensor": + def __rsub__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) - self.data) - def __mul__(self, other: TTensor) -> "Tensor": + def __mul__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data * unwrap_tensor_data(other)) - def __rmul__(self, other: TTensor) -> "Tensor": + def __rmul__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) * self.data) - def __pow__(self, other: TTensor) -> "Tensor": + def __pow__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data ** unwrap_tensor_data(other)) - def __truediv__(self, other: TTensor) -> "Tensor": - return Tensor(self.data / unwrap_tensor_data(other)) + def __truediv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_op_nowarn", self, other, operator.truediv) - def __rtruediv__(self, other: TTensor) -> "Tensor": - return Tensor(unwrap_tensor_data(other) / self.data) + def __rtruediv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv) - def __floordiv__(self, other: TTensor) -> "Tensor": - return Tensor(self.data // unwrap_tensor_data(other)) + def __floordiv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_op_nowarn", self, other, operator.floordiv) - def __rfloordiv__(self, other: TTensor) -> "Tensor": - return Tensor(unwrap_tensor_data(other) // self.data) + def __rfloordiv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv) - def __neg__(self) -> "Tensor": + def __neg__(self) -> Tensor: return Tensor(-self.data) # Comparison operators - def 
__lt__(self, other: TTensor) -> "Tensor": + def __lt__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data < unwrap_tensor_data(other)) - def __le__(self, other: TTensor) -> "Tensor": + def __le__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data <= unwrap_tensor_data(other)) - def __eq__(self, other: TTensor) -> "Tensor": + def __eq__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data == unwrap_tensor_data(other)) - def __ne__(self, other: TTensor) -> "Tensor": + def __ne__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data != unwrap_tensor_data(other)) - def __gt__(self, other: TTensor) -> "Tensor": + def __gt__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data > unwrap_tensor_data(other)) - def __ge__(self, other: TTensor) -> "Tensor": + def __ge__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data >= unwrap_tensor_data(other)) # Tensor functions - def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> "Tensor": + def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: return _call_function("squeeze", self, axis) - def flatten(self) -> "Tensor": + def flatten(self) -> Tensor: return _call_function("flatten", self) - def max(self, axis: Optional[TTensor] = None) -> "Tensor": + def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: return _call_function("max", self, axis) - def min(self, axis: Optional[TTensor] = None) -> "Tensor": + def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: return _call_function("min", self, axis) - def abs(self) -> "Tensor": + def abs(self) -> Tensor: return _call_function("abs", self) - def isempty(self) -> "Tensor": + def isempty(self) -> bool: return _call_function("isempty", self) - def astype(self, dtype: TensorDataType): + def astype(self, dtype: TensorDataType) -> Tensor: return _call_function("astype", self, dtype) - def 
reshape(self, shape: TTensor) -> "Tensor": + def reshape(self, shape: Tuple[int, ...]) -> Tensor: return _call_function("reshape", self, shape) diff --git a/nncf/experimental/tensor/torch_functions.py b/nncf/experimental/tensor/torch_functions.py index 09ef0f1b886..273d5419781 100644 --- a/nncf/experimental/tensor/torch_functions.py +++ b/nncf/experimental/tensor/torch_functions.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from nncf.experimental.tensor import TensorDataType from nncf.experimental.tensor import TensorDeviceType -from nncf.experimental.tensor import functions +from nncf.experimental.tensor import functions as fns DTYPE_MAP = { TensorDataType.float16: torch.float16, @@ -28,7 +28,7 @@ DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} -@functions.device.register(torch.Tensor) +@fns.device.register(torch.Tensor) def _(a: torch.Tensor) -> TensorDeviceType: DEVICE_MAP = { "cpu": TensorDeviceType.CPU, @@ -37,112 +37,162 @@ def _(a: torch.Tensor) -> TensorDeviceType: return DEVICE_MAP[a.device.type] -@functions.squeeze.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.squeeze.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: if axis is None: return a.squeeze() + if isinstance(axis, Tuple) and any(1 != a.shape[i] for i in axis): + # Make Numpy behavior, torch.squeeze skips axes that are not equal to one.. 
+ raise ValueError("Cannot select an axis to squeeze out which has size not equal to one") return a.squeeze(axis) -@functions.flatten.register(torch.Tensor) +@fns.flatten.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return a.flatten() -@functions.max.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.max.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: + # Analog of numpy.max is torch.amax if axis is None: - return torch.max(a) - return torch.max(a, dim=axis).values + return torch.amax(a) + return torch.amax(a, dim=axis) -@functions.min.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.min.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: + # Analog of numpy.min is torch.amin if axis is None: - return torch.min(a) - return torch.min(a, dim=axis).values + return torch.amin(a) + return torch.amin(a, dim=axis) -@functions.abs.register(torch.Tensor) +@fns.abs.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return torch.absolute(a) -@functions.astype.register(torch.Tensor) +@fns.astype.register(torch.Tensor) def _(a: torch.Tensor, dtype: TensorDataType) -> torch.Tensor: return a.type(DTYPE_MAP[dtype]) -@functions.dtype.register(torch.Tensor) +@fns.dtype.register(torch.Tensor) def _(a: torch.Tensor) -> TensorDataType: return DTYPE_MAP_REV[a.dtype] -@functions.reshape.register(torch.Tensor) -def _(a: torch.Tensor, shape: List[int]) -> torch.Tensor: +@fns.reshape.register(torch.Tensor) +def _(a: torch.Tensor, shape: Tuple[int, ...]) -> torch.Tensor: return a.reshape(shape) -@functions.all.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[torch.Tensor, bool]: +@fns.all.register(torch.Tensor) +def _(a: torch.Tensor, axis: 
Optional[Union[int, Tuple[int, ...]]] = None) -> Union[torch.Tensor, bool]: if axis is None: return torch.all(a) return torch.all(a, dim=axis) -@functions.allclose.register(torch.Tensor) -def _(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool: +@fns.allclose.register(torch.Tensor) +def _( + a: torch.Tensor, b: Union[torch.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> bool: + if not isinstance(b, torch.Tensor): + b = torch.tensor(b, device=a.device) return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@functions.any.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[torch.Tensor, bool]: +@fns.any.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[torch.Tensor, bool]: if axis is None: return torch.any(a) return torch.any(a, dim=axis) -@functions.count_nonzero.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.count_nonzero.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: return torch.count_nonzero(a, dim=axis) -@functions.isempty.register(torch.Tensor) +@fns.isempty.register(torch.Tensor) def _(a: torch.Tensor) -> bool: return a.numel() == 0 -@functions.isclose.register(torch.Tensor) -def _(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False): +@fns.isclose.register(torch.Tensor) +def _( + a: torch.Tensor, b: Union[torch.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +): + if not isinstance(b, torch.Tensor): + b = torch.tensor(b, device=a.device) return torch.isclose(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan) -@functions.maximum.register(torch.Tensor) -def _(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 
+@fns.maximum.register(torch.Tensor) +def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: if not isinstance(x2, torch.Tensor): x2 = torch.tensor(x2, device=x1.data.device) return torch.maximum(x1, x2) -@functions.minimum.register(torch.Tensor) -def _(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: +@fns.minimum.register(torch.Tensor) +def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: if not isinstance(x2, torch.Tensor): x2 = torch.tensor(x2, device=x1.data.device) return torch.minimum(x1, x2) -@functions.ones_like.register(torch.Tensor) +@fns.ones_like.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return torch.ones_like(a) -@functions.where.register(torch.Tensor) +@fns.where.register(torch.Tensor) def _( condition: torch.Tensor, x: Union[torch.Tensor, float, bool], y: Union[torch.Tensor, float, bool] ) -> torch.Tensor: return torch.where(condition, x, y) -@functions.zeros_like.register(torch.Tensor) +@fns.zeros_like.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return torch.zeros_like(a) + + +@fns.stack.register(torch.Tensor) +def _(x: List[torch.Tensor], axis: int = 0) -> List[torch.Tensor]: + return torch.stack(x, dim=axis) + + +@fns.unstack.register(torch.Tensor) +def _(x: torch.Tensor, axis: int = 0) -> List[torch.Tensor]: + if not list(x.shape): + x = x.unsqueeze(0) + return torch.unbind(x, dim=axis) + + +@fns.moveaxis.register(torch.Tensor) +def _(a: torch.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> torch.Tensor: + return torch.moveaxis(a, source, destination) + + +@fns.mean.register(torch.Tensor) +def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> torch.Tensor: + return torch.mean(a, axis=axis, keepdims=keepdims) + + +@fns.round.register(torch.Tensor) +def _(a: torch.Tensor, decimals=0) -> torch.Tensor: + return torch.round(a, decimals=decimals) + + 
+@fns._binary_op_nowarn.register(torch.Tensor) # pylint: disable=protected-access +def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor: + return operator_fn(a, b) + + +@fns._binary_reverse_op_nowarn.register(torch.Tensor) # pylint: disable=protected-access +def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor: + return operator_fn(b, a) diff --git a/nncf/onnx/quantization/quantizer_parameters.py b/nncf/onnx/quantization/quantizer_parameters.py index 71b3d976b50..e4470fdce1b 100644 --- a/nncf/onnx/quantization/quantizer_parameters.py +++ b/nncf/onnx/quantization/quantizer_parameters.py @@ -54,8 +54,8 @@ def convert_fq_params_to_onnx_params( if levels not in [255, 256]: raise ValueError("Can only export to INT8/UIN8 256-level ONNX Quantize/Dequantize pairs.") - input_low, input_high = parameters.input_low, parameters.input_high - output_low, output_high = parameters.output_low, parameters.output_high + input_low, input_high = parameters.input_low.data, parameters.input_high.data + output_low, output_high = parameters.output_low.data, parameters.output_high.data if not np.allclose(input_high, output_high) or not np.allclose(input_low, output_low): raise ValueError( "ONNX Quantize/Dequantize pairs only support input_high == output_high and input_low == output_low." 
diff --git a/nncf/openvino/graph/model_transformer.py b/nncf/openvino/graph/model_transformer.py index 19e43f4b131..16cad27bb65 100644 --- a/nncf/openvino/graph/model_transformer.py +++ b/nncf/openvino/graph/model_transformer.py @@ -249,10 +249,10 @@ def _convert_to_fp16(data): clip_data = np.clip(data, np.finfo(np.float16).min, np.finfo(np.float16).max) return clip_data.astype(np.float16) - input_low = _convert_to_fp16(fq_params.input_low) - input_high = _convert_to_fp16(fq_params.input_high) - output_low = _convert_to_fp16(fq_params.output_low) - output_high = _convert_to_fp16(fq_params.output_high) + input_low = _convert_to_fp16(fq_params.input_low.data) + input_high = _convert_to_fp16(fq_params.input_high.data) + output_low = _convert_to_fp16(fq_params.output_low.data) + output_high = _convert_to_fp16(fq_params.output_high.data) return input_low, input_high, output_low, output_high @staticmethod @@ -266,10 +266,10 @@ def _insert_fake_quantize_op( :param name_to_node_mapping: Mapping from node name to node instance. """ fq_params = transformation.quantizer_parameters - input_low = fq_params.input_low - input_high = fq_params.input_high - output_low = fq_params.output_low - output_high = fq_params.output_high + input_low = fq_params.input_low.data + input_high = fq_params.input_high.data + output_low = fq_params.output_low.data + output_high = fq_params.output_high.data levels = fq_params.levels node_name = transformation.target_point.target_node_name diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 7ed9182d004..65b058b612f 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from math import inf from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from nncf import Dataset @@ -25,6 +26,9 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend +from nncf.experimental.common.tensor_statistics.statistical_functions import mean_per_channel +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import functions as fns from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS @@ -167,9 +171,9 @@ def apply( output_name=sub_output_name, ) - bias_shift = self.reshape_bias_shift(bias_shift, bias_value, channel_axis) + bias_shift = self._reshape_bias_shift(bias_shift, bias_value, channel_axis) updated_bias = bias_value + bias_shift - magnitude = self._backend_entity.get_bias_shift_magnitude(bias_value, updated_bias) + magnitude = self._get_bias_shift_magnitude(bias_value, updated_bias) if magnitude < self.threshold: nncf_logger.debug(f"{node_name} bias would be changed") @@ -185,7 +189,22 @@ def apply( return transformed_model - def reshape_bias_shift(self, bias_shift: TTensor, bias_value: TTensor, channel_axis: int) -> TTensor: + @staticmethod + def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Tensor) -> float: + """ + Calculates bias shift magnitude based on the current and updated values. + + :param current_bias_value: The original bias value. + :param updated_bias_value: The updated bias value. + :return: Magnitude between original and updated bias values. 
+ """ + bias_shift_magnitude = inf + if fns.count_nonzero(current_bias_value == 0) == 0: + bias_shift_magnitude = fns.max(fns.abs((updated_bias_value - current_bias_value) / current_bias_value)) + return bias_shift_magnitude + + @staticmethod + def _reshape_bias_shift(bias_shift: Tensor, bias_value: Tensor, channel_axis: int) -> Tensor: """ Reshape bias_shift tensor in case of dimensions of bias_value is more then 1. @@ -198,7 +217,7 @@ def reshape_bias_shift(self, bias_shift: TTensor, bias_value: TTensor, channel_a if bias_value.ndim > 1: new_shape = [1] * bias_value.ndim new_shape[channel_axis] = bias_shift.shape[0] - bias_shift = self._backend_entity.reshape_tensor(bias_shift, new_shape) + bias_shift = bias_shift.reshape(new_shape) return bias_shift def _get_fp_inputs(self, statistic_points: StatisticPointsContainer, node_name: str) -> Tuple[List, List]: @@ -222,7 +241,7 @@ def input_filter_func(point): node_name, input_filter_func, self._algorithm_key ): statistics = tensor_collector.get_statistics() - input_fp.extend(statistics.mean_values) + input_fp.extend(Tensor(statistics.mean_values)) input_shape.extend(statistics.shape) return input_fp, input_shape @@ -245,7 +264,7 @@ def output_filter_func(point): for tensor_collector in statistic_points.get_algo_statistics_for_node( node_name, output_filter_func, self._algorithm_key ): - output_fp.extend(tensor_collector.get_statistics().mean_values) + output_fp.extend(Tensor(tensor_collector.get_statistics().mean_values)) return output_fp def _extract_submodel(self, model_transformer: ModelTransformer, node_name: str) -> TModel: @@ -299,8 +318,8 @@ def _get_bias_shift( engine = EngineFactory.create(model) raw_output = engine.infer(input_blob) q_outputs = self._backend_entity.process_model_output(raw_output, output_name) - q_outputs = self._backend_entity.tensor_processor.mean_per_channel(q_outputs, channel_axis).tensor - bias_shift = self._backend_entity.post_process_output_data(output_fp) - q_outputs + q_outputs = 
mean_per_channel(q_outputs, channel_axis) + bias_shift = fns.stack(output_fp) - q_outputs return bias_shift def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: diff --git a/nncf/quantization/algorithms/fast_bias_correction/backend.py b/nncf/quantization/algorithms/fast_bias_correction/backend.py index 38618fe2efa..6dde985c5d4 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/backend.py @@ -23,6 +23,7 @@ from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.common.utils.registry import Registry +from nncf.experimental.tensor import Tensor TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") @@ -31,13 +32,6 @@ class FastBiasCorrectionAlgoBackend(ABC): - @property - @abstractmethod - def tensor_processor(self): - """ - Returns backend-specific instance of the NNCFCollectorTensorProcessor. - """ - @staticmethod @abstractmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint: @@ -120,7 +114,7 @@ def create_input_data( @staticmethod @abstractmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: TModel) -> np.ndarray: + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: TModel) -> Tensor: """ Returns bias value in the NumPy format of provided node. @@ -156,7 +150,7 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @staticmethod @abstractmethod - def process_model_output(raw_data: OutputType, output_name: str) -> NNCFTensor: + def process_model_output(raw_data: OutputType, output_name: str) -> Tensor: """ Returns backend-specific processed output from the model. @@ -176,37 +170,6 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: :return: Boolean indicating whether the node has a bias or not. 
""" - @staticmethod - @abstractmethod - def get_bias_shift_magnitude(current_bias_value: TTensor, updated_bias_value: TTensor) -> float: - """ - Calculates bias shift magnitude based on the current and updated values. - - :param current_bias_value: The original bias value. - :param updated_bias_value: The updated bias value. - :return: Magnitude between original and updated bias values. - """ - - @staticmethod - @abstractmethod - def post_process_output_data(data: List[TTensor]) -> TTensor: - """ - Convert data to backend specific type. - - :param data: List of data. - :return: Converted data. - """ - - @staticmethod - @abstractmethod - def reshape_tensor(data: TTensor, new_shape: List[int]) -> TTensor: - """ - Reshape tensor. - - :param data: Tensor. - :param new_shape: New shape. - """ - @staticmethod @abstractmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: diff --git a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py index 02bb54106c6..d0646f6aeb2 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py @@ -18,6 +18,7 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetType from nncf.common.utils.backend import BackendType +from nncf.experimental.tensor import Tensor from nncf.onnx.graph.node_utils import get_bias_value from nncf.onnx.graph.node_utils import is_any_weight_quantized from nncf.onnx.graph.node_utils import is_node_with_bias @@ -27,8 +28,6 @@ from nncf.onnx.graph.transformations.commands import ONNXNullBiasInsertionCommand from nncf.onnx.graph.transformations.commands import ONNXTargetPoint from nncf.onnx.statistics.collectors import ONNXMeanStatisticCollector -from nncf.onnx.statistics.collectors import ONNXNNCFCollectorTensorProcessor -from nncf.onnx.tensor 
import ONNXNNCFTensor from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend @@ -39,10 +38,6 @@ class ONNXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): def types_to_insert_bias(self): return [] - @property - def tensor_processor(self) -> ONNXNNCFCollectorTensorProcessor: - return ONNXNNCFCollectorTensorProcessor - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint: return ONNXTargetPoint(target_type, target_node_name, port_id) @@ -53,9 +48,9 @@ def create_bias_insertion_command(node: NNCFNode) -> ONNXNullBiasInsertionComman @staticmethod def create_bias_correction_command( - node: NNCFNode, bias_value: np.ndarray, nncf_graph: NNCFGraph + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph ) -> ONNXBiasCorrectionCommand: - return create_bias_correction_command(node, bias_value) + return create_bias_correction_command(node, bias_value.data) @staticmethod def model_extraction_command(inputs: List[str], outputs: List[str]) -> ONNXModelExtractionCommand: @@ -76,27 +71,26 @@ def get_sub_input_output_names(subgraph: onnx.ModelProto) -> Tuple[str, str]: @staticmethod def create_input_data( - shape: Tuple[int], data: List[np.ndarray], input_name: str, channel_axis: int + shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int ) -> Dict[str, np.array]: - blob = np.zeros(shape) + blob = np.zeros(shape, dtype=data[0].data.dtype) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) - blob[index] = data[j] - blob = blob.astype(data[0].dtype) + blob[index] = data[j].data input_data = {input_name: blob} return input_data @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto) -> np.ndarray: - return get_bias_value(node, 
model) + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto) -> Tensor: + return Tensor(get_bias_value(node, model)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: return 0, 0 @staticmethod - def process_model_output(raw_data: Dict, output_name: str) -> ONNXNNCFTensor: - return ONNXNNCFTensor(raw_data[output_name]) + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data[output_name]) @staticmethod def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @@ -106,21 +100,6 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return is_node_with_bias(node) - @staticmethod - def get_bias_shift_magnitude(current_bias_value: np.ndarray, updated_bias_value: np.ndarray) -> float: - bias_shift_magnitude = np.inf - if np.count_nonzero(current_bias_value == 0) == 0: - bias_shift_magnitude = np.max(np.abs((updated_bias_value - current_bias_value) / current_bias_value)) - return bias_shift_magnitude - - @staticmethod - def post_process_output_data(data: List[np.ndarray]) -> np.ndarray: - return np.array(data) - - @staticmethod - def reshape_tensor(data: np.ndarray, new_shape: List[int]) -> np.ndarray: - return data.reshape(new_shape) - @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: return node.node_name, node.node_name diff --git a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py index 61ebb5a695b..d2744da5864 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py @@ -19,6 +19,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.utils.backend import 
BackendType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.tensor import Tensor from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS from nncf.openvino.graph.node_utils import get_bias_value from nncf.openvino.graph.node_utils import is_node_with_bias @@ -26,28 +27,22 @@ from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand from nncf.openvino.graph.transformations.commands import OVTargetPoint -from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor from nncf.openvino.statistics.collectors import get_mean_statistic_collector -from nncf.openvino.tensor import OVNNCFTensor from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend @ALGO_BACKENDS.register(BackendType.OPENVINO) class OVFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): - @property - def tensor_processor(self) -> OVNNCFCollectorTensorProcessor: - return OVNNCFCollectorTensorProcessor - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: return OVTargetPoint(target_type, target_node_name, port_id) @staticmethod def create_bias_correction_command( - node: NNCFNode, bias_value: np.ndarray, nncf_graph: NNCFGraph + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph ) -> OVBiasCorrectionCommand: - return OVCommandCreator.create_command_to_update_bias(node, bias_value, nncf_graph) + return OVCommandCreator.create_command_to_update_bias(node, bias_value.data, nncf_graph) @staticmethod def model_extraction_command(inputs: List[str], outputs: List[str]) -> OVModelExtractionCommand: @@ -68,19 +63,18 @@ def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]: @staticmethod def create_input_data( 
- shape: Tuple[int], data: List[np.ndarray], input_name: str, channel_axis: int + shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int ) -> Dict[str, np.ndarray]: - blob = np.zeros(shape) + blob = np.zeros(shape, dtype=data[0].data.dtype) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) - blob[index] = data[j] - blob = blob.astype(data[0].dtype) + blob[index] = data[j].data input_data = {input_name: blob} return input_data @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray: - return get_bias_value(node, nncf_graph, model) + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor: + return Tensor(get_bias_value(node, nncf_graph, model)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: @@ -97,28 +91,13 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return weight_node.metatype in FAKE_QUANTIZE_OPERATIONS @staticmethod - def process_model_output(raw_data: Dict, output_name: str) -> OVNNCFTensor: - return OVNNCFTensor(raw_data[output_name]) + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data[output_name]) @staticmethod def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return is_node_with_bias(node, nncf_graph) - @staticmethod - def get_bias_shift_magnitude(current_bias_value: np.ndarray, updated_bias_value: np.ndarray) -> float: - bias_shift_magnitude = np.inf - if np.count_nonzero(current_bias_value == 0) == 0: - bias_shift_magnitude = np.max(np.abs((updated_bias_value - current_bias_value) / current_bias_value)) - return bias_shift_magnitude - - @staticmethod - def post_process_output_data(data: List[np.ndarray]) -> np.ndarray: - return np.array(data) - - @staticmethod - def reshape_tensor(data: np.ndarray, new_shape: List[int]) -> 
np.ndarray: - return data.reshape(new_shape) - @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: return node.node_name, node.node_name diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py index cb9b0026e3f..193be8994d9 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py @@ -20,6 +20,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.tensor import Tensor from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.torch.graph.transformations.command_creation import create_bias_correction_command @@ -31,8 +32,6 @@ from nncf.torch.model_analyzer import is_node_with_fused_bias from nncf.torch.model_analyzer import is_quantized_weights from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector @@ -43,10 +42,6 @@ class PTFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, } - @property - def tensor_processor(self) -> PTNNCFCollectorTensorProcessor: - return PTNNCFCollectorTensorProcessor - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: @@ -57,9 +52,9 @@ def 
target_point(target_type: TargetType, target_node_name: str, port_id: int) - @staticmethod def create_bias_correction_command( - node: NNCFNode, bias_value: np.ndarray, nncf_graph: NNCFGraph + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph ) -> PTBiasCorrectionCommand: - return create_bias_correction_command(node, bias_value) + return create_bias_correction_command(node, bias_value.data) @staticmethod def model_extraction_command(inputs: List[str], outputs: List[str]) -> PTModelExtractionWithFusedBiasCommand: @@ -80,26 +75,24 @@ def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: return None, None @staticmethod - def create_input_data( - shape: Tuple[int], data: List[torch.Tensor], input_name: str, channel_axis: int - ) -> torch.Tensor: - blob = torch.zeros(shape, dtype=data[0].dtype) + def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: + blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) - blob[index] = data[j] + blob[index] = data[j].data return blob @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: NNCFNetwork) -> np.ndarray: - return get_fused_bias_value(node, model) + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: NNCFNetwork) -> Tensor: + return Tensor(get_fused_bias_value(node, model)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: return 0, 0 @staticmethod - def process_model_output(raw_data: Dict, output_name: str) -> PTNNCFTensor: - return PTNNCFTensor(raw_data) + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data) @staticmethod def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @@ -109,21 +102,6 @@ def is_quantized_weights(node: 
NNCFNode, nncf_graph: NNCFGraph) -> bool: def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return is_node_with_fused_bias(node, nncf_graph) - @staticmethod - def get_bias_shift_magnitude(current_bias_value: torch.Tensor, updated_bias_value: torch.Tensor) -> float: - bias_shift_magnitude = torch.inf - if torch.count_nonzero(current_bias_value == 0) == 0: - bias_shift_magnitude = torch.max(torch.abs((updated_bias_value - current_bias_value) / current_bias_value)) - return bias_shift_magnitude - - @staticmethod - def post_process_output_data(data: List[torch.Tensor]) -> torch.Tensor: - return torch.Tensor(data) - - @staticmethod - def reshape_tensor(data: torch.Tensor, new_shape: List[int]) -> torch.Tensor: - return data.reshape(new_shape) - @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: input_node_name = node.node_name diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 47cf5695832..d3c1e25d0ae 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -103,7 +103,7 @@ def create_quantizer_insertion_command( quantizer_config: QuantizerConfig, parameters: FakeQuantizeParameters, ): - tensor_type = np.int8 if np.any(parameters.input_low < 0) else np.uint8 + tensor_type = np.int8 if np.any(parameters.input_low.data < 0) else np.uint8 if target_point.is_weight_target_point(): tensor_type = np.int8 # The weight is restricted to have only signed range nncf_input_node_next_nodes = ONNXMinMaxAlgoBackend._get_input_edges_mapping(nncf_graph) diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 4ad4e309dc8..5412d42853d 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -138,7 
+138,7 @@ def _get_reduction_axes_and_use_abs_max( else: raise NotImplementedError(f"Unsupported target point type {target_point.type}.") - # TODO (l-bat): Disable quantizer propogation through layout changing operations + # TODO (l-bat): Disable quantizer propagation through layout changing operations channel_axis = 1 # OpenVINO activations have channel first layout: [N, C, Z, Y, X] axes = get_channel_agnostic_reduction_axes([channel_axis], shape) return axes, use_abs_max diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 0a8fe5778c5..d258411698e 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -11,7 +11,6 @@ from typing import Dict, List, Optional, Set, Tuple -import numpy as np import torch import nncf.torch.graph.operator_metatypes as om @@ -19,7 +18,6 @@ from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.layer_attributes import WeightedLayerAttributes -from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.common.hardware.config import HWConfig @@ -38,7 +36,6 @@ from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand from nncf.torch.hardware.config import PTHWConfig -from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT from nncf.torch.quantization.init_range import PTRangeInitCollectorParams @@ -112,10 +109,6 @@ def hw_config(self) -> HWConfig: def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT - @staticmethod - def 
model_transformer(model: NNCFNetwork) -> ModelTransformer: - return PTModelTransformer(model) - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: @@ -139,10 +132,10 @@ def create_quantizer_insertion_command( def unify_statistics(statistics: List[PTMinMaxTensorStatistic]) -> PTMinMaxTensorStatistic: max_values, min_values = [], [] for statistic in statistics: - max_values.append(torch.tensor(statistic.max_values).flatten()) - min_values.append(torch.tensor(statistic.min_values).flatten()) - max_values = torch.max(torch.tensor(max_values)) - min_values = torch.min(torch.tensor(min_values)) + max_values.append(statistic.max_values.flatten()) + min_values.append(statistic.min_values.flatten()) + max_values = torch.amax(torch.stack(max_values), dim=0) + min_values = torch.amin(torch.stack(min_values), dim=0) return PTMinMaxTensorStatistic(min_values=min_values, max_values=max_values) @staticmethod @@ -279,13 +272,12 @@ def _create_quantizer( def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None: quantizer.eps = 0 if isinstance(quantizer, AsymmetricQuantizer): - quantizer.input_low = torch.nn.Parameter(torch.from_numpy(parameters.input_low)) - quantizer.input_range = torch.nn.Parameter( - torch.from_numpy(np.array(parameters.input_high - parameters.input_low)) - ) + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data) + input_range = parameters.input_high - parameters.input_low + quantizer.input_range = torch.nn.Parameter(input_range.data) else: - quantizer.signed = np.any(parameters.input_low < 0) - quantizer.scale = torch.nn.Parameter(torch.from_numpy(parameters.input_high)) + quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) + quantizer.scale = torch.nn.Parameter(parameters.input_high.data) @staticmethod def 
_create_quantizer_insertion_command( diff --git a/nncf/quantization/fake_quantize.py b/nncf/quantization/fake_quantize.py index 4b68187e68e..74b61523830 100644 --- a/nncf/quantization/fake_quantize.py +++ b/nncf/quantization/fake_quantize.py @@ -21,6 +21,9 @@ from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import TensorDataType +from nncf.experimental.tensor import functions as fns @dataclass @@ -35,14 +38,14 @@ class FakeQuantizeParameters: :param levels: Number of quantization levels. """ - input_low: np.ndarray - input_high: np.ndarray - output_low: np.ndarray - output_high: np.ndarray + input_low: Tensor + input_high: Tensor + output_low: Tensor + output_high: Tensor levels: int -def fix_zero_filters_symmetric(max_values: np.ndarray, eps: float = 0.01) -> np.ndarray: +def fix_zero_filters_symmetric(max_values: Tensor, eps: float = 0.01) -> Tensor: """ Fixes zero filters for symmetric quantizer. @@ -50,14 +53,12 @@ def fix_zero_filters_symmetric(max_values: np.ndarray, eps: float = 0.01) -> np. :param eps: Correction coefficient. :return: Fixed the high quant number. """ - max_range = np.max(max_values) - lower_threshold = np.maximum(8e-5, eps * max_range) - return np.maximum(lower_threshold, max_values) + max_range = fns.max(max_values) + lower_threshold = fns.maximum(max_range * eps, 8e-5) + return fns.maximum(lower_threshold, max_values) -def fix_zero_filters_asymmetric( - min_values: np.ndarray, max_values: np.ndarray, eps: float = 1e-8 -) -> Tuple[np.ndarray, np.ndarray]: +def fix_zero_filters_asymmetric(min_values: Tensor, max_values: Tensor, eps: float = 1e-8) -> Tuple[Tensor, Tensor]: """ Fixes zero filters for asymmetric quantizer. 
@@ -69,20 +70,17 @@ def fix_zero_filters_asymmetric( level_high - fixed the high quant number """ ranges = max_values - min_values - ranges = ranges.flatten() if isinstance(ranges, np.ndarray) else np.array([ranges]) min_correction = 8e-4 - corrections = [ - (np.maximum(eps * rng, rng) - rng) * 0.5 if rng > min_correction else min_correction for rng in ranges - ] - corrections = np.array(corrections).reshape(max_values.shape) + corrections = fns.where(ranges > min_correction, (fns.maximum(eps * ranges, ranges) - ranges) * 0.5, min_correction) + level_low = min_values - corrections level_high = max_values + corrections return level_low, level_high def tune_range( - left_border: np.ndarray, right_border: np.ndarray, num_bits: int, unify_zp: bool = False -) -> Tuple[np.ndarray, np.ndarray]: + left_border: Tensor, right_border: Tensor, num_bits: int, unify_zp: bool = False +) -> Tuple[Tensor, Tensor]: """ Tunes asymmetric quantization range to unify the zero point of all channels if `unify_zp` is True, or sets zero quant precisely to zero value otherwise. 
@@ -101,22 +99,21 @@ def tune_range( if unify_zp: scale = (right_border - left_border) / level_high zero_point = -left_border / scale - avg_zpts = np.round(np.mean(zero_point)) - qval = np.ones_like(left_border) * avg_zpts + avg_zpts = fns.round(fns.mean(zero_point)) + qval = fns.ones_like(left_border) * avg_zpts else: s = level_high / (right_border - left_border) fval = -left_border * s - qval = np.round(fval) + qval = fns.round(fval) - with np.errstate(invalid="ignore", divide="ignore"): - ra = np.where(qval < level_high, qval / (qval - level_high) * right_border, left_border) - rb = np.where(qval > 0.0, (qval - level_high) / qval * left_border, right_border) + ra = fns.where(qval < level_high, qval / (qval - level_high) * right_border, left_border) + rb = fns.where(qval > 0.0, (qval - level_high) / qval * left_border, right_border) range_a = right_border - ra range_b = rb - left_border - mask = np.where(range_a > range_b, 1.0, 0.0) - inv_mask = np.abs(1.0 - mask) + mask = fns.where(range_a > range_b, 1.0, 0.0) + inv_mask = fns.abs(1.0 - mask) ra = mask * ra + inv_mask * left_border rb = inv_mask * rb + mask * right_border @@ -125,12 +122,12 @@ def tune_range( def symmetric_range( - min_values: np.ndarray, - max_values: np.ndarray, + min_values: Tensor, + max_values: Tensor, levels: int, quantizer_config: QuantizerConfig, q_group: QuantizerGroup, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[Tensor, Tensor]: """ Calculates the numbers of the low and high quant for the symmetric quantization scheme. 
@@ -148,21 +145,23 @@ def symmetric_range( else: signed = quantizer_config.signedness_to_force is True level_low = ( - np.zeros_like(level_high) if np.all(min_values >= 0) and not signed else -level_high * levels / (levels - 2) + fns.zeros_like(level_high) + if fns.all(min_values >= 0) and not signed + else -level_high * levels / (levels - 2) ) - level_low = level_low.astype(np.float32) - level_high = level_high.astype(np.float32) + level_low = level_low.astype(TensorDataType.float32) + level_high = level_high.astype(TensorDataType.float32) return level_low, level_high def asymmetric_range( - min_values: np.ndarray, - max_values: np.ndarray, + min_values: Tensor, + max_values: Tensor, quantizer_config: QuantizerConfig, q_group: QuantizerGroup, unify_zp: bool = False, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[Tensor, Tensor]: """ Calculates the numbers of the low and high quant for the asymmetric quantization scheme. @@ -176,15 +175,15 @@ def asymmetric_range( level_high - the high quant number """ level_low, level_high = fix_zero_filters_asymmetric(min_values, max_values) - level_low = np.where(level_low < 0.0, level_low, 0.0) - level_high = np.where(level_high > 0.0, level_high, 0.0) + level_low = fns.where(level_low < 0.0, level_low, 0.0) + level_high = fns.where(level_high > 0.0, level_high, 0.0) if unify_zp and q_group == QuantizerGroup.ACTIVATIONS: raise NotImplementedError("Unified zero point is not supported for activations.") level_low, level_high = tune_range(level_low, level_high, quantizer_config.num_bits, unify_zp=unify_zp) - level_low = level_low.astype(np.float32) - level_high = level_high.astype(np.float32) + level_low = level_low.astype(TensorDataType.float32) + level_high = level_high.astype(TensorDataType.float32) return level_low, level_high @@ -221,8 +220,8 @@ def calculate_quantizer_parameters( False - the full range is used. :return: Parameters of the FakeQuantize layer. 
""" - min_values = np.array(statistics.min_values).astype(np.float32) - max_values = np.array(statistics.max_values).astype(np.float32) + min_values = Tensor(statistics.min_values).astype(TensorDataType.float32) + max_values = Tensor(statistics.max_values).astype(TensorDataType.float32) if half_range: input_low, input_high, levels = _calculate_scaled_parameters( @@ -240,21 +239,20 @@ def calculate_quantizer_parameters( input_low, input_high = asymmetric_range(min_values, max_values, quantizer_config, quant_group) if not quantizer_config.per_channel: - input_low = np.squeeze(input_low) - input_high = np.squeeze(input_high) + input_low = fns.squeeze(input_low) + input_high = fns.squeeze(input_high) - input_low, input_high = np.array(input_low), np.array(input_high) output_low, output_high = input_low, input_high return FakeQuantizeParameters(input_low, input_high, output_low, output_high, levels) def _calculate_scaled_parameters( - min_values: np.ndarray, - max_values: np.ndarray, + min_values: Tensor, + max_values: Tensor, quantizer_config: QuantizerConfig, quant_group: QuantizerGroup, narrow_range: bool, -) -> Tuple[np.ndarray, np.ndarray, int]: +) -> Tuple[Tensor, Tensor, int]: """ Calculates FakeQuantize layer attributes scaled to effectively use a half range of the quantization range. 
diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index 1d3464882fe..01a916b61a5 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -16,6 +16,7 @@ import onnx from nncf import Dataset +from nncf.experimental.tensor import Tensor from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.onnx.graph.onnx_graph import ONNXGraph from nncf.onnx.statistics.statistics import ONNXMinMaxTensorStatistic @@ -32,10 +33,18 @@ def mock_collect_statistics(mocker): - get_statistics_value = ONNXMinMaxTensorStatistic(min_values=-1, max_values=1) + get_statistics_value = ONNXMinMaxTensorStatistic( + min_values=np.array(-1, dtype=np.float32), max_values=np.array(1, dtype=np.float32) + ) _ = mocker.patch( "nncf.quantization.fake_quantize.calculate_quantizer_parameters", - return_value=FakeQuantizeParameters(np.array(0), np.array(0), np.array(0), np.array(0), 256), + return_value=FakeQuantizeParameters( + Tensor(np.array(0, dtype=np.float32)), + Tensor(np.array(0, dtype=np.float32)), + Tensor(np.array(0, dtype=np.float32)), + Tensor(np.array(0, dtype=np.float32)), + 256, + ), ) _ = mocker.patch( "nncf.common.tensor_statistics.aggregator.StatisticsAggregator.collect_statistics", return_value=None diff --git a/tests/post_training/test_templates/test_calculate_quantizer_parameters.py b/tests/post_training/test_templates/test_calculate_quantizer_parameters.py index bffe249d064..ebcc5df7e75 100644 --- a/tests/post_training/test_templates/test_calculate_quantizer_parameters.py +++ b/tests/post_training/test_templates/test_calculate_quantizer_parameters.py @@ -19,6 +19,7 @@ from nncf.common.quantization.structs import QuantizationMode from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup +from nncf.experimental.tensor import functions as fns from nncf.quantization.fake_quantize import FakeQuantizeParameters from nncf.quantization.fake_quantize 
import calculate_quantizer_parameters from tests.post_training.conftest import FQ_CALCULATED_PARAMETERS_PATH @@ -32,10 +33,10 @@ def compare_fq_parameters(ref_params, params): assert ref_params.input_high.shape == params.input_high.shape assert ref_params.output_low.shape == params.output_low.shape assert ref_params.output_high.shape == params.output_high.shape - assert np.allclose(ref_params.input_low, params.input_low) - assert np.allclose(ref_params.input_high, params.input_high) - assert np.allclose(ref_params.output_low, params.output_low) - assert np.allclose(ref_params.output_high, params.output_high) + assert fns.allclose(ref_params.input_low, params.input_low) + assert fns.allclose(ref_params.input_high, params.input_high) + assert fns.allclose(ref_params.output_low, params.output_low) + assert fns.allclose(ref_params.output_high, params.output_high) def get_test_reference_key(q_group, q_config, narrow_range, hf_range): diff --git a/tests/post_training/test_templates/test_fast_bias_correction.py b/tests/post_training/test_templates/test_fast_bias_correction.py index b972ce851cd..c4ea71d6551 100644 --- a/tests/post_training/test_templates/test_fast_bias_correction.py +++ b/tests/post_training/test_templates/test_fast_bias_correction.py @@ -67,7 +67,7 @@ def test_reshape_bias_shift(self, bias_value: list, bias_shift: list, channel_ax algo = FastBiasCorrection(subset_size=1, inplace_statistics=False) # pylint: disable=protected-access algo._backend_entity = self.get_backend() - new_bias_shift = algo.reshape_bias_shift(bias_shift, bias_value, channel_axis) + new_bias_shift = algo._reshape_bias_shift(bias_shift, bias_value, channel_axis) assert list(new_bias_shift.shape) == ref_shape @staticmethod diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 9fff5e9de1c..461deb14fce 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ 
b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -17,10 +17,11 @@ import pytest +from nncf.experimental.common.tensor_statistics import statistical_functions as s_fns from nncf.experimental.tensor import Tensor from nncf.experimental.tensor import TensorDataType from nncf.experimental.tensor import TensorDeviceType -from nncf.experimental.tensor import functions +from nncf.experimental.tensor import functions as fns TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") @@ -68,6 +69,7 @@ def test_operators_tensor(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", OPERATOR_MAP.keys()) def test_operators_int(self, op_name): @@ -83,6 +85,7 @@ def test_operators_int(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", ("add", "sub", "mul", "truediv", "floordiv")) def test_operators_int_rev(self, op_name): @@ -98,6 +101,7 @@ def test_operators_int_rev(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) def test_comparison_tensor(self, op_name): @@ -150,6 +154,7 @@ def test_comparison_int_rev(self, op_name): ([[[[1], [2]], [[1], [2]]]], None, [[1, 2], [1, 2]]), ([[[[1], [2]], [[1], [2]]]], 0, [[[1], [2]], [[1], [2]]]), ([[[[1], [2]], [[1], [2]]]], -1, [[[1, 2], [1, 2]]]), + ([[[[1], [2]], [[1], [2]]]], (0, 3), [[1, 2], [1, 2]]), ), ) def test_squeeze(self, val, axis, ref): @@ -157,11 +162,23 @@ def test_squeeze(self, val, axis, ref): nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) res = nncf_tensor.squeeze(axis=axis) - if isinstance(ref, list): - 
assert functions.all(res == ref_tensor) - else: - assert res == ref_tensor assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device + + @pytest.mark.parametrize( + "val, axis, exception_type, exception_match", + ( + ([[[[1], [2]], [[1], [2]]]], (0, 1), ValueError, "not equal to one"), + ([[[[1], [2]], [[1], [2]]]], 42, IndexError, "out of"), + ([[[[1], [2]], [[1], [2]]]], (0, 42), IndexError, "out of"), + ), + ) + def test_squeeze_axis_error(self, val, axis, exception_type, exception_match): + tensor = self.to_tensor(val) + nncf_tensor = Tensor(tensor) + with pytest.raises(exception_type, match=exception_match): + nncf_tensor.squeeze(axis=axis) @pytest.mark.parametrize( "val, axis, ref", @@ -177,12 +194,10 @@ def test_fn_squeeze(self, val, axis, ref): tensor = self.to_tensor(val) nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) - res = functions.squeeze(nncf_tensor, axis=axis) - if isinstance(ref, list): - assert functions.all(res == ref_tensor) - else: - assert res == ref_tensor + res = fns.squeeze(nncf_tensor, axis=axis) assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val,ref", @@ -197,31 +212,9 @@ def test_flatten(self, val, ref): nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) res = nncf_tensor.flatten() - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor - assert isinstance(res, Tensor) - - @pytest.mark.parametrize( - "val, axis, ref", - ( - (1, None, 1), - ([1], None, 1), - ([[[[1], [2]], [[3], [4]]]], None, 4), - ([[1, 2], [3, 4]], 1, [2, 4]), - ), - ) - def test_max(self, val, axis, ref): - tensor = self.to_tensor(val) - nncf_tensor = Tensor(tensor) - ref_tensor = self.to_tensor(ref) - res = nncf_tensor.max(axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor assert 
isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, axis, ref", @@ -236,12 +229,10 @@ def test_fn_max(self, val, axis, ref): tensor = self.to_tensor(val) nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) - res = functions.max(nncf_tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor + res = fns.max(nncf_tensor, axis=axis) assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, axis, ref", @@ -256,30 +247,9 @@ def test_min(self, val, axis, ref): nncf_tensor = Tensor(self.to_tensor(val)) ref_tensor = self.to_tensor(ref) res = nncf_tensor.min(axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor - assert isinstance(res, Tensor) - - @pytest.mark.parametrize( - "val, axis, ref", - ( - (1, None, 1), - ([1], None, 1), - ([[[[1], [2]], [[3], [4]]]], None, 1), - ([[1, 2], [3, 4]], 1, [1, 3]), - ), - ) - def test_fn_min(self, val, axis, ref): - nncf_tensor = Tensor(self.to_tensor(val)) - ref_tensor = self.to_tensor(ref) - res = functions.min(nncf_tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, ref", @@ -292,11 +262,9 @@ def test_abs(self, val, ref): nncf_tensor = Tensor(self.to_tensor(val)) nncf_ref_tensor = Tensor(self.to_tensor(ref)) res = nncf_tensor.abs() - if isinstance(ref, list): - assert all(res == nncf_ref_tensor) - else: - assert res == nncf_ref_tensor assert isinstance(res, Tensor) + assert fns.allclose(res, nncf_ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, ref", @@ -308,12 +276,10 @@ def 
test_abs(self, val, ref): def test_fn_abs(self, val, ref): nncf_tensor = Tensor(self.to_tensor(val)) nncf_ref_tensor = Tensor(self.to_tensor(ref)) - res = functions.abs(nncf_tensor) - if isinstance(ref, list): - assert all(res == nncf_ref_tensor) - else: - assert res == nncf_ref_tensor + res = fns.abs(nncf_tensor) assert isinstance(res, Tensor) + assert fns.allclose(res, nncf_ref_tensor) + assert res.device == nncf_tensor.device def test_getitem(self): arr = [0, 1, 2] @@ -321,6 +287,7 @@ def test_getitem(self): res = nncf_tensor[1] assert res == 1 assert isinstance(res, Tensor) + assert res.device == nncf_tensor.device def test_iter(self): arr = [0, 1, 2] @@ -341,67 +308,72 @@ def test_iter(self): ), ) def test_fn_count_nonzero(self, axis, ref): - tensor = self.to_tensor([[1, 2], [1, 0]]) + tensor = self.to_tensor([[1.0, 2.0], [1.0, 0.0]]) nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) - res = functions.count_nonzero(nncf_tensor, axis=axis) - if axis is None: - assert res.data == ref_tensor - else: - assert all(res.data == self.to_tensor(ref)) + res = fns.count_nonzero(nncf_tensor, axis=axis) + assert isinstance(res, Tensor) + assert fns.allclose(res.data, ref_tensor) + assert res.device == nncf_tensor.device def test_fn_zeros_like(self): tensor = self.to_tensor([1, 2]) nncf_tensor = Tensor(tensor) - res = functions.zeros_like(nncf_tensor) + res = fns.zeros_like(nncf_tensor) assert all(res == Tensor(tensor * 0)) assert isinstance(res, Tensor) + assert res.device == nncf_tensor.device def test_fn_maximum(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = Tensor(self.to_tensor([2, 1])) tensor_ref = self.to_tensor([2, 2]) - res = functions.maximum(tensor_a, tensor_b) + res = fns.maximum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_maximum_list(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = [2, 1] tensor_ref = self.to_tensor([2, 2]) - res = 
functions.maximum(tensor_a, tensor_b) + res = fns.maximum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_minimum(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = Tensor(self.to_tensor([2, 1])) tensor_ref = self.to_tensor([1, 1]) - res = functions.minimum(tensor_a, tensor_b) + res = fns.minimum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_minimum_list(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = [2, 1] tensor_ref = self.to_tensor([1, 1]) - res = functions.minimum(tensor_a, tensor_b) + res = fns.minimum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_ones_like(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_ref = self.to_tensor([1, 1]) - res = functions.ones_like(tensor_a) + res = fns.ones_like(tensor_a) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device @pytest.mark.parametrize( "val, axis, ref", @@ -414,12 +386,10 @@ def test_fn_ones_like(self): ) def test_fn_all(self, val, axis, ref): tensor = Tensor(self.to_tensor(val)) - res = functions.all(tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == self.to_tensor(ref)) - else: - assert res.data == self.to_tensor(ref) + res = fns.all(tensor, axis=axis) assert isinstance(res, Tensor) + assert fns.allclose(res.data, self.to_tensor(ref)) + assert res.device == tensor.device @pytest.mark.parametrize( "val, axis, ref", @@ -432,19 +402,19 @@ def test_fn_all(self, val, axis, ref): ) def test_fn_any(self, val, axis, ref): tensor = Tensor(self.to_tensor(val)) - res = functions.any(tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == self.to_tensor(ref)) - else: - assert res == ref + res = fns.any(tensor, axis=axis) + assert 
isinstance(res, Tensor) + assert fns.allclose(res.data, self.to_tensor(ref)) + assert res.device == tensor.device def test_fn_where(self): tensor = Tensor(self.to_tensor([1, -1])) tensor_ref = self.to_tensor([1, 0]) - res = functions.where(tensor > 0, 1, 0) + res = fns.where(tensor > 0, 1, 0) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor.device @pytest.mark.parametrize( "val, ref", @@ -456,9 +426,9 @@ def test_fn_where(self): ) def test_fn_isempty(self, val, ref): tensor = Tensor(self.to_tensor(val)) - res = functions.isempty(tensor) + res = fns.isempty(tensor) assert res == ref - assert isinstance(res, Tensor) + assert isinstance(res, bool) @pytest.mark.parametrize( "val, ref", @@ -472,7 +442,7 @@ def test_isempty(self, val, ref): tensor = Tensor(self.to_tensor(val)) res = tensor.isempty() assert res == ref - assert isinstance(res, Tensor) + assert isinstance(res, bool) @pytest.mark.parametrize( "x1, x2, rtol, atol, ref", @@ -482,19 +452,19 @@ def test_isempty(self, val, ref): ([0.1], [0.10001], 0.1, None, True), ([0.1], [0.10001], None, 0.1, True), ([0.1], [0.20001], None, 0.1, False), + ([0.1], 0.1, None, None, True), ), ) def test_fn_allclose(self, x1, x2, rtol, atol, ref): tensor1 = Tensor(self.to_tensor(x1)) tensor2 = Tensor(self.to_tensor(x2)) if rtol is not None: - res = functions.allclose(tensor1, tensor2, rtol=rtol) + res = fns.allclose(tensor1, tensor2, rtol=rtol) elif atol is not None: - res = functions.allclose(tensor1, tensor2, atol=atol) + res = fns.allclose(tensor1, tensor2, atol=atol) else: - res = functions.allclose(tensor1, tensor2) + res = fns.allclose(tensor1, tensor2) assert res == ref - assert isinstance(res, Tensor) @pytest.mark.parametrize( "x1, x2, rtol, atol, ref", @@ -503,17 +473,18 @@ def test_fn_allclose(self, x1, x2, rtol, atol, ref): ([0.1], [0.10001], None, None, [False]), ([0.1], [0.10001], 0.1, None, [True]), ([0.1], [0.10001], None, 0.1, [True]), + ([0.1], 0.1, None, None, 
[True]), ), ) def test_fn_isclose(self, x1, x2, rtol, atol, ref): tensor1 = Tensor(self.to_tensor(x1)) tensor2 = Tensor(self.to_tensor(x2)) if rtol is not None: - res = functions.isclose(tensor1, tensor2, rtol=rtol) + res = fns.isclose(tensor1, tensor2, rtol=rtol) elif atol is not None: - res = functions.isclose(tensor1, tensor2, atol=atol) + res = fns.isclose(tensor1, tensor2, atol=atol) else: - res = functions.isclose(tensor1, tensor2) + res = fns.isclose(tensor1, tensor2) assert all(res == self.to_tensor(ref)) assert isinstance(res, Tensor) @@ -526,23 +497,202 @@ def test_astype(self): res = tensor.astype(TensorDataType.int8) assert isinstance(res, Tensor) assert res.dtype == TensorDataType.int8 + assert res.device == tensor.device def test_fn_astype(self): tensor = Tensor(self.to_tensor([1])) - res = functions.astype(tensor, TensorDataType.int8) + res = fns.astype(tensor, TensorDataType.int8) assert isinstance(res, Tensor) assert res.dtype == TensorDataType.int8 def test_reshape(self): tensor = Tensor(self.to_tensor([1, 1])) - assert tensor.shape == [2] - assert tensor.reshape([1, 2]).shape == [1, 2] + res = tensor.reshape((1, 2)) + assert tensor.shape == (2,) + assert res.shape == (1, 2) + assert res.device == tensor.device def test_fn_reshape(self): tensor = Tensor(self.to_tensor([1, 1])) - assert tensor.shape == [2] - assert functions.reshape(tensor, [1, 2]).shape == [1, 2] + res = fns.reshape(tensor, (1, 2)) + assert tensor.shape == (2,) + assert res.shape == (1, 2) + assert res.device == tensor.device def test_not_implemented(self): with pytest.raises(NotImplementedError, match="is not implemented for"): - functions.device({}, [1, 2]) + fns.device({}, [1, 2]) + + @pytest.mark.parametrize( + "x, axis, ref", + ( + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 1, + [[0.8, 0.1], [0.2, 0.7], [0.2, 0.1]], + ), + ), + ) + def test_fn_unstack(self, x, axis, ref): + tensor = 
Tensor(self.to_tensor(x)) + ref = [self.to_tensor(r) for r in ref] + + res = fns.unstack(tensor, axis=axis) + + assert isinstance(res, list) + for i, _ in enumerate(ref): + assert all(res[i] == ref[i]) + assert res[i].device == tensor.device + + @pytest.mark.parametrize( + "x, axis, ref", + ( + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 1, + [[0.8, 0.1], [0.2, 0.7], [0.2, 0.1]], + ), + ), + ) + def test_fn_stack(self, x, axis, ref): + list_tensor = [Tensor(self.to_tensor(i)) for i in x] + ref = self.to_tensor(ref) + + res = fns.stack(list_tensor, axis=axis) + + assert isinstance(res, Tensor) + assert fns.all(res.data == ref) + assert res.device == list_tensor[0].device + + def test_fn_moveaxis(self): + tensor = [[0, 0, 0], [0, 0, 0]] + tensor = Tensor(self.to_tensor(tensor)) + + res = fns.moveaxis(tensor, 0, -1) + + assert res.shape == (3, 2) + + @pytest.mark.parametrize( + "x, axis, keepdims, ref", + ( + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + False, + [0.45, 0.45, 0.15], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + True, + [[0.45, 0.45, 0.15]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + (0, 1), + True, + [[0.35]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + None, + False, + 0.35, + ), + ), + ) + def test_fn_mean(self, x, axis, keepdims, ref): + tensor = Tensor(self.to_tensor(x)) + ref_tensor = self.to_tensor(ref) + + res = fns.mean(tensor, axis, keepdims) + + assert isinstance(res, Tensor) + assert fns.allclose(res.data, ref_tensor) + assert res.device == tensor.device + + @pytest.mark.parametrize( + "val, decimals, ref", + ( + (1.1, 0, 1.0), + ([1.1, 0.9], 0, [1.0, 1.0]), + ([1.11, 0.91], 1, [1.1, 0.9]), + ), + ) + def test_fn_round(self, val, decimals, ref): + tensor = Tensor(self.to_tensor(val)) + ref_tensor = self.to_tensor(ref) + + res = fns.round(tensor, decimals) + + assert isinstance(res, Tensor) + assert fns.allclose(res.data, ref_tensor) + 
assert res.device == tensor.device + + @pytest.mark.parametrize( + "val, axis, ref", + ( + ( + [[9.0, 9.0], [7.0, 1.0]], + 0, + [8.0, 5.0], + ), + ( + [[[9.0, 9.0], [0.0, 3.0]], [[5.0, 1.0], [7.0, 1.0]]], + 0, + [5.25, 3.5], + ), + ( + [[[9.0, 9.0], [0.0, 3.0]], [[5.0, 1.0], [7.0, 1.0]]], + 2, + [5.25, 3.5], + ), + ( + [ + [[[9.0, 6.0], [8.0, 5.0]], [[3.0, 9.0], [4.0, 6.0]]], + [[[3.0, 9.0], [9.0, 2.0]], [[2.0, 4.0], [2.0, 5.0]]], + ], + 0, + [6.25, 4.5], + ), + ( + [ + [[[9.0, 6.0], [8.0, 5.0]], [[3.0, 9.0], [4.0, 6.0]]], + [[[3.0, 9.0], [9.0, 2.0]], [[2.0, 4.0], [2.0, 5.0]]], + ], + 1, + [6.375, 4.375], + ), + ( + [ + [[[9.0, 6.0], [8.0, 5.0]], [[3.0, 9.0], [4.0, 6.0]]], + [[[3.0, 9.0], [9.0, 2.0]], [[2.0, 4.0], [2.0, 5.0]]], + ], + -1, + [5.0, 5.75], + ), + ), + ) + def test_fn_mean_per_channel(self, val, axis, ref): + tensor = Tensor(self.to_tensor(val)) + ref_tensor = self.to_tensor(ref) + res = s_fns.mean_per_channel(tensor, axis) + assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor), f"{res.data}" + assert res.device == tensor.device + + @pytest.mark.parametrize("axis", (3, 4, -4, -5)) + def test_fn_mean_per_channel_incorrect_axis(self, axis): + tensor = Tensor(self.to_tensor([[[9.0, 9.0], [0.0, 3.0]], [[5.0, 1.0], [7.0, 1.0]]])) + with pytest.raises(ValueError, match="is out of bounds for array of dimension"): + s_fns.mean_per_channel(tensor, axis) diff --git a/tests/torch/ptq/test_calculation_quantizer_params.py b/tests/torch/ptq/test_calculation_quantizer_params.py index 3c0cf83a64e..cea666d0065 100644 --- a/tests/torch/ptq/test_calculation_quantizer_params.py +++ b/tests/torch/ptq/test_calculation_quantizer_params.py @@ -25,6 +25,8 @@ from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import functions as fn from 
nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from nncf.quantization.fake_quantize import FakeQuantizeParameters @@ -54,10 +56,10 @@ class CaseSymParams: SYM_CASES = ( CaseSymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49920455, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), - np.array(-0.49920455, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), + Tensor(torch.tensor(-0.49920455, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(-0.49920455, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), 256, ), per_channel=False, @@ -66,10 +68,10 @@ class CaseSymParams: ), CaseSymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49530452, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), - np.array(-0.49530452, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), 255, ), per_channel=False, @@ -78,33 +80,33 @@ class CaseSymParams: ), CaseSymParams( fq_params=FakeQuantizeParameters( - np.array([-0.4835594, -0.49530452, -0.49221927], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4797816, 0.49920455, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([-0.4835594, -0.49530452, -0.49221927], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4797816, 0.49920455, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), + Tensor(torch.tensor([-0.4835594, -0.49530452, -0.49221927], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4797816, 0.49920455, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([-0.4835594, -0.49530452, -0.49221927], 
dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4797816, 0.49920455, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), 256, ), per_channel=True, quant_group=QuantizerGroup.ACTIVATIONS, - ref_scale=np.array([0.4797816, 0.49920455, 0.48837382]).reshape(1, 3, 1, 1), + ref_scale=torch.tensor([0.4797816, 0.49920455, 0.48837382]).reshape(1, 3, 1, 1), ), CaseSymParams( fq_params=FakeQuantizeParameters( - np.array([-0.48837382, -0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([-0.48837382, -0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), + Tensor(torch.tensor([-0.48837382, -0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([-0.48837382, -0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), 255, ), per_channel=True, quant_group=QuantizerGroup.WEIGHTS, - ref_scale=np.array([0.48837382, 0.49530452]).reshape(2, 1, 1, 1), + ref_scale=torch.tensor([0.48837382, 0.49530452]).reshape(2, 1, 1, 1), ), ) @pytest.mark.parametrize("case_to_test", SYM_CASES) -def test_quantizer_params_sym(case_to_test): +def test_quantizer_params_sym(case_to_test: CaseSymParams): per_ch = case_to_test.per_channel fq_params = case_to_test.fq_params quant_group = case_to_test.quant_group @@ -140,10 +142,10 @@ class CaseAsymParams: ASYM_CASES = ( CaseAsymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), + 
Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), 256, ), per_channel=False, @@ -153,10 +155,10 @@ class CaseAsymParams: ), CaseAsymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), 256, ), per_channel=False, @@ -166,35 +168,35 @@ class CaseAsymParams: ), CaseAsymParams( fq_params=FakeQuantizeParameters( - np.array([-0.48051512, -0.49776307, -0.44099426], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4767611, 0.47861832, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([-0.48051512, -0.49776307, -0.44099426], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4767611, 0.47861832, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), + Tensor(torch.tensor([-0.48051512, -0.49776307, -0.44099426], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4767611, 0.47861832, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([-0.48051512, -0.49776307, -0.44099426], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4767611, 0.47861832, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), 256, ), per_channel=True, quant_group=QuantizerGroup.ACTIVATIONS, - ref_inp_low=np.array([-0.48051512, -0.49776307, -0.44099426]).reshape(1, 3, 1, 1), - ref_inp_range=np.array([0.9572762, 0.9763814, 0.9293681]).reshape(1, 3, 1, 1), + ref_inp_low=torch.tensor([-0.48051512, -0.49776307, -0.44099426]).reshape(1, 3, 1, 1), + ref_inp_range=torch.tensor([0.9572762, 0.9763814, 0.9293681]).reshape(1, 3, 1, 1), ), CaseAsymParams( fq_params=FakeQuantizeParameters( - 
np.array([-0.4845584, -0.49583155], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.4767611], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([-0.4845584, -0.49583155], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.4767611], dtype=np.float32).reshape(2, 1, 1, 1), + Tensor(torch.tensor([-0.4845584, -0.49583155], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.4767611], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([-0.4845584, -0.49583155], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.4767611], dtype=torch.float32).reshape(2, 1, 1, 1)), 256, ), per_channel=True, quant_group=QuantizerGroup.WEIGHTS, - ref_inp_low=np.array([-0.4845584, -0.49583155]).reshape(2, 1, 1, 1), - ref_inp_range=np.array([0.97293222, 0.97259265]).reshape(2, 1, 1, 1), + ref_inp_low=torch.tensor([-0.4845584, -0.49583155]).reshape(2, 1, 1, 1), + ref_inp_range=torch.tensor([0.97293222, 0.97259265]).reshape(2, 1, 1, 1), ), ) @pytest.mark.parametrize("case_to_test", ASYM_CASES) -def test_quantizer_params_asym(case_to_test): +def test_quantizer_params_asym(case_to_test: CaseSymParams): per_ch = case_to_test.per_channel fq_params = case_to_test.fq_params quant_group = case_to_test.quant_group @@ -212,8 +214,8 @@ def test_quantizer_params_asym(case_to_test): ) quantizer = PTMinMaxAlgoBackend._create_quantizer(qconfig, scale_shape, fq_params, target_type) assert quantizer.levels == fq_params.levels - assert np.allclose(quantizer.input_low.detach().numpy(), case_to_test.ref_inp_low) - assert np.allclose(quantizer.input_range.detach().numpy(), case_to_test.ref_inp_range) + assert fn.allclose(quantizer.input_low.data, case_to_test.ref_inp_low) + assert fn.allclose(quantizer.input_range.data, case_to_test.ref_inp_range) class LinearTestModel(nn.Module): @@ -270,10 +272,7 @@ def calculate_statistics(data, mode, qgroup, half_range=False): else: max_values = np.amax(data, axes) - statistics 
class TestTorchCudaFBCAlgorithm(TestTorchFBCAlgorithm):
    """
    Variant of `TestTorchFBCAlgorithm` that places the test data and the model on a CUDA device.

    NOTE(review): no CUDA-availability guard (marker or skip) is visible on this class;
    confirm that one is applied elsewhere before running on CPU-only hosts.
    """

    @staticmethod
    def list_to_backend_type(data: List) -> torch.Tensor:
        """
        Converts a python list to a torch tensor and moves it to CUDA.

        :param data: Values to convert.
        :return: CUDA tensor built from `data`.
        """
        return torch.Tensor(data).cuda()

    @staticmethod
    def backend_specific_model(model: torch.nn.Module, tmp_dir: str):
        """
        Moves the model to CUDA and passes it to `get_nncf_network`.

        :param model: Model to wrap; it must provide `cuda()` and an `INPUT_SIZE` attribute.
        :param tmp_dir: Not used by this implementation.
        :return: The value returned by `get_nncf_network` for the CUDA model.
        """
        return get_nncf_network(model.cuda(), model.INPUT_SIZE)

    @staticmethod
    def fn_to_type(tensor) -> torch.Tensor:
        """
        Converts the given data to a torch tensor and moves it to CUDA.

        :param tensor: Data accepted by the `torch.Tensor` constructor.
        :return: CUDA tensor built from `tensor`.
        """
        return torch.Tensor(tensor).cuda()

    @staticmethod
    def check_bias(model: NNCFNetwork, ref_bias: list):
        """
        Compares the fused bias of the first matching node with the reference values.

        :param model: Model whose graph is searched for a node with a fused bias.
        :param ref_bias: Reference bias values.
        :raises ValueError: If no node with a fused bias is found in the graph.
        """
        ref_bias = torch.Tensor(ref_bias)
        nncf_graph = NNCFGraphFactory.create(model)
        for node in nncf_graph.get_all_nodes():
            if not is_node_with_fused_bias(node, nncf_graph):
                continue
            # The reference tensor is created without a device move, so the bias is
            # brought to the CPU before the comparison.
            bias_value = get_fused_bias_value(node, model).cpu()
            # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
            assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
            return
        raise ValueError("Not found node with bias")