Skip to content

Commit

Permalink
Introduce fns.divide
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-savelyevv committed Nov 4, 2024
1 parent 9cf712c commit 74c4168
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 31 deletions.
15 changes: 3 additions & 12 deletions nncf/quantization/algorithms/weight_compression/weight_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def calculate_signed_scale(weight: Tensor, reduction_axes: ReductionAxes, num_bi
w_max = fns.max(weight, axis=reduction_axes, keepdims=True)

scale = fns.where(w_abs_min >= w_max, w_abs_min, -w_max)
scale /= level_high
fns.inplace_divide(scale, level_high)

eps = fns.finfo(scale).eps
scale = fns.where(fns.abs(scale) < eps, eps, scale)
Expand Down Expand Up @@ -286,7 +286,6 @@ def calculate_quantized_weight(
config: WeightCompressionConfig,
scale: Tensor,
zero_point: Optional[Tensor] = None,
invert_scale=False,
) -> Tensor:
"""
Quantizes the weight tensor using the provided scale and zero point.
Expand All @@ -295,7 +294,6 @@ def calculate_quantized_weight(
:param config: Weight compression configuration.
:param scale: Scale tensor used for quantization.
:param zero_point: Zero point tensor used for quantization.
:param invert_scale: applies inversion for scale and then multiply by weights instead of division.
:return: Quantized weight tensor of uint8 or int8 type.
"""
if weight.dtype != TensorDataType.float32:
Expand All @@ -309,11 +307,7 @@ def calculate_quantized_weight(
level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1

if invert_scale:
scale = fns.power(scale, -1)
compressed_weights = weight * scale
else:
compressed_weights = weight / scale
compressed_weights = fns.divide(weight, scale)
if zero_point is not None:
compressed_weights += zero_point.astype(weight.dtype)
compressed_weights = fns.round(compressed_weights)
Expand All @@ -328,7 +322,6 @@ def do_int_quantization(
config: WeightCompressionConfig,
precomputed_scale: Tensor = None,
precomputed_zero_point: Tensor = None,
invert_scale=False,
) -> Tuple[Tensor, Tensor, Tensor]:
"""
The method quantizes the given weights to integer data type uniformly in accordance with the compression config.
Expand All @@ -351,8 +344,6 @@ def do_int_quantization(
:param config: Information on how to compress (quantize) a specific weight.
:param precomputed_scale: Precomputed scale.
:param precomputed_zero_point: Precomputed zero point.
:param invert_scale: applies inversion for scale and then multiply by weights instead of division.
Need as reference implementation for OV.
:return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type,
scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization.
"""
Expand All @@ -373,7 +364,7 @@ def do_int_quantization(
if precomputed_zero_point is not None:
zero_point = precomputed_zero_point

compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point, invert_scale)
compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point)
return compressed_weights, scale, zero_point


Expand Down
4 changes: 2 additions & 2 deletions nncf/quantization/fake_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,11 @@ def calculate_scale_zero_point(
:return: Scale and Zero point values.
"""
levels = level_high - level_low if narrow_range else level_high - level_low + 1
scale = ((input_high - input_low) / (levels - 1)).astype(TensorDataType.float32)
scale = fns.divide((input_high - input_low), (levels - 1)).astype(TensorDataType.float32)
eps = fns.finfo(scale).eps
# NOTE: adding machine epsilon to avoid division by zero
scale = fns.where(fns.abs(scale) < eps, eps, scale)
expected_level_low = level_low + 1 if narrow_range else level_low
zero_point = expected_level_low - fns.round(input_low / scale)
zero_point = expected_level_low - fns.round(fns.divide(input_low, scale))
zero_point = fns.clip(zero_point.astype(TensorDataType.int32), level_low, level_high)
return scale, zero_point
2 changes: 2 additions & 0 deletions nncf/tensor/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
from nncf.tensor.functions.numeric import count_nonzero as count_nonzero
from nncf.tensor.functions.numeric import device as device
from nncf.tensor.functions.numeric import diag as diag
from nncf.tensor.functions.numeric import divide as divide
from nncf.tensor.functions.numeric import dtype as dtype
from nncf.tensor.functions.numeric import expand_dims as expand_dims
from nncf.tensor.functions.numeric import eye as eye
from nncf.tensor.functions.numeric import finfo as finfo
from nncf.tensor.functions.numeric import flatten as flatten
from nncf.tensor.functions.numeric import from_numpy as from_numpy
from nncf.tensor.functions.numeric import inplace_divide as inplace_divide
from nncf.tensor.functions.numeric import isclose as isclose
from nncf.tensor.functions.numeric import isempty as isempty
from nncf.tensor.functions.numeric import item as item
Expand Down
39 changes: 39 additions & 0 deletions nncf/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,3 +905,42 @@ def ceil(a: Tensor) -> Tensor:
:return: An array of the same type as a, containing the ceiling values.
"""
return Tensor(ceil(a.data))


@functools.singledispatch
@tensor_guard
def divide(a: Union[Tensor, float], b: Union[Tensor, float], invert: Optional[bool] = True) -> Tensor:
"""
Divide two tensors or a tensor and a float.
This function divides `a` by `b`. If `invert` is True, it performs the division as `a * (1.0 / b)`.
Otherwise, it performs the division as `a / b`.
:param a: The first input tensor or float.
:param b: The second input tensor or float.
:param invert: If True, the division is performed as `a * (1.0 / b)`. If False, it is performed as `a / b`.
Defaults to True.
:return: A new tensor resulting from the division.
"""
return Tensor(a * (1.0 / b) if invert else a / b)


@functools.singledispatch
@tensor_guard
def inplace_divide(a: Union[Tensor, float], b: Union[Tensor, float], invert: Optional[bool] = True) -> None:
"""
In-place division of two tensors or a tensor and a float.
This function divides `a` by `b` in place. If `invert` is True, it performs the division as `a *= (1.0 / b)`.
Otherwise, it performs the division as `a /= b`.
:param a: The first input tensor or float.
:param b: The second input tensor or float.
:param invert: If True, the division is performed as `a *= (1.0 / b)`. If False, the division it is as `a /= b`.
Defaults to True.
:return: None. The operation is performed in place.
"""
if invert:
a *= 1.0 / b
else:
a /= b
4 changes: 2 additions & 2 deletions nncf/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def __ipow__(self, other: Union[Tensor, float]) -> Tensor:
return self

def __truediv__(self, other: Union[Tensor, float]) -> Tensor:
return self * _call_function("_binary_op_nowarn", 1.0, other, operator.truediv)
return _call_function("_binary_op_nowarn", self, other, operator.truediv)

def __rtruediv__(self, other: Union[Tensor, float]) -> Tensor:
return other * _call_function("_binary_reverse_op_nowarn", self, 1.0, operator.truediv)
return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv)

def __itruediv__(self, other: Union[Tensor, float]) -> Tensor:
self._data /= unwrap_tensor_data(other)
Expand Down
45 changes: 30 additions & 15 deletions tests/openvino/native/quantization/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,32 +1025,47 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
assert ref_e8m0_nodes == names_e8m0


@pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM))
def test_np_ov_compression_decompression(mode):
sz = 60
w = np.arange(-sz, sz).reshape(2, sz).astype(np.float32) / 9.0
@pytest.mark.parametrize("mode", [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM])
@pytest.mark.parametrize(
"w,s,zp",
[
(
np.array([[1.4372410774230957]], np.float32),
np.array([[-0.9581607580184937]], np.float32),
np.array([[1]], np.int32),
),
(np.arange(-60, 60).reshape(2, 60).astype(np.float32) / 9.0, None, None),
],
)
def test_np_ov_compression_decompression(mode, w, s, zp):
w = Tensor(w)
if s is not None:
s = Tensor(s)
if mode == CompressWeightsMode.INT4_SYM:
zp = None
if zp is not None:
zp = Tensor(zp)

