auto round quantizer supports gptq kernel #155

Merged
3 commits merged on Jun 13, 2024
18 changes: 2 additions & 16 deletions auto_round/auto_quantizer.py
@@ -41,7 +41,7 @@
from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod

-from auto_round.utils import get_module, set_module
+from auto_round.utils import get_module, set_module, dynamic_import_inference_linear
import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits

logger = getLogger(__name__)
@@ -316,20 +316,6 @@ def convert_model(self, model: nn.Module):
        self._replace_by_quant_layers(model, layer_configs, backend)
        return model

-    def _dynamic_import_inference_linear(self, bits, backend):
-        if (not torch.cuda.is_available()) or "qbits" in backend or "cpu" in backend:
-            try:
-                from intel_extension_for_transformers import qbits  # pylint: disable=E0401
-            except Exception as e:
-                raise ImportError("Please install Intel Extension for Transformers via 'pip install "
-                                  "intel-extension-for-transformers' to inference on X86 CPU")
-            return qlinear_qbits.QuantLinear
-        if bits == 4 and self.exllama2_available and "exllamav2" in backend:
-            from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
-        else:
-            from auto_round_extension.cuda.qliner_triton import QuantLinear
-        return QuantLinear

    def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
        """Replaces linear layers in `module` by `QuantLinear`

@@ -351,7 +337,7 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):

            layer = get_module(module, layer_name)
            device = get_device(layer)
-            QuantLinear = self._dynamic_import_inference_linear(bits, backend)
+            QuantLinear = dynamic_import_inference_linear(bits, group_size, backend)
            if isinstance(layer, nn.Linear):
                in_features = layer.in_features
                out_features = layer.out_features
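For context, the net effect of this change: `_replace_by_quant_layers` now passes the per-layer `bits` and `group_size` to the shared helper in `auto_round/utils.py` instead of the removed private method, then swaps each `nn.Linear` for the returned class. Below is a minimal sketch of that flow; the `QuantLinear` constructor layout `(bits, group_size, in_features, out_features, bias)` follows the usual auto-gptq convention and is an assumption, not something this diff shows.

# Illustrative sketch only, not part of the patch. The QuantLinear constructor
# signature is assumed to follow the auto-gptq-style layout.
import torch.nn as nn
from auto_round.utils import dynamic_import_inference_linear

def replace_one_linear(parent: nn.Module, name: str, bits: int, group_size: int, backend: str):
    layer = getattr(parent, name)
    if not isinstance(layer, nn.Linear):
        return
    QuantLinear = dynamic_import_inference_linear(bits, group_size, backend)
    new_layer = QuantLinear(  # assumed constructor layout
        bits, group_size, layer.in_features, layer.out_features,
        bias=layer.bias is not None,
    )
    setattr(parent, name, new_layer)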
2 changes: 1 addition & 1 deletion auto_round/export/export_to_autoround/export.py
@@ -180,7 +180,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exl
    save(model, output_dir)


-def save(model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True):
+def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True):
"""Save model state dict and configs.

Args:
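The only change in export.py is the default shard size, 10GB down to 5GB, presumably to align with the default used by transformers' save_pretrained; callers can still override it. A minimal usage sketch, assuming only the signature shown above:

# Sketch: save with an explicit shard size instead of the new 5GB default.
save(model, "./quantized_model", max_shard_size="10GB", safe_serialization=True)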
105 changes: 104 additions & 1 deletion auto_round/utils.py
@@ -35,7 +35,7 @@

import importlib
import transformers

+from functools import lru_cache

class LazyImport(object):
"""Lazy import python module till use."""
@@ -560,3 +560,106 @@ def get_layer_names_in_block(model, supported_types=[torch.nn.Linear, transforme
        if hasattr(m, "tmp_name"):
            delattr(m, "tmp_name")
    return layers_in_block


def is_autoround_exllamav2_available():
    """Checks if the AutoRound ExLlamaV2 kernels are available.

    Returns:
        bool:
            True if the AutoRound ExLlamaV2 kernels are available, False otherwise.
    """
    res = True
    try:
        from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix
    except ImportError as e:
        res = False
    return res


def get_autogptq_backend_config(backend, bits=4):
    use_triton = False
    disable_exllamav2 = False
    disable_exllamav1 = False
    disable_marlin = True
    use_qigen = False
    if backend == "gptq:qigen":
        use_qigen = True
    if backend == "gptq:triton":  ##TODO refine the code
        use_triton = True
    if backend == "gptq:marlin":
        use_triton = False
        disable_marlin = True
    if backend == "gptq:exllamav2":  ##need v1 code to export
        use_triton = False
        disable_marlin = True
    if backend == "gptq:exllamav1":
        use_triton = False
        disable_marlin = True
    if backend == "gptq:cuda":
        use_triton = False
        disable_marlin = True
        disable_exllamav2 = True
        disable_exllamav1 = True
    if bits not in [2, 4, 8]:
        use_qigen = False
    if bits not in [2, 4]:
        use_triton = False
    return use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin
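To make the backend-to-flag mapping concrete, the tuple returned for the exllamav2 path follows directly from the branches above (no external calls involved):

# Derived purely from get_autogptq_backend_config above.
flags = get_autogptq_backend_config("gptq:exllamav2", bits=4)
# (use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin)
assert flags == (False, False, False, False, True)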

@lru_cache(None)
def warning_once(msg: str):
    # Cache on the message text so each distinct warning is emitted only once.
    logger.warning(msg)


logger.warning_once = warning_once

def dynamic_import_inference_linear(bits, group_size, backend):
"""Dynamically imports and returns the appropriate QuantLinear class based on the given bits and backend.

Args:
bits (int):
The number of bits for quantization.
backend (str):
The backend to be used for quantization, such as "qbits", "cpu", or "exllamav2".

Returns:
class:
The appropriate QuantLinear class for the given configuration.
"""
    exllama2_available = is_autoround_exllamav2_available()

    if (not torch.cuda.is_available()) or "qbits" in backend or "cpu" in backend:
        try:
            from intel_extension_for_transformers import qbits  # pylint: disable=E0401
        except Exception as e:
            raise ImportError("Please install Intel Extension for Transformers via 'pip install "
                              "intel-extension-for-transformers' to run inference on X86 CPU")
        import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits
        return qlinear_qbits.QuantLinear
    if "gptq" in backend:
        try:
            import auto_gptq  # pylint: disable=E0401
        except Exception as e:
            raise ImportError("Please install auto-gptq via 'pip install auto-gptq' to support the GPTQ backend.")
        use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
            backend, bits
        )
        from auto_gptq.utils.import_utils import dynamically_import_QuantLinear  # pylint: disable=E0401
        QuantLinear = dynamically_import_QuantLinear(
            use_triton=use_triton,
            desc_act=False,
            group_size=group_size,
            bits=bits,
            disable_exllama=disable_exllamav1,
            disable_exllamav2=disable_exllamav2,
            use_qigen=use_qigen,
            disable_marlin=disable_marlin,
        )
        return QuantLinear
    if bits == 4 and exllama2_available and "exllamav2" in backend:
        from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
    else:
        if bits == 4 and "exllamav2" in backend:
            logger.warning_once("Please install auto-round from source to enable exllamav2 kernels; "
                                "falling back to triton kernels for now")
        from auto_round_extension.cuda.qliner_triton import QuantLinear
    return QuantLinear
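A hedged usage sketch of the new helper (illustrative, not part of the patch): on a CUDA machine with auto-gptq installed, a "gptq:*" backend string routes kernel selection through auto-gptq, which is the point of this PR; without CUDA the helper takes the qbits CPU path, and missing optional packages raise ImportError.

# Usage sketch only; the bits/group_size values are arbitrary examples.
from auto_round.utils import dynamic_import_inference_linear

QuantLinear = dynamic_import_inference_linear(bits=4, group_size=128, backend="gptq:exllamav2")
print(QuantLinear)  # the QuantLinear class auto-gptq selected for this configuration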
2 changes: 1 addition & 1 deletion auto_round_extension/qbits/qlinear_qbits.py
@@ -138,7 +138,7 @@ def post_init(self):

        scales = self.scales

-        logger.info(
+        logger.debug(
            f"QBits repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}")
        self.qweight = self.qbits.repack_quantized_weight(intweight.contiguous(), scales.float().contiguous(), zeros.contiguous(), torch.empty(0),
                                                          # weight_dtype
10 changes: 5 additions & 5 deletions setup.py
@@ -62,7 +62,7 @@ def detect_local_sm_architectures():
    import torch
except Exception as e:
    print(
-        f"Building PyTorch CUDA extension requires PyTorch being installed, please install PyTorch first: {e}.\n NOTE: This issue may be raised due to pip build isolation system (ignoring local packages). Please use `--no-build-isolation` when installing with pip, and refer to https://github.com/AutoGPTQ/AutoGPTQ/pull/620 for more details.")
+        f"Building PyTorch CUDA extension requires PyTorch being installed, please install PyTorch first: {e}.\n NOTE: This issue may be raised due to pip build isolation system (ignoring local packages). Please use `--no-build-isolation` when installing with pip, and refer to https://github.com/AutoRound/AutoRound/pull/620 for more details.")
    sys.exit(1)
if not torch.cuda.is_available():
    print(
@@ -75,7 +75,7 @@ def detect_local_sm_architectures():
    ROCM_VERSION = os.environ.get('ROCM_VERSION', None)
    if ROCM_VERSION and not torch.version.hip:
        print(
-            f"Trying to compile auto-gptq for ROCm, but PyTorch {torch.__version__} "
+            f"Trying to compile auto-round for ROCm, but PyTorch {torch.__version__} "
            "is installed without ROCm support."
        )
        sys.exit(1)
@@ -89,7 +89,7 @@ def detect_local_sm_architectures():
else:
    if not CUDA_VERSION:
        print(
-            f"Trying to compile auto-gptq for CUDA, but Pytorch {torch.__version__} "
+            f"Trying to compile auto-round for CUDA, but Pytorch {torch.__version__} "
            "is installed without CUDA support."
        )
        sys.exit(1)
@@ -102,13 +102,13 @@ def detect_local_sm_architectures():
    requested_but_unsupported_archs = {arch for arch in archs if arch in UNSUPPORTED_COMPUTE_CAPABILITIES}
    if len(requested_but_unsupported_archs) > 0:
        raise ValueError(
-            f"Trying to compile AutoGPTQ for CUDA compute capabilities {torch_cuda_arch_list}, but AutoGPTQ does not support the compute capabilities {requested_but_unsupported_archs} (AutoGPTQ requires Pascal or higher). Please fix your environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139).")
+            f"Trying to compile AutoRound for CUDA compute capabilities {torch_cuda_arch_list}, but AutoRound does not support the compute capabilities {requested_but_unsupported_archs} (AutoRound requires Pascal or higher). Please fix your environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139).")
else:
    local_arch_list = detect_local_sm_architectures()
    local_but_unsupported_archs = {arch for arch in local_arch_list if arch in UNSUPPORTED_COMPUTE_CAPABILITIES}
    if len(local_but_unsupported_archs) > 0:
        raise ValueError(
-            f"PyTorch detected the compute capabilities {local_arch_list} for the NVIDIA GPUs on the current machine, but AutoGPTQ can not be built for compute capabilities {local_but_unsupported_archs} (AutoGPTQ requires Pascal or higher). Please set the environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139) with your necessary architectures.")
+            f"PyTorch detected the compute capabilities {local_arch_list} for the NVIDIA GPUs on the current machine, but AutoRound can not be built for compute capabilities {local_but_unsupported_archs} (AutoRound requires Pascal or higher). Please set the environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139) with your necessary architectures.")

# For the PyPI release, the version is simply x.x.x to comply with PEP 440.
if not PYPI_RELEASE: