Added cache_block_outputs parameter to handle models with non-regular structure such as ChatGLM (#1479)

* Added cache_block_outputs parameter to handle models with non-regular structure in GPTQ

* Code style

* Added variable description

* Applied comments

* Changed default. Added more docstring

* Added a test for cache_block_outputs feature

* Style
AlexKoff88 authored Oct 31, 2023
1 parent 8e7588b commit a49a116
Showing 2 changed files with 56 additions and 23 deletions.
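For context, a minimal usage sketch of the new flag, assuming the GPTQQuantizer API shown in the diff below; the model id, dataset choice, and loading options are illustrative and not part of this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer

model_id = "THUDM/chatglm2-6b"  # illustrative example of a model with a non-regular block structure
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", trust_remote_code=True)

# cache_block_outputs=False makes the quantizer re-run the full model before each block
# instead of reusing the previous block's cached outputs as the next block's inputs.
quantizer = GPTQQuantizer(bits=4, dataset="c4", cache_block_outputs=False)
quantized_model = quantizer.quantize_model(model, tokenizer)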
69 changes: 46 additions & 23 deletions optimum/gptq/quantizer.py
@@ -72,6 +72,7 @@ def __init__(
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
max_input_length: Optional[int] = None,
cache_block_outputs: Optional[bool] = True,
*args,
**kwargs,
):
@@ -115,6 +116,9 @@ def __init__(
max_input_length (`Optional[int]`, defaults to `None`):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
It is specific to the exllama backend with act-order.
cache_block_outputs (`bool`, defaults to `True`):
Whether to cache block outputs to reuse as inputs for the succeeding block. It allows optimization of non-standard models
(e.g. ChatGLM) but can require more time.
"""

self.bits = bits
@@ -134,6 +138,7 @@ def __init__(
self.disable_exllamav2 = disable_exllamav2
self.max_input_length = max_input_length
self.quant_method = QuantizationMethod.GPTQ
self.cache_block_outputs = cache_block_outputs

if self.bits not in [2, 3, 4, 8]:
raise ValueError("only support quantize to [2,3,4,8] bits.")
@@ -355,10 +360,9 @@ def quantize_model(self, model: nn.Module, tokenizer: Any):

def store_input_hook(_, input, *args):
kwargs = args[0]
input = input[0]
if input is None:
if "hidden_states" in kwargs:
input = kwargs["hidden_states"]
input = (kwargs["hidden_states"],)
else:
raise ValueError("No input value found in the foward pass")
layer_inputs.append(input)
@@ -369,17 +373,18 @@ def store_input_hook(_, input, *args):
layer_input_kwargs.append(other_kwargs)
raise ValueError

handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
data[k] = v.to(0)
try:
model(**data)
except ValueError:
pass
if self.cache_block_outputs:
handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
data[k] = v.to(0)
try:
model(**data)
except ValueError:
pass
handle.remove()

handle.remove()
if not has_device_map:
blocks[0].to(device)
for module_name in self.module_name_preceding_first_block:
@@ -393,6 +398,19 @@ def store_input_hook(_, input, *args):
quantizers = {}
for i, block in enumerate(tqdm(blocks, desc=f"Quantizing {self.block_name_to_quantize} blocks ")):
logger.info(f"Start quantizing block {self.block_name_to_quantize} {i + 1}/{len(blocks)}")

if not self.cache_block_outputs:
handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
data[k] = v.to(0)
try:
model(**data)
except ValueError:
pass
handle.remove()

# move block to cuda if needed
# in case we have offload modules, we need to put them on cuda because of GPTQ object
if not has_device_map or get_device(block) == torch.device("cpu"):
@@ -425,7 +443,7 @@ def tmp(_, input, output):
for j in range(len(dataset)):
# the args are already on the gpu
# don't need to store the output
block(layer_inputs[j], **layer_input_kwargs[j])
block(*layer_inputs[j], **layer_input_kwargs[j])
# remove hook
for h in handles:
h.remove()
@@ -443,16 +461,21 @@ def tmp(_, input, output):
gptq[name].free()
del subset_layers
# we get the new output from the partial quantized block
for j in range(len(dataset)):
layer_output = block(layer_inputs[j], **layer_input_kwargs[j])[0]
layer_outputs.append(layer_output)

# put back to device
if not has_device_map:
blocks[i] = block.to(device)
del layers
del layer_inputs
layer_inputs, layer_outputs = layer_outputs, []
if self.cache_block_outputs:
for j in range(len(dataset)):
layer_output = block(*layer_inputs[j], **layer_input_kwargs[j])
layer_outputs.append(layer_output)

# put back to device
if not has_device_map:
blocks[i] = block.to(device)
del layers
del layer_inputs
layer_inputs, layer_outputs = layer_outputs, []
else:
del layers
del layer_inputs
layer_inputs = []
torch.cuda.empty_cache()

if self.bits == 4:
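The quantize_model changes above rely on a PyTorch forward pre-hook registered with with_kwargs=True to capture each block's calibration inputs: with cache_block_outputs=True the hook is attached once to the first block and later blocks reuse the previous block's outputs, while with cache_block_outputs=False the hook is re-attached to every block and the full model is re-run for each one. A simplified, self-contained sketch of that capture mechanism (illustrative only, not the optimum implementation):

import torch
import torch.nn as nn

captured_inputs = []

def capture_block_inputs(module, args, kwargs):
    # With with_kwargs=True (PyTorch >= 2.0) the hook receives positional args and kwargs
    # separately; some models pass hidden states positionally, others via a "hidden_states" kwarg.
    inputs = args if len(args) > 0 else (kwargs.get("hidden_states"),)
    captured_inputs.append(inputs)

block = nn.Linear(16, 16)  # stand-in for a transformer block
handle = block.register_forward_pre_hook(capture_block_inputs, with_kwargs=True)
block(torch.randn(2, 16))
handle.remove()
print(len(captured_inputs))  # -> 1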
10 changes: 10 additions & 0 deletions tests/gptq/test_quantization.py
@@ -47,6 +47,7 @@ class GPTQTest(unittest.TestCase):
desc_act = False
disable_exllama = True
disable_exllamav2 = True
cache_block_outputs = True

dataset = [
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
@@ -71,6 +72,7 @@ def setUpClass(cls):
desc_act=cls.desc_act,
disable_exllama=cls.disable_exllama,
disable_exllamav2=cls.disable_exllamav2,
cache_block_outputs=cls.cache_block_outputs,
)

cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer)
@@ -261,6 +263,14 @@ def test_exllama_serialization(self):
self.check_inference_correctness(quantized_model_from_saved)


class GPTQTestNoBlockCaching(GPTQTest):
cache_block_outputs = False
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of")


class GPTQUtilsTest(unittest.TestCase):
"""
Test utilities