config = WeightCompressionConfig(mode)

compressed_weighs, scale, zp = do_int_quantization(w, -1, config, invert_scale=True)
decompressed_weighs = do_int_dequantization(compressed_weighs, scale, zp)
compressed_weights, s, zp = do_int_quantization(w, -1, config, precomputed_scale=s, precomputed_zero_point=zp)
decompressed_weights = do_int_dequantization(compressed_weights, s, zp)

compressed_weighs = compressed_weighs.data
decompressed_weighs = decompressed_weighs.data
compressed_weights = compressed_weights.data
decompressed_weights = decompressed_weights.data
zp_shape = zp.shape if zp is not None else None

compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, scale.shape, zp_shape)
compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, s.shape, zp_shape)
compress_decompress = OVWeightCompressionAlgoBackend.get_compress_decompress_pipeline(
config, w.shape, scale.shape, zp_shape
config, w.shape, s.shape, zp_shape
)

params = [w.data, scale.data, zp.data] if zp is not None else [w.data, scale.data]
compressed_weighs_ov = compress(params)
decompressed_weighs_ov = compress_decompress(params)
params = [w.data, s.data, zp.data] if zp is not None else [w.data, s.data]
compressed_weights_ov = compress(params)
decompressed_weights_ov = compress_decompress(params)

assert np.allclose(compressed_weighs, compressed_weighs_ov)
assert np.allclose(decompressed_weighs, decompressed_weighs_ov)
assert np.allclose(compressed_weights, compressed_weights_ov, atol=0)
assert np.allclose(decompressed_weights, decompressed_weights_ov, atol=0)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 74c4168

Please sign in to comment.