
Commit

updated
wenhuach21 committed Jun 8, 2024
1 parent 65d73fc commit ae6687b
Showing 2 changed files with 5 additions and 17 deletions.
16 changes: 5 additions & 11 deletions auto_round/auto_quantizer.py
@@ -318,6 +318,11 @@ def convert_model(self, model: nn.Module):

    def _dynamic_import_inference_linear(self, bits, backend, device):
        if (str(device) == "cpu" and not torch.cuda.is_available()) or "qbits" in backend:
            try:
                from intel_extension_for_transformers import qbits  # noqa: F401
            except Exception as e:
                raise ImportError(
                    "Please install Intel Extension for Transformers via "
                    "'pip install intel-extension-for-transformers' to run inference on Intel CPUs"
                ) from e
            return qlinear_qbits.QuantLinear
        if bits == 4 and self.exllama2_available and "exllamav2" in backend:
            from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
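For context, the change above moves the qbits availability check to the point where the CPU backend is actually selected, so a missing dependency surfaces as an actionable ImportError rather than a failure deep inside the qbits layer. Below is a minimal, standalone sketch of that pattern; it is illustrative only, and the helper names (_qbits_available, pick_cpu_linear) are hypothetical rather than part of auto-round.

def _qbits_available() -> bool:
    # Probe the optional dependency without importing it at module load time.
    try:
        from intel_extension_for_transformers import qbits  # noqa: F401
        return True
    except Exception:
        return False


def pick_cpu_linear():
    # Fail early with installation instructions when the CPU kernels are unavailable.
    if not _qbits_available():
        raise ImportError(
            "Please install Intel Extension for Transformers via "
            "'pip install intel-extension-for-transformers' to run inference on Intel CPUs"
        )
    from auto_round_extension.qbits import qlinear_qbits
    return qlinear_qbits.QuantLinear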
@@ -366,17 +371,6 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
                weight_dtype=layer.weight.dtype,
            )

            # if new_layer.qweight.device.type == "cpu":  # fallback to qbits linear when qweight on cpu device
            #     QuantLinear = qlinear_qbits.QuantLinear
            #     new_layer = QuantLinear(  # pylint: disable=E1123
            #         bits,
            #         group_size,
            #         in_features,
            #         out_features,
            #         bias,
            #         weight_dtype=layer.weight.dtype,
            #     )

            new_layer.device = device
            set_module(module, layer_name, new_layer)
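set_module here swaps the freshly constructed quantized layer into the parent model under the layer's dotted name. Its internals are not shown in this diff; as a rough sketch (assuming typical behavior, not the exact auto_round.utils implementation), such a helper resolves the dotted path and replaces the leaf module:

import torch.nn as nn

def set_module_sketch(root: nn.Module, name: str, new_module: nn.Module) -> None:
    # Walk the dotted path (e.g. "model.decoder.layers.0.fc1") down to the parent
    # module, then rebind the final attribute to the replacement layer.
    parent = root
    parts = name.split(".")
    for part in parts[:-1]:
        parent = getattr(parent, part)
    setattr(parent, parts[-1], new_module)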

6 changes: 0 additions & 6 deletions auto_round_extension/qbits/qlinear_qbits.py
@@ -20,12 +20,6 @@
import torch.nn as nn
from auto_round.utils import convert_dtype_torch2str, logger
QBITS_AVAILABLE = True
try:
    from intel_extension_for_transformers import qbits  # noqa: F401
except Exception as e:
    QBITS_AVAILABLE = False
    # logger.warning(
    #     "qlinear_qbits should be used with Intel Extension for Transformers.")

BITS_DTYPE_MAPPING = {
2: "int2_clip",
