
Commit

Merge branch 'main' into xuehao/remove_autogptq
wenhuach21 authored Dec 12, 2024
2 parents 92533e5 + e88882e commit 1b78adc
Showing 6 changed files with 34 additions and 24 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -16,8 +16,7 @@ steps,
 which competes impressively against recent methods without introducing any additional inference overhead and keeping low
 tuning cost. The below
 image presents an overview of AutoRound. Check out our paper on [arxiv](https://arxiv.org/pdf/2309.05516) for more
-details and visit [low_bit_open_llm_leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard) for
-more accuracy data and recipes across various models.
+details and quantized huggingface space models in [OPEA](https://huggingface.co/OPEA), [Kaitchup](https://huggingface.co/kaitchup) and [fbaldassarri](https://huggingface.co/fbaldassarri).

<div align="center">

@@ -398,3 +397,4 @@ If you find AutoRound useful for your research, please cite our paper:
```



11 changes: 10 additions & 1 deletion auto_round/auto_quantizer.py
@@ -363,7 +363,14 @@ def detect_device(self, target_backend, orig_backend):
         if backend is None:
             raise ValueError("Backend not found, please set it to 'auto' to have a try ")

-        return BackendInfos[backend].device[0]
+        device = BackendInfos[backend].device[0]
+        if "cuda" in device and torch.cuda.is_available():
+            return device
+        elif "hpu" in device and is_hpu_supported():
+            return device
+        else:
+            return "cpu"


     def convert_model(self, model: nn.Module):
         """Converts the given model to an AutoRound model by replacing its layers with quantized layers.
@@ -392,6 +399,7 @@ def convert_model(self, model: nn.Module):
             quantization_config.target_backend = quantization_config.backend

         target_device = self.detect_device(quantization_config.target_backend, quantization_config.backend)
+
         self.target_device = target_device

         if hasattr(quantization_config, "backend"): # pragma: no cover
@@ -744,3 +752,4 @@ def is_serializable(self):
transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
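The detect_device() change in the first hunk above makes the quantizer fall back to CPU when the backend's preferred accelerator is not actually present, instead of returning the device string unconditionally. Below is a minimal self-contained sketch of that fallback pattern; the BACKEND_DEVICES table and the is_hpu_supported() stub are illustrative stand-ins for auto_round's BackendInfos registry and its HPU check, not the library's real API.

```python
import torch

# Illustrative stand-in for auto_round's BackendInfos registry
# (assumed shape: backend name -> candidate devices, most preferred first).
BACKEND_DEVICES = {
    "auto_gptq:exllamav2": ["cuda"],
    "ipex": ["cpu"],
}


def is_hpu_supported() -> bool:
    """Rough stand-in for the HPU check: probe the Habana torch bridge."""
    try:
        import habana_frameworks.torch.core  # noqa: F401  # pylint: disable=E0401
        return True
    except ImportError:
        return False


def detect_device(backend: str) -> str:
    """Return the backend's preferred device, falling back to CPU when unavailable."""
    device = BACKEND_DEVICES[backend][0]
    if "cuda" in device and torch.cuda.is_available():
        return device
    if "hpu" in device and is_hpu_supported():
        return device
    return "cpu"


if __name__ == "__main__":
    # On a machine without a GPU this prints "cpu" instead of an unusable "cuda".
    print(detect_device("auto_gptq:exllamav2"))
```

Falling back to "cpu" keeps model loading functional on machines where the packaged kernels target an accelerator that is absent.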


24 changes: 10 additions & 14 deletions auto_round/export/export_to_autogptq/export.py
@@ -121,7 +121,7 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
     supported_types = kwargs["supported_types"]
     safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
     to_quant_block_names = kwargs["to_quant_block_names"]
-    quant_block_list = kwargs.get("quant_block_list", None)
+    quant_block_list = kwargs.get("quant_block_list", get_block_names(model))
     logger.info("Saving quantized model to autogptq format, this may take a while...")
     tokenizer = kwargs.get("tokenizer", None)
     processor = kwargs.get("processor", None)
@@ -131,19 +131,14 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
         processor.save_pretrained(output_dir)
     ##check module quantized in block, this may have bug for mixed precision quantization
     quantization_config = kwargs["serialization_dict"]
-    if bool(quant_block_list):
-        all_blocks = quant_block_list
-        flattened_list = [item for sublist in all_blocks for item in sublist]
-        common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
-        if common_prefix not in BLOCK_PATTERNS:
-            logger.error(f"auto-gptq format may not support loading this quantized model")
-            quantization_config['block_name_to_quantize'] = common_prefix
-    else:
-        all_blocks = get_block_names(model)
-        flattened_list = [item for sublist in all_blocks for item in sublist]
-        common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
-        if common_prefix not in BLOCK_PATTERNS:
-            quantization_config['block_name_to_quantize'] = common_prefix
+    all_blocks = quant_block_list
+    flattened_list = [item for sublist in all_blocks for item in sublist]
+    common_prefix = os.path.commonprefix(flattened_list).rstrip('.')
+    if common_prefix not in BLOCK_PATTERNS:
+        logger.error(f"auto-gptq format may not support loading this quantized model")
+        quantization_config['block_name_to_quantize'] = common_prefix
     quantization_config.pop("to_quant_block_names", None)


     all_to_quantized = True
     modules_in_block_to_quantize = []
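With quant_block_list now defaulting to get_block_names(model) in the first hunk, the duplicated if/else collapses into one path: flatten the block names, take their common prefix, and record it as block_name_to_quantize only when it is not one of AutoGPTQ's known BLOCK_PATTERNS. A small standalone illustration of that prefix computation follows; the block names and the BLOCK_PATTERNS subset here are made up for the example.

```python
import os

# Example nested block-name list in the shape auto_round passes around
# (the concrete names are illustrative, not taken from a real model).
quant_block_list = [
    ["model.layers.0", "model.layers.1", "model.layers.2"],
]

# Block prefixes a loader can handle without extra hints (illustrative subset).
BLOCK_PATTERNS = ["model.decoder.layers", "model.layers", "transformer.h"]

flattened = [name for sublist in quant_block_list for name in sublist]
common_prefix = os.path.commonprefix(flattened).rstrip('.')  # -> "model.layers"

config_hint = {}
if common_prefix not in BLOCK_PATTERNS:
    # Non-standard layout: tell the loader explicitly which blocks were quantized.
    config_hint['block_name_to_quantize'] = common_prefix

print(common_prefix, config_hint)  # model.layers {}
```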
@@ -222,3 +217,4 @@ def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", saf
         json.dump(model.config.quantization_config, f, indent=2)



2 changes: 2 additions & 0 deletions auto_round/mllm/eval.py
@@ -81,6 +81,7 @@
     "llava_next": dict(cls="LLaVA_Next"),
     "phi3_v": dict(cls="Phi3Vision"),
     "mllama": dict(cls="llama_vision"),
+    "glm-4v-9b": dict(cls="GLM4v"),
 }
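The new "glm-4v-9b" entry extends the model-type to evaluator-class registry used by the MLLM evaluation path. A minimal sketch of how such a registry is typically consumed; the resolve() helper is an illustrative assumption, not auto_round's actual lookup code.

```python
# Registry mirroring the structure shown in the diff; resolve() is illustrative only.
MLLM_EVALUATORS = {
    "llava_next": dict(cls="LLaVA_Next"),
    "phi3_v": dict(cls="Phi3Vision"),
    "mllama": dict(cls="llama_vision"),
    "glm-4v-9b": dict(cls="GLM4v"),
}


def resolve(model_type: str) -> str:
    """Return the evaluator class name registered for a given model type."""
    try:
        return MLLM_EVALUATORS[model_type]["cls"]
    except KeyError as exc:
        raise KeyError(f"no evaluator registered for {model_type!r}") from exc


print(resolve("glm-4v-9b"))  # -> GLM4v
```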


@@ -409,3 +410,4 @@ class CliArgs:
     json.dump(results, open(output_file, 'w'), indent=4, default=_handle_non_serializable)

     return results

3 changes: 1 addition & 2 deletions auto_round/quantizer.py
@@ -94,7 +94,6 @@ def _init_tuning_params_and_quant_func(self):
         self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0)
         self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0)
         self._init_params("value", p_dtype, weight_reshape.shape, 0, True)
-
         # Min-max scale initialization
         shape = get_scale_shape(orig_weight, orig_layer.group_size)
         self._init_params("min_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning)
@@ -304,7 +303,6 @@ def forward(self, x):
         bias = self.orig_layer.bias
         if bias is not None and bias.device.type == 'meta':
             bias = self.orig_layer.get_bias().to(self.device)
-
         if self.enable_norm_bias_tuning:
             bias, _, _ = self._qdq_bias(bias, self.bias_v)

@@ -520,3 +518,4 @@ def unwrapper_block(block, best_params):
                 best_param = None
             orig_layer = m.unwrapper(best_param)
             set_module(block, n, orig_layer)

14 changes: 9 additions & 5 deletions auto_round/script/mllm.py
@@ -282,18 +282,21 @@ def tune(args):

     # load_model
     processor, image_processor = None, None
-    if "llava" in model_name:
-        from llava.model.builder import load_pretrained_model # pylint: disable=E0401
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+    if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
+        from llava.model.builder import load_pretrained_model # pylint: disable=E0401
         tokenizer, model, image_processor, _ = load_pretrained_model(
             model_name, model_base=None, model_name=model_name,
             torch_dtype=torch_dtype)
         model_type = "llava"
     else:
-        config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         model_type = config.model_type
-        if "qwen2_vl" in model_type:
+        if "llava" in model_type:
+            from transformers import LlavaForConditionalGeneration
+            cls = LlavaForConditionalGeneration
+        elif "qwen2_vl" in model_type:
             from transformers import Qwen2VLForConditionalGeneration
             cls = Qwen2VLForConditionalGeneration
         elif "mllama" in model_type:
@@ -511,3 +514,4 @@ def lmms_eval(args):
         apply_chat_template=False,
     )
     return results
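The tune() changes above route LLaVA checkpoints by inspecting the Hugging Face config: repos whose architectures entry is not LlavaForConditionalGeneration still go through the native llava loader, while HF-format checkpoints (LLaVA, Qwen2-VL, ...) are loaded through their transformers classes, and the tokenizer is now created with trust_remote_code as well. Below is a hedged sketch of that dispatch which returns the loader name instead of instantiating the (large) models; the generic fallback branch and the model id in the usage line are assumptions for illustration.

```python
from transformers import AutoConfig


def pick_loader(model_name: str, trust_remote_code: bool = True) -> str:
    """Illustrative mirror of the dispatch in tune(): decide how to load a VLM checkpoint."""
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)

    # Non-HF LLaVA repos (architectures[0] != "LlavaForConditionalGeneration")
    # still need the original llava package's loader.
    if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
        return "llava.model.builder.load_pretrained_model"

    model_type = config.model_type
    if "llava" in model_type:
        return "transformers.LlavaForConditionalGeneration"
    if "qwen2_vl" in model_type:
        return "transformers.Qwen2VLForConditionalGeneration"
    if "mllama" in model_type:
        return "transformers.MllamaForConditionalGeneration"
    return "transformers.AutoModelForCausalLM"  # assumed generic fallback


if __name__ == "__main__":
    # Example with an HF-format LLaVA checkpoint (downloads only the config file).
    print(pick_loader("llava-hf/llava-1.5-7b-hf"))
```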
