diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py index 8a4952fc..a0e08db4 100644 --- a/auto_round/auto_quantizer.py +++ b/auto_round/auto_quantizer.py @@ -41,7 +41,7 @@ from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod -from auto_round.utils import get_module, set_module +from auto_round.utils import get_module, set_module, dynamic_import_inference_linear import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits logger = getLogger(__name__) @@ -316,20 +316,6 @@ def convert_model(self, model: nn.Module): self._replace_by_quant_layers(model, layer_configs, backend) return model - def _dynamic_import_inference_linear(self, bits, backend): - if (not torch.cuda.is_available()) or "qbits" in backend or "cpu" in backend: - try: - from intel_extension_for_transformers import qbits # pylint: disable=E0401 - except Exception as e: - raise ImportError("Please install Intel Extension for Transformers via 'pip install " - "intel-extension-for-transformers' to inference on X86 CPU") - return qlinear_qbits.QuantLinear - if bits == 4 and self.exllama2_available and "exllamav2" in backend: - from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear - else: - from auto_round_extension.cuda.qliner_triton import QuantLinear - return QuantLinear - def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend): """Replaces linear layers in `module` by `QuantLinear` @@ -351,7 +337,7 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend): layer = get_module(module, layer_name) device = get_device(layer) - QuantLinear = self._dynamic_import_inference_linear(bits, backend) + QuantLinear = dynamic_import_inference_linear(bits, group_size, backend) if isinstance(layer, nn.Linear): in_features = layer.in_features out_features = layer.out_features diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index f89a03aa..ce7dc07c 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -180,7 +180,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exl save(model, output_dir) -def save(model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True): +def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True): """Save model state dict and configs. Args: diff --git a/auto_round/utils.py b/auto_round/utils.py index a4a50a6e..47c4ef2f 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -35,7 +35,7 @@ import importlib import transformers - +from functools import lru_cache class LazyImport(object): """Lazy import python module till use.""" @@ -560,3 +560,106 @@ def get_layer_names_in_block(model, supported_types=[torch.nn.Linear, transforme if hasattr(m, "tmp_name"): delattr(m, "tmp_name") return layers_in_block + + +def is_autoround_exllamav2_available(): + """Checks if the AutoRound ExLlamaV2 kernels are available. + + Returns: + bool: + True if the AutoRound ExLlamaV2 kernels are available, False otherwise. 
+    """
+    res = True
+    try:
+        from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix  # pylint: disable=E0401
+    except ImportError:
+        res = False
+    return res
+
+
+def get_autogptq_backend_config(backend, bits=4):
+    """Maps an auto-gptq backend string to the kernel flags expected by auto-gptq."""
+    use_triton = False
+    disable_exllamav2 = False
+    disable_exllamav1 = False
+    disable_marlin = True
+    use_qigen = False
+    if backend == "gptq:qigen":
+        use_qigen = True
+    if backend == "gptq:triton":  # TODO: refine the code
+        use_triton = True
+    if backend == "gptq:marlin":
+        use_triton = False
+        disable_marlin = True
+    if backend == "gptq:exllamav2":  # need v1 code to export
+        use_triton = False
+        disable_marlin = True
+    if backend == "gptq:exllamav1":
+        use_triton = False
+        disable_marlin = True
+    if backend == "gptq:cuda":
+        use_triton = False
+        disable_marlin = True
+        disable_exllamav2 = True
+        disable_exllamav1 = True
+    if bits not in [2, 4, 8]:
+        use_qigen = False
+    if bits not in [2, 4]:
+        use_triton = False
+    return use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin
+
+
+@lru_cache(None)
+def warning_once(msg: str):
+    """Logs a warning message only once per unique message."""
+    logger.warning(msg)
+
+
+logger.warning_once = warning_once
+
+
+def dynamic_import_inference_linear(bits, group_size, backend):
+    """Dynamically imports and returns the appropriate QuantLinear class for the given bits, group size, and backend.
+
+    Args:
+        bits (int):
+            The number of bits used for quantization.
+        group_size (int):
+            The quantization group size.
+        backend (str):
+            The backend to be used for inference, such as "qbits", "cpu", "gptq", or "exllamav2".
+
+    Returns:
+        class:
+            The appropriate QuantLinear class for the given configuration.
+    """
+    exllama2_available = is_autoround_exllamav2_available()
+
+    if (not torch.cuda.is_available()) or "qbits" in backend or "cpu" in backend:
+        try:
+            from intel_extension_for_transformers import qbits  # pylint: disable=E0401
+        except Exception as e:
+            raise ImportError("Please install Intel Extension for Transformers via 'pip install "
+                              "intel-extension-for-transformers' to run inference on an x86 CPU.") from e
+        import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits
+        return qlinear_qbits.QuantLinear
+    if "gptq" in backend:
+        try:
+            import auto_gptq  # pylint: disable=E0401
+        except Exception as e:
+            raise ImportError("Please install auto-gptq via 'pip install auto-gptq' "
+                              "to support the GPTQ backend.") from e
+        use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
+            backend, bits
+        )
+        from auto_gptq.utils.import_utils import dynamically_import_QuantLinear  # pylint: disable=E0401
+        QuantLinear = dynamically_import_QuantLinear(
+            use_triton=use_triton,
+            desc_act=False,
+            group_size=group_size,
+            bits=bits,
+            disable_exllama=disable_exllamav1,
+            disable_exllamav2=disable_exllamav2,
+            use_qigen=use_qigen,
+            disable_marlin=disable_marlin,
+        )
+        return QuantLinear
+    if bits == 4 and exllama2_available and "exllamav2" in backend:
+        from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
+    elif bits == 4 and "exllamav2" in backend:
+        logger.warning_once("Please install auto-round from source to enable the exllamav2 kernels; "
+                            "falling back to the Triton kernels for now.")
+        from auto_round_extension.cuda.qliner_triton import QuantLinear
+    else:
+        from auto_round_extension.cuda.qliner_triton import QuantLinear
+    return QuantLinear
diff --git a/auto_round_extension/qbits/qlinear_qbits.py b/auto_round_extension/qbits/qlinear_qbits.py
index 1c567946..0ef74668 100644
--- a/auto_round_extension/qbits/qlinear_qbits.py
+++ b/auto_round_extension/qbits/qlinear_qbits.py
@@ -138,7 +138,7 @@ def post_init(self):
         scales = self.scales
 
-        logger.info(
+        logger.debug(
            f"QBits repack quantized weight: 
K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}") self.qweight = self.qbits.repack_quantized_weight(intweight.contiguous(), scales.float().contiguous(), zeros.contiguous(), torch.empty(0), # weight_dtype diff --git a/setup.py b/setup.py index fd244df2..55cf381a 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ def detect_local_sm_architectures(): import torch except Exception as e: print( - f"Building PyTorch CUDA extension requires PyTorch being installed, please install PyTorch first: {e}.\n NOTE: This issue may be raised due to pip build isolation system (ignoring local packages). Please use `--no-build-isolation` when installing with pip, and refer to https://github.com/AutoGPTQ/AutoGPTQ/pull/620 for more details.") + f"Building PyTorch CUDA extension requires PyTorch being installed, please install PyTorch first: {e}.\n NOTE: This issue may be raised due to pip build isolation system (ignoring local packages). Please use `--no-build-isolation` when installing with pip, and refer to https://github.com/AutoRound/AutoRound/pull/620 for more details.") sys.exit(1) if not torch.cuda.is_available(): print( @@ -75,7 +75,7 @@ def detect_local_sm_architectures(): ROCM_VERSION = os.environ.get('ROCM_VERSION', None) if ROCM_VERSION and not torch.version.hip: print( - f"Trying to compile auto-gptq for ROCm, but PyTorch {torch.__version__} " + f"Trying to compile auto-round for ROCm, but PyTorch {torch.__version__} " "is installed without ROCm support." ) sys.exit(1) @@ -89,7 +89,7 @@ def detect_local_sm_architectures(): else: if not CUDA_VERSION: print( - f"Trying to compile auto-gptq for CUDA, but Pytorch {torch.__version__} " + f"Trying to compile auto-round for CUDA, but Pytorch {torch.__version__} " "is installed without CUDA support." ) sys.exit(1) @@ -102,13 +102,13 @@ def detect_local_sm_architectures(): requested_but_unsupported_archs = {arch for arch in archs if arch in UNSUPPORTED_COMPUTE_CAPABILITIES} if len(requested_but_unsupported_archs) > 0: raise ValueError( - f"Trying to compile AutoGPTQ for CUDA compute capabilities {torch_cuda_arch_list}, but AutoGPTQ does not support the compute capabilities {requested_but_unsupported_archs} (AutoGPTQ requires Pascal or higher). Please fix your environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139).") + f"Trying to compile AutoRound for CUDA compute capabilities {torch_cuda_arch_list}, but AutoRound does not support the compute capabilities {requested_but_unsupported_archs} (AutoRound requires Pascal or higher). Please fix your environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139).") else: local_arch_list = detect_local_sm_architectures() local_but_unsupported_archs = {arch for arch in local_arch_list if arch in UNSUPPORTED_COMPUTE_CAPABILITIES} if len(local_but_unsupported_archs) > 0: raise ValueError( - f"PyTorch detected the compute capabilities {local_arch_list} for the NVIDIA GPUs on the current machine, but AutoGPTQ can not be built for compute capabilities {local_but_unsupported_archs} (AutoGPTQ requires Pascal or higher). 
Please set the environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139) with your necessary architectures.") + f"PyTorch detected the compute capabilities {local_arch_list} for the NVIDIA GPUs on the current machine, but AutoRound can not be built for compute capabilities {local_but_unsupported_archs} (AutoRound requires Pascal or higher). Please set the environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139) with your necessary architectures.") # For the PyPI release, the version is simply x.x.x to comply with PEP 440. if not PYPI_RELEASE:
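
Usage sketch (not part of the patch): how a caller such as AutoRoundQuantizer._replace_by_quant_layers is expected to resolve an inference kernel through the relocated dynamic_import_inference_linear helper. The bits, group_size, and backend values below are illustrative assumptions, not taken from the diff.

import torch

from auto_round.utils import dynamic_import_inference_linear

# Illustrative settings (assumed, not from the patch): a 4-bit layer quantized
# with group size 128.
bits, group_size = 4, 128

# With CUDA available, "autoround:exllamav2" resolves to the packaged exllamav2
# kernel when autoround_exllamav2_kernels is installed; otherwise it warns once
# and falls back to the Triton kernel. Without CUDA, or with a "cpu"/"qbits"
# backend, the qbits QuantLinear is returned (requires
# intel-extension-for-transformers). A "gptq:*" backend instead routes through
# auto-gptq's dynamically_import_QuantLinear with the flags from
# get_autogptq_backend_config.
backend = "autoround:exllamav2" if torch.cuda.is_available() else "cpu"
QuantLinear = dynamic_import_inference_linear(bits, group_size, backend)
print(QuantLinear)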