diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
index 13406c0b288..531ef7c28f9 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -286,7 +286,6 @@ def calculate_quantized_weight(
     config: WeightCompressionConfig,
     scale: Tensor,
     zero_point: Optional[Tensor] = None,
-    invert_scale=False,
 ) -> Tensor:
     """
     Quantizes the weight tensor using the provided scale and zero point.
@@ -295,7 +294,6 @@ def calculate_quantized_weight(
     :param config: Weight compression configuration.
     :param scale: Scale tensor used for quantization.
     :param zero_point: Zero point tensor used for quantization.
-    :param invert_scale: applies inversion for scale and then multiply by weights instead of division.
     :return: Quantized weight tensor of uint8 or int8 type.
     """
     if weight.dtype != TensorDataType.float32:
@@ -309,11 +307,7 @@ def calculate_quantized_weight(
     level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
     level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1
 
-    if invert_scale:
-        scale = fns.power(scale, -1)
-        compressed_weights = weight * scale
-    else:
-        compressed_weights = weight / scale
+    compressed_weights = weight * fns.reciprocal(scale)
     if zero_point is not None:
         compressed_weights += zero_point.astype(weight.dtype)
     compressed_weights = fns.round(compressed_weights)
@@ -328,7 +322,6 @@ def do_int_quantization(
     config: WeightCompressionConfig,
     precomputed_scale: Tensor = None,
     precomputed_zero_point: Tensor = None,
-    invert_scale=False,
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """
     The method quantizes the given weights to integer data type uniformly in accordance with the compression config.
@@ -351,8 +344,6 @@ def do_int_quantization(
     :param config: Information on how to compress (quantize) a specific weight.
     :param precomputed_scale: Precomputed scale.
    :param precomputed_zero_point: Precomputed zero point.
-    :param invert_scale: applies inversion for scale and then multiply by weights instead of division.
-        Need as reference implementation for OV.
     :return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type,
         scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization.
""" @@ -380,7 +371,7 @@ def do_int_quantization( if precomputed_zero_point is not None: zero_point = precomputed_zero_point - compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point, invert_scale) + compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point) return compressed_weights, scale, zero_point diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py index 5a286a6fc13..3c6f836ae46 100644 --- a/nncf/tensor/functions/__init__.py +++ b/nncf/tensor/functions/__init__.py @@ -46,10 +46,12 @@ from nncf.tensor.functions.numeric import minimum as minimum from nncf.tensor.functions.numeric import moveaxis as moveaxis from nncf.tensor.functions.numeric import multiply as multiply +from nncf.tensor.functions.numeric import ones as ones from nncf.tensor.functions.numeric import ones_like as ones_like from nncf.tensor.functions.numeric import percentile as percentile from nncf.tensor.functions.numeric import power as power from nncf.tensor.functions.numeric import quantile as quantile +from nncf.tensor.functions.numeric import reciprocal as reciprocal from nncf.tensor.functions.numeric import reshape as reshape from nncf.tensor.functions.numeric import round as round from nncf.tensor.functions.numeric import searchsorted as searchsorted diff --git a/nncf/tensor/functions/numeric.py b/nncf/tensor/functions/numeric.py index 061d1ee6e66..f0be76add57 100644 --- a/nncf/tensor/functions/numeric.py +++ b/nncf/tensor/functions/numeric.py @@ -818,6 +818,27 @@ def zeros( return Tensor(get_numeric_backend_fn("zeros", backend)(shape, dtype=dtype, device=device)) +def ones( + shape: Tuple[int, ...], + *, + backend: TensorBackend, + dtype: Optional[TensorDataType] = None, + device: Optional[TensorDeviceType] = None, +) -> Tensor: + """ + Return a new array of given shape and type, filled with ones. + + :param shape: Shape of the new array + :param backend: The backend type for which the ones tensor is required. + :param dtype: The data type of the returned tensor, If dtype is not given, + then the default data type is determined by backend. + :param device: The device on which the tensor will be allocated, If device is not given, + then the default device is determined by backend. + :return: A tensor filled with ones of the specified shape and data type. + """ + return Tensor(get_numeric_backend_fn("ones", backend)(shape, dtype=dtype, device=device)) + + def eye( n: int, m: Optional[int] = None, @@ -905,3 +926,17 @@ def ceil(a: Tensor) -> Tensor: :return: An array of the same type as a, containing the ceiling values. """ return Tensor(ceil(a.data)) + + +@functools.singledispatch +@tensor_guard +def reciprocal(a: Tensor) -> Tensor: + """ + Compute the reciprocal of a tensor or a float. + + This function returns a new tensor where each element is the reciprocal of the corresponding element in `a`. + + :param a: The input tensor or float. + :return: A tensor containing the reciprocal of each element in `a`. 
+ """ + return Tensor(reciprocal(a.data)) diff --git a/nncf/tensor/functions/numpy_numeric.py b/nncf/tensor/functions/numpy_numeric.py index 2496882ddba..71025b9e34f 100644 --- a/nncf/tensor/functions/numpy_numeric.py +++ b/nncf/tensor/functions/numpy_numeric.py @@ -394,6 +394,19 @@ def zeros( return np.zeros(shape, dtype=dtype) +def ones( + shape: Tuple[int, ...], + *, + dtype: Optional[TensorDataType] = None, + device: Optional[TensorDeviceType] = None, +) -> np.ndarray: + if device is not None and device != TensorDeviceType.CPU: + raise ValueError("numpy_numeric.ones only supports CPU device.") + if dtype is not None: + dtype = DTYPE_MAP[dtype] + return np.ones(shape, dtype=dtype) + + def eye( n: int, m: Optional[int] = None, @@ -431,3 +444,8 @@ def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]: @register_numpy_types(numeric.ceil) def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.ceil(a) + + +@register_numpy_types(numeric.reciprocal) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: + return np.reciprocal(a) diff --git a/nncf/tensor/functions/torch_numeric.py b/nncf/tensor/functions/torch_numeric.py index e3163d28aab..da2beb9f0d4 100644 --- a/nncf/tensor/functions/torch_numeric.py +++ b/nncf/tensor/functions/torch_numeric.py @@ -423,6 +423,19 @@ def zeros( return torch.zeros(*shape, dtype=dtype, device=device) +def ones( + shape: Tuple[int, ...], + *, + dtype: Optional[TensorDataType] = None, + device: Optional[TensorDeviceType] = None, +) -> torch.Tensor: + if dtype is not None: + dtype = DTYPE_MAP[dtype] + if device is not None: + device = DEVICE_MAP[device] + return torch.ones(*shape, dtype=dtype, device=device) + + def eye( n: int, m: Optional[int] = None, @@ -465,3 +478,8 @@ def _(a: torch.Tensor) -> torch.Tensor: @numeric.ceil.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return torch.ceil(a) + + +@numeric.reciprocal.register(torch.Tensor) +def _(a: torch.Tensor) -> torch.Tensor: + return torch.reciprocal(a) diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index 13f2d6bc976..ff6477513b3 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -1514,6 +1514,19 @@ def test_fn_zeros(self): assert tensor_a.shape == shape assert fns.all(tensor_a == 0) + def test_fn_ones(self): + shape = (2, 2) + for dtype in TensorDataType: + if dtype == TensorDataType.bfloat16 and self.backend() == TensorBackend.numpy: + continue + tensor_a = fns.ones(shape, backend=self.backend(), dtype=dtype, device=self.device()) + assert isinstance(tensor_a, Tensor) + assert tensor_a.device == self.device() + assert tensor_a.backend == self.backend() + assert tensor_a.dtype == dtype + assert tensor_a.shape == shape + assert fns.all(tensor_a == 1) + @pytest.mark.parametrize( "n, m, ref", ( @@ -1695,3 +1708,9 @@ def test_svd(self, a, full_matrices, abs_res_ref): for act, abs_ref in zip(res, abs_res_ref): assert isinstance(act, Tensor) assert fns.allclose(fns.abs(act), abs_ref, atol=1e-7) + + @pytest.mark.parametrize("a,ref", [([1], [1.0]), ([2, 4], [0.5, 0.25])]) + def test_reciprocal(self, a, ref): + t_a = Tensor(self.to_tensor(a)).astype(TensorDataType.float32) + res = fns.reciprocal(t_a) + assert fns.allclose(res, Tensor(self.to_tensor(ref)).astype(TensorDataType.float32)) diff --git a/tests/openvino/native/quantization/test_weights_compression.py 
index 7f8ccea21ec..1ad6b5ef01e 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -1078,32 +1078,47 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
     assert ref_e8m0_nodes == names_e8m0
 
 
-@pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM))
-def test_np_ov_compression_decompression(mode):
-    sz = 60
-    w = np.arange(-sz, sz).reshape(2, sz).astype(np.float32) / 9.0
+@pytest.mark.parametrize("mode", [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM])
+@pytest.mark.parametrize(
+    "w,s,zp",
+    [
+        (
+            np.array([[1.4372410774230957]], np.float32),
+            np.array([[-0.9581607580184937]], np.float32),
+            np.array([[1]], np.int32),
+        ),
+        (np.arange(-60, 60).reshape(2, 60).astype(np.float32) / 9.0, None, None),
+    ],
+)
+def test_np_ov_compression_decompression(mode, w, s, zp):
     w = Tensor(w)
+    if s is not None:
+        s = Tensor(s)
+    if mode == CompressWeightsMode.INT4_SYM:
+        zp = None
+    if zp is not None:
+        zp = Tensor(zp)
 
     config = WeightCompressionConfig(mode)
 
-    compressed_weighs, scale, zp = do_int_quantization(w, -1, config, invert_scale=True)
-    decompressed_weighs = do_int_dequantization(compressed_weighs, scale, zp)
+    compressed_weights, s, zp = do_int_quantization(w, -1, config, precomputed_scale=s, precomputed_zero_point=zp)
+    decompressed_weights = do_int_dequantization(compressed_weights, s, zp)
 
-    compressed_weighs = compressed_weighs.data
-    decompressed_weighs = decompressed_weighs.data
+    compressed_weights = compressed_weights.data
+    decompressed_weights = decompressed_weights.data
     zp_shape = zp.shape if zp is not None else None
 
-    compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, scale.shape, zp_shape)
+    compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, s.shape, zp_shape)
     compress_decompress = OVWeightCompressionAlgoBackend.get_compress_decompress_pipeline(
-        config, w.shape, scale.shape, zp_shape
+        config, w.shape, s.shape, zp_shape
     )
 
-    params = [w.data, scale.data, zp.data] if zp is not None else [w.data, scale.data]
-    compressed_weighs_ov = compress(params)
-    decompressed_weighs_ov = compress_decompress(params)
+    params = [w.data, s.data, zp.data] if zp is not None else [w.data, s.data]
+    compressed_weights_ov = compress(params)
+    decompressed_weights_ov = compress_decompress(params)
 
-    assert np.allclose(compressed_weighs, compressed_weighs_ov)
-    assert np.allclose(decompressed_weighs, decompressed_weighs_ov)
+    assert np.allclose(compressed_weights, compressed_weights_ov, atol=0)
+    assert np.allclose(decompressed_weights, decompressed_weights_ov, atol=0)
 
 
 @pytest.mark.parametrize(
diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml
index 6c48904c91a..b5517a43972 100644
--- a/tests/post_training/data/wc_reference_data.yaml
+++ b/tests/post_training/data/wc_reference_data.yaml
@@ -23,7 +23,7 @@ tinyllama_int8_data_free_backend_TORCH:
   num_int4: 0
   num_int8: 312
 tinyllama_data_aware_gptq_scale_estimation_stateful_backend_OV:
-  metric_value: 0.86503
+  metric_value: 0.81880
   num_int4: 94
   num_int8: 124
   metrics_xfail_reason: "Issue-148819"
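Reviewer note: the minimal sketch below (not part of the patch) illustrates the two behaviors this diff standardizes on, assuming an NNCF checkout that already contains it so that `fns.ones` and `fns.reciprocal` exist. Quantization now always multiplies by the reciprocal of the scale, the OV reference behavior previously gated behind `invert_scale=True`; since `w / s` and `w * (1/s)` are not guaranteed to be bit-identical in float32, sharing a single path is what lets the OV comparison test assert with `atol=0`.

```python
# Minimal sketch, not part of the patch. Assumes an NNCF build containing this
# diff, so that fns.ones and fns.reciprocal are available.
import numpy as np

from nncf.tensor import Tensor
from nncf.tensor import functions as fns
from nncf.tensor.definitions import TensorBackend, TensorDataType

# The single-element pair added to test_np_ov_compression_decompression above.
w = Tensor(np.array([[1.4372410774230957]], dtype=np.float32))
s = Tensor(np.array([[-0.9581607580184937]], dtype=np.float32))

# Single quantization path after this patch: multiply by the reciprocal of the
# scale (matching the OV compress pipeline) instead of dividing by the scale.
q_mul = w * fns.reciprocal(s)
q_div = w / s  # the branch removed from calculate_quantized_weight
# Numerically close, but not guaranteed bit-identical in float32; after the
# subsequent rounding step, a one-ULP difference can flip the quantized level.
assert fns.allclose(q_mul, q_div)

# fns.ones mirrors fns.zeros: backend is keyword-only, dtype/device optional.
t = fns.ones((2, 2), backend=TensorBackend.numpy, dtype=TensorDataType.float32)
assert fns.all(t == 1)
```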