# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
from operator import attrgetter

import torch
from tqdm import tqdm

from accelerate.utils.operations import send_to_device

from brevitas.graph.gptq import gptq_mode


@torch.no_grad()
def apply_gptq(
        model,
        dataloader,
        act_order=True,
        group_of_parallel_layers=None,
        block_name='model.layers',
        device='cuda'):
    """Apply GPTQ weight quantization to ``model``, one block at a time.

    The inputs to the first block are captured on CPU for every calibration
    batch; each block is then quantized in isolation on those cached inputs,
    and the quantized block's outputs are re-captured (again on CPU) to serve
    as the inputs of the next block. This keeps at most one block's
    activations on the accelerator at any time.

    Args:
        model: model to quantize in place.
        dataloader: iterable of dicts of calibration inputs, passed to the
            model as keyword arguments.
        act_order: whether GPTQ processes columns in activation-magnitude
            order (forwarded to ``gptq_mode``).
        group_of_parallel_layers: optional grouping of layers to quantize in
            parallel (forwarded to ``gptq_mode``).
        block_name: dotted attribute path from ``model`` to the sequence of
            blocks to quantize one by one (default matches HF decoder-style
            models). Pass ``None`` to run GPTQ on the whole model at once.
        device: device the cached CPU activations are moved to whenever a
            block is executed.
    """
    if block_name is None:
        # Whole-model fallback: every layer pass re-runs the full model on
        # the complete calibration set.
        with gptq_mode(model,
                       use_quant_activations=False,
                       group_of_parallel_layers=group_of_parallel_layers,
                       act_order=act_order,
                       create_weight_orig=False) as gptq:
            gptq_model = gptq.model
            for _ in tqdm(range(gptq.num_layers)):
                for inps in dataloader:
                    gptq_model(**inps)
                gptq.update()
        return

    class _StopFwd(Exception):
        # Control-flow sentinel raised by the capture hooks to abort a
        # forward pass once the tensors of interest have been cached. A
        # dedicated type avoids swallowing genuine runtime errors the way a
        # bare ``except`` on RuntimeError would.
        pass

    blocks = attrgetter(block_name)(model)
    cached_args, cached_kwargs = [], []

    def intercept_input(module, args, kwargs):
        # Cache the first block's inputs on CPU and stop: the rest of the
        # model does not need to run during input collection.
        cached_args.append(send_to_device(args, 'cpu'))
        cached_kwargs.append(send_to_device(kwargs, 'cpu'))
        raise _StopFwd

    def intercept_output(module, args, kwargs, output):
        # Cache the block's output — the next block's positional input — on
        # CPU, and stop the forward pass.
        if isinstance(output, tuple):
            output = output[0]
        cached_args.append((send_to_device(output, 'cpu'),))
        raise _StopFwd

    # Collect the inputs to the first block for every calibration batch.
    hook = blocks[0].register_forward_pre_hook(intercept_input, with_kwargs=True)
    for inps in dataloader:
        try:
            model(**inps)
        except _StopFwd:
            pass
    hook.remove()

    for index, block in enumerate(tqdm(blocks)):
        # Quantize the current block in isolation on the cached inputs.
        with gptq_mode(block,
                       use_quant_activations=False,
                       group_of_parallel_layers=group_of_parallel_layers,
                       act_order=act_order,
                       create_weight_orig=False) as gptq:
            for _ in tqdm(range(gptq.num_layers)):
                for args, kwargs in zip(cached_args, cached_kwargs):
                    block(*send_to_device(args, device),
                          **send_to_device(kwargs, device))
                gptq.update()

        # NOTE(review): ``cached_kwargs`` is deliberately left untouched —
        # the keyword inputs captured at the first block (presumably
        # attention masks / position ids) are reused for every block. TODO
        # confirm this holds for all supported model families.
        past_cached_args, cached_args = cached_args, []

        if index < len(blocks) - 1:
            # Run the now-quantized block on its cached inputs to capture
            # its outputs, which become the inputs of the next block; the
            # hook appends into the freshly rebound ``cached_args``.
            hook = block.register_forward_hook(intercept_output, with_kwargs=True)
            for args, kwargs in zip(past_cached_args, cached_kwargs):
                try:
                    block(*send_to_device(args, device),
                          **send_to_device(kwargs, device))
                except _StopFwd:
                    pass
            hook.remove()