From 21703a0edcd21a6e5dea02624631b22c82902c04 Mon Sep 17 00:00:00 2001
From: Aleksandr Suslov
Date: Thu, 31 Oct 2024 13:05:17 +0400
Subject: [PATCH] Support save/load of compressed models in the Torch backend

Register the INT8/INT4 weight decompressors in COMPRESSION_MODULES and
implement get_config()/from_config() for them, so that a model compressed
with compress_weights() can be serialized and later restored via
nncf.torch.load_from_config(). Validation failures in pack_weight() now
raise nncf.ValidationError instead of ValueError.

---
 nncf/torch/layer_utils.py                   |  3 +-
 nncf/torch/quantization/layers.py           | 82 ++++++++++++++++++---
 tests/torch/ptq/test_weights_compression.py | 36 +++++++++
 3 files changed, 111 insertions(+), 10 deletions(-)

diff --git a/nncf/torch/layer_utils.py b/nncf/torch/layer_utils.py
index fb7d7bed79f..5f9ba8a9a32 100644
--- a/nncf/torch/layer_utils.py
+++ b/nncf/torch/layer_utils.py
@@ -44,7 +44,8 @@ def get_config(self) -> Dict[str, Any]:
         Returns the compression module config.
         """
 
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def from_config(cls, state: Dict[str, Any]) -> object:
         """
         Creates a compression module instance from the given config.
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 85e41428846..72c50cd4b0a 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -1049,7 +1049,7 @@ def get_scale_shape(input_shape: List[int], is_weights: bool, per_channel: bool,
     return get_per_channel_scale_shape(input_shape, is_weights, channel_idx)
 
 
-class BaseWeightsDecompressor(nn.Module, ABC):
+class BaseWeightsDecompressor(nn.Module, StatefullModuleInterface, ABC):
     """
     Base class for implementing weights decompression modules within NNCF.
 
@@ -1081,6 +1081,7 @@ def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         """
 
 
+@COMPRESSION_MODULES.register()
 class INT8AsymmetricWeightsDecompressor(BaseWeightsDecompressor):
     """
     Applies asymmetric decompression of compressed weights in the forward pass
@@ -1103,9 +1104,9 @@ def quantization_mode(self) -> QuantizationMode:
 
     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.dtype}. Integer types are supported.")
         if torch.any((weight < 0) | (weight > 255)):
-            raise ValueError("Weight values are not in [0, 255].")
+            raise nncf.ValidationError("Weight values are not in [0, 255].")
         return weight.type(dtype=torch.uint8)
 
     def forward(self, x) -> torch.Tensor:
@@ -1113,7 +1114,22 @@ def forward(self, x) -> torch.Tensor:
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result
 
+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "zero_point_shape": self._zero_point.shape,
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        zero_point = torch.zeros(state["zero_point_shape"], dtype=torch.uint8)
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, zero_point, result_dtype)
+
 
+@COMPRESSION_MODULES.register()
 class INT8SymmetricWeightsDecompressor(BaseWeightsDecompressor):
     """
     Applies symmetric decompression of compressed weights in the forward pass
@@ -1134,9 +1150,9 @@ def quantization_mode(self) -> QuantizationMode:
 
     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.dtype}. Integer types are supported.")
         if torch.any((weight < -128) | (weight > 127)):
-            raise ValueError("Weight values are not in [-128, 127].")
+            raise nncf.ValidationError("Weight values are not in [-128, 127].")
         return weight.type(dtype=torch.int8)
 
     def forward(self, x) -> torch.Tensor:
@@ -1144,7 +1160,20 @@ def forward(self, x) -> torch.Tensor:
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result
 
+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, result_dtype)
+
 
+@COMPRESSION_MODULES.register()
 class INT4AsymmetricWeightsDecompressor(BaseWeightsDecompressor):
     def __init__(
         self,
@@ -1177,9 +1206,9 @@ def quantization_mode(self) -> QuantizationMode:
 
     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.dtype}. Integer types are supported.")
         if torch.any((weight < 0) | (weight > 15)):
-            raise ValueError("Weight values are not in [0, 15].")
+            raise nncf.ValidationError("Weight values are not in [0, 15].")
         return pack_uint4(weight.type(dtype=torch.uint8))
 
     def forward(self, x):
@@ -1194,7 +1223,26 @@ def forward(self, x):
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result
 
+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "zero_point_shape": self.zero_point_shape,
+            "compressed_weight_shape": self.compressed_weight_shape,
+            "result_shape": self.result_shape if self.result_shape is not None else "",
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        zero_point = torch.zeros(state["zero_point_shape"], dtype=torch.uint8)
+        compressed_weight_shape = state["compressed_weight_shape"]
+        result_shape = state["result_shape"] if state["result_shape"] else None
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, zero_point, compressed_weight_shape, result_shape, result_dtype)
+
 
+@COMPRESSION_MODULES.register()
 class INT4SymmetricWeightsDecompressor(BaseWeightsDecompressor):
     def __init__(
         self,
@@ -1222,9 +1270,9 @@ def quantization_mode(self) -> QuantizationMode:
 
     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.dtype}. Integer types are supported.")
         if torch.any((weight < -8) | (weight > 7)):
-            raise ValueError("Tensor values are not in [-8, 7].")
+            raise nncf.ValidationError("Weight values are not in [-8, 7].")
         return pack_int4(weight.type(dtype=torch.int8))
 
     def forward(self, x):
@@ -1235,3 +1283,19 @@ def forward(self, x):
         result = result.reshape(self.result_shape) if self.result_shape is not None else result
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result
+
+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "compressed_weight_shape": self.compressed_weight_shape,
+            "result_shape": self.result_shape if self.result_shape is not None else "",
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        compressed_weight_shape = state["compressed_weight_shape"]
+        result_shape = state["result_shape"] if state["result_shape"] else None
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, compressed_weight_shape, result_shape, result_dtype)
diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py
index 2e902e1af50..6420d7baba1 100644
--- a/tests/torch/ptq/test_weights_compression.py
+++ b/tests/torch/ptq/test_weights_compression.py
@@ -333,3 +333,39 @@ def test_pack_int4():
     assert packed_w.numel() * 2 == w_int8.numel()
     unpacked_w = unpack_int4(packed_w).reshape(w_int8.shape)
     assert torch.all(unpacked_w == w_int8)
+
+
+@pytest.mark.parametrize("mode", SUPPORTED_MODES)
+def test_save_load(mode, tmp_path):
+    model = ShortTransformer(8, 16)
+    input_ids = torch.randint(0, 10, (8,))
+    wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True)
+
+    kwargs = {}
+    if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+        kwargs["group_size"] = 4
+    compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs)
+
+    state_dict = compressed_model.state_dict()
+    compression_config = compressed_model.nncf.get_config()
+
+    ckpt_path = tmp_path / f"{mode}_model.pt"
+    torch.save(
+        {
+            "model_state_dict": state_dict,
+            "compression_config": compression_config,
+        },
+        ckpt_path,
+    )
+
+    compressed_result = compressed_model(input_ids)
+
+    restored_model = ShortTransformer(8, 16)
+
+    ckpt = torch.load(ckpt_path)
+    restored_model = nncf.torch.load_from_config(restored_model, ckpt["compression_config"], input_ids)
+    restored_model.load_state_dict(ckpt["model_state_dict"])
+
+    restored_compressed_result = restored_model(input_ids)
+
+    assert torch.allclose(compressed_result, restored_compressed_result)
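
Below is a minimal sketch of the save/load flow this patch enables, mirroring
the new test_save_load test. YourModel and example_input are placeholders and
compressed.pt is an arbitrary path; wrap_model, compress_weights,
model.nncf.get_config(), and nncf.torch.load_from_config are the NNCF entry
points exercised in the test above.

    import torch
    from nncf import CompressWeightsMode, compress_weights
    from nncf.torch import load_from_config, wrap_model

    # Trace the model and compress its weights; the INT8/INT4 decompressor
    # modules registered by this patch are inserted into the traced graph.
    model = wrap_model(YourModel(), example_input=example_input, trace_parameters=True)
    model = compress_weights(model, mode=CompressWeightsMode.INT8_ASYM)

    # Persist both the compressed weights and the compression state
    # serialized via the new get_config() methods.
    torch.save(
        {"state_dict": model.state_dict(), "config": model.nncf.get_config()},
        "compressed.pt",
    )

    # Restore: rebuild the decompressor modules from the config (from_config()),
    # then load the compressed weights into the reconstructed model.
    ckpt = torch.load("compressed.pt")
    restored = load_from_config(YourModel(), ckpt["config"], example_input)
    restored.load_state_dict(ckpt["state_dict"])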