# [Torch] INT4 weight compression (#3014)
### Changes
- Support INT4 weight compression in Torch and Torch.FX backends
- Added `INT4SymmetricWeightsDecompressor` and `INT4AsymmetricWeightsDecompressor`

### Reason for changes

Support INT4 weight compression of PyTorch models in NNCF.
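For reference, a minimal usage sketch of the new capability (the toy model, shapes, and `group_size` value are illustrative; unwrapped PyTorch models follow the usual NNCF pattern of passing a one-sample dataset for tracing):

```python
import torch
import nncf

# Toy model; any torch.nn.Module with weighted layers is handled the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
)

# Data-free INT4 symmetric weight compression; the one-sample dataset
# is only used to trace the model.
compressed_model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    group_size=64,
    dataset=nncf.Dataset([torch.randn(1, 128)]),
)
```

The compressed weights are stored packed; the new `INT4SymmetricWeightsDecompressor` / `INT4AsymmetricWeightsDecompressor` modules are registered in the model and dequantize them on the fly during inference.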

### Related tickets

#3005

### Tests

Updated tests.
alexsu52 authored Oct 28, 2024
1 parent 94f1006 commit ef9cdd2
Showing 15 changed files with 586 additions and 239 deletions.
### nncf/quantization/algorithms/weight_compression/torch_backend.py (32 additions, 22 deletions)
```diff
@@ -43,8 +43,10 @@
 from nncf.torch.model_graph_manager import split_const_name
 from nncf.torch.model_transformer import PTModelTransformer
 from nncf.torch.nncf_network import NNCFNetwork
-from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor
-from nncf.torch.quantization.layers import SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
@@ -212,10 +214,9 @@ def transform_model(
         for wc_params in weight_compression_parameters:
             compression_config = wc_params.compression_config
-            if compression_config.mode not in [
-                CompressWeightsMode.INT8_ASYM,
-                CompressWeightsMode.INT8_SYM,
-                CompressWeightsMode.INT8,
+            if compression_config.mode in [
+                CompressWeightsMode.NF4,
+                CompressWeightsMode.E2M1,
             ]:
                 raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.")
@@ -235,17 +236,35 @@ def transform_model(
                 None if precomputed_scales is None else precomputed_scales.get(wc_params.weight_name),
                 None if precomputed_zero_points is None else precomputed_zero_points.get(wc_params.weight_name),
             )
+            compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16)
 
-            # pack compressed tensor
+            # creates weight decompressor
             if compression_config.mode == CompressWeightsMode.INT8_SYM:
-                dtype = TensorDataType.int8
-            else:
-                dtype = TensorDataType.uint8
-            packed_tensor = compressed_weight.tensor.astype(dtype)
+                decompressor = INT8SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype)
+            elif compression_config.mode == CompressWeightsMode.INT8_ASYM:
+                decompressor = INT8AsymmetricWeightsDecompressor(
+                    compressed_weight.scale.data, compressed_weight.zero_point.data, result_dtype=weight.dtype
+                )
+            elif compression_config.mode == CompressWeightsMode.INT4_SYM:
+                decompressor = INT4SymmetricWeightsDecompressor(
+                    scale=compressed_weight.scale.data,
+                    compressed_weight_shape=compressed_weight.tensor.shape,
+                    result_shape=weight.shape,
+                    result_dtype=weight.dtype,
+                )
+            elif compression_config.mode == CompressWeightsMode.INT4_ASYM:
+                decompressor = INT4AsymmetricWeightsDecompressor(
+                    scale=compressed_weight.scale.data,
+                    zero_point=compressed_weight.zero_point.data,
+                    compressed_weight_shape=compressed_weight.tensor.shape,
+                    result_shape=weight.shape,
+                    result_dtype=weight.dtype,
+                )
+
+            # pack tensor
+            packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data)
 
             # sets compressed tensor
-            compressed_parameter = torch.nn.Parameter(packed_tensor.data, requires_grad=False)
+            compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
             setattr(module, weight_attr_name, compressed_parameter)
 
             consumer_nodes = graph.get_next_nodes(weight_node)
@@ -256,15 +275,6 @@ def transform_model(
                     if id(param) == id(weight):
                         setattr(c_module, name, compressed_parameter)
 
-            # creates weight decompressor
-            if compression_config.mode == CompressWeightsMode.INT8_SYM:
-                decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype)
-            else:
-                packed_zero_point = compressed_weight.zero_point.astype(dtype)
-                decompressor = AsymmetricWeightsDecompressor(
-                    compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.dtype
-                )
-
             # registry weight decompression module in the model
             decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
```
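For context (not part of the diff): INT4 packing stores two 4-bit codes per `uint8` byte, which is why the INT4 decompressors take `compressed_weight_shape` and `result_shape` to restore the original layout on unpacking. A standalone sketch of the idea for the symmetric case, assuming codes in [-8, 7] (illustrative only, not the NNCF implementation):

```python
import torch

def pack_int4(codes: torch.Tensor) -> torch.Tensor:
    """Pack a tensor of symmetric INT4 codes ([-8, 7]) into uint8, two per byte."""
    flat = (codes.flatten() + 8).to(torch.uint8)  # shift into the unsigned nibble range [0, 15]
    if flat.numel() % 2:
        flat = torch.cat([flat, flat.new_zeros(1)])  # pad odd-length tensors
    return flat[0::2] | (flat[1::2] << 4)

def unpack_int4(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Recover the original INT4 codes from the packed uint8 tensor."""
    low = (packed & 0x0F).to(torch.int8) - 8
    high = (packed >> 4).to(torch.int8) - 8
    return torch.stack([low, high], dim=-1).flatten()[:numel]

codes = torch.randint(-8, 8, (10,), dtype=torch.int8)
assert torch.equal(unpack_int4(pack_int4(codes), codes.numel()), codes)
```

Dequantization then multiplies the unpacked codes by the per-group scale (kept in `float16` by the new `compressed_weight.scale.astype(...)` line above) and reshapes to `result_shape`.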
### nncf/quantization/algorithms/weight_compression/torch_fx_backend.py (33 additions, 23 deletions)
```diff
@@ -45,8 +45,10 @@
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.model_graph_manager import get_const_node
 from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
-from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor
-from nncf.torch.quantization.layers import SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
@@ -176,10 +178,9 @@ def transform_model(
         for wc_params in weight_compression_parameters:
             compression_config = wc_params.compression_config
-            if compression_config.mode not in [
-                CompressWeightsMode.INT8_ASYM,
-                CompressWeightsMode.INT8_SYM,
-                CompressWeightsMode.INT8,
+            if compression_config.mode in [
+                CompressWeightsMode.NF4,
+                CompressWeightsMode.E2M1,
             ]:
                 raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.")
             weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
@@ -196,35 +197,44 @@ def transform_model(
                 None if precomputed_scales is None else precomputed_scales.get(wc_params.weight_name),
                 None if precomputed_zero_points is None else precomputed_zero_points.get(wc_params.weight_name),
             )
+            compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16)
 
-            # pack compressed tensor
-            if compression_config.mode == CompressWeightsMode.INT8_SYM:
-                dtype = TensorDataType.int8
-            else:
-                dtype = TensorDataType.uint8
-            packed_tensor = compressed_weight.tensor.astype(dtype)
-
-            self.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, packed_tensor)
             # creates weight decompressor
             if compression_config.mode == CompressWeightsMode.INT8_SYM:
-                decompressor = SymmetricWeightsDecompressor(
+                decompressor = INT8SymmetricWeightsDecompressor(
                     compressed_weight.scale.data, result_dtype=weight.data.dtype
                 )
-                decompressor_type = "symmetric"
-            else:
-                packed_zero_point = compressed_weight.zero_point.astype(dtype)
-                decompressor = AsymmetricWeightsDecompressor(
-                    compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.data.dtype
-                )
-                decompressor_type = "asymmetric"
+            elif compression_config.mode == CompressWeightsMode.INT8_ASYM:
+                decompressor = INT8AsymmetricWeightsDecompressor(
+                    compressed_weight.scale.data, compressed_weight.zero_point.data, result_dtype=weight.data.dtype
+                )
+            elif compression_config.mode == CompressWeightsMode.INT4_SYM:
+                decompressor = INT4SymmetricWeightsDecompressor(
+                    scale=compressed_weight.scale.data,
+                    compressed_weight_shape=compressed_weight.tensor.shape,
+                    result_shape=weight.shape,
+                    result_dtype=weight.data.dtype,
+                )
+            elif compression_config.mode == CompressWeightsMode.INT4_ASYM:
+                decompressor = INT4AsymmetricWeightsDecompressor(
+                    scale=compressed_weight.scale.data,
+                    zero_point=compressed_weight.zero_point.data,
+                    compressed_weight_shape=compressed_weight.tensor.shape,
+                    result_shape=weight.shape,
+                    result_dtype=weight.data.dtype,
+                )
+
+            # pack tensor
+            packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data)
+
+            # sets compressed tensor
+            self.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, packed_tensor)
 
             # register weight decompression module in the model
             graph_weight_node = get_graph_node_by_name(model.graph, wc_params.node_with_weight.node_name)
             compressed_weight_name = graph_weight_node.all_input_nodes[wc_params.weight_port_id].name
 
             decompressor_suffix = "_".join(compressed_weight_name.replace(".", "_").split("_")[:-2])
-            decompressor_name = f"{decompressor_type}_weights_decompressor_{decompressor_suffix}"
+            decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}"
 
             # inserts the weight decompressor into the model as the post hook on the model weight
             transformation_layout.register(
```
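The asymmetric variants additionally subtract a stored zero point before applying the scale, i.e. `w ≈ (q - zero_point) * scale` per quantization group. Schematically (an illustrative sketch, not NNCF's implementation):

```python
import torch

def dequantize_asym(
    codes: torch.Tensor,       # unpacked unsigned codes, e.g. [0, 15] for INT4
    scale: torch.Tensor,       # per-group scale, broadcastable to codes
    zero_point: torch.Tensor,  # per-group zero point, broadcastable to codes
    result_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Subtract in a wide integer type to avoid overflow, then scale.
    return (codes.to(torch.int32) - zero_point.to(torch.int32)).to(result_dtype) * scale.to(result_dtype)
```

Note also that the hand-maintained `decompressor_type` string is gone: the registered module name is now derived from the decompressor's own `quantization_mode`, which covers all four INT8/INT4 variants without extra bookkeeping.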
### nncf/quantization/quantize_model.py (30 additions, 17 deletions)
```diff
@@ -505,20 +505,26 @@ def compress_weights(
         from nncf.torch.model_creation import wrap_model
         from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl
 
-        if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
+        if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
             raise nncf.ParameterNotSupportedError(
-                "Torch backend supports only INT8_ASYM, INT8_SYM modes for weight compression, "
-                f"but given {mode.value} mode."
+                "Torch backend does not support NF4 and E2M1 modes for weight compression."
             )
 
-        if True in [awq, scale_estimation, gptq, lora_correction]:
+        options = {
+            "sensitivity_metric": sensitivity_metric,
+            "awq": awq,
+            "scale_estimation": scale_estimation,
+            "gptq": gptq,
+            "lora_correction": lora_correction,
+        }
+        unsupported_options = [name for name, value in options.items() if value is not None]
+        if unsupported_options:
             raise nncf.ParameterNotSupportedError(
-                "Torch backend does not support 'awq', 'scale_estimation', 'gptq' and 'lora_correction' options. "
-                "Set them to None."
+                f"Torch backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
             )
 
         if backup_mode is not None:
             raise nncf.ParameterNotSupportedError("Torch backend does not support backup_mode option.")
+        if ratio is not None and ratio != 1:
+            raise nncf.ParameterNotSupportedError("Torch backend does not support ratio != 1.")
 
         if is_wrapped_model(model):
             if not model.nncf.trace_parameters:
@@ -541,20 +547,27 @@ def compress_weights(
             compress_weights_impl as fx_compression_weights_impl,
         )
 
-        if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
+        if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
             raise nncf.ParameterNotSupportedError(
-                "TorchFX backend supports only INT8_ASYM, INT8_SYM modes for weight compression, "
-                f"but given {mode.value} mode."
+                "Torch backend does not support NF4 and E2M1 modes for weight compression."
             )
 
         if backup_mode is not None:
             raise nncf.ParameterNotSupportedError("TorchFX backend does not support backup_mode option.")
 
-        if any((awq, scale_estimation, gptq, lora_correction)):
+        options = {
+            "sensitivity_metric": sensitivity_metric,
+            "awq": awq,
+            "scale_estimation": scale_estimation,
+            "gptq": gptq,
+            "lora_correction": lora_correction,
+        }
+        unsupported_options = [name for name, value in options.items() if value is not None]
+        if unsupported_options:
             raise nncf.ParameterNotSupportedError(
-                "TorchFX backend does not support 'awq', 'scale_estimation', 'gptq',"
-                "and 'lora_correction' options. Set them to None."
+                f"TorchFX backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
             )
 
+        if ratio is not None and ratio != 1:
+            raise nncf.ParameterNotSupportedError("TorchFX backend does not support ratio != 1.")
+
         if dataset:
             raise nncf.ParameterNotSupportedError(
                 "TorchFX only supports data-free weights compression," "Set the 'dataset' option to None"
```
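The option validation is now data-driven: unsupported options are collected by name and reported together, rather than through a fixed message. A sketch of the resulting behavior (toy model; the exact message text may differ):

```python
import torch
import nncf

model = torch.nn.Linear(16, 16)
try:
    # gptq is not supported by the Torch backend, so this raises before any tracing.
    nncf.compress_weights(model, mode=nncf.CompressWeightsMode.INT4_SYM, gptq=True)
except nncf.ParameterNotSupportedError as err:
    print(err)  # e.g. "Torch backend does not support gptq option(s). Set them to None."
```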
