Skip to content

Commit

Permalink
Add proposal
Browse files Browse the repository at this point in the history
  • Loading branch information
andrey-churkin committed Jul 19, 2023
1 parent fd79455 commit 55f8060
Show file tree
Hide file tree
Showing 8 changed files with 395 additions and 340 deletions.
9 changes: 9 additions & 0 deletions nncf/common/tensor_new/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ class TensorDataType(Enum):
float64 = auto()
int8 = auto()
uint8 = auto()


class TensorDeviceType(Enum):
"""
Enum representing the different tensor device types.
"""

CPU = auto()
GPU = auto()
127 changes: 102 additions & 25 deletions nncf/common/tensor_new/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,91 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union
import functools
from typing import Optional, Tuple, Union, TypeVar

from nncf.common.tensor_new.tensor import Tensor
from nncf.common.tensor_new.tensor import tensor_func_dispatcher

from nncf.common.tensor_new.enums import TensorDeviceType
from nncf.common.tensor_new.enums import TensorDataType

def all(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin

TTensor = TypeVar("TTensor")
T = TypeVar("T") # TODO: Verify


@functools.singledispatch
def device(a: TTensor) -> TensorDeviceType:
"""
:param a:
:return:
"""


@functools.singledispatch
def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:
"""
:param a:
:param axis:
:return:
"""


@functools.singledispatch
def flatten(a: TTensor) -> TTensor:
"""
:param a:
:return:
"""


@functools.singledispatch
def max(a: TTensor, axis: Optional[T] = None) -> TTensor:
"""
:param a:
:param axis:
:return:
"""


@functools.singledispatch
def min(a: TTensor, axis: Optional[T] = None) -> TTensor:
"""
:param a:
:param axis:
:return:
"""


@functools.singledispatch
def abs(a: TTensor) -> TTensor:
"""
:param a:
:return:
"""


@functools.singledispatch
def as_type(a: TTensor, dtype: TensorDataType):
"""
:param a:
:param dtype:
"""


@functools.singledispatch
def reshape(a: TTensor, shape: T) -> TTensor:
"""
:param a:
:param shape:
:return:
"""


###############################################################################


@functools.singledispatch
def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
"""
Test whether all tensor elements along a given axis evaluate to True.
Expand All @@ -25,10 +103,10 @@ def all(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: #
:return: A new boolean or tensor is returned unless out is specified,
in which case a reference to out is returned.
"""
return tensor_func_dispatcher("all", a, axis=axis)


def allclose(a: Tensor, b: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
@functools.singledispatch
def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor:
"""
Returns True if two arrays are element-wise equal within a tolerance.
Expand All @@ -41,10 +119,10 @@ def allclose(a: Tensor, b: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equ
Defaults to False.
:return: True if the two arrays are equal within the given tolerance, otherwise False.
"""
return tensor_func_dispatcher("allclose", a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
@functools.singledispatch
def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
"""
Test whether all tensor elements along a given axis evaluate to True.
Expand All @@ -54,10 +132,10 @@ def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: #
:return: A new boolean or tensor is returned unless out is specified,
in which case a reference to out is returned.
"""
return tensor_func_dispatcher("any", a, axis=axis)


def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
@functools.singledispatch
def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:
"""
Counts the number of non-zero values in the tensor input.
Expand All @@ -67,20 +145,20 @@ def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> T
:return: Number of non-zero values in the tensor along a given axis.
Otherwise, the total number of non-zero values in the tensor is returned.
"""
return tensor_func_dispatcher("count_nonzero", a, axis=axis)


def is_empty(a: Tensor) -> Tensor:
@functools.singledispatch
def is_empty(a: TTensor) -> bool:
"""
Return True if input tensor is empty.
:param a: The input tensor.
:return: True is tensor is empty, otherwise False.
"""
return tensor_func_dispatcher("is_empty", a)


def isclose(a: Tensor, b: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
@functools.singledispatch
def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor:
"""
Returns a boolean array where two arrays are element-wise equal within a tolerance.
Expand All @@ -93,42 +171,42 @@ def isclose(a: Tensor, b: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equa
Defaults to False.
:return: Returns a boolean tensor of where a and b are equal within the given tolerance.
"""
return tensor_func_dispatcher("isclose", a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def maximum(x1: Tensor, x2: Tensor) -> Tensor:
@functools.singledispatch
def maximum(x1: TTensor, x2: TTensor) -> TTensor:
"""
Element-wise maximum of tensor elements.
:param x1: The first input tensor.
:param x2: The second input tensor.
:return: Output tensor.
"""
return tensor_func_dispatcher("maximum", x1, x2)


def minimum(x1: Tensor, x2: Tensor) -> Tensor:
@functools.singledispatch
def minimum(x1: TTensor, x2: TTensor) -> TTensor:
"""
Element-wise minimum of tensor elements.
:param input: The first input tensor.
:param other: The second input tensor.
:return: Output tensor.
"""
return tensor_func_dispatcher("minimum", x1, x2)


def ones_like(a: Tensor) -> Tensor:
@functools.singledispatch
def ones_like(a: TTensor) -> TTensor:
"""
Return an tensor of ones with the same shape and type as a given tensor.
:param a: The shape and data-type of a define these same attributes of the returned tensor.
:return: Tensor of ones with the same shape and type as a.
"""
return tensor_func_dispatcher("ones_like", a)


def where(condition: Tensor, x: Tensor, y: Tensor) -> Tensor:
@functools.singledispatch
def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor:
"""
Return elements chosen from x or y depending on condition.
Expand All @@ -137,14 +215,13 @@ def where(condition: Tensor, x: Tensor, y: Tensor) -> Tensor:
:param y: Value at indices where condition is False.
:return: An tensor with elements from x where condition is True, and elements from y elsewhere.
"""
return tensor_func_dispatcher("where", condition, x, y)


def zeros_like(a: Tensor) -> Tensor:
@functools.singledispatch
def zeros_like(a: TTensor) -> TTensor:
"""
Return an tensor of zeros with the same shape and type as a given tensor.
:param input: The shape and data-type of a define these same attributes of the returned tensor.
:return: tensor of zeros with the same shape and type as a.
"""
return tensor_func_dispatcher("zeros_like", a)
123 changes: 123 additions & 0 deletions nncf/common/tensor_new/numpy_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, TypeVar, Union

import numpy as np

from nncf.common.tensor_new import functions
from nncf.common.tensor_new.enums import TensorDataType
from nncf.common.tensor_new.enums import TensorDeviceType


TensorType = TypeVar("TensorType") # TODO: Should be removed


DTYPE_MAP = {
TensorDataType.float16: np.float16,
TensorDataType.float32: np.float32,
TensorDataType.float64: np.float64,
TensorDataType.int8: np.int8,
TensorDataType.uint8: np.uint8,
}


@functions.device.register
def _(a: np.ndarray) -> TensorDeviceType:
return TensorDeviceType.CPU


@functions.squeeze.register
def _(a: np.ndarray, axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray:
return np.squeeze(a, axis=axis)


@functions.flatten.register
def _(a: np.ndarray) -> np.ndarray:
return a.flatten()


@functions.max.register
def _(a: np.ndarray, axis: Optional[TensorType] = None) -> np.ndarray: # pylint: disable=redefined-builtin
return np.max(a, axis=axis)


@functions.min.register
def _(a: np.ndarray, axis: Optional[TensorType] = None) -> np.ndarray: # pylint: disable=redefined-builtin
return np.min(a, axis=axis)


@functions.abs.register
def _(a: np.ndarray) -> np.ndarray:
return np.absolute(a)


@functions.as_type.register
def _(a: np.ndarray, dtype: TensorDataType):
return a.astype(DTYPE_MAP[dtype])


###############################################################################


@functions.all.register
def _(a: np.ndarray, axis: Optional[TensorType] = None) -> TensorType: # pylint: disable=redefined-builtin
return np.all(a, axis=axis)


@functions.allclose.register
def _(a: np.ndarray, b: np.ndarray, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool:
return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


@functions.any.register
def _(a: np.ndarray, axis: Optional[TensorType] = None) -> TensorType: # pylint: disable=redefined-builtin
return np.any(a, axis=axis)


@functions.count_nonzero.register
def _(a: np.ndarray, axis: Optional[TensorType] = None) -> np.ndarray:
return np.count_nonzero(a, axis=axis)


@functions.is_empty.register
def _(a: np.ndarray) -> bool:
return a.size == 0


@functions.isclose.register
def _(a: np.ndarray, b: np.ndarray, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False):
return np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


@functions.maximum.register
def _(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
return np.maximum(x1, x2)


@functions.minimum.register
def _(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
return np.minimum(x1, x2)


@functions.ones_like.register
def _(a: np.ndarray) -> np.ndarray:
return np.ones_like(a)


@functions.where.register
def _(condition: np.ndarray, x: Union[np.ndarray, float, bool], y: Union[np.ndarray, float, bool]) -> np.ndarray:
return np.where(condition, x, y)


@functions.zeros_like.register
def _(a: np.ndarray) -> np.ndarray:
return np.zeros_like(a)
Loading

0 comments on commit 55f8060

Please sign in to comment.