save/load of compressed model in Torch backend
alexsu52 committed Oct 31, 2024
1 parent 51a7fb6 commit 21703a0
Showing 3 changed files with 111 additions and 10 deletions.
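In short: a weight-compressed Torch model can now be saved as a plain checkpoint (state_dict plus NNCF config) and rebuilt later. A condensed sketch of the flow this commit enables, mirroring the new test at the bottom of this diff (the two-layer stand-in model is illustrative only; the test uses ShortTransformer):

    import nncf
    import torch
    from torch import nn

    model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8))  # stand-in model
    example_input = torch.randint(0, 10, (8,))

    wrapped = nncf.torch.wrap_model(model, example_input=example_input, trace_parameters=True)
    compressed = nncf.compress_weights(wrapped, mode=nncf.CompressWeightsMode.INT8_ASYM)

    # Persist both the compressed weights and the NNCF compression config.
    torch.save(
        {"state_dict": compressed.state_dict(), "config": compressed.nncf.get_config()},
        "ckpt.pt",
    )

    # Later: rebuild the compression modules from the config, then load the values.
    ckpt = torch.load("ckpt.pt")
    fresh = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8))
    restored = nncf.torch.load_from_config(fresh, ckpt["config"], example_input)
    restored.load_state_dict(ckpt["state_dict"])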
3 changes: 2 additions & 1 deletion nncf/torch/layer_utils.py
@@ -44,7 +44,8 @@ def get_config(self) -> Dict[str, Any]:
         Returns the compression module config.
         """

-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def from_config(cls, state: Dict[str, Any]) -> object:
         """
         Creates a compression module instance from the given config.
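Aside: abc.abstractclassmethod has been deprecated since Python 3.3; stacking @classmethod on top of @abstractmethod is the supported spelling. For context, a hypothetical module implementing this get_config/from_config contract (MyScaler is illustrative, not part of the commit):

    from typing import Any, Dict

    import torch
    from torch import nn

    class MyScaler(nn.Module):
        """Hypothetical compression module following the interface above."""

        def __init__(self, scale_shape):
            super().__init__()
            self.register_buffer("scale", torch.ones(scale_shape))

        def get_config(self) -> Dict[str, Any]:
            # Persist only what is needed to rebuild the module skeleton;
            # the actual tensor values travel in the state_dict.
            return {"scale_shape": tuple(self.scale.shape)}

        @classmethod
        def from_config(cls, state: Dict[str, Any]) -> "MyScaler":
            return cls(state["scale_shape"])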
82 changes: 73 additions & 9 deletions nncf/torch/quantization/layers.py
@@ -1049,7 +1049,7 @@ def get_scale_shape(input_shape: List[int], is_weights: bool, per_channel: bool,
     return get_per_channel_scale_shape(input_shape, is_weights, channel_idx)


-class BaseWeightsDecompressor(nn.Module, ABC):
+class BaseWeightsDecompressor(nn.Module, StatefullModuleInterface, ABC):
     """
     Base class for implementing weights decompression modules within NNCF.
@@ -1081,6 +1081,7 @@ def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         """


+@COMPRESSION_MODULES.register()
 class INT8AsymmetricWeightsDecompressor(BaseWeightsDecompressor):
     """
     Applies asymmetric decompression of compressed weights in the forward pass
@@ -1103,17 +1104,32 @@ def quantization_mode(self) -> QuantizationMode:

     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
         if torch.any((weight < 0) | (weight > 255)):
-            raise ValueError("Weight values are not in [0, 255].")
+            raise nncf.ValidationError("Weight values are not in [0, 255].")
         return weight.type(dtype=torch.uint8)

     def forward(self, x) -> torch.Tensor:
         result = decompress_asymmetric(x, self._scale, self._zero_point)
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result

+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "zero_point_shape": self._zero_point.shape,
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        zero_point = torch.zeros(state["zero_point_shape"], dtype=torch.uint8)
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, zero_point, result_dtype)


+@COMPRESSION_MODULES.register()
 class INT8SymmetricWeightsDecompressor(BaseWeightsDecompressor):
     """
     Applies symmetric decompression of compressed weights in the forward pass
@@ -1134,17 +1150,30 @@ def quantization_mode(self) -> QuantizationMode:

     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
         if torch.any((weight < -128) | (weight > 127)):
-            raise ValueError("Weight values are not in [-128, 127].")
+            raise nncf.ValidationError("Weight values are not in [-128, 127].")
         return weight.type(dtype=torch.int8)

     def forward(self, x) -> torch.Tensor:
         result = decompress_symmetric(x, self._scale)
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result

+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, result_dtype)


+@COMPRESSION_MODULES.register()
 class INT4AsymmetricWeightsDecompressor(BaseWeightsDecompressor):
     def __init__(
         self,
@@ -1177,9 +1206,9 @@ def quantization_mode(self) -> QuantizationMode:

     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
         if torch.any((weight < 0) | (weight > 15)):
-            raise ValueError("Weight values are not in [0, 15].")
+            raise nncf.ValidationError("Weight values are not in [0, 15].")
         return pack_uint4(weight.type(dtype=torch.uint8))

     def forward(self, x):
@@ -1194,7 +1223,26 @@ def forward(self, x):
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result

+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "zero_point_shape": self.zero_point_shape,
+            "compressed_weight_shape": self.compressed_weight_shape,
+            "result_shape": self.result_shape if self.result_shape is not None else "",
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        zero_point = torch.zeros(state["zero_point_shape"], dtype=torch.uint8)
+        compressed_weight_shape = state["compressed_weight_shape"]
+        result_shape = state["result_shape"] if state["result_shape"] else None
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, zero_point, compressed_weight_shape, result_shape, result_dtype)


+@COMPRESSION_MODULES.register()
 class INT4SymmetricWeightsDecompressor(BaseWeightsDecompressor):
     def __init__(
         self,
@@ -1222,9 +1270,9 @@ def quantization_mode(self) -> QuantizationMode:

     def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
         if torch.is_floating_point(weight):
-            raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
+            raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
         if torch.any((weight < -8) | (weight > 7)):
-            raise ValueError("Tensor values are not in [-8, 7].")
+            raise nncf.ValidationError("Tensor values are not in [-8, 7].")
         return pack_int4(weight.type(dtype=torch.int8))

     def forward(self, x):
@@ -1235,3 +1283,19 @@ def forward(self, x):
         result = result.reshape(self.result_shape) if self.result_shape is not None else result
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result
+
+    def get_config(self) -> Dict[str, Any]:
+        return {
+            "scale_shape": self._scale.shape,
+            "compressed_weight_shape": self.compressed_weight_shape,
+            "result_shape": self.result_shape if self.result_shape is not None else "",
+            "result_dtype": self.result_dtype if self.result_dtype is not None else "",
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        scale = torch.ones(state["scale_shape"], dtype=torch.float16)
+        compressed_weight_shape = state["compressed_weight_shape"]
+        result_shape = state["result_shape"] if state["result_shape"] else None
+        result_dtype = state["result_dtype"] if state["result_dtype"] else None
+        return cls(scale, compressed_weight_shape, result_shape, result_dtype)
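Note the pattern shared by all four from_config implementations: the module is rebuilt around placeholder tensors (torch.ones / torch.zeros) of the recorded shapes, and the real scale/zero-point values are filled in afterwards by load_state_dict. A minimal round trip for the symmetric INT8 case, assuming the constructor takes (scale, result_dtype) positionally, as the from_config call above does:

    import torch
    from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor

    module = INT8SymmetricWeightsDecompressor(torch.rand(16, 1, dtype=torch.float16), torch.float32)

    config = module.get_config()  # shapes and dtypes only
    state = module.state_dict()   # actual tensor values

    restored = INT8SymmetricWeightsDecompressor.from_config(config)  # placeholder scale
    restored.load_state_dict(state)  # placeholders overwritten with the saved values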
36 changes: 36 additions & 0 deletions tests/torch/ptq/test_weights_compression.py
@@ -333,3 +333,39 @@ def test_pack_int4():
     assert packed_w.numel() * 2 == w_int8.numel()
     unpacked_w = unpack_int4(packed_w).reshape(w_int8.shape)
     assert torch.all(unpacked_w == w_int8)
+
+
+@pytest.mark.parametrize("mode", SUPPORTED_MODES)
+def test_save_load(mode, tmp_path):
+    model = ShortTransformer(8, 16)
+    input_ids = torch.randint(0, 10, (8,))
+    wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True)
+
+    kwargs = {}
+    if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+        kwargs["group_size"] = 4
+    compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs)
+
+    state_dict = compressed_model.state_dict()
+    compression_config = compressed_model.nncf.get_config()
+
+    ckpt_path = tmp_path / f"{mode}_model.pt"
+    torch.save(
+        {
+            "model_state_dict": state_dict,
+            "compression_config": compression_config,
+        },
+        ckpt_path,
+    )
+
+    compressed_result = compressed_model(input_ids)
+
+    restored_model = ShortTransformer(8, 16)
+
+    ckpt = torch.load(ckpt_path)
+    restored_model = nncf.torch.load_from_config(restored_model, ckpt["compression_config"], input_ids)
+    restored_model.load_state_dict(ckpt["model_state_dict"])
+
+    restored_compressed_result = restored_model(input_ids)
+
+    assert torch.allclose(compressed_result, restored_compressed_result)
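For reference, the pack_uint4 / pack_int4 helpers exercised above store two 4-bit values per uint8 byte, which is why packed_w.numel() * 2 == w_int8.numel(). A rough sketch of the unsigned layout (the nibble order here is an assumption and NNCF's actual helpers may differ; the signed variant additionally offsets values into the unsigned range):

    import torch

    def pack_uint4_sketch(w: torch.Tensor) -> torch.Tensor:
        # w holds values in [0, 15]; pack neighbouring pairs into one byte:
        # even-indexed value -> low nibble, odd-indexed value -> high nibble.
        pairs = w.to(torch.uint8).reshape(-1, 2)
        return pairs[:, 0] | (pairs[:, 1] << 4)

    def unpack_uint4_sketch(packed: torch.Tensor) -> torch.Tensor:
        low = packed & 0x0F
        high = (packed >> 4) & 0x0F
        return torch.stack((low, high), dim=-1).reshape(-1)

    w = torch.randint(0, 16, (64,), dtype=torch.uint8)
    packed = pack_uint4_sketch(w)
    assert packed.numel() * 2 == w.numel()
    assert torch.all(unpack_uint4_sketch(packed) == w)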
