diff --git a/src/compressed_tensors/base.py b/src/compressed_tensors/base.py index d096bc86..65803e5c 100644 --- a/src/compressed_tensors/base.py +++ b/src/compressed_tensors/base.py @@ -13,4 +13,5 @@ # limitations under the License. SPARSITY_CONFIG_NAME = "sparsity_config" -QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config" +QUANTIZATION_CONFIG_NAME = "quantization_config" +COMPRESSION_CONFIG_NAME = "compression_config" diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py index 42724967..17acadb9 100644 --- a/src/compressed_tensors/compressors/__init__.py +++ b/src/compressed_tensors/compressors/__init__.py @@ -14,7 +14,9 @@ # flake8: noqa -from .base import ModelCompressor +from .base import Compressor from .dense import DenseCompressor from .helpers import load_compressed, save_compressed, save_compressed_model +from .int_quantized import IntQuantizationCompressor +from .model_compressor import ModelCompressor from .sparse_bitmask import BitmaskCompressor, BitmaskTensor diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index e30492d0..a0ceef74 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -12,56 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import operator -from typing import Dict, Generator, Optional, Tuple +from typing import Dict, Generator, Tuple, Union -from compressed_tensors.base import SPARSITY_CONFIG_NAME -from compressed_tensors.config import CompressionConfig +from compressed_tensors.config import SparsityCompressionConfig +from compressed_tensors.quantization import QuantizationConfig from compressed_tensors.registry import RegistryMixin -from compressed_tensors.utils import get_safetensors_folder from torch import Tensor -from torch.nn import Module, Parameter -from tqdm import tqdm -from transformers import AutoConfig -__all__ = ["ModelCompressor"] +__all__ = ["Compressor"] -class ModelCompressor(RegistryMixin): +class Compressor(RegistryMixin): """ - Base class representing a model compression algorithm. 
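For orientation: the renamed `Compressor` base keeps the `RegistryMixin` pattern, so format-specific compressors are still registered and resolved by a format string. Below is a minimal sketch of that contract; `ToyCompressor` and the `"toy"` format name are hypothetical and not part of this diff, only `register`, `load_from_registry`, `compress`, and `decompress` come from the code shown here.

```python
from typing import Dict, Generator, Tuple

from torch import Tensor

from compressed_tensors.compressors import Compressor


@Compressor.register(name="toy")  # hypothetical format string, for illustration only
class ToyCompressor(Compressor):
    """Identity compressor used only to illustrate the registry contract."""

    def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        # real compressors (bitmask, int-quantized) transform the state dict here
        return model_state

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu"
    ) -> Generator[Tuple[str, Tensor], None, None]:
        # real compressors yield (weight_name, dense_tensor) pairs read from disk
        yield from ()


# resolved by format string, exactly as helpers.py and model_compressor.py do below
compressor = Compressor.load_from_registry("toy")
```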
+ Base class representing a model compression algorithm :param config: config specifying compression parameters """ - @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path: str - ) -> Optional["ModelCompressor"]: - """ - Given a path to a model config, extract a sparsity config if it exists and - return the associated ModelCompressor - - :param pretrained_model_name_or_path: path to model config on disk or HF hub - :return: matching compressor if config contains a sparsity config - """ - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None) - if sparsity_config is None: - return None - - format = sparsity_config.get("format") - sparsity_config = CompressionConfig.load_from_registry( - format, **sparsity_config - ) - compressor = cls.load_from_registry(format, config=sparsity_config) - return compressor - - def __init__(self, config: Optional[CompressionConfig] = None): + def __init__( + self, config: Union[SparsityCompressionConfig, QuantizationConfig, None] = None + ): self.config = config - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: + def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]: """ Compresses a dense state dict @@ -83,21 +57,3 @@ def decompress( :return: compressed state dict """ raise NotImplementedError() - - def overwrite_weights(self, model_path: str, model: Module): - """ - Overwrites the weights in model with weights decompressed from model_path - - :param model_path: path to compressed weights - :param model: pytorch model to load decompressed weights into - """ - model_path = get_safetensors_folder(model_path) - dense_gen = self.decompress(model_path) - for name, data in tqdm(dense_gen, desc="Decompressing model"): - # loading the decompressed weights into the model - model_device = operator.attrgetter(name)(model).device - data_new = Parameter(data.to(model_device)) - data_old = operator.attrgetter(name)(model) - data_old.data = data_new.data - - setattr(model, SPARSITY_CONFIG_NAME, self.config) diff --git a/src/compressed_tensors/compressors/dense.py b/src/compressed_tensors/compressors/dense.py index d6319980..8f09c8bf 100644 --- a/src/compressed_tensors/compressors/dense.py +++ b/src/compressed_tensors/compressors/dense.py @@ -14,18 +14,18 @@ from typing import Dict, Generator, Tuple -from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.compressors import Compressor from compressed_tensors.config import CompressionFormat from torch import Tensor -@ModelCompressor.register(name=CompressionFormat.dense_sparsity.value) -class DenseCompressor(ModelCompressor): +@Compressor.register(name=CompressionFormat.dense.value) +class DenseCompressor(Compressor): """ Identity compressor for dense models, returns the original state_dict """ - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: + def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]: return model_state def decompress( diff --git a/src/compressed_tensors/compressors/helpers.py b/src/compressed_tensors/compressors/helpers.py index 64ccfb3d..fe4b361c 100644 --- a/src/compressed_tensors/compressors/helpers.py +++ b/src/compressed_tensors/compressors/helpers.py @@ -16,8 +16,8 @@ from typing import Dict, Generator, Optional, Tuple, Union import torch -from compressed_tensors.compressors import ModelCompressor -from compressed_tensors.config import CompressionConfig, CompressionFormat +from 
compressed_tensors.compressors import Compressor +from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.utils.safetensors_load import get_weight_mappings from safetensors import safe_open from safetensors.torch import save_file @@ -48,20 +48,20 @@ def save_compressed( if tensors is None or len(tensors) == 0: raise ValueError("No tensors or empty tensors provided to compress") - # if no compression_format specified, default to `dense_sparsity` - compression_format = compression_format or CompressionFormat.dense_sparsity.value + # if no compression_format specified, default to `dense` + compression_format = compression_format or CompressionFormat.dense.value if not ( - compression_format in ModelCompressor.registered_names() - or compression_format in ModelCompressor.registered_aliases() + compression_format in Compressor.registered_names() + or compression_format in Compressor.registered_aliases() ): raise ValueError( f"Unknown compression format: {compression_format}. " - f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501 + f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}" # noqa E501 ) # compress - compressor = ModelCompressor.load_from_registry(compression_format) + compressor = Compressor.load_from_registry(compression_format) # save compressed tensors compressed_tensors = compressor.compress(tensors) save_file(compressed_tensors, save_path) @@ -69,7 +69,7 @@ def save_compressed( def load_compressed( compressed_tensors: Union[str, Path], - compression_config: CompressionConfig = None, + compression_config: SparsityCompressionConfig = None, device: Optional[str] = "cpu", ) -> Generator[Tuple[str, Tensor], None, None]: """ @@ -90,9 +90,9 @@ def load_compressed( if ( compression_config is None - or compression_config.format == CompressionFormat.dense_sparsity.value + or compression_config.format == CompressionFormat.dense.value ): - # if no compression_config specified, or `dense_sparsity` format specified, + # if no compression_config specified, or `dense` format specified, # assume tensors are not compressed on disk weight_mappings = get_weight_mappings(compressed_tensors) for weight_name, file_with_weight_name in weight_mappings.items(): @@ -102,7 +102,7 @@ def load_compressed( else: # decompress tensors compression_format = compression_config.format - compressor = ModelCompressor.load_from_registry( + compressor = Compressor.load_from_registry( compression_format, config=compression_config ) yield from compressor.decompress(compressed_tensors, device=device) diff --git a/src/compressed_tensors/compressors/int_quantized.py b/src/compressed_tensors/compressors/int_quantized.py new file mode 100644 index 00000000..6fbc0c66 --- /dev/null +++ b/src/compressed_tensors/compressors/int_quantized.py @@ -0,0 +1,95 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
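The reworked helpers keep the same user-facing flow under the new naming (`"dense"` default format, `SparsityCompressionConfig`). A small usage sketch mirroring the updated tests at the end of this diff; the file path and tensor contents are arbitrary:

```python
import torch

from compressed_tensors import BitmaskConfig
from compressed_tensors.compressors import load_compressed, save_compressed

# toy state dict with some zeros, standing in for a sparse model
tensors = {"layer.weight": torch.tensor([[0.0, 1.5, 0.0], [2.0, 0.0, 0.0]])}

# compress and write to safetensors; omitting compression_format now defaults to "dense"
save_compressed(
    tensors,
    compression_format="sparse-bitmask",
    save_path="model.safetensors",
)

# load_compressed yields (name, dense_tensor) pairs, decompressing on the fly
loaded = dict(
    load_compressed("model.safetensors", compression_config=BitmaskConfig())
)
assert torch.allclose(tensors["layer.weight"], loaded["layer.weight"])
```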
+ +import logging +from typing import Dict, Generator, Tuple + +import torch +from compressed_tensors.compressors import Compressor +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize +from compressed_tensors.utils import get_nested_weight_mappings, merge_names +from safetensors import safe_open +from torch import Tensor +from tqdm import tqdm + + +__all__ = ["IntQuantizationCompressor"] + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +@Compressor.register(name=CompressionFormat.int_quantized.value) +class IntQuantizationCompressor(Compressor): + """ + Integer compression for quantized models. Weight of each quantized layer is + converted from its original float type to the format specified by the layer's + quantization scheme. + """ + + COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"] + + def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]: + model_quant_args = kwargs["model_quant_args"] + compressed_dict = {} + _LOGGER.debug( + f"Compressing model with {len(model_state)} parameterized layers..." + ) + + for name, value in tqdm(model_state.items(), desc="Compressing model"): + if name.endswith(".weight"): + prefix = name.removesuffix(".weight") + scale = model_state.get(merge_names(prefix, "weight_scale"), None) + zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) + if scale is not None and zp is not None: + # weight is quantized, compress it + quant_args = model_quant_args[prefix] + try: + bit_depth = torch.finfo(value.dtype).bits + except TypeError: + bit_depth = torch.iinfo(value.dtype).bits + if bit_depth > quant_args.num_bits: + # only quantize if not already quantized + value = quantize( + x=value, + scale=scale, + zero_point=zp, + args=quant_args, + dtype=torch.int8, + ) + + compressed_dict[name] = value.to("cpu") + + return compressed_dict + + def decompress( + self, path_to_model_or_tensors: str, device: str = "cpu" + ) -> Generator[Tuple[str, Tensor], None, None]: + weight_mappings = get_nested_weight_mappings( + path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES + ) + for weight_name in weight_mappings.keys(): + weight_data = {} + for param_name, safe_path in weight_mappings[weight_name].items(): + full_name = merge_names(weight_name, param_name) + with safe_open(safe_path, framework="pt", device=device) as f: + weight_data[param_name] = f.get_tensor(full_name) + + if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES): + decompressed = dequantize( + x_q=weight_data["weight"], + scale=weight_data["weight_scale"], + zero_point=weight_data["weight_zero_point"], + ) + yield merge_names(weight_name, "weight"), decompressed diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py new file mode 100644 index 00000000..9d1cf6df --- /dev/null +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -0,0 +1,264 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
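The compressor above ultimately reduces to one `quantize` call per weight on save and one `dequantize` call on reload. A minimal per-tensor sketch of that round trip using the updated `quantize(..., args, dtype)` signature from this diff; the shapes, scale, and zero point are arbitrary, and `QuantizationArgs()` uses its 8-bit defaults as in the new tests.

```python
import torch

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize

args = QuantizationArgs()  # 8-bit defaults, matching the dummy config in tests/test_int_quant.py
weight = torch.rand(512, 1024)
scale = torch.tensor(0.01)
zero_point = torch.tensor(0, dtype=torch.int32)

# what IntQuantizationCompressor.compress stores on disk: an int8 weight plus its qparams
weight_q = quantize(x=weight, scale=scale, zero_point=zero_point, args=args, dtype=torch.int8)
assert weight_q.dtype == torch.int8

# what decompress yields back: a dense floating-point reconstruction of the weight
weight_dq = dequantize(x_q=weight_q, scale=scale, zero_point=zero_point)
```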
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import operator +import os +from typing import Dict, Optional, Union + +from compressed_tensors.base import ( + COMPRESSION_CONFIG_NAME, + QUANTIZATION_CONFIG_NAME, + SPARSITY_CONFIG_NAME, +) +from compressed_tensors.compressors import Compressor +from compressed_tensors.config import SparsityCompressionConfig +from compressed_tensors.quantization import ( + QuantizationConfig, + QuantizationStatus, + apply_quantization_config, + load_pretrained_quantization, +) +from compressed_tensors.quantization.utils import ( + is_module_quantized, + iter_named_leaf_modules, +) +from compressed_tensors.utils import get_safetensors_folder +from torch import Tensor +from torch.nn import Module, Parameter +from tqdm import tqdm +from transformers import AutoConfig +from transformers.file_utils import CONFIG_NAME + + +__all__ = ["ModelCompressor"] + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +class ModelCompressor: + """ + Handles compression and decompression of a model with a sparsity config and/or + quantization config. + + Compression LifeCycle + - compressor = ModelCompressor.from_pretrained_model(model) + - compressed_state_dict = compressor.compress(model, state_dict) + - compressor.quantization_compressor.compress(model, state_dict) + - compressor.sparsity_compressor.compress(model, state_dict) + - model.save_pretrained(output_dir, state_dict=compressed_state_dict) + - compressor.update_config(output_dir) + + Decompression LifeCycle + - compressor = ModelCompressor.from_pretrained(comp_model_path) + - model = AutoModel.from_pretrained(comp_model_path) + - compressor.decompress(comp_model_path, model) + - compressor.sparsity_compressor.decompress(comp_model_path, model) + - compressor.quantization_compressor.decompress(comp_model_path, model) + + :param sparsity_config: config specifying sparsity compression parameters + :param quantization_config: config specifying quantization compression parameters + """ + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + ) -> Optional["ModelCompressor"]: + """ + Given a path to a model config, extract the sparsity and/or quantization + configs and load a ModelCompressor + + :param pretrained_model_name_or_path: path to model config on disk or HF hub + :return: compressor for the extracted configs + """ + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None) + if compression_config is None: + return None + + sparsity_config = compression_config.get(SPARSITY_CONFIG_NAME, None) + quantization_config = compression_config.get(QUANTIZATION_CONFIG_NAME, None) + + if sparsity_config is None and quantization_config is None: + return None + + if sparsity_config is not None: + format = sparsity_config.get("format") + sparsity_config = SparsityCompressionConfig.load_from_registry( + format, **sparsity_config + ) + if quantization_config is not None: + quantization_config = QuantizationConfig.parse_obj(quantization_config) + + return cls( + sparsity_config=sparsity_config, quantization_config=quantization_config + ) + + @classmethod + def from_pretrained_model( + cls, + model: Module, + sparsity_config: Union[SparsityCompressionConfig, str, None] = None, + quantization_format: Optional[str] = None, + ) -> Optional["ModelCompressor"]: + """ + Given a pytorch model and optional sparsity and/or quantization 
configs, + load the appropriate compressors + + :param model: pytorch model to target for compression + :param sparsity_config: a filled in sparsity config or string corresponding + to a sparsity compression algorithm + :param quantization_format: string corresponding to a quantization compression + algorithm + :return: compressor for the extracted configs + """ + quantization_config = QuantizationConfig.from_pretrained( + model, format=quantization_format + ) + + if isinstance(sparsity_config, str): # we passed in a sparsity format + sparsity_config = SparsityCompressionConfig.load_from_registry( + sparsity_config + ) + + if sparsity_config is None and quantization_config is None: + return None + + return cls( + sparsity_config=sparsity_config, quantization_config=quantization_config + ) + + def __init__( + self, + sparsity_config: Optional[SparsityCompressionConfig] = None, + quantization_config: Optional[QuantizationConfig] = None, + ): + self.sparsity_config = sparsity_config + self.quantization_config = quantization_config + self.sparsity_compressor = None + self.quantization_compressor = None + + if sparsity_config is not None: + self.sparsity_compressor = Compressor.load_from_registry( + sparsity_config.format, config=sparsity_config + ) + if quantization_config is not None: + self.quantization_compressor = Compressor.load_from_registry( + quantization_config.format, config=quantization_config + ) + + def compress( + self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None + ) -> Dict[str, Tensor]: + """ + Compresses a dense state dict or model with sparsity and/or quantization + + :param model: uncompressed model to compress + :param model_state: optional uncompressed state_dict to insert into model + :return: compressed state dict + """ + if state_dict is None: + state_dict = model.state_dict() + + compressed_state_dict = state_dict + quantized_modules_to_args = _get_weight_arg_mappings(model) + if self.quantization_compressor is not None: + compressed_state_dict = self.quantization_compressor.compress( + state_dict, model_quant_args=quantized_modules_to_args + ) + + if self.sparsity_compressor is not None: + compressed_state_dict = self.sparsity_compressor.compress( + compressed_state_dict + ) + + return compressed_state_dict + + def decompress(self, model_path: str, model: Module): + """ + Overwrites the weights in model with weights decompressed from model_path + + :param model_path: path to compressed weights + :param model: pytorch model to load decompressed weights into + """ + model_path = get_safetensors_folder(model_path) + if self.sparsity_compressor is not None: + dense_gen = self.sparsity_compressor.decompress(model_path) + self._replace_weights(dense_gen, model) + setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config) + + if self.quantization_compressor is not None: + apply_quantization_config(model, self.quantization_config) + load_pretrained_quantization(model, model_path) + dense_gen = self.quantization_compressor.decompress(model_path) + self._replace_weights(dense_gen, model) + + def update_status(module): + module.quantization_status = QuantizationStatus.FROZEN + + model.apply(update_status) + setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config) + + def update_config(self, save_directory: str): + """ + Update the model config located at save_directory with compression configs + for sparsity and/or quantization + + :param save_directory: path to a folder containing a HF model config + """ + config_file_path = 
os.path.join(save_directory, CONFIG_NAME) + if not os.path.exists(config_file_path): + _LOGGER.warning( + f"Could not find a valid model config file in " + f"{save_directory}. Compression config will not be saved." + ) + return + + with open(config_file_path, "r") as config_file: + config_data = json.load(config_file) + + config_data[COMPRESSION_CONFIG_NAME] = {} + if self.quantization_config is not None: + quant_config_data = self.quantization_config.model_dump() + config_data[COMPRESSION_CONFIG_NAME][ + QUANTIZATION_CONFIG_NAME + ] = quant_config_data + if self.sparsity_config is not None: + sparsity_config_data = self.sparsity_config.model_dump() + config_data[COMPRESSION_CONFIG_NAME][ + SPARSITY_CONFIG_NAME + ] = sparsity_config_data + + with open(config_file_path, "w") as config_file: + json.dump(config_data, config_file, indent=2, sort_keys=True) + + def _replace_weights(self, dense_weight_generator, model): + for name, data in tqdm(dense_weight_generator, desc="Decompressing model"): + # loading the decompressed weights into the model + model_device = operator.attrgetter(name)(model).device + data_new = Parameter(data.to(model_device)) + data_old = operator.attrgetter(name)(model) + data_old.data = data_new.data + + +def _get_weight_arg_mappings(model: Module) -> Dict: + quantized_modules_to_args = {} + for name, submodule in iter_named_leaf_modules(model): + if is_module_quantized(submodule): + if submodule.quantization_scheme.weights is not None: + quantized_modules_to_args[name] = submodule.quantization_scheme.weights + + return quantized_modules_to_args diff --git a/src/compressed_tensors/compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_bitmask.py index abf09fa3..10b398c1 100644 --- a/src/compressed_tensors/compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_bitmask.py @@ -17,7 +17,7 @@ import numpy import torch -from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.compressors import Compressor from compressed_tensors.config import CompressionFormat from compressed_tensors.utils import get_nested_weight_mappings, merge_names from safetensors import safe_open @@ -37,8 +37,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -@ModelCompressor.register(name=CompressionFormat.sparse_bitmask.value) -class BitmaskCompressor(ModelCompressor): +@Compressor.register(name=CompressionFormat.sparse_bitmask.value) +class BitmaskCompressor(Compressor): """ Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d values tensor, with their locations stored in a 2d bitmask @@ -67,7 +67,7 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: f"found an existing entry for {key}. The existing entry will " "be replaced." 
) - compressed_dict.update(bitmask_dict) + compressed_dict |= bitmask_dict return compressed_dict diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 96778995..17c8da73 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -19,17 +19,18 @@ from pydantic import BaseModel -__all__ = ["CompressionConfig", "CompressionFormat"] +__all__ = ["SparsityCompressionConfig", "CompressionFormat"] class CompressionFormat(Enum): - dense_sparsity = "dense-sparsity" + dense = "dense" sparse_bitmask = "sparse-bitmask" + int_quantized = "int-quantized" -class CompressionConfig(RegistryMixin, BaseModel): +class SparsityCompressionConfig(RegistryMixin, BaseModel): """ - Base data class for storing compression parameters + Base data class for storing sparsity compression parameters :param format: name of compression format :param global_sparsity: average sparsity of the entire model diff --git a/src/compressed_tensors/config/dense.py b/src/compressed_tensors/config/dense.py index 0a18309e..8e7e3b7a 100644 --- a/src/compressed_tensors/config/dense.py +++ b/src/compressed_tensors/config/dense.py @@ -14,14 +14,14 @@ from typing import Optional -from compressed_tensors.config import CompressionConfig, CompressionFormat +from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig __all__ = ["DenseSparsityConfig"] -@CompressionConfig.register(name=CompressionFormat.dense_sparsity.value) -class DenseSparsityConfig(CompressionConfig): +@SparsityCompressionConfig.register(name=CompressionFormat.dense.value) +class DenseSparsityConfig(SparsityCompressionConfig): """ Identity configuration for storing a sparse model in an uncompressed dense format @@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig): "unstructured", "2:4", "8:16" etc """ - format: str = CompressionFormat.dense_sparsity.value + format: str = CompressionFormat.dense.value global_sparsity: Optional[float] = 0.0 sparsity_structure: Optional[str] = "unstructured" diff --git a/src/compressed_tensors/config/sparse_bitmask.py b/src/compressed_tensors/config/sparse_bitmask.py index 9d2015f3..c14d9f7c 100644 --- a/src/compressed_tensors/config/sparse_bitmask.py +++ b/src/compressed_tensors/config/sparse_bitmask.py @@ -14,14 +14,14 @@ from typing import Optional -from compressed_tensors.config import CompressionConfig, CompressionFormat +from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig __all__ = ["BitmaskConfig"] -@CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value) -class BitmaskConfig(CompressionConfig): +@SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value) +class BitmaskConfig(SparsityCompressionConfig): """ Configuration for storing a sparse model using bitmask compression diff --git a/src/compressed_tensors/quantization/lifecycle/__init__.py b/src/compressed_tensors/quantization/lifecycle/__init__.py index 9504597b..0fa15b64 100644 --- a/src/compressed_tensors/quantization/lifecycle/__init__.py +++ b/src/compressed_tensors/quantization/lifecycle/__init__.py @@ -19,4 +19,5 @@ from .forward import * from .frozen import * from .initialize import * +from .compressed import * from .apply import * diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 4c601d07..64c5ae06 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ 
b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -19,6 +19,9 @@ from compressed_tensors.quantization.lifecycle.calibration import ( set_module_for_calibration, ) +from compressed_tensors.quantization.lifecycle.compressed import ( + compress_quantized_weights, +) from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, @@ -118,13 +121,20 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): :param model: model to apply quantization to :param status: status to update the module to """ - if status >= QuantizationStatus.INITIALIZED: + current_status = _infer_status(model) + + if status >= QuantizationStatus.INITIALIZED > current_status: model.apply(initialize_module_for_quantization) - if status >= QuantizationStatus.CALIBRATION: + + if current_status < status >= QuantizationStatus.CALIBRATION > current_status: model.apply(set_module_for_calibration) - if status >= QuantizationStatus.FROZEN: + + if current_status < status >= QuantizationStatus.FROZEN > current_status: model.apply(freeze_module_quantization) + if current_status < status >= QuantizationStatus.COMPRESSED > current_status: + model.apply(compress_quantized_weights) + def find_first_name_or_class_match( name: str, module: Module, targets: Iterable[str], check_contains: bool = False @@ -156,6 +166,14 @@ def _find_first_match( return None +def _infer_status(model: Module) -> Optional[QuantizationStatus]: + for module in model.modules(): + status = getattr(module, "quantization_status", None) + if status is not None: + return status + return None + + def _load_quant_args_from_state_dict( base_name: str, module_name: str, module: Module, state_dict: Dict ): diff --git a/src/compressed_tensors/quantization/lifecycle/compressed.py b/src/compressed_tensors/quantization/lifecycle/compressed.py new file mode 100644 index 00000000..84962df4 --- /dev/null +++ b/src/compressed_tensors/quantization/lifecycle/compressed.py @@ -0,0 +1,69 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
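With the status inference added above, driving a model through the lifecycle from user code looks roughly like the following. This is a sketch on a toy `torch.nn` model; the config mirrors `get_dummy_quant_config` from the new test file, and real calibration data is elided.

```python
import torch

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
)
from compressed_tensors.quantization.lifecycle import (
    apply_quantization_config,
    apply_quantization_status,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # stand-in for a real model

config = QuantizationConfig(
    config_groups={
        "group_1": QuantizationScheme(targets=["Linear"], weights=QuantizationArgs())
    },
    ignore=["lm_head"],
)

apply_quantization_config(model, config)                          # attach schemes, INITIALIZED
apply_quantization_status(model, QuantizationStatus.CALIBRATION)  # enable observers
model(torch.rand(4, 64))  # forward passes populate weight_scale / weight_zero_point
apply_quantization_status(model, QuantizationStatus.FROZEN)       # drop observers
apply_quantization_status(model, QuantizationStatus.COMPRESSED)   # weights become int8 in place
```

Because `_infer_status` is consulted before each transition, repeating a call for a status the model has already reached is a no-op rather than a re-initialization.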
+ + +import logging + +import torch +from compressed_tensors.quantization.lifecycle.forward import quantize +from compressed_tensors.quantization.quant_config import QuantizationStatus +from torch.nn import Module + + +__all__ = [ + "compress_quantized_weights", +] + + +_LOGGER = logging.getLogger(__name__) + + +def compress_quantized_weights(module: Module): + """ + Quantizes the module weight representation to use fewer bits in memory + + apply to full model with `model.apply(compress_quantized_weights)` + + :param module: module to compress to quantized representation + """ + scheme = getattr(module, "quantization_scheme", None) + if not scheme or not scheme.weights: + # no quantization scheme or weights not quantized, nothing to do + return + + if scheme is QuantizationStatus.COMPRESSED: + # module is already compressed, nothing to do + return + + weight = getattr(module, "weight", None) + scale = getattr(module, "weight_scale", None) + zero_point = getattr(module, "weight_zero_point", None) + + if weight is None or scale is None or zero_point is None: + # no weight, scale, or ZP, nothing to do + + # mark as compressed here to maintain consistent status throughout the model + module.quantization_status = QuantizationStatus.COMPRESSED + return + + module.weight.requires_grad = False # cannot use auto grad after compression + module.weight.data = quantize( + x=weight, + scale=scale, + zero_point=zero_point, + args=scheme.weights, + dtype=torch.int8, + ) + + module.quantization_status = QuantizationStatus.COMPRESSED diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index c17ed33f..31fb58f1 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -14,6 +14,7 @@ from functools import wraps from math import ceil +from typing import Optional import torch from compressed_tensors.quantization.quant_args import ( @@ -33,16 +34,24 @@ def quantize( x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, - q_min: torch.Tensor, - q_max: torch.Tensor, + args: QuantizationArgs, + dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + bit_range = 2**args.num_bits + q_max = torch.tensor(bit_range / 2 - 1, device=x.device) + q_min = torch.tensor(-bit_range / 2, device=x.device) - return torch.clamp( + quantized_value = torch.clamp( torch.round(x / scale + zero_point), q_min, q_max, ) + if dtype is not None: + quantized_value = quantized_value.to(dtype) + + return quantized_value + @torch.no_grad() def dequantize( @@ -75,10 +84,6 @@ def fake_quantize( :return: fake quantized tensor """ - bit_range = 2**args.num_bits - max_q = torch.tensor(bit_range / 2 - 1, device=x.device) - min_q = torch.tensor(-bit_range / 2, device=x.device) - group_size = args.group_size # group @@ -111,7 +116,7 @@ def fake_quantize( zp = zero_point[:, i].unsqueeze(1) idx = i * group_size - Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q) + Q = quantize(x[:, idx : (idx + group_size)], sc, zp, args) DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp) # channel-wise @@ -121,7 +126,7 @@ def fake_quantize( scale = scale.unsqueeze(0) zero_point = zero_point.unsqueeze(0) - Q = quantize(x, scale, zero_point, min_q, max_q) + Q = quantize(x, scale, zero_point, args) DQ = dequantize(Q, scale, zero_point) # per-token @@ -134,11 +139,11 @@ def fake_quantize( scale = scale.unsqueeze(1) zero_point = zero_point.unsqueeze(1) - Q = quantize(x, scale, zero_point, 
min_q, max_q) + Q = quantize(x, scale, zero_point, args) DQ = dequantize(Q, scale, zero_point) else: - Q = quantize(x, scale, zero_point, min_q, max_q) + Q = quantize(x, scale, zero_point, args) DQ = dequantize(Q, scale, zero_point) return DQ diff --git a/src/compressed_tensors/quantization/lifecycle/frozen.py b/src/compressed_tensors/quantization/lifecycle/frozen.py index 34d132ec..652f1c3a 100644 --- a/src/compressed_tensors/quantization/lifecycle/frozen.py +++ b/src/compressed_tensors/quantization/lifecycle/frozen.py @@ -35,6 +35,10 @@ def freeze_module_quantization(module: Module): # no quantization scheme nothing to do return + if module.quantization_status == QuantizationStatus.FROZEN: + # nothing to do, already frozen + return + # delete observers from module if not dynamic if scheme.input_activations and not scheme.input_activations.dynamic: delattr(module, "input_observer") diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index f548dba3..ea65e90a 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -35,12 +35,13 @@ def calculate_qparams( """ min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + device = min_vals.device bit_range = 2**quantization_args.num_bits - 1 bit_min = -(bit_range + 1) / 2 bit_max = bit_min + bit_range if quantization_args.symmetric: - zero_points = torch.tensor(0).to(torch.int8) + zero_points = torch.tensor(0, device=device).to(torch.int8) max_val_pos = torch.max(-min_vals, max_vals) scales = max_val_pos / (float(bit_range) / 2) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index a894b4c2..43127b79 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -16,6 +16,7 @@ from typing import Dict, List, Optional from compressed_tensors.base import QUANTIZATION_CONFIG_NAME +from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( calculate_compression_ratio, @@ -62,10 +63,33 @@ def lifecycle_order(cls) -> List["QuantizationStatus"]: return def __ge__(self, other): + if other is None: + return True if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other) + def __gt__(self, other): + if other is None: + return True + if not isinstance(other, self.__class__): + raise NotImplementedError + return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other) + + def __lt__(self, other): + if other is None: + return False + if not isinstance(other, self.__class__): + raise NotImplementedError + return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other) + + def __le__(self, other): + if other is None: + return False + if not isinstance(other, self.__class__): + raise NotImplementedError + return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other) + LIFECYCLE_ORDER = [ QuantizationStatus.INITIALIZED, @@ -116,7 +140,9 @@ def from_model_config(model_name_or_path) -> "QuantizationConfig": return QuantizationConfig.parse_obj(quantization_config) @staticmethod - def from_pretrained(model: Module) -> "QuantizationConfig": + def 
from_pretrained( + model: Module, format: Optional[str] = None + ) -> Optional["QuantizationConfig"]: """ Converts a model into its associated QuantizationConfig based on the QuantizationScheme attached to each quanitzed module @@ -147,6 +173,9 @@ def from_pretrained(model: Module) -> "QuantizationConfig": if not match_found: quant_scheme_to_layers.append(scheme) + if len(quant_scheme_to_layers) == 0: # No quantized layers + return None + # clean up ignore list, we can leave out layers types if none of the # instances are quantized consolidated_ignore = [] @@ -162,10 +191,20 @@ def from_pretrained(model: Module) -> "QuantizationConfig": group_name = "group_" + str(idx) config_groups[group_name] = scheme + # TODO: this is incorrect in compressed mode, since we are overwriting the + # original weight we lose the uncompressed bit_depth indo compression_ratio = calculate_compression_ratio(model) + + if format is None: + if quantization_status == QuantizationStatus.COMPRESSED: + format = CompressionFormat.int_quantized.value + else: + format = CompressionFormat.dense.value + return QuantizationConfig( config_groups=config_groups, quantization_status=quantization_status, global_compression_ratio=compression_ratio, + format=format, ignore=consolidated_ignore, ) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 8676ef15..66944eb1 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -15,6 +15,7 @@ from typing import Tuple import torch +from compressed_tensors.quantization.observers.base import Observer from torch.nn import Module from tqdm import tqdm @@ -78,11 +79,25 @@ def module_type(module: Module) -> str: def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]: - # yields modules that do not have any submodules - # TODO: potentially expand to add list of allowed submodules such as observers + """ + Yields modules that do not have any submodules except observers. The observers + themselves are not yielded + + :param model: model to get leaf modules of + :returns: generator tuple of (name, leaf_submodule) + """ for name, submodule in model.named_modules(): - if len(list(submodule.children())) == 0: + children = list(submodule.children()) + if len(children) == 0 and not isinstance(submodule, Observer): yield name, submodule + else: + has_non_observer_children = False + for child in children: + if not isinstance(child, Observer): + has_non_observer_children = True + + if not has_non_observer_children: + yield name, submodule def calculate_compression_ratio(model: Module) -> float: diff --git a/tests/test_int_quant.py b/tests/test_int_quant.py new file mode 100644 index 00000000..b5b2cdb3 --- /dev/null +++ b/tests/test_int_quant.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
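Putting the pieces together, the intended end-to-end flow follows the `ModelCompressor` docstring earlier in this diff. A sketch assuming a Hugging Face causal LM checkpoint; the model id and output directory are placeholders, and `"sparse-bitmask"` is one of the registered sparsity formats.

```python
from transformers import AutoModelForCausalLM

from compressed_tensors.compressors import ModelCompressor

# --- compression side ---
model = AutoModelForCausalLM.from_pretrained("path/to/quantized-or-sparse-model")  # placeholder
compressor = ModelCompressor.from_pretrained_model(model, sparsity_config="sparse-bitmask")
if compressor is not None:  # None when there is no sparsity or quantization to record
    compressed_state_dict = compressor.compress(model)
    model.save_pretrained("compressed_output", state_dict=compressed_state_dict)
    compressor.update_config("compressed_output")  # writes the compression_config block

# --- decompression side ---
compressor = ModelCompressor.from_pretrained("compressed_output")
model = AutoModelForCausalLM.from_pretrained("compressed_output")
if compressor is not None:
    compressor.decompress("compressed_output", model)
```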
+ +import shutil + +import torch +from compressed_tensors import IntQuantizationCompressor +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationScheme, +) +from compressed_tensors.quantization.lifecycle.forward import fake_quantize +from safetensors.torch import save_file + + +def get_dummy_quant_config(): + config_groups = { + "group_1": QuantizationScheme(targets=["Linear"], weights=QuantizationArgs()), + } + ignore = ["lm_head"] + quant_config = QuantizationConfig( + config_groups=config_groups, + ignore=ignore, + ) + + return quant_config + + +def test_quant_format(): + dense_state_dict = { + "dummy.weight": torch.rand((512, 1024)), + "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32), + "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32), + } + quant_config = get_dummy_quant_config() + + compressor = IntQuantizationCompressor(config=quant_config) + quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} + compressed_state_dict = compressor.compress( + dense_state_dict, model_quant_args=quantized_modules_to_args + ) + + # state_dict params should be the same + assert len(dense_state_dict) == len(compressed_state_dict) + + # check compressed to int8 + assert compressed_state_dict["dummy.weight"].dtype == torch.int8 + assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32 + assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32 + + +def test_reload_match(tmp_path): + dense_state_dict = { + "dummy.weight": torch.rand((511, 350)), + "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32), + "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32), + "dummy2.weight": torch.rand((128, 280)), + "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32), + "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int32), + } + quant_config = get_dummy_quant_config() + + compressor = IntQuantizationCompressor(config=quant_config) + quantized_modules_to_args = { + "dummy": quant_config.config_groups["group_1"].weights, + "dummy2": quant_config.config_groups["group_1"].weights, + } + compressed_state_dict = compressor.compress( + dense_state_dict, model_quant_args=quantized_modules_to_args + ) + save_file(compressed_state_dict, tmp_path / "model.safetensors") + reconstructed_dense_gen = compressor.decompress(tmp_path) + reconstructed_dense = {} + for name, value in reconstructed_dense_gen: + reconstructed_dense[name] = value + + fake_quant_dummy = fake_quantize( + dense_state_dict["dummy.weight"], + scale=dense_state_dict["dummy.weight_scale"], + zero_point=dense_state_dict["dummy.weight_zero_point"], + args=quantized_modules_to_args["dummy"], + ) + assert torch.equal( + fake_quant_dummy, reconstructed_dense["dummy.weight"].to(torch.float32) + ) + + fake_quant_dummy2 = fake_quantize( + dense_state_dict["dummy2.weight"], + scale=dense_state_dict["dummy2.weight_scale"], + zero_point=dense_state_dict["dummy2.weight_zero_point"], + args=quantized_modules_to_args["dummy2"], + ) + assert torch.equal( + fake_quant_dummy2, reconstructed_dense["dummy2.weight"].to(torch.float32) + ) + + shutil.rmtree(tmp_path) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 6a3d17af..82e21afb 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # 
limitations under the License. +from typing import Optional -from compressed_tensors.quantization.lifecycle import apply_quantization_config +import torch +from compressed_tensors.quantization.lifecycle import ( + apply_quantization_config, + apply_quantization_status, +) from compressed_tensors.quantization.quant_config import ( QuantizationConfig, QuantizationStatus, @@ -22,7 +27,7 @@ def test_apply_quantization_config_tinyllama(): - quant_config = get_sample_tinyllama_quant_config() + quant_config = get_sample_tinyllama_quant_config(status="calibration") model = get_tinyllama_model() # check that model is not already quantized @@ -55,6 +60,23 @@ def test_apply_quantization_config_tinyllama(): assert num_embeddings == 1 assert num_rotary_embeddings == 22 + # test quantization compression + # sample forward pass to fill scales, zps + model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) + apply_quantization_status(model, QuantizationStatus.COMPRESSED) + for name, module in model.named_modules(): + if name in quant_config.ignore: + continue + module_type = module.__class__.__name__ + if module_type == "Linear": + _test_layer_quantization_status( + module, + inputs=True, + weights=True, + expected_status=QuantizationStatus.COMPRESSED, + expected_dtype=torch.int8, + ) + def test_serialize_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config() @@ -74,18 +96,26 @@ def test_serialize_config_tinyllama(): assert serialized_config.config_groups["group_1"].targets == ["Linear"] assert serialized_config.config_groups["group_1"].input_activations is not None assert serialized_config.quantization_status == QuantizationStatus.FROZEN - assert serialized_config.format == "fakequant" + assert serialized_config.format == "dense" assert serialized_config.quant_method == "sparseml" assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"] assert serialized_config.global_compression_ratio > 1.0 assert serialized_config.global_compression_ratio < 8.0 -def _test_layer_quantization_status(module, inputs: bool, weights: bool): +def _test_layer_quantization_status( + module, + inputs: bool, + weights: bool, + expected_status: Optional[QuantizationStatus] = None, + expected_dtype: Optional[torch.dtype] = None, +): # check if quantization is applied at all (true if inputs or weights targeted) quantized = inputs or weights assert hasattr(module, "quantization_scheme") == quantized assert hasattr(module, "quantization_status") == quantized + if expected_status is not None: + assert module.quantization_status is expected_status # check inputs matches expected assert hasattr(module, "input_scale") == inputs @@ -94,6 +124,8 @@ def _test_layer_quantization_status(module, inputs: bool, weights: bool): # check weights matches expected assert hasattr(module, "weight_scale") == weights assert hasattr(module, "weight_zero_point") == weights + if weights and expected_dtype is not None: + assert module.weight.dtype is expected_dtype def get_tinyllama_model(): @@ -102,11 +134,11 @@ def get_tinyllama_model(): ) -def get_sample_tinyllama_quant_config(): +def get_sample_tinyllama_quant_config(status: str = "frozen"): config_dict = { "quant_method": "sparseml", "format": "fakequant", - "quantization_status": "frozen", + "quantization_status": status, "global_compression_ratio": None, "config_groups": { "group_1": { diff --git a/tests/test_quantization/lifecycle/test_lifecycle.py b/tests/test_quantization/lifecycle/test_lifecycle.py index 352fcb4d..7ded8ef9 100644 --- 
a/tests/test_quantization/lifecycle/test_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_lifecycle.py @@ -97,7 +97,7 @@ def test_lifecyle(create_quantization_scheme): for _ in range(10): layer(torch.randn(4, 4)) - assert initialized_layer_input_zero_point != layer.input_zero_point + assert initialized_layer_input_zero_point != 0 assert initialized_layer_input_scale != layer.input_scale assert initialized_layer_weight_scale == layer.weight_scale diff --git a/tests/test_registry.py b/tests/test_registry.py index ffe66b85..4726fcf2 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -16,11 +16,11 @@ from compressed_tensors import ( BitmaskCompressor, BitmaskConfig, - CompressionConfig, CompressionFormat, + Compressor, DenseCompressor, DenseSparsityConfig, - ModelCompressor, + SparsityCompressionConfig, ) @@ -28,11 +28,11 @@ "name,type", [ [CompressionFormat.sparse_bitmask.value, BitmaskConfig], - [CompressionFormat.dense_sparsity.value, DenseSparsityConfig], + [CompressionFormat.dense.value, DenseSparsityConfig], ], ) def test_configs(name, type): - config = CompressionConfig.load_from_registry(name) + config = SparsityCompressionConfig.load_from_registry(name) assert isinstance(config, type) assert config.format == name @@ -41,13 +41,13 @@ def test_configs(name, type): "name,type", [ [CompressionFormat.sparse_bitmask.value, BitmaskCompressor], - [CompressionFormat.dense_sparsity.value, DenseCompressor], + [CompressionFormat.dense.value, DenseCompressor], ], ) def test_compressors(name, type): - compressor = ModelCompressor.load_from_registry( - name, config=CompressionConfig(format="none") + compressor = Compressor.load_from_registry( + name, config=SparsityCompressionConfig(format="none") ) assert isinstance(compressor, type) - assert isinstance(compressor.config, CompressionConfig) + assert isinstance(compressor.config, SparsityCompressionConfig) assert compressor.config.format == "none" diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index 7ae0799d..d8a430e4 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -44,10 +44,10 @@ def test_save_compressed_sparse_bitmask(tmp_path, tensors): assert (tmp_path / "model.safetensors").exists() -def test_save_compressed_dense_sparsity(tmp_path, tensors): +def test_save_compressed_dense(tmp_path, tensors): save_compressed( tensors, - compression_format="dense-sparsity", + compression_format="dense", save_path=tmp_path / "model.safetensors", ) assert (tmp_path / "model.safetensors").exists() @@ -92,10 +92,10 @@ def test_load_compressed_sparse_bitmask(tmp_path, tensors): assert torch.allclose(tensors[key], loaded_tensors[key]) -def test_load_compressed_dense_sparsity(tmp_path, tensors): +def test_load_compressed_dense(tmp_path, tensors): save_compressed( tensors, - compression_format="dense-sparsity", + compression_format="dense", save_path=tmp_path / "model.safetensors", ) save_compressed(