From 68138e8204ca49d18981f07c51b822b87978ef88 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Mon, 21 Oct 2024 12:01:20 +0800 Subject: [PATCH] refine AuoRound format and support marlin repacking (#280) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 8 +- auto_round/__main__.py | 6 +- auto_round/auto_quantizer.py | 406 ++++++++++++---- auto_round/backend.py | 458 ++++++++++++++++++ .../export/export_to_autoround/export.py | 58 +-- auto_round/utils.py | 134 +---- .../qbits/qlinear_qbits_gptq.py | 227 +++++++++ examples/language-modeling/main.py | 4 +- examples/language-modeling/run_autoround.sh | 5 +- .../run_autoround_on_gaudi.sh | 3 +- examples/language-modeling/run_xpu.sh | 9 +- test/test_generation.py | 177 +++++++ 12 files changed, 1222 insertions(+), 273 deletions(-) create mode 100644 auto_round/backend.py create mode 100644 auto_round_extension/qbits/qlinear_qbits_gptq.py create mode 100644 test/test_generation.py diff --git a/README.md b/README.md index 31ace144..ab8a8b86 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ asymmetric kernel has issues** that can cause considerable accuracy drops, parti models. Additionally, symmetric quantization tends to perform poorly at 2-bit precision. -**AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted +**AutoAWQ Format**(>0.3.0): This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted within the community, only 4-bits quantization is supported. It features specialized layer fusion tailored for Llama models. @@ -230,13 +230,13 @@ in [Gaudi Guide](https://docs.habana.ai/en/latest/). from transformers import AutoModelForCausalLM, AutoTokenizer from auto_round import AutoRoundConfig -device = "auto" ##cpu, hpu, cuda +backend = "auto" ##cpu, hpu, cuda, cuda:marlin('pip install -v gptqmodel --no-build-isolation') quantization_config = AutoRoundConfig( - backend=device + backend=backend ) quantized_model_path = "./tmp_autoround" model = AutoModelForCausalLM.from_pretrained(quantized_model_path, - device_map=device, quantization_config=quantization_config) + device_map=backend.split(':')[0], quantization_config=quantization_config) tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) text = "There is a girl who likes adventure," inputs = tokenizer(text, return_tensors="pt").to(model.device) diff --git a/auto_round/__main__.py b/auto_round/__main__.py index f76e885c..29c73231 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -92,7 +92,7 @@ def setup_parser(): parser.add_argument("--format", default=None, type=str, help="The format in which to save the model. " - "The options are 'auto_round', 'auto_round:gptq','auto_round:marlin'," + "The options are 'auto_round', 'auto_round:gptq','auto_round:awq'," " 'auto_gptq', 'auto_awq', 'itrex', 'itrex_xpu' and 'fake'." "default to 'auto_round." ) @@ -316,7 +316,9 @@ def tune(args): format_list = args.format.replace(' ', '').split(',') inplace = False if len(format_list) > 1 else True for format_ in format_list: - eval_folder = f'{export_dir}-{format_}' + save_format_ = format_.replace(":", "-") + save_format_ = save_format_.replace("_", "-") + eval_folder = f'{export_dir}-{save_format_}' autoround.save_quantized(eval_folder, format=format_, inplace=inplace) lm_eval_version = get_library_version("lm-eval") diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py index 96540253..c9d3d3d7 100644 --- a/auto_round/auto_quantizer.py +++ b/auto_round/auto_quantizer.py @@ -26,6 +26,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc import importlib.util import warnings from dataclasses import dataclass @@ -41,10 +42,18 @@ from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod -from auto_round.utils import get_module, set_module, dynamic_import_inference_linear +from auto_round.utils import get_module, set_module, is_hpu_supported + +from auto_round.backend import get_layer_backend, dynamic_import_inference_linear + import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits +import auto_round_extension.qbits.qlinear_qbits_gptq as qlinear_qbits_gptq +from auto_round.backend import BackendInfos +from transformers.utils.versions import require_version from enum import Enum +from tqdm import tqdm import copy +import re logger = getLogger(__name__) import sys @@ -99,27 +108,6 @@ def is_auto_round_available(): ) -def is_autoround_exllamav2_available(): - res = True - try: - from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix - except ImportError as e: - res = False - return res - - -def is_hpu_supported(): # pragma: no cover - try: - import subprocess - import habana_frameworks.torch.core as htcore # pylint: disable=E0401 - hqt_version = subprocess.check_output(['pip', 'show', \ - 'habana_quantization_toolkit']).decode().split('\n')[1].split(': ')[1] - assert (hqt_version >= "1.17") - except ImportError as e: - return False - return True - - if is_auto_round_available(): from auto_round_extension.cuda.post_init import autoround_post_init @@ -247,15 +235,6 @@ def __init__( self.dataset = dataset self.group_size = group_size self.sym = sym - if "auto" == backend: - if torch.cuda.is_available(): - backend = "auto_round:exllamav2" - elif is_hpu_supported(): - backend = "hpu" - else: - backend = "cpu" - elif "cuda" == backend: - backend = "auto_round:exllamav2" self.backend = backend self.layer_config = layer_config if kwargs is not None: @@ -270,12 +249,12 @@ def post_init(self): raise ValueError(f"Only support quantization to [2,4,8] bits but found {self.bits}") if self.group_size != -1 and self.group_size <= 0: raise ValueError("group_size must be greater than 0 or equal to -1") - ##TODO add more check def get_loading_attributes(self): - attributes_dict = copy.deepcopy(self.__dict__) - loading_attributes = ["backend"] - loading_attibutes_dict = {i: j for i, j in attributes_dict.items() if i in loading_attributes} + # attributes_dict = copy.deepcopy(self.__dict__) + loading_attibutes_dict = {"target_backend": self.backend} + # loading_attributes = ["backend"] + # loading_attibutes_dict = {i: j for i, j in attributes_dict.items() if i in loading_attributes} return loading_attibutes_dict def to_dict(self): @@ -293,7 +272,6 @@ class AutoRoundQuantizer(HfQuantizer): def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): super().__init__(quantization_config, **kwargs) - self.exllama2_available = is_autoround_exllamav2_available() def validate_environment(self, *args, **kwargs): if not is_auto_round_available(): @@ -316,34 +294,131 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AutoRound.") return torch_dtype + def find_backend(self, target_backend: str): + """Finds the matching backend key based on the target backend or its alias. + + This function checks if the provided `target_backend` is directly present in `BackendInfos`. + If not, it iterates through the backends to see if the `target_backend` matches any backend's alias. + + Args: + target_backend (str): + The name of the backend or alias to find. + + Returns: + str or None: + The backend key if a match is found, otherwise `None`. + """ + # Directly return if target_backend exists in BackendInfos + if target_backend in BackendInfos: + return target_backend + + # Search through BackendInfos to check if target_backend matches any backend alias + for key in BackendInfos.keys(): + backendInfo = BackendInfos[key] + if backendInfo.alias is not None and target_backend in backendInfo.alias: + return key + + # Return None if no matching backend or alias is found + return None + + def detect_device(self, target_backend, orig_backend): + """Detects the appropriate device for the specified backend. + + This function determines the device type based on the target backend. If the target backend is + not specified, it defaults to the original backend. The function checks for the availability + of CUDA, HPU, or CPU, and returns the appropriate device type. + + Args: + target_backend (str or None): + The name of the target backend. If None, defaults to `orig_backend`. + orig_backend (str): + The original backend name to fall back on if `target_backend` is None. + + Returns: + str: + The type of device detected ('cuda', 'hpu', or 'cpu'). + + Raises: + ValueError: + If the specified backend cannot be found. + """ + # Default to the original backend if target_backend is not provided + if target_backend is None: + target_backend = orig_backend + + # Check for specific device types based on the target backend + if "cuda" in target_backend: + return "cuda" + elif "hpu" in target_backend: + return "hpu" + elif "cpu" in target_backend: + return "cpu" + + # Determine the device automatically based on availability + if target_backend.split(":")[0] == "auto": + if torch.cuda.is_available(): + return "cuda" + elif is_hpu_supported(): + return "hpu" + else: + return "cpu" + + # Find the backend and determine the device type from BackendInfos + backend = self.find_backend(target_backend) + if backend is None: + raise ValueError("Backend not found, please set it to 'auto' to have a try ") + + return BackendInfos[backend].device[0] + def convert_model(self, model: nn.Module): - """Convert the model to an AutoRound model by getting and replacing the layers. + """Converts the given model to an AutoRound model by replacing its layers with quantized layers. + + This method extracts the quantization configuration from the model and adjusts its layers + according to the specified quantization parameters. It supports different backends and + ensures that the model's data type is compatible with the selected hardware. Args: - model (`nn.Module`): - Model to be converted + model (nn.Module): + The model to be converted into an AutoRound model. + + Returns: + nn.Module: + The converted AutoRound model with quantized layers. + + Raises: + ValueError: + If the quantization backend is not specified in the configuration. """ + from auto_round.utils import get_layer_names_in_block quantization_config = model.config.quantization_config + if not hasattr(quantization_config, "target_backend"): + quantization_config.target_backend = quantization_config.backend + + target_device = self.detect_device(quantization_config.target_backend, quantization_config.backend) + if hasattr(quantization_config, "backend"): # pragma: no cover - backend = quantization_config.backend - if "hpu" in backend and model.dtype != torch.bfloat16: - logger.info("change the dtype to `bfloat16` as HPU does not support float16") + if "hpu" == target_device and model.dtype != torch.bfloat16: + logger.info("Change the dtype to `bfloat16` as HPU does not support float16") model = model.to(torch.bfloat16) + bits = quantization_config.bits group_size = quantization_config.group_size - data_type = quantization_config.data_type if hasattr(quantization_config, "data_type") \ - else "int" # pragma: no cover + data_type = quantization_config.data_type if hasattr(quantization_config, + "data_type") else "int" # pragma: no cover sym = quantization_config.sym - quant_block_list = quantization_config.quant_block_list \ - if hasattr(quantization_config, "quant_block_list") else None + quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config, + "quant_block_list") else None layer_names = get_layer_names_in_block(model, quant_block_list=quant_block_list) + extra_config = {} if hasattr(quantization_config, "extra_config"): extra_config = quantization_config.extra_config + layer_names += extra_config.keys() layer_names = list(set(layer_names)) + layer_configs = {} for layer_name in layer_names: layer_configs[layer_name] = {} @@ -359,27 +434,74 @@ def convert_model(self, model: nn.Module): layer_configs[layer_name]["data_type"] = extra_config[layer_name].get("data_type", data_type) layer_configs[layer_name]["sym"] = extra_config[layer_name].get("sym", sym) layer_configs[layer_name]["clip"] = extra_config[layer_name].get("clip", False) + if hasattr(quantization_config, "backend"): # pragma: no cover backend = quantization_config.backend elif 'gptq' in quantization_config.quant_method: # pragma: no cover backend = 'gptq' else: # pragma: no cover logger.error("Please specify quantization backend") + raise ValueError("Quantization backend must be specified.") - self._replace_by_quant_layers(model, layer_configs, backend) + self._replace_by_quant_layers(model, layer_configs, quantization_config.target_backend, target_device, backend) return model - def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend): - """Replaces linear layers in `module` by `QuantLinear` + def _replace_by_quant_layers(self, module: nn.Module, layer_configs, target_backend, target_device, orig_backend): + """Replaces linear layers in the given module with quantized layers. + + This method iterates over the specified layer configurations and replaces + the original layers in the module with instances of `QuantLinear`. It handles + various layer types and ensures that the correct quantization parameters are applied. Args: - module (`nn.Module`): - Module to quantize - names (`List[str]`): - List of names of the module to quantize - name (`str`, defaults to `""`): - To keep track of the name of the current module + module (nn.Module): + The module containing layers to be quantized. + layer_configs (dict): + A dictionary containing configuration for each layer's quantization. + target_backend (str): + The backend to use for quantization, which includes device and format information. + target_device (str): + The device on which the model will run (e.g., 'cuda', 'cpu', 'hpu'). + orig_backend (str): + The original backend of the packing. + + Raises: + AssertionError: + If any condition related to backend or quantization configuration is not met. """ + + def remove_str(input_string: str, sub_str) -> str: + """Removes the specified substring from the input string, if present. + + Args: + input_string (str): + The original string from which to remove the substring. + sub_str (str): + The substring to be removed. + + Returns: + str: + The modified string with the substring removed. + """ + pattern = re.escape(sub_str) + r':?' + return re.sub(pattern, '', input_string) + + if "auto" == target_backend.split(':')[0]: + target_backend = target_backend[4:] # Remove 'auto' + if len(target_backend) >= 1 and target_backend[0] == ":": + target_backend = target_backend[1:] + + # Remove device info from target_backend + target_backend = remove_str(target_backend, "cpu") + target_backend = remove_str(target_backend, "hpu") + target_backend = remove_str(target_backend, "cuda") + orig_backend = self.find_backend(orig_backend) + + if target_backend == "": + target_backend = orig_backend + + self.need_marlin_repacking = False + for layer_name in layer_configs.keys(): config = layer_configs[layer_name] bits = config["bits"] @@ -387,55 +509,170 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend): data_type = config["data_type"] sym = config["sym"] clip = config["clip"] + if not (bits <= 8): continue layer = get_module(module, layer_name) - device = get_device(layer) - QuantLinear = dynamic_import_inference_linear(backend, bits, group_size, sym) if isinstance(layer, nn.Linear): in_features = layer.in_features out_features = layer.out_features - elif isinstance(layer, nn.Conv2d): ##not supported now + elif isinstance(layer, nn.Conv2d): # Not supported currently in_features = layer.in_channels out_features = layer.out_channels - elif isinstance(layer, Conv1D): ##TODO need to have a check + elif isinstance(layer, Conv1D): # TODO: Needs verification in_features = layer.weight.shape[0] out_features = layer.weight.shape[1] - bias = layer.bias is not None - try: - new_layer = QuantLinear( # pylint: disable=E1123 - bits, - group_size, - in_features, - out_features, - bias, - weight_dtype=layer.weight.dtype, - clip = clip + else: + continue + + if "marlin" in target_backend and "marlin" not in orig_backend: + # Need to repack + assert sym == True, "Marlin only supports symmetric quantization" + assert target_device == "cuda", "Marlin only supports CUDA device" + assert not "awq" in orig_backend, "Marlin does not support repacking from AWQ format" + self.need_marlin_repacking = True + # Using original backend to load the layer then replace + layer_backend = orig_backend + else: + target_backend = self.find_backend(target_backend) # TODO: Move out if have supported marlin + layer_backend = get_layer_backend( + target_device, target_backend, orig_backend, bits, group_size, sym, in_features, out_features ) - except: - new_layer = QuantLinear( # pylint: disable=E1123 + + QuantLinear = dynamic_import_inference_linear(layer_backend, bits, group_size, sym) + + layer_device = get_device(layer) + + bias = layer.bias is not None + if "awq" in layer_backend: + new_layer = QuantLinear.from_linear( # pylint: disable=E1123 + layer, bits, group_size, - in_features, - out_features, - bias, - weight_dtype=layer.weight.dtype, + init_only=True ) - - new_layer.device = device + else: + try: + new_layer = QuantLinear( # pylint: disable=E1123 + bits, + group_size, + in_features, + out_features, + bias, + weight_dtype=layer.weight.dtype, + clip=clip + ) + except: + new_layer = QuantLinear( # pylint: disable=E1123 + bits, + group_size, + in_features, + out_features, + bias, + weight_dtype=layer.weight.dtype, + ) + + new_layer.device = layer_device set_module(module, layer_name, new_layer) def qbits_post_init(self, model): dep_check = True for layer in model.modules(): - if isinstance(layer, qlinear_qbits.QuantLinear): + if isinstance(layer, (qlinear_qbits.QuantLinear, qlinear_qbits_gptq.QuantLinear)): if dep_check: layer.req_check() layer.post_init() dep_check = False return model + def repack_marlin(self, model): + """Repack the model to use Marlin format for quantized layers. + + This method iterates through the model's modules, identifies instances of + `QuantLinear`, and replaces them with `MarlinInferenceQuantLinear`. It + handles the initialization of various parameters and the repacking of + quantized weights and scales for optimized performance on Marlin. + + Args: + model (nn.Module): + The model to be repacked into Marlin format. + + Raises: + ImportError: + If the required modules for Marlin inference cannot be imported. + """ + message = "Repacking to Marlin format" + + for n, m in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))): + if m.__class__.__name__ == "QuantLinear": + try: + from gptqmodel.nn_modules.qlinear.qlinear_marlin_inference import ( # pylint: disable=E0401 + MarlinInferenceQuantLinear, + marlin_permute_scales, + marlin_make_workspace + ) + except ImportError: + raise ImportError("Failed to import Marlin inference modules.") + + with torch.device("meta"): + # Create a new MarlinInferenceQuantLinear module with the appropriate parameters. + new_module = MarlinInferenceQuantLinear( + bits=4, + group_size=m.group_size, + sym=True, + desc_act=False, + infeatures=m.infeatures, + outfeatures=m.outfeatures, + bias=m.bias is not None, + ) + + device = m.qweight.device + import gptqmodel_marlin_cuda_inference # pylint: disable=E0401 + + # Initialize the necessary parameters for the new module. + new_module.g_idx = torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + new_module.g_idx_sort_indices = torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + new_module.zp = torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + new_module.bias = m.bias + + # Repack the quantized weight for the Marlin format. + marlin_qweight = gptqmodel_marlin_cuda_inference.gptq_marlin_repack( # pylint: disable=E0401 + m.qweight, + new_module.g_idx_sort_indices, + m.infeatures, + m.outfeatures, + m.bits + ) + new_module.qweight.resize_(marlin_qweight.shape) + new_module.qweight = nn.Parameter(marlin_qweight, requires_grad=False) + + # Permute scales for the new module's configuration. + marlin_scales = marlin_permute_scales( + m.scales, + size_k=m.infeatures, + size_n=m.outfeatures, + group_size=m.group_size + ) + + new_module.scales.resize_(marlin_scales.shape) + new_module.scales = nn.Parameter(marlin_scales, requires_grad=False) + + # Create a workspace for the new module. + new_module.workspace = marlin_make_workspace( # TODO: Consider moving this to post-init. + new_module.outfeatures, device + ) + + # Replace the original module in the model with the new Marlin module. + set_module(model, n, new_module) + + # Clear cache and perform garbage collection to free memory. + torch.cuda.empty_cache() + gc.collect() + def post_init_model(self, model): """Post-initialization that require device information, for example buffers initialization on device. @@ -448,6 +685,12 @@ class StoreAttr(object): pass model.quantize_config = StoreAttr() + if self.need_marlin_repacking: + require_version("gptqmodel", + "marlin format requires gptqmodel to be installed, " + "`pip install -v gptqmodel --no-build-isolation `") + self.repack_marlin(model) + model = autoround_post_init(model) # there are no side-effects after call qbits_post_init when model quant-type not equal to qbits. model = self.qbits_post_init(model) @@ -457,7 +700,7 @@ class StoreAttr(object): def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): if model.__class__.main_input_name != "input_ids": logger.warning("We can only quantize pure text models and " \ - "certain types(Llava/Qwen-VL/Phi-3-vision) of multimodal models.") + "certain types(Llava/Qwen-VL/Phi-3-vision) of multimodal models.") if self.pre_quantized: model = self.convert_model(model) @@ -479,11 +722,8 @@ def is_serializable(self): import transformers -transformers_version = [int(item) for item in transformers.__version__.split('.')[:2]] -if transformers_version[0] == 4 and transformers_version[1] < 38: +if version.parse(transformers.__version__) < version.parse("4.38.0"): logger.error("Please upgrade transformers>=4.38.0 to support lm-head quantization") transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer - - diff --git a/auto_round/backend.py b/auto_round/backend.py new file mode 100644 index 00000000..16e898ab --- /dev/null +++ b/auto_round/backend.py @@ -0,0 +1,458 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from dataclasses import dataclass, field +from typing import List, Any, Optional + +from auto_round.utils import get_library_version + +BackendInfos = {} + + +@dataclass +class BackendInfo: + """Stores configuration details for various backend formats. + + Attributes: + device: A list of strings representing the devices the backend supports + (e.g., 'cuda', 'cpu'). + sym: A list of booleans indicating whether the backend supports symmetric + quantization (True if symmetric, False if not). + packing_format: A string representing the packing format used by the backend + (e.g., 'triton', 'qbits'). + bits: A list of integers specifying the bit-widths supported by the backend + (e.g., [2, 4, 8]). + group_size: An optional list of integers specifying the group size for + quantization. Defaults to None. + priority: An integer representing the backend's priority, where higher values + indicate higher priority. Defaults to 0. + convertable_format: A list of strings specifying the formats that the backend + can convert from. Defaults to an empty list. + feature_checks: A list of feature check functions (e.g., validation methods) + used to verify whether the backend supports certain features. Defaults to + an empty list. + alias: An optional list of strings representing alternative names for the + backend. Defaults to None. + """ + device: List[str] + sym: List[bool] + packing_format: str + bits: List[int] + group_size: Optional[List[int]] = None + priority: int = 0 ##higher is better + convertable_format: List[str] = field(default_factory=list) + feature_checks: List[Any] = field(default_factory=list) + alias: Optional[List[str]] = None + + +def feature_multiply_checker(in_feature, out_feature, in_feature_multiplier, out_feature_multiplier=None): + if out_feature_multiplier is None: + out_feature_multiplier = in_feature_multiplier + return in_feature % in_feature_multiplier == 0 and out_feature % out_feature_multiplier == 0 + + +def feature_num_greater_checker(in_feature, out_feature, num): + return in_feature * out_feature > num + + +feature_multiply_checker_32 = functools.partial(feature_multiply_checker, in_feature_multiplier=32) + +feature_multiply_checker_marlin = functools.partial(feature_multiply_checker, in_feature_multiplier=128, + out_feature_multiplier=256) + +feature_num_greater_checker_1024 = functools.partial(feature_num_greater_checker, num=1024) + +BackendInfos['auto_round:exllamav2'] = BackendInfo(device=["cuda"], sym=[True, False], + packing_format="triton", + bits=[4], group_size=None, + priority=5, + feature_checks=[feature_multiply_checker_32], + alias=["auto_round"] + ) + +BackendInfos['auto_round:tritonv2'] = BackendInfo(device=["cuda"], sym=[True, False], + packing_format="triton", + bits=[2, 4, 8], group_size=None, + priority=0, feature_checks=[feature_multiply_checker_32], + ) + +BackendInfos['gptq:exllamav2'] = BackendInfo(device=["cuda"], sym=[True, False], + packing_format="triton_zp+-1", + bits=[4], group_size=None, + priority=5, + feature_checks=[feature_multiply_checker_32], + alias=["auto_round:gptq:exllamav2", "auto_round:auto_gptq:exllamav2", + 'gptq', 'auto_gptq', "auto_round:gptq", "auto_round:auto_gptq"] + ) + +BackendInfos['gptq:tritonv2'] = BackendInfo(device=["cuda"], sym=[True, False], + packing_format="triton_zp+-1", + bits=[2, 4, 8], group_size=None, + priority=0, feature_checks=[feature_multiply_checker_32], + alias=["auto_round:gptq:tritonv2", "auto_round:auto_gptq:tritonv2", + "auto_gptq:tritonv2"]) + +BackendInfos['awq:gemm'] = BackendInfo(device=["cuda"], sym=[True, False], ##actrally is gemm + packing_format="awq", + bits=[4], group_size=None, + priority=4, feature_checks=[feature_num_greater_checker_1024], + alias=["auto_awq:gemm", "auto_round:awq:gemm", "auto_round:auto_awq:gemm", "awq", + "auto_awq", "auto_round:awq", "aut_round:auto_awq"]) + +BackendInfos['auto_round:qbits'] = BackendInfo(device=["cpu"], sym=[True, False], + packing_format="qbits", + bits=[2, 4, 8], group_size=None, + priority=0, + feature_checks=[], + convertable_format=["triton"]) + +BackendInfos['auto_round:qbits_zp'] = BackendInfo(device=["cpu"], sym=[True, False], + packing_format="qbits", + bits=[2, 4, 8], group_size=None, + priority=0, + feature_checks=[], + convertable_format=["triton_zp+-1"]) + +# BackendInfos['auto_round:marlin'] = BackendInfo(device=["gpu"], sym=[True], +# packing_format="marlin", +# bits=[4], group_size=[-1, 128], +# priority=6, +# feature_checks=[feature_multiply_checker_marlin], +# alias=["marlin", "auto_gptq:marlin", "auto_round:gptq:marlin", +# "auto_round:auto_gptq:marlin"]) + +BackendInfos['auto_round:hpu'] = BackendInfo(device=["hpu"], sym=[True, False], + packing_format="hpu", + bits=[4], + priority=0, + convertable_format=["triton"] + ) + +BackendInfos['auto_round:hpu_zp'] = BackendInfo(device=["hpu"], sym=[True, False], + packing_format="hpu_zp+-1", + bits=[4], + priority=0, + convertable_format=["triton_zp+-1"]) + + +def check_compatible(backend_name, device, bits, group_size, sym, packing_format, in_features, out_features): + """Checks if the given configuration is compatible with the specified backend. + + Args: + backend_name (str): The name of the backend to check compatibility for. + device (str): The device on which the backend operates (e.g., 'cuda', 'cpu'). + bits (int): The bit-width of the quantization (e.g., 2, 4, 8). + group_size (Optional[int]): The size of the quantization group. Can be None if + not required by the backend. + sym (bool): Whether symmetric quantization is required (True for symmetric). + packing_format (str): The packing format used by the backend (e.g., 'triton'). + in_features (int): The number of input features for the model layer. + out_features (int): The number of output features for the model layer. + + Returns: + bool: True if the configuration is compatible with the backend, False otherwise. + + Raises: + KeyError: If the backend_name is not found in BackendInfos. + + Compatibility checks: + - Device must match one of the backend's supported devices. + - Bit-width must be supported by the backend. + - If group_size is required by the backend, it must match. + - Symmetric or asymmetric quantization must be supported. + - If the packing format matches exactly, all feature checks must pass. + - If the packing format does not match, it must be convertible. + """ + backend = BackendInfos[backend_name] + + # Check if device is supported by the backend + if not device in backend.device: + return False + + # Check if bit-width is supported + if bits not in backend.bits: + return False + + # Check if group_size is valid (if required by backend) + if backend.group_size is not None and group_size not in backend.group_size: + return False + + # Check if symmetric/asymmetric quantization is supported + if sym not in backend.sym: + return False + + # Check packing format and apply feature checks + if packing_format == backend.packing_format: + for check in backend.feature_checks: + if not check(in_features, out_features): + return False + + # Check if the format is convertible when packing formats differ + if packing_format != backend.packing_format and packing_format not in backend.convertable_format: + return False + + return True + + +def dynamic_import_inference_linear(backend, bits, group_size, sym): + """Dynamically imports and returns the appropriate QuantLinear class based on the given backend. + + This function dynamically loads the correct `QuantLinear` class based on the backend and quantization + configuration (e.g., qbits, marlin, hpu, gptq, awq, auto_round). It imports specific modules or raises + errors if the required packages are not installed or the environment is not set up. + + Args: + backend (str): + The backend to be used for quantization (e.g., 'qbits', 'marlin', 'hpu', 'gptq', 'awq', 'auto_round'). + bits (int): + The number of bits to be used for quantization. + group_size (Optional[int]): + The size of the quantization group (if applicable). + sym (bool): + Whether symmetric quantization is required. + + Returns: + class: + The dynamically imported QuantLinear class that corresponds to the given backend configuration. + + Raises: + ImportError: + If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq). + """ + if "qbits" in backend: + try: + from intel_extension_for_transformers import qbits # pylint: disable=E0401 + except Exception as e: + raise ImportError( + "Please install Intel Extension for Transformers via 'pip install " + "intel-extension-for-transformers' to inference on X86 CPU" + ) + if "zp" in backend: + import auto_round_extension.qbits.qlinear_qbits_gptq as qlinear_qbits_gptq + return qlinear_qbits_gptq.QuantLinear + else: # auto_round must be at the end + import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits_autoround + return qlinear_qbits_autoround.QuantLinear + + if "marlin" in backend: + from transformers.utils.versions import require_version + require_version( + "gptqmodel", + "marlin format requires gptqmodel to be installed, `pip install -v gptqmodel --no-build-isolation`" + ) + from gptqmodel.nn_modules.qlinear.qlinear_marlin_inference import \ + MarlinInferenceQuantLinear # pylint: disable=E0401 + return MarlinInferenceQuantLinear + + if "hpu" in backend: + try: + import habana_frameworks.torch.hpu # pylint: disable=E0401 + except ImportError: + raise ImportError("Please setup hpu environment before using hpu backend") + + if "zp" in backend: + from auto_round_extension.hpu.qlinear_hpu_gptq import QuantLinear as QuantLinear_gptq + return QuantLinear_gptq + else: # auto_round must be at the end + from auto_round_extension.hpu.qlinear_hpu import QuantLinear + return QuantLinear + + if "gptq" in backend: + return get_autogptq_infer_linear(backend, bits, group_size, sym) + + if "awq" in backend: + try: + from awq.modules.linear import WQLinear_GEMM # pylint: disable=E0401 + except ImportError: + raise ImportError( + "autoawq is required. Please install it by 'pip install autoawq' to support auto_awq format.") + return WQLinear_GEMM + + if "auto_round" in backend: + if "exllamav2" in backend: + import auto_round_extension.cuda.qlinear_exllamav2 + return auto_round_extension.cuda.qlinear_exllamav2.QuantLinear + else: + import auto_round_extension.cuda.qlinear_tritonv2 + return auto_round_extension.cuda.qlinear_tritonv2.QuantLinear + + +def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False): + """Returns the appropriate QuantLinear class based on backend configuration. + + This function selects and dynamically imports the `QuantLinear` class according to the specified backend + and its features, such as using Triton, ExLlama, Marlin, or Qigen for quantization. + + Args: + backend (str): + The backend to be used for quantization (e.g., 'triton', 'qigen', 'marlin', 'exllamav2'). + bits (int, optional): + The number of bits used for quantization. Default is 4. + group_size (int, optional): + The group size for quantization. Default is 128. + sym (bool, optional): + Whether symmetric quantization is enabled. Default is False. + + Returns: + class: + The dynamically imported QuantLinear class for the given configuration. + + Raises: + ImportError: + If required packages or backends are not installed. + """ + use_triton = False + disable_exllamav2 = False + disable_exllamav1 = False + disable_marlin = True + use_qigen = False + use_tritonv2 = False + + # Determine backend configurations based on input string + if "qigen" in backend: + use_qigen = True + elif "triton" in backend: + use_triton = True + elif "tritonv2" in backend: + use_triton = False + use_tritonv2 = True + elif "marlin" in backend: + use_triton = False + disable_marlin = False + elif "exllamav2" in backend: + use_triton = False + disable_exllamav2 = False + disable_marlin = True + elif "exllamav1" in backend: + use_triton = False + disable_marlin = True + elif "cuda" in backend: + use_triton = False + disable_marlin = True + disable_exllamav2 = True + disable_exllamav1 = True + + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear # pylint: disable=E0401 + version = get_library_version("auto_gptq") + from packaging.version import Version + + # Import the appropriate QuantLinear based on the version of auto_gptq + if Version(version) <= Version("0.7.1"): + QuantLinear = dynamically_import_QuantLinear( + use_triton=use_triton, + desc_act=False, + group_size=group_size, + bits=bits, + disable_exllama=disable_exllamav1, + disable_exllamav2=disable_exllamav2, + use_qigen=use_qigen, + disable_marlin=disable_marlin + ) + else: + QuantLinear = dynamically_import_QuantLinear( # pylint: disable=E1123 + use_triton=use_triton, + desc_act=False, + group_size=group_size, + bits=bits, + disable_exllama=disable_exllamav1, + disable_exllamav2=disable_exllamav2, + use_qigen=use_qigen, + use_marlin=not disable_marlin, + use_tritonv2=use_tritonv2 + ) + + return QuantLinear + + +def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_features, out_features): + """Selects the most suitable backend for the layer based on compatibility and priority. + + This function first checks if the specified backend supports the layer with the provided configuration. + If not, it iterates through other available backends, + checking compatibility and returning the one with the highest priority. + + Args: + device (str): + The device on which the layer will run, e.g., 'cpu', 'cuda'. + backend (str): + The target backend to be used for this layer. + orig_backend (str): + The original backend from which packing format information is retrieved. + bits (int): + The number of bits used for quantization. + group_size (int): + The group size for quantization. + sym (bool): + Whether symmetric quantization is enabled. + in_features (int): + The number of input features for the layer. + out_features (int): + The number of output features for the layer. + + Returns: + str: + The selected backend that is compatible with the layer configuration. + + Raises: + AssertionError: + If the specified backend is not supported. + ValueError: + If no compatible backend is found for the given layer configuration. + """ + # Check if the provided backend is in BackendInfos + assert backend in BackendInfos.keys(), \ + f"Unsupported backend {backend}, please set it to `auto` to try automatic selection" + + packing_format = BackendInfos[orig_backend].packing_format + + # Check if the provided backend supports the layer configuration + if check_compatible(backend, device, bits, group_size, sym, packing_format, in_features, out_features): + return backend + + # Find and store other compatible backends + supported_backends = [] + for key in BackendInfos.keys(): + if key == backend: + continue + if check_compatible(key, device, bits, group_size, sym, packing_format, in_features, out_features): + supported_backends.append(key) + + # Raise an error if no compatible backends are found + if len(supported_backends) == 0: + raise ValueError(f"None of the backends support this layer") + + # Sort the compatible backends by priority and return the one with the highest priority + supported_backends = sorted(supported_backends, key=lambda support_backend: BackendInfos[support_backend].priority, + reverse=True) + + return supported_backends[0] + + +if __name__ == "__main__": + res = get_layer_backend("cuda", "gptq:exllamav2", "gptq:exllamav2", 4, 128, sym=False, in_features=128, + out_features=128) + assert res == "gptq:exllamav2" + + res = get_layer_backend("cuda", "gptq:exllamav2", "gptq:exllamav2", 2, 128, sym=False, in_features=128, + out_features=128) + assert res == "gptq:tritonv2" + + res = get_layer_backend("cpu", "auto_round:exllamav2", "auto_round:exllamav2", 4, 128, sym=False, in_features=128, + out_features=128) + assert res == "auto_round:qbits" + + res = get_layer_backend("cpu", "gptq:exllamav2", "gptq:exllamav2", 4, 128, sym=False, in_features=128, + out_features=128) + assert res == "auto_round:qbits_zp" diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index e9cace5b..0ca355c3 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -181,8 +181,8 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex backend = backend.replace("autoround", "auto_round") backend = backend.replace("auto-round", "auto_round") ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source - if (kwargs.get("sym") is None or kwargs.get("sym") == True) and ("gptq" not in backend and "awq" not in backend): - backend = backend.replace('auto_round','auto_round:gptq') + if (kwargs.get("sym") is None or kwargs.get("sym") == True) and ("gptq" not in backend and "awq" not in backend): + backend = backend.replace('auto_round', 'auto_round:gptq') if not ("triton" in backend or "exllamav2" in backend or "awq" in backend or "gptq" in backend): logger.info(f"AutoRound format does not support {backend}, try to pack each layer with AutoGPTQ") @@ -198,10 +198,10 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex layer_config = kwargs["layer_config"] quantization_config = kwargs["serialization_dict"] quantization_config["quant_method"] = "intel/auto-round" + + quantization_config["backend"] = backend tokenizer = kwargs.get("tokenizer", None) processor = kwargs.get("processor", None) - if "awq" not in backend: - quantization_config["backend"] = backend extra_config = {} for layer_name in layer_config: if layer_name not in layer_names_in_block and layer_config[layer_name]["bits"] <= 8: ##lm head @@ -243,17 +243,11 @@ def wrapper(name): return model if tokenizer is not None: tokenizer.save_pretrained(output_dir) + if processor is not None: processor.save_pretrained(output_dir) - modules_to_not_convert = [] - if "awq" not in backend: - save(model, output_dir, safe_serialization=safe_serialization) - else: - for name in layer_config.keys(): - config = kwargs["layer_config"][name] - if config["bits"] > 8: - modules_to_not_convert.append(name) - save_awq(model, output_dir, modules_to_not_convert=modules_to_not_convert) + save(model, output_dir, safe_serialization=safe_serialization) + return model @@ -284,41 +278,3 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f: json.dump(model.config.quantization_config, f, indent=2) - -def save_awq( - model: nn.Module, - save_dir: str, - max_shard_size: str = "5GB", - safe_serialization: bool = True, - modules_to_not_convert: list = [], -): - """Save model state dict and configs. - - Args: - model (`nn.Module`): - Model to be saved. The model can be wrapped or unwrapped. - save_dir (`str`): - Directory to which to save. Will be created if it doesn't exist. - max_shard_size (`str`, defaults to `"10GB"`): - The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size - lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). - - - If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard - which will be bigger than `max_shard_size`. - - - safe_serialization (`bool`, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - """ - os.makedirs(save_dir, exist_ok=True) - quantization_config = model.config.quantization_config - model.config.quantization_config["quant_method"] = "awq" - model.config.quantization_config["modules_to_not_convert"] = None if not modules_to_not_convert \ - else modules_to_not_convert - model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) - config_file = "quantization_config.json" - if hasattr(model, "config") and hasattr(model.config, "quantization_config"): - with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f: - json.dump(quantization_config, f, indent=2) - diff --git a/auto_round/utils.py b/auto_round/utils.py index a3913ce1..4d0da4a4 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -27,6 +27,7 @@ from torch.amp import autocast from functools import lru_cache +from packaging import version @lru_cache(None) @@ -760,126 +761,18 @@ def is_autoround_exllamav2_available(): return res -def get_autogptq_infer_linear(backend, bits=4, group_size=128, sym=False): - use_triton = False - disable_exllamav2 = False - disable_exllamav1 = False - disable_marlin = True - use_qigen = False - use_tritonv2 = False - if "qigen" in backend: - use_qigen = True - elif "triton" in backend: - use_triton = True - elif "tritonv2" in backend: - use_triton = False - use_tritonv2 = True - elif "marlin" in backend: - use_triton = False - disable_marlin = False - elif "exllamav2" in backend: - use_triton = False - disable_exllamav2 = False - disable_marlin = True - elif "exllamav1" in backend: - use_triton = False - disable_marlin = True - elif "cuda" in backend: - use_triton = False - disable_marlin = True - disable_exllamav2 = True - disable_exllamav1 = True - from auto_gptq.utils.import_utils import dynamically_import_QuantLinear # pylint: disable=E0401 - version = get_library_version("auto_gptq") - from packaging.version import Version - if Version(version) <= Version("0.7.1"): - QuantLinear = dynamically_import_QuantLinear( - use_triton=use_triton, - desc_act=False, - group_size=group_size, - bits=bits, - disable_exllama=disable_exllamav1, - disable_exllamav2=disable_exllamav2, - use_qigen=use_qigen, - disable_marlin=disable_marlin - ) - else: - QuantLinear = dynamically_import_QuantLinear( # pylint: disable=E1123 - use_triton=use_triton, - desc_act=False, - group_size=group_size, - bits=bits, - disable_exllama=disable_exllamav1, - disable_exllamav2=disable_exllamav2, - use_qigen=use_qigen, - use_marlin=not disable_marlin, - use_tritonv2=use_tritonv2 - ) - return QuantLinear - -def dynamic_import_inference_linear(backend, bits, group_size, sym): - """Dynamically imports and returns the appropriate QuantLinear class based on the given bits and backend. - Args: - bits (int): - The number of bits for quantization. - backend (str): - The backend to be used for quantization, such as "qbits", "cpu", or "exllamav2". - - Returns: - class: - The appropriate QuantLinear class for the given configuration. - """ - exllama2_available = is_autoround_exllamav2_available() - ##TODO may have bug for marlin backend - if (not torch.cuda.is_available() and not is_optimum_habana_available()) or "qbits" in backend or "cpu" in backend: - try: - from intel_extension_for_transformers import qbits # pylint: disable=E0401 - except Exception as e: - raise ImportError("Please install Intel Extension for Transformers via 'pip install " - "intel-extension-for-transformers' to inference on X86 CPU") - import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits - return qlinear_qbits.QuantLinear - if "gptq" in backend: - if not is_optimum_habana_available(): - try: - import auto_gptq # pylint: disable=E0401 - except Exception as e: - raise ImportError("Please install auto-gptq via 'pip install auto-gptq' to support GPTQ backend ") - return get_autogptq_infer_linear(backend, bits, group_size, sym) - else: # pragma: no cover - try: - import habana_frameworks.torch.hpu # noqa: F401 # pylint: disable=E0401 - except Exception as e: - pass - else: - from auto_round_extension.hpu.qlinear_hpu_gptq import QuantLinear - return QuantLinear - if (bits == 4 and is_optimum_habana_available()) or "hpu" in backend: # pragma: no cover - try: - import habana_frameworks.torch.hpu # noqa: F401 # pylint: disable=E0401 - except Exception as e: - pass - else: - from auto_round_extension.hpu.qlinear_hpu import QuantLinear - return QuantLinear - if "awq" in backend: # pragma: no cover - try: - from awq.modules.linear import WQLinear_GEMM # pylint: disable=E0401 - except: - raise ImportError("autoawq is required. Please install it by 'pip install autoawq' to \ - support auto_awq format.") - return WQLinear_GEMM - if bits == 4 and exllama2_available and "exllamav2" in backend: - from auto_round_extension.cuda.qlinear_exllamav2 import QuantLinear - elif bits == 4 and "exllamav2" in backend: - logger.warning_once("Please install auto-round from source to enable exllamav2 kernels, switch to triton " - "kernels for now") - from auto_round_extension.cuda.qlinear_tritonv2 import QuantLinear - else: - from auto_round_extension.cuda.qlinear_tritonv2 import QuantLinear - return QuantLinear +def is_hpu_supported(): # pragma: no cover + try: + import subprocess + import habana_frameworks.torch.core as htcore # pylint: disable=E0401 + hqt_version = subprocess.check_output(['pip', 'show', \ + 'habana_quantization_toolkit']).decode().split('\n')[1].split(': ')[1] + assert (hqt_version >= "1.17") + except ImportError as e: + return False + return True def get_library_version(library_name): @@ -964,7 +857,7 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False): disable_marlin=disable_marlin, ) else: - QuantLinear = dynamically_import_QuantLinear(# pylint: disable=E1123 + QuantLinear = dynamically_import_QuantLinear( # pylint: disable=E1123 use_triton=use_triton, desc_act=False, group_size=group_size, @@ -975,6 +868,3 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False): use_marlin=not disable_marlin, ) return QuantLinear - - - diff --git a/auto_round_extension/qbits/qlinear_qbits_gptq.py b/auto_round_extension/qbits/qlinear_qbits_gptq.py new file mode 100644 index 00000000..658adead --- /dev/null +++ b/auto_round_extension/qbits/qlinear_qbits_gptq.py @@ -0,0 +1,227 @@ + +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +from auto_round.utils import convert_dtype_torch2str, logger +QBITS_AVAILABLE = True + +BITS_DTYPE_MAPPING = { + 2: "int2_clip", + 4: "int4_clip", + 8: "int8", +} + + +class QuantLinear(nn.Module): + QUANT_TYPE = "qbits" + + def __init__( + self, + bits, + group_size, + infeatures, + outfeatures, + bias, + kernel_switch_threshold=128, + trainable=False, + weight_dtype=torch.bfloat16, + **kwargs, + ): + super().__init__() + + if bits not in [2, 4, 8]: + raise NotImplementedError( + "Only 2, 4,8 bits are supported for QBits.") + + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.group_size = group_size if group_size != -1 else infeatures + self.maxq = 2**self.bits - 1 + self.weight_dtype = weight_dtype + self.asym = True + self.qbits = None + + self.register_buffer( + "qweight", + torch.zeros((infeatures // 32 * self.bits, + outfeatures), dtype=torch.int32), + ) + self.register_buffer( + "qzeros", + torch.zeros( + ( + math.ceil(infeatures / self.group_size), + outfeatures // 32 * self.bits, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (math.ceil(infeatures / self.group_size), outfeatures), + dtype=weight_dtype, + ), + ) + if bias: + self.register_buffer("bias", torch.zeros( + (outfeatures), dtype=torch.float)) + else: + self.bias = None + + self.kernel_switch_threshold = kernel_switch_threshold + + self.trainable = trainable + + def req_check(self): + torch_version = str(torch.__version__) + if QBITS_AVAILABLE: + import intel_extension_for_transformers + itrex_version = str(intel_extension_for_transformers.__version__) + version_match_map = {"1.4": "2.2.0+cpu", + "1.4.1": "2.2.0+cpu", "1.4.2": "2.3.0+cpu"} + if itrex_version in version_match_map: + if torch_version != version_match_map[itrex_version]: + logger.warning( + f"Please install torch {version_match_map[itrex_version]} by command 'pip install torch=={version_match_map[itrex_version]} --extra-index-url https://download.pytorch.org/whl/cpu' as Intel Extension for Transformers {itrex_version} is not compatible with current torch.") + else: + logger.error( + "Please install Intel Extension for Transformers by running 'pip install intel-extension-for-transformers' as qbits linear requirements checking fail. ") + exit(1) + + def post_init(self): + import intel_extension_for_transformers + self.qbits = intel_extension_for_transformers.qbits + assert self.qweight.device.type == "cpu" + if self.bias is not None: + self.bias = self.bias.to(dtype=torch.float32) + + # intweight: k x n, zeros: k / group_size x n + intweight, zeros = unpack_to_8bit_signed( + self.qweight, self.qzeros, self.bits) + if zeros is None: + zeros = torch.empty(0, dtype=torch.int8) + self.asym = False + else: + # change it to int8 with offset 128 + if self.bits == 8: + zeros = (zeros.to(torch.int32) - + (2 ** (self.bits - 1))).to(torch.int8) + else: + zeros -= (2**(self.bits - 1)) + + if not self.asym: + intweight -= (2**(self.bits - 1)) + intweight = intweight.to(torch.uint8 if self.asym else torch.int8) + # due to asym return torch.uint8 but backend request int8, + # change it to int8 with offset 128 + if self.asym: + intweight = (intweight.to(torch.int32) - + (2 ** (self.bits - 1))).to(torch.int8) + + scales = self.scales + + logger.debug( + f"QBits repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}") + self.qweight = self.qbits.repack_quantized_weight(intweight.contiguous(), scales.float().contiguous(), zeros.contiguous(), torch.empty(0), + # weight_dtype + BITS_DTYPE_MAPPING[self.bits], + # scale_dtype + "fp32", + # TODO(zhe): consider dynamic-set cmpt for better perf? + "fp32", + self.asym, + self.group_size) + # free mem + self.qzeros = torch.empty(0) + self.scales = torch.empty(0) + + def forward(self, x: torch.Tensor): + raw_input_dtype = x.dtype + if raw_input_dtype != torch.float32: + x = x.to(torch.float32) + out_shape = x.shape[:-1] + (self.outfeatures,) + x = x.view(-1, x.shape[-1]) # convert xd to 2d + out_2d_shape = x.shape[:-1] + (self.outfeatures,) + + outputs = torch.zeros(out_2d_shape, device=x.device, dtype=torch.float) + bias = self.bias if self.bias is not None else torch.empty( + 0, dtype=torch.float) + + self.qbits.woq_linear(x, self.qweight, bias, outputs, + convert_dtype_torch2str(torch.float), # compute_dtype + BITS_DTYPE_MAPPING[self.bits], # weight_dtype + "fp32", # scale_dtype + self.asym) + return outputs.to(raw_input_dtype).view(out_shape) + + +@torch.no_grad() +def unpack_to_8bit_signed(qweight, qzeros, bits): + wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0) + zeros = None + if not torch.all(torch.eq(qzeros, 2004318071 if bits == 4 else 0b01111111011111110111111101111111)): + zp_shape = list(qzeros.shape) + zp_shape[1] = zp_shape[1] * (32 // bits) + + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand(-1, -1, + 32 // bits), wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2**bits) - 1, out=zeros) + if bits == 8: + zeros = zeros.to(torch.uint8) + zeros+=1 + try: + zeros = zeros.reshape(zp_shape) + except: + # zeros and scales have different iteam numbers. + # remove 1 (due to 0 + 1 in line 252) + zeros = zeros[zeros != 1] + zeros = zeros.reshape(zp_shape) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1).expand(-1, 32 // + bits, -1), wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8) + weight.bitwise_and_((2**bits) - 1) + weight = weight.view(-1, weight.shape[-1]) + + return weight, zeros + + +# Copied from qlinear_marlin.py +@torch.no_grad() +def dequantize_weight(qweight, qzeros, scales, bits): + unpacked_qweight, unpacked_qzeros = unpack_to_8bit_signed( + qweight, qzeros, bits) + group_size = unpacked_qweight.shape[0] // scales.shape[0] + scales = scales.repeat_interleave(group_size, dim=0) + if unpacked_qzeros is not None: + unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0) + else: + unpacked_qzeros = torch.full_like( + scales, 8 if bits == 4 else 128, dtype=torch.int32) + unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales + + return unpacked_qweight, unpacked_qzeros + + +__all__ = ["QuantLinear"] diff --git a/examples/language-modeling/main.py b/examples/language-modeling/main.py index a8608157..a72a48bc 100644 --- a/examples/language-modeling/main.py +++ b/examples/language-modeling/main.py @@ -352,7 +352,9 @@ format_list = args.format.replace(' ', '').split(',') inplace = False if len(format_list) > 1 else True for format_ in format_list: - eval_folder = f'{export_dir}-{format_}' + save_format_ = format_.replace(":", "-") + save_format_ = save_format_.replace("_", "-") + eval_folder = f'{export_dir}-{save_format_}' autoround.save_quantized(eval_folder, format=format_, inplace=inplace) else: deployment_device = args.deployment_device.split(',') diff --git a/examples/language-modeling/run_autoround.sh b/examples/language-modeling/run_autoround.sh index 38a907f5..2f875ff5 100644 --- a/examples/language-modeling/run_autoround.sh +++ b/examples/language-modeling/run_autoround.sh @@ -9,9 +9,8 @@ python3 main.py \ --bits 4 \ --group_size 128 \ --iters 200 \ ---deployment_device 'fake,cpu' \ ---scale_dtype 'fp32' \ ---eval_bs 32 \ +--format 'auto_round,auto_gptq' \ +--eval_bs 16 \ --output_dir "./tmp_autoround" diff --git a/examples/language-modeling/run_autoround_on_gaudi.sh b/examples/language-modeling/run_autoround_on_gaudi.sh index 2dd0f4b6..6f0cf5cf 100644 --- a/examples/language-modeling/run_autoround_on_gaudi.sh +++ b/examples/language-modeling/run_autoround_on_gaudi.sh @@ -6,7 +6,8 @@ python3 main.py \ --model_name $model_name \ --group_size 128 \ --bits 4 \ - --deployment_device "fake" \ + --device 0 \ + --format "auto_round" \ --output_dir "./tmp_autoround" diff --git a/examples/language-modeling/run_xpu.sh b/examples/language-modeling/run_xpu.sh index 26286f9c..f5215725 100644 --- a/examples/language-modeling/run_xpu.sh +++ b/examples/language-modeling/run_xpu.sh @@ -1,6 +1,6 @@ #!/bin/bash set -x -device=1 +device=0 model_name="facebook/opt-125m" CUDA_VISIBLE_DEVICES=$device \ @@ -8,9 +8,6 @@ python3 main.py \ --model_name $model_name \ --bits 4 \ --group_size 32 \ ---deployment_device "xpu,fake" \ ---disable_low_gpu_mem_usage \ ---disable_eval \ ---scale_dtype "fp16" \ ---output_dir "./tmp_autoround" \ +--format "auto_gptq" \ +--output_dir "./tmp_autoround" diff --git a/test/test_generation.py b/test/test_generation.py new file mode 100644 index 00000000..462d2af4 --- /dev/null +++ b/test/test_generation.py @@ -0,0 +1,177 @@ +import copy +import shutil +import sys +import unittest + +sys.path.insert(0, "..") +import torch +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer + +from auto_round import AutoRound + + +class LLMDataLoader: + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(2): + yield torch.ones([1, 10], dtype=torch.long) + + +class TestAutoRoundFormatGeneration(unittest.TestCase): + @classmethod + def setUpClass(self): + model_name = "facebook/opt-125m" + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.llm_dataloader = LLMDataLoader() + + @classmethod + def tearDownClass(self): + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_llm_generation_sym_gpu_gptq(self): + if not torch.cuda.is_available(): + return + bits = 4 + group_size = 32 + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=True, + iters=1, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + quantized_model_path = "./saved" + + autoround.save_quantized(output_dir=quantized_model_path, format="auto_round:gptq",inplace=False) + device = "auto" ##cpu, hpu, cuda + from auto_round import AutoRoundConfig + quantization_config = AutoRoundConfig( + backend=device + ) + + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, + device_map=device, quantization_config=quantization_config) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + res = tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]) + assert ( + res == """There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""") + + # # + # + # def test_llm_generation_sym_gpu_gptq_marlin(self): ##need auto_gptq >0.7.1 + # if not torch.cuda.is_available(): + # return + # bits = 4 + # group_size = 128 + # autoround = AutoRound( + # self.model, + # self.tokenizer, + # bits=bits, + # group_size=group_size, + # sym=True, + # iters=1, + # seqlen=2, + # dataset=self.llm_dataloader, + # ) + # autoround.quantize() + # quantized_model_path = "./saved" + # + # autoround.save_quantized(output_dir=quantized_model_path, format="auto_round:marlin") + # device = "auto" ##cpu, hpu, cuda + # from auto_round import AutoRoundConfig + # quantization_config = AutoRoundConfig( + # backend=device + # ) + # + # model = AutoModelForCausalLM.from_pretrained(quantized_model_path, + # device_map=device, quantization_config=quantization_config) + # tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + # text = "There is a girl who likes adventure," + # inputs = tokenizer(text, return_tensors="pt").to(model.device) + # res = tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]) + # assert ( + # res == """There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""") + + + def test_llm_generation_asym_gpu_awq(self): + if not torch.cuda.is_available(): + return + bits = 4 + group_size = 32 + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=True, + iters=1, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + quantized_model_path = "./saved" + + autoround.save_quantized(output_dir=quantized_model_path, format="auto_round:awq",inplace=False) + device = "auto" ##cpu, hpu, cuda + from auto_round import AutoRoundConfig + quantization_config = AutoRoundConfig( + backend=device + ) + + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, + device_map=device, quantization_config=quantization_config) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + res = tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]) + assert ( + res == """There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""") + + def test_llm_generation_asym_qbits(self): + try: + import intel_extension_for_transformers + except: + return + bits = 4 + group_size = 32 + autoround = AutoRound( + self.model, + self.tokenizer, + bits=bits, + group_size=group_size, + sym=True, + iters=1, + seqlen=2, + dataset=self.llm_dataloader, + ) + autoround.quantize() + quantized_model_path = "./saved" + + autoround.save_quantized(output_dir=quantized_model_path, format="auto_round",inplace=False) + device = "cpu" ##cpu, hpu, cuda + from auto_round import AutoRoundConfig + quantization_config = AutoRoundConfig( + backend=device + ) + + model = AutoModelForCausalLM.from_pretrained(quantized_model_path, + device_map="cpu", quantization_config=quantization_config) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + res = tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]) + assert ( + res == """There is a girl who likes adventure, and I'm not sure if she's into it, but I'm sure she's into it.\nI'm not sure if she's into adventure, but I'm sure she's into it.\nI'm not sure if she's into adventure""") + +