Fix asym kernel issue by following autogptq's pr (#137)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
wenhuach21 and pre-commit-ci[bot] authored Jun 3, 2024
1 parent 4d2d259 commit 794cd90
Showing 42 changed files with 4,844 additions and 123 deletions.
10 changes: 6 additions & 4 deletions README.md
@@ -25,7 +25,7 @@ image presents an overview of AutoRound.
<div align="left">

## What's New

* [2024/06] AutoRound format supports mixed bit-widths and group sizes for inference, resolving the significant performance drop issue with the asymmetric kernel
* [2024/05] Check out our updated paper on [arxiv](https://arxiv.org/pdf/2309.05516v4)
* [2024/05] AutoRound supports lm-head quantization, saving 0.7G for LLaMA3-8B at W4G128.
* [2024/05] AutoRound performs well
@@ -42,6 +42,8 @@ image presents an overview of AutoRound.
```bash
pip install -r requirements.txt
python setup.py install
# or
pip install -vvv --no-build-isolation -e .
```

### Install from pypi
@@ -55,7 +57,7 @@ pip install auto-round
### Gaudi2/ CPU/ GPU

We found a significant accuracy discrepancy with the qdq model using the AutoGPTQ GPU backend with asymmetric
quantization in some scenarios. Please switch to symmetric quantization to alleviate this issue.
quantization in some scenarios, especially at lower bits such as 2. Please save the quantized model in the AutoRound format to fix this issue.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -71,7 +73,7 @@ bits, group_size, sym = 4, 128, False
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, device=None)
autoround.quantize()
output_dir = "./tmp_autoround"
autoround.save_quantized(output_dir)
autoround.save_quantized(output_dir)  ## or save_quantized(output_dir, format="auto_round")
```
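
The inline comment above points at an explicit `format` argument for `save_quantized`. A minimal sketch of that call, assuming `format="auto_round"` selects the AutoRound format described in the note above (the keyword is taken from the comment, not verified against the API):

```python
## Sketch only: the `format` keyword is taken from the inline comment above and may differ in the actual API.
output_dir = "./tmp_autoround"
autoround.save_quantized(output_dir, format="auto_round")  ## the AutoRound format avoids the asymmetric-kernel accuracy drop noted above
```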

<details>
@@ -149,7 +151,7 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
##from auto_round.auto_quantizer import AutoHfQuantizer ## uncomment it for models with quantized lm-head
##from auto_round.auto_quantizer import AutoHfQuantizer ## uncomment it for models saved in the auto_round format

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
135 changes: 57 additions & 78 deletions auto_round/auto_quantizer.py
@@ -37,7 +37,6 @@
from packaging import version
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import Conv1D
import transformers
from transformers.quantizers import AutoQuantizationConfig, HfQuantizer
from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod
@@ -52,11 +51,17 @@
else:
import importlib.metadata as importlib_metadata

AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
AUTOROUND_MINIMUM_VERSION = version.parse("0.2")


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
try: ##TODO remove it later
import auto_round
return (True, auto_round.__version__) if return_version else True
except:
pass

package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
@@ -71,26 +76,32 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
return package_exists


_auto_gptq_available = _is_package_available("auto_gptq")
_auto_round_available = _is_package_available("auto_round")


def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
if AUTOGPTQ_MINIMUM_VERSION < version_autogptq:
def is_auto_round_available():
if _auto_round_available:
version_autoround = version.parse(importlib_metadata.version("auto_round"))
if AUTOROUND_MINIMUM_VERSION < version_autoround:
return True
else:
raise ImportError(
f"Found an incompatible version of auto-gptq. Found version {version_autogptq},"
f" but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported"
f"Found an incompatible version of auto-round. Found version {version_autoround},"
f" but only version above {AUTOROUND_MINIMUM_VERSION} are supported"
)


if is_auto_gptq_available():
from auto_gptq import exllama_set_max_input_length
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.quantization import GPTQ
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
def is_autoround_exllamav2_available():
res = True
try:
from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix
except ImportError as e:
res = False
return res


if is_auto_round_available():
from auto_round_extension.cuda.post_init import autoround_post_init


#
@@ -201,15 +212,8 @@ def __init__(
dataset: str = None,
group_size: int = 128,
sym: bool = False,
backend="gptq:exllamav2",
iters: int = 200,
backend="autoround:exllamav2",
weight_config: dict = None,
enable_quanted_input=True,
enable_minmax_tuning=True,
lr=None,
minmax_lr=None,
n_samples=512,
seqlen=2048,
**kwargs,
):
self.bits = bits
@@ -218,31 +222,20 @@
self.group_size = group_size
self.sym = sym
self.backend = backend
self.inters = iters
self.weight_config = weight_config
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
self.lr = lr
self.minmax_lr = minmax_lr
self.n_samples = n_samples
self.seqlen = seqlen
if kwargs is not None:
for key in kwargs.keys():
setattr(self, key, kwargs[key])

self.post_init()

def get_loading_attributes(self):
pass
# attibutes_dict = copy.deepcopy(self.__dict__)
# loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
# loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
# return loading_attibutes_dict
return {}

def post_init(self):
r"""Safety checker that arguments are correct."""
if self.bits not in [2, 3, 4, 8]:
raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
if self.bits not in [2, 4, 8]:
raise ValueError(f"Only support quantization to [2,4,8] bits but found {self.bits}")
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("group_size must be greater than 0 or equal to -1")
##TODO add more check
@@ -254,23 +247,23 @@ def to_dict(self):


class AutoRoundQuantizer(HfQuantizer):
"""Quantizer of the Autoround method, currently only gptq backend has been supported."""
"""Quantizer of the AutoRound method, currently only triton and exllamav2 backend has been supported."""

requires_calibration = False
required_packages = ["auto_gptq"]
required_packages = ["auto_round"]
optimum_quantizer = None

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
self.exllama2_available = is_autoround_exllamav2_available()

def validate_environment(self, *args, **kwargs):
gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
if not gptq_supports_cpu and not torch.cuda.is_available():
raise RuntimeError("GPU is required to quantize or run quantize model.")
elif not is_auto_gptq_available():
raise ImportError("Loading a GPTQ quantized model requires auto-gptq library (`pip install auto-gptq`)")
elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
raise ImportError("You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`")
if not is_auto_round_available():
raise ImportError("Loading a AutoRound quantized model requires auto-round library (`pip install "
"auto-round`)")
elif version.parse(importlib.metadata.version("auto_round")) < version.parse("0.2.0"):
raise ImportError("You need a version of auto_round > 0.2.0 to use AutoRound: `pip install --upgrade "
"auto-round`")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
@@ -280,7 +273,7 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
return torch_dtype

def convert_model(self, model: nn.Module):
"""Convert the model to a GPTQ model by getting and replacing the layers.
"""Convert the model to an AutoRound model by getting and replacing the layers.
Args:
model (`nn.Module`):
@@ -308,15 +301,22 @@ def convert_model(self, model: nn.Module):
layer_configs[layer_name]["data_type"] = data_type
layer_configs[layer_name]["sym"] = sym
else:
layer_configs[layer_name]["bits"] = extra_config.get("bits", bits)
layer_configs[layer_name]["group_size"] = extra_config.get("group_size", group_size)
layer_configs[layer_name]["data_type"] = extra_config.get("data_type", data_type)
layer_configs[layer_name]["sym"] = extra_config.get("sym", sym)
layer_configs[layer_name]["bits"] = extra_config[layer_name].get("bits", bits)
layer_configs[layer_name]["group_size"] = extra_config[layer_name].get("group_size", group_size)
layer_configs[layer_name]["data_type"] = extra_config[layer_name].get("data_type", data_type)
layer_configs[layer_name]["sym"] = extra_config[layer_name].get("sym", sym)
backend = quantization_config.backend

self._replace_by_quant_layers(model, layer_configs, backend)
return model

def _dynamic_import_inference_linear(self, bits, backend):
if bits == 4 and self.exllama2_available and "exllamav2" in backend:
from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
else:
from auto_round_extension.cuda.qliner_triton import QuantLinear
return QuantLinear

def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
"""Replaces linear layers in `module` by `QuantLinear`
@@ -335,21 +335,7 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
data_type = config["data_type"]
if not (bits <= 8 and data_type == "int"):
continue
from auto_round.export.export_to_autoround.export_to_autoround import get_autogptq_backend_config

use_triton, disable_exllama, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
backend, bits
)
QuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=False,
group_size=group_size,
bits=bits,
disable_exllama=True,
disable_exllamav2=False,
use_qigen=use_qigen,
disable_marlin=disable_marlin,
)
QuantLinear = self._dynamic_import_inference_linear(bits, backend)
layer = get_module(module, layer_name)
device = get_device(layer)
if isinstance(layer, nn.Linear):
@@ -382,23 +368,17 @@ def post_init_model(self, model):
The input model
"""

# if self.bits == 4 and not self.disable_exllama:
# if get_device(model) == torch.device("cpu") or (
# hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
# ):
# raise ValueError(
# "Found modules on cpu/disk. Using Exllama
# or Exllamav2 backend requires all the modules to be on GPU."
# "You can deactivate exllama backend by
# setting `disable_exllama=True` in the quantization config object"
# )
#
# if self.bits == 4: if get_device(model) == torch.device("cpu") or ( hasattr(model, "hf_device_map") and
# any(d in model.hf_device_map for d in ["cpu", "disk"]) ): raise ValueError( "Found modules on cpu/disk.
# Using Exllamav2 backend requires all the modules to be on GPU." "You can deactivate exllama backend by
# setting `disable_exllama=True` in the quantization config object" )

class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = False
model = autogptq_post_init(model, use_act_order=False)
model = autoround_post_init(model)
return model

def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
@@ -436,4 +416,3 @@ def is_serializable(self):

transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
from transformers import AutoModelForCausalLM as AutoModelForCausalLM
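
The two assignments just above patch `AutoHfQuantizer` into transformers, so a plain `from_pretrained` call can load an AutoRound-format checkpoint once `auto_round.auto_quantizer` has been imported. A minimal usage sketch, assuming a model was saved to the illustrative path `./tmp_autoround` as in the README example:

```python
## Illustrative sketch: importing auto_round.auto_quantizer installs the patched AutoHfQuantizer into transformers.
from auto_round.auto_quantizer import AutoHfQuantizer  # noqa: F401
from transformers import AutoModelForCausalLM, AutoTokenizer

quantized_model_path = "./tmp_autoround"  ## hypothetical output dir from the quantization example above
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path, trust_remote_code=True)

text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```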
2 changes: 1 addition & 1 deletion auto_round/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
from .register import EXPORT_FORMAT
from .export_to_autogptq import save_quantized_as_autogptq
from .export_to_itrex import save_quantized_as_itrex, QuantConfig
from .export_to_autoround.export_to_autoround import save_quantized_as_autoround
from .export_to_autoround.export import save_quantized_as_autoround
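
For orientation, the `save_quantized_as_*` functions imported here are collected in the `EXPORT_FORMAT` registry via the `@register_format` decorator used in the export modules below. A minimal sketch of that registry pattern, written as an assumption about `auto_round/export/register.py` rather than a copy of it:

```python
## Assumed shape of auto_round/export/register.py; the real module may differ in detail.
EXPORT_FORMAT = {}


def register_format(name):
    """Map a format name (e.g. "auto_gptq", "auto_round") to its export function."""

    def decorator(func):
        EXPORT_FORMAT[name] = func
        return func

    return decorator


## With this in place, save_quantized(output_dir, format="auto_round") can dispatch via
## EXPORT_FORMAT["auto_round"](output_dir, **kwargs).
```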


10 changes: 7 additions & 3 deletions auto_round/export/export_to_autogptq.py
@@ -52,6 +52,10 @@
@register_format("auto_gptq")
def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwargs):
"""Export the model to autogptq format to easily leverage cuda kernel."""
try:
import auto_gptq
except ImportError:
raise ImportError("export to autogptq requires autogptq library. Please run 'pip install auto-gptq'")
model = kwargs["model"]
weight_config = kwargs["weight_config"]
sym = kwargs["sym"]
@@ -95,7 +99,7 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwar
else:
compressed_model = copy.deepcopy(model.to("cpu"))

from auto_gptq.modeling._utils import pack_model
from auto_gptq.modeling._utils import pack_model # pylint: disable=E0401

if bits == 3 or use_triton is False:
if bits == 3 and use_triton is True:
Expand Down Expand Up @@ -127,7 +131,7 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwar
info = weight_config[key]
if not check_to_quantized(info):
continue
quantizers[key] = (None, info["scale"].to(torch.float32), info["zp"].to(torch.float32), info["g_idx"])
quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32), info["g_idx"])
pack_model(
compressed_model,
quantizers,
Expand Down Expand Up @@ -236,7 +240,7 @@ def _save_quantized_to_autogptq(
model_save_name = model_base_name + ".bin"
torch.save(model.state_dict(), join(save_dir, model_save_name))

from auto_gptq.modeling._base import BaseQuantizeConfig
from auto_gptq.modeling._base import BaseQuantizeConfig # pylint: disable=E0401

quantization_config = BaseQuantizeConfig(
bits=bits,
2 changes: 1 addition & 1 deletion auto_round/export/export_to_autoround/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .export_to_autoround import save_quantized_as_autoround
from .export import save_quantized_as_autoround

