Skip to content

Commit

Permalink
updated test references
Browse files: browse the repository at this point in the history
  • Loading branch information
alexsu52 committed Oct 29, 2024
1 parent c3d75c2 commit 101d51a
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 19 deletions.
14 changes: 12 additions & 2 deletions nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ def compress_weights(
)

options = {
"sensitivity_metric": sensitivity_metric,
"awq": awq,
"scale_estimation": scale_estimation,
"gptq": gptq,
Expand All @@ -523,6 +522,12 @@ def compress_weights(
f"Torch backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
)

if sensitivity_metric not in [None, SensitivityMetric.WEIGHT_QUANTIZATION_ERROR]:
raise nncf.ParameterNotSupportedError(
"Torch backend only supports data-free sensitivity metric. "
"Set None or SensitivityMetric.WEIGHT_QUANTIZATION_ERROR."
)

if is_wrapped_model(model):
if not model.nncf.trace_parameters:
raise nncf.ValidationError(
Expand Down Expand Up @@ -550,7 +555,6 @@ def compress_weights(
)

options = {
"sensitivity_metric": sensitivity_metric,
"awq": awq,
"scale_estimation": scale_estimation,
"gptq": gptq,
Expand All @@ -562,6 +566,12 @@ def compress_weights(
f"TorchFX backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
)

if sensitivity_metric not in [None, SensitivityMetric.WEIGHT_QUANTIZATION_ERROR]:
raise nncf.ParameterNotSupportedError(
"TorchFX backend only supports data-free sensitivity metric. "
"Set None or SensitivityMetric.WEIGHT_QUANTIZATION_ERROR."
)

if dataset:
raise nncf.ParameterNotSupportedError(
"TorchFX only supports data-free weights compression," "Set the 'dataset' option to None"
Expand Down
6 changes: 3 additions & 3 deletions tests/post_training/data/wc_reference_data_2024.5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
num_int4: 11
num_int8: 290
tinyllama_int4_data_free_backend_TORCH:
metric_value: 0.73541
num_int4: 308
num_int8: 4
metric_value: 0.73873
num_int4: 114
num_int8: 84
28 changes: 16 additions & 12 deletions tests/post_training/test_quantize_conformance.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def ref_data_correction(data: Dict, file_name: str):
with file_path.open() as f:
correction_data = yaml.safe_load(f)
for m_name, c_data in correction_data.items():
data[m_name].update(c_data)
if m_name in data:
data[m_name].update(c_data)
else:
data[m_name] = c_data
print(f"Applied correction file {file_path}")

return data
Expand All @@ -125,17 +128,18 @@ def fixture_wc_reference_data():
path_reference = DATA_ROOT / "wc_reference_data.yaml"
with path_reference.open() as f:
data = yaml.safe_load(f)
fp32_test_cases = defaultdict(dict)
for test_case_name in data:
if "atol" not in data[test_case_name]:
data[test_case_name]["atol"] = 1e-5
reported_name = test_case_name.split("_backend_")[0]
fp32_case_name = f"{reported_name}_backend_FP32"
fp32_test_cases[fp32_case_name]["metric_value"] = 1
if "atol" not in fp32_test_cases[fp32_case_name]:
fp32_test_cases[fp32_case_name]["atol"] = 1e-10
data.update(fp32_test_cases)
return ref_data_correction(data, "wc_reference_data")
data = ref_data_correction(data, "wc_reference_data")
fp32_test_cases = defaultdict(dict)
for test_case_name in data:
if "atol" not in data[test_case_name]:
data[test_case_name]["atol"] = 1e-5
reported_name = test_case_name.split("_backend_")[0]
fp32_case_name = f"{reported_name}_backend_FP32"
fp32_test_cases[fp32_case_name]["metric_value"] = 1
if "atol" not in fp32_test_cases[fp32_case_name]:
fp32_test_cases[fp32_case_name]["atol"] = 1e-10
data.update(fp32_test_cases)
return data


@pytest.fixture(scope="session", name="ptq_result_data")
Expand Down
3 changes: 2 additions & 1 deletion tests/torch/fx/test_compress_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nncf.quantization import compress_weights
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
from tests.torch.ptq.test_weights_compression import ALL_SENSITIVITY_METRICS
from tests.torch.ptq.test_weights_compression import DATA_BASED_SENSITIVITY_METRICS
from tests.torch.ptq.test_weights_compression import INT4_MODES
from tests.torch.ptq.test_weights_compression import INT8_MODES
from tests.torch.ptq.test_weights_compression import SUPPORTED_MODES
Expand Down Expand Up @@ -240,7 +241,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params):
@pytest.mark.parametrize(
"params",
(
*({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS),
*({"sensitivity_metric": metric} for metric in DATA_BASED_SENSITIVITY_METRICS),
{"gptq": True},
{"awq": True},
{"scale_estimation": True},
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/ptq/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params):
@pytest.mark.parametrize(
"params",
(
*({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS),
*({"sensitivity_metric": metric} for metric in DATA_BASED_SENSITIVITY_METRICS),
{"gptq": True},
{"awq": True},
{"scale_estimation": True},
Expand Down

0 comments on commit 101d51a

Please sign in to comment.