enable llava & Qwen-VL multimodal model quantization

Signed-off-by: Zhang, Weiwei1 <[email protected]>
WeiweiZhang1 committed Jun 28, 2024
1 parent f9e7d79 commit e273472
Showing 17 changed files with 2,578 additions and 59 deletions.
91 changes: 60 additions & 31 deletions auto_round/autoround.py
@@ -23,7 +23,7 @@

from .calib_dataset import get_dataloader
from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer
from .special_model_handler import check_hidden_state_dim, check_share_attention_mask
from .special_model_handler import check_hidden_state_dim, check_share_attention_mask, check_not_share_position_ids
from .utils import (
CpuInfo,
block_forward,
@@ -89,6 +89,7 @@ class AutoRound(object):
data_type (str): The data type to be used (default is "int").
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
multimodal (bool): Whether to enable multimodal model quantization (default is False).
Returns:
The quantized model.
@@ -124,6 +125,7 @@ def __init__(
dynamic_max_gap: int = -1,
data_type: str = "int", ##only support int for now
scale_dtype: str = "fp16",
multimodal: bool = False,
**kwargs,
):
self.quantized = False
@@ -153,6 +155,7 @@ def __init__(
logger.info(f"using {self.model.dtype} for quantization tuning")
self.dataset = dataset
self.iters = iters
self.multimodal = multimodal
if self.iters <= 0:
logger.warning("iters must be positive, reset it to 200")
self.iters = 200
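For orientation, a minimal usage sketch of the new argument. The checkpoint name and loading calls below are assumptions for illustration and are not part of this commit; a multimodal run also needs a calibration dataloader that yields image inputs, as handled in the calib() changes further down.

# Hypothetical usage of the new `multimodal` flag; the model name and loading
# details are assumptions, not taken from this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)

autoround = AutoRound(model, tokenizer, bits=4, group_size=128, multimodal=True)
autoround.quantize()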
@@ -203,8 +206,8 @@ def quantize(self):
The quantized model and weight configurations.
"""
# logger.info("cache block input")
block_names = get_block_names(self.model)
if len(block_names) == 0:
all_blocks = get_block_names(self.model)
if len(all_blocks) == 0:
logger.warning("could not find blocks, exit with original model")
return self.model, self.weight_config

@@ -213,29 +216,28 @@

layer_names = self.get_quantized_layer_names_outside_blocks()
self.start_time = time.time()
all_inputs = self.try_cache_inter_data_gpucpu([block_names[0]], self.nsamples, layer_names=layer_names)
del self.inputs
inputs = all_inputs[block_names[0]]

all_inputs.pop(block_names[0])
self.inputs = None
del self.inputs
if "input_ids" in inputs.keys():
total_samples = len(inputs["input_ids"])
self.nsamples = total_samples
if total_samples < self.train_bs:
self.train_bs = total_samples
logger.warning(f"force the train batch size to {total_samples} ")

self.model = self.model.to("cpu")
torch.cuda.empty_cache()
self.quant_blocks(
self.model,
inputs,
block_names,
nblocks=self.nblocks,
device=self.device,
)
all_first_block_names = [block[0] for block in all_blocks]
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
for block_names in all_blocks:
inputs = all_inputs[block_names[0]]
all_inputs.pop(block_names[0])
self.inputs = None
del self.inputs
if "input_ids" in inputs.keys():
total_samples = len(inputs["input_ids"])
self.nsamples = total_samples
if total_samples < self.train_bs:
self.train_bs = total_samples
logger.warning(f"force the train batch size to {total_samples} ")
self.model = self.model.to("cpu")
torch.cuda.empty_cache()
self.quant_blocks(
self.model,
inputs,
block_names,
nblocks=self.nblocks,
device=self.device,
)

self.quant_layers(layer_names, all_inputs)
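To make the reworked control flow concrete: get_block_names() now returns one list of block names per ModuleList (e.g. one for the language-model decoder and one for a vision tower), and quantize() first caches calibration inputs for the first block of every branch, then quantizes each branch in turn. A toy illustration with hypothetical block names:

# Hypothetical structure driving the loop above; actual names depend on the model.
all_blocks = [
    ["model.layers.0", "model.layers.1"],    # language-model decoder blocks
    ["visual.blocks.0", "visual.blocks.1"],  # vision-tower blocks
]
all_first_block_names = [block[0] for block in all_blocks]
print(all_first_block_names)  # ['model.layers.0', 'visual.blocks.0']
for block_names in all_blocks:
    print("quantizing branch starting at", block_names[0])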

@@ -338,11 +340,11 @@ def set_layerwise_config(self, weight_config):
Returns:
None
"""
layers_inblocks = get_layer_names_in_block(self.model, self.supported_types)
layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types)
for n, m in self.model.named_modules():
if not isinstance(m, tuple(self.supported_types)):
continue
if n not in weight_config.keys() and n in layers_inblocks:
if n not in weight_config.keys() and n in layers_in_blocks:
weight_config[n] = {}
weight_config[n]["data_type"] = self.data_type
weight_config[n]["bits"] = self.bits
@@ -396,7 +398,13 @@ def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_de
end_index = min(self.nsamples, i + bs)
indices = torch.arange(i, end_index).to(torch.long)
tmp_input_ids, tmp_input_others = sampling_inputs(
input_ids, input_others, indices, self.seqlen, self.share_attention_mask_flag, self.input_dim
input_ids,
input_others,
indices,
self.seqlen,
self.share_attention_mask_flag,
self.not_share_position_ids_flag,
self.input_dim
)
tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
cache_device
@@ -449,6 +457,14 @@ def calib(self, nsamples, bs):
for key in data.keys():
data_new[key] = data[key].to(self.model.device)
input_ids = data_new["input_ids"]
elif isinstance(data, tuple) or isinstance(data, list):
if self.multimodal:
data_new = {"input_ids": data[0].to(self.model.device), \
"images": data[1].to(self.model.device, dtype=self.model.dtype), "image_sizes": data[2]}
input_ids = data_new["input_ids"]
else:
data_new = data
input_ids = data_new[0]
else:
data_new = {}
for key in data.keys():
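The multimodal branch above assumes each calibration batch arrives as an (input_ids, images, image_sizes) tuple. A minimal sketch of that batch layout; the tensor shapes and dtype cast are illustrative assumptions:

import torch

# Assumed multimodal calibration batch, mirroring the tuple indexing above.
input_ids = torch.randint(0, 32000, (1, 512))  # token ids
images = torch.randn(1, 3, 336, 336)           # pixel values for one image
image_sizes = [(336, 336)]                     # original (height, width) per image
batch = (input_ids, images, image_sizes)

data_new = {
    "input_ids": batch[0],
    "images": batch[1].to(dtype=torch.float16),  # cast to the model dtype
    "image_sizes": batch[2],
}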
@@ -460,13 +476,15 @@
try:
if isinstance(data_new, torch.Tensor):
self.model(data_new)
elif isinstance(data_new, tuple) or isinstance(data_new, list):
self.model(*data_new)
else:
self.model(**data_new)
except NotImplementedError:
pass
except Exception as error:
logger.error(error)
total_cnt += input_ids.shape[0]
total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1
if total_cnt >= nsamples:
break
if total_cnt == 0:
@@ -483,7 +501,7 @@

@torch.no_grad()
def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=[], last_cache_name=None):
"""Attempts to cache intermediate data on GPUif failed, then using CPU.
"""Attempts to cache intermediate data on GPU, if failed, then using CPU.
Args:
block_names (list): List of block names to cache data for.
@@ -542,6 +560,7 @@ def cache_inter_data(self, block_names, nsamples, layer_names=[], last_cache_nam
self.last_cache_name = last_cache_name
if last_cache_name is None and len(block_names) + len(layer_names) == 1:
self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0]
# do not set last_cache_name for multimodal models
calib_bs = self.train_bs
self.hook_handles = []
self._replace_forward()
@@ -579,6 +598,7 @@ def forward(m, hidden_states, *positional_args, **kwargs):
if self.share_attention_mask_flag is None:
self.input_dim = check_hidden_state_dim(self.model, positional_args)
self.share_attention_mask_flag = check_share_attention_mask(self.model, hidden_states, **kwargs)
self.not_share_position_ids_flag = check_not_share_position_ids(self.model, **kwargs)
if name in self.inputs:
self.inputs[name]["input_ids"].extend(list(torch.split(hidden_states.to("cpu"), 1, dim=self.input_dim)))
else:
@@ -612,6 +632,13 @@ def forward(m, hidden_states, *positional_args, **kwargs):
self.inputs[name][key].extend(list(torch.split(alibi.to("cpu"), 1, dim=0)))
else:
self.inputs[name][key] = list(torch.split(alibi.to("cpu"), 1, dim=0))
elif "position_ids" in key:
if key not in self.inputs[name].keys():
self.inputs[name][key] = list(torch.split(kwargs[key].to("cpu"), 1, dim=0)) \
if self.not_share_position_ids_flag \
else to_device(kwargs[key], device=torch.device("cpu"))
elif kwargs[key] is not None and self.not_share_position_ids_flag:
self.inputs[name][key].extend(list(torch.split(kwargs[key].to("cpu"), 1, dim=0)))
elif key not in self.inputs[name].keys():
self.inputs[name][key] = to_device(kwargs[key], device=torch.device("cpu"))
if name == self.last_cache_name:
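A small self-contained sketch of the per-sample split used above when position ids are not shared across samples; the shapes are illustrative:

import torch

# With not_share_position_ids_flag set, position_ids are cached one sample at a
# time, exactly like torch.split(kwargs[key].to("cpu"), 1, dim=0) above.
position_ids = torch.arange(8).unsqueeze(0).repeat(4, 1)  # shape (batch=4, seqlen=8)
per_sample = list(torch.split(position_ids, 1, dim=0))
print(len(per_sample), per_sample[0].shape)  # 4 torch.Size([1, 8])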
@@ -847,6 +874,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
indices,
seqlen=self.seqlen,
share_attention_mask_flag=self.share_attention_mask_flag,
not_share_position_ids_flag=self.not_share_position_ids_flag,
input_dim=self.input_dim,
)

@@ -1414,3 +1442,4 @@ def __init__(
**kwargs,
)


29 changes: 15 additions & 14 deletions auto_round/export/export_to_autogptq.py
@@ -68,22 +68,23 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True,
if tokenizer is not None:
tokenizer.save_pretrained(output_dir)
##check module quantized in block, this may have bug for mixed precision quantization
block_name = get_block_names(model)[0]
first_block = get_module(model, block_name)
all_blocks = get_block_names(model)
all_to_quantized = True
modules_in_block_to_quantize = []
for n, m in first_block.named_modules():
is_supported_type = False
for supported_type in supported_types:
if isinstance(m, supported_type):
is_supported_type = True
break
if not is_supported_type:
continue
if not check_to_quantized(m):
all_to_quantized = False
else:
modules_in_block_to_quantize.append(n)
for block_names in all_blocks:
first_block = get_module(model, block_names[0])
for n, m in first_block.named_modules():
is_supported_type = False
for supported_type in supported_types:
if isinstance(m, supported_type):
is_supported_type = True
break
if not is_supported_type:
continue
if not check_to_quantized(m):
all_to_quantized = False
else:
modules_in_block_to_quantize.append(n)
modules_in_block_to_quantize = [modules_in_block_to_quantize]
if all_to_quantized:
modules_in_block_to_quantize = None
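With the loop above, the exported list now gathers layer names from the first block of every branch into a single flat list before wrapping it; a hypothetical result for a model with a decoder branch and a vision branch (layer names are illustrative only):

# Hypothetical value of modules_in_block_to_quantize after the loop.
modules_in_block_to_quantize = [
    [
        "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",  # first decoder block
        "attn.qkv", "mlp.fc1", "mlp.fc2",                            # first vision block
    ]
]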
11 changes: 11 additions & 0 deletions auto_round/special_model_handler.py
@@ -16,6 +16,7 @@

share_attention_mask_tuple = ("baichuan",)
special_states_dim_tuple = ("chatglm",)
not_share_position_ids_tuple = ("llava",)


def check_share_attention_mask(model, hidden_states, attention_mask=None, **kwargs):
@@ -54,3 +55,13 @@ def check_hidden_state_dim(model, positional_args):
is_special = True
break
return int(is_special and positional_args is not None)


def check_not_share_position_ids(model, **kwargs):
is_special = False
for key in not_share_position_ids_tuple:
if hasattr(model, "config") and key in model.config.model_type:
is_special = True
break
return bool(is_special)
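A quick sketch of how the new helper behaves; the dummy classes below are stand-ins defined only for illustration:

from types import SimpleNamespace

from auto_round.special_model_handler import check_not_share_position_ids


class DummyLlava:
    # Stand-in for a model whose config.model_type contains "llava".
    config = SimpleNamespace(model_type="llava_llama")


class DummyOpt:
    config = SimpleNamespace(model_type="opt")


print(check_not_share_position_ids(DummyLlava()))  # True
print(check_not_share_position_ids(DummyOpt()))    # False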

34 changes: 20 additions & 14 deletions auto_round/utils.py
@@ -132,7 +132,7 @@ def get_scale_shape(weight, group_size):
return shape


def to_device(input, device=torch.device("cpu")):
def to_device(input, device=torch.device("cpu"), multimodal=False):
"""Moves input data to the specified device.
Args:
@@ -185,13 +185,15 @@ def get_block_names(model):
block_names: A list of block names.
"""
block_names = []
target_m = None
target_modules = []
for n, m in model.named_modules():
if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
target_m = (n, m)
break ## only find the first modulelist, may be not robust
for n, m in target_m[1].named_children():
block_names.append(target_m[0] + "." + n)
target_modules.append((n, m))
# collect every ModuleList, not just the first one, so models with multiple branches (e.g. multimodal) are covered
for i, target_m in enumerate(target_modules):
block_names.append([])
for n, m in target_m[1].named_children():
block_names[i].append(target_m[0] + "." + n)
return block_names
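A toy module (defined here only for illustration) showing the updated behavior: every ModuleList now contributes its own sub-list of block names instead of only the first one found:

import torch

from auto_round.utils import get_block_names


class ToyMultimodal(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(2)])
        self.visual_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])


print(get_block_names(ToyMultimodal()))
# expected: [['layers.0', 'layers.1'], ['visual_blocks.0']]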


@@ -232,7 +234,8 @@ def collect_minmax_scale(block):


@torch.no_grad()
def sampling_inputs(input_ids, input_others, indices, seqlen, share_attention_mask_flag=False, input_dim=0):
def sampling_inputs(input_ids, input_others, indices, seqlen,
share_attention_mask_flag=False, not_share_position_ids_flag=False, input_dim=0):
"""Samples inputs based on the given indices and sequence length.
Args:
@@ -250,7 +253,8 @@ def sampling_inputs(input_ids, input_others, indices, seqlen, share_attention_ma

current_input_others = {"positional_inputs": input_others["positional_inputs"]}
for key in input_others.keys():
if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key):
if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key) \
or (not_share_position_ids_flag and "position_ids" in key):
current_input_others[key] = None
if input_others[key] is not None:
current_input_others[key] = [input_others[key][i] for i in indices]
@@ -556,12 +560,13 @@ def get_layer_names_in_block(model, supported_types=[torch.nn.Linear, transforme
if isinstance(m, tuple(supported_types)):
m.tmp_name = n
layers_in_block = []
block_names = get_block_names(model)
for block_name in block_names:
block = get_module(model, block_name)
for n, m in block.named_modules():
if hasattr(m, "tmp_name"):
layers_in_block.append(m.tmp_name)
all_blocks = get_block_names(model)
for block_names in all_blocks:
for block_name in block_names:
block = get_module(model, block_name)
for n, m in block.named_modules():
if hasattr(m, "tmp_name"):
layers_in_block.append(m.tmp_name)
for n, m in model.named_modules():
if hasattr(m, "tmp_name"):
delattr(m, "tmp_name")
@@ -665,3 +670,4 @@ def dynamic_import_inference_linear(bits, group_size, backend):
else:
from auto_round_extension.cuda.qliner_triton import QuantLinear
return QuantLinear
