diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index a0e9922f24..fd0548c99a 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -511,7 +511,6 @@ def compress_weights(
         )
 
         options = {
-            "sensitivity_metric": sensitivity_metric,
             "awq": awq,
             "scale_estimation": scale_estimation,
             "gptq": gptq,
@@ -523,6 +522,12 @@ def compress_weights(
                 f"Torch backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
             )
 
+        if sensitivity_metric not in [None, SensitivityMetric.WEIGHT_QUANTIZATION_ERROR]:
+            raise nncf.ParameterNotSupportedError(
+                "Torch backend only supports data-free sensitivity metric. "
+                "Set None or SensitivityMetric.WEIGHT_QUANTIZATION_ERROR."
+            )
+
         if is_wrapped_model(model):
             if not model.nncf.trace_parameters:
                 raise nncf.ValidationError(
@@ -550,7 +555,6 @@ def compress_weights(
         )
 
         options = {
-            "sensitivity_metric": sensitivity_metric,
             "awq": awq,
             "scale_estimation": scale_estimation,
             "gptq": gptq,
@@ -562,6 +566,12 @@ def compress_weights(
                 f"TorchFX backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
             )
 
+        if sensitivity_metric not in [None, SensitivityMetric.WEIGHT_QUANTIZATION_ERROR]:
+            raise nncf.ParameterNotSupportedError(
+                "TorchFX backend only supports data-free sensitivity metric. "
+                "Set None or SensitivityMetric.WEIGHT_QUANTIZATION_ERROR."
+            )
+
         if dataset:
             raise nncf.ParameterNotSupportedError(
                 "TorchFX only supports data-free weights compression," "Set the 'dataset' option to None"
diff --git a/tests/post_training/data/wc_reference_data_2024.5.yaml b/tests/post_training/data/wc_reference_data_2024.5.yaml
index ee5b1ffdad..bd263305a7 100644
--- a/tests/post_training/data/wc_reference_data_2024.5.yaml
+++ b/tests/post_training/data/wc_reference_data_2024.5.yaml
@@ -8,6 +8,6 @@ tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
   num_int4: 11
   num_int8: 290
 tinyllama_int4_data_free_backend_TORCH:
-  metric_value: 0.73541
-  num_int4: 308
-  num_int8: 4
+  metric_value: 0.73873
+  num_int4: 114
+  num_int8: 84
diff --git a/tests/post_training/test_quantize_conformance.py b/tests/post_training/test_quantize_conformance.py
index 2ea880fde3..20504a8b08 100644
--- a/tests/post_training/test_quantize_conformance.py
+++ b/tests/post_training/test_quantize_conformance.py
@@ -106,7 +106,10 @@ def ref_data_correction(data: Dict, file_name: str):
         with file_path.open() as f:
             correction_data = yaml.safe_load(f)
         for m_name, c_data in correction_data.items():
-            data[m_name].update(c_data)
+            if m_name in data:
+                data[m_name].update(c_data)
+            else:
+                data[m_name] = c_data
         print(f"Applied correction file {file_path}")
 
     return data
@@ -125,17 +128,18 @@ def fixture_wc_reference_data():
     path_reference = DATA_ROOT / "wc_reference_data.yaml"
     with path_reference.open() as f:
         data = yaml.safe_load(f)
-    fp32_test_cases = defaultdict(dict)
-    for test_case_name in data:
-        if "atol" not in data[test_case_name]:
-            data[test_case_name]["atol"] = 1e-5
-        reported_name = test_case_name.split("_backend_")[0]
-        fp32_case_name = f"{reported_name}_backend_FP32"
-        fp32_test_cases[fp32_case_name]["metric_value"] = 1
-        if "atol" not in fp32_test_cases[fp32_case_name]:
-            fp32_test_cases[fp32_case_name]["atol"] = 1e-10
-    data.update(fp32_test_cases)
-    return ref_data_correction(data, "wc_reference_data")
+    data = ref_data_correction(data, "wc_reference_data")
+    fp32_test_cases = defaultdict(dict)
+    for test_case_name in data:
+        if "atol" not in data[test_case_name]:
+            data[test_case_name]["atol"] = 1e-5
+        reported_name = test_case_name.split("_backend_")[0]
+        fp32_case_name = f"{reported_name}_backend_FP32"
+        fp32_test_cases[fp32_case_name]["metric_value"] = 1
+        if "atol" not in fp32_test_cases[fp32_case_name]:
+            fp32_test_cases[fp32_case_name]["atol"] = 1e-10
+    data.update(fp32_test_cases)
+    return data
 
 
 @pytest.fixture(scope="session", name="ptq_result_data")
diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py
index 2cc1768afa..835398bd57 100644
--- a/tests/torch/fx/test_compress_weights.py
+++ b/tests/torch/fx/test_compress_weights.py
@@ -24,6 +24,7 @@
 from nncf.quantization import compress_weights
 from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
 from tests.torch.ptq.test_weights_compression import ALL_SENSITIVITY_METRICS
+from tests.torch.ptq.test_weights_compression import DATA_BASED_SENSITIVITY_METRICS
 from tests.torch.ptq.test_weights_compression import INT4_MODES
 from tests.torch.ptq.test_weights_compression import INT8_MODES
 from tests.torch.ptq.test_weights_compression import SUPPORTED_MODES
@@ -240,7 +241,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params):
 @pytest.mark.parametrize(
     "params",
     (
-        *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS),
+        *({"sensitivity_metric": metric} for metric in DATA_BASED_SENSITIVITY_METRICS),
         {"gptq": True},
         {"awq": True},
         {"scale_estimation": True},
diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py
index 3735baa120..2e902e1af5 100644
--- a/tests/torch/ptq/test_weights_compression.py
+++ b/tests/torch/ptq/test_weights_compression.py
@@ -250,7 +250,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params):
 @pytest.mark.parametrize(
     "params",
     (
-        *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS),
+        *({"sensitivity_metric": metric} for metric in DATA_BASED_SENSITIVITY_METRICS),
        {"gptq": True},
         {"awq": True},
         {"scale_estimation": True},