Skip to content

Commit

Permalink
Block gptq
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 8, 2024
1 parent 686beb7 commit f29d7ad
Showing 1 changed file with 71 additions and 10 deletions.
81 changes: 71 additions & 10 deletions src/brevitas_examples/llm/llm_quant/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,82 @@
# SPDX-License-Identifier: BSD-3-Clause
"""

from copy import deepcopy
import torch
from tqdm import tqdm

from brevitas.graph.gptq import gptq_mode
from accelerate.utils.operations import send_to_device


@torch.no_grad()
def apply_gptq(
        model,
        dataloader,
        act_order=True,
        group_of_parallel_layers=None,
        block_wise=True):
    """Apply GPTQ weight quantization to ``model`` using calibration data.

    Args:
        model: the (already quant-annotated) model to calibrate.
        dataloader: iterable of dicts of keyword inputs for ``model``.
        act_order: whether GPTQ processes weight columns in activation order.
        group_of_parallel_layers: optional grouping of layers quantized together.
        block_wise: when True, quantize one transformer block at a time,
            replaying cached block inputs instead of full model forwards
            (far cheaper in memory/compute). When False, fall back to the
            original whole-model GPTQ loop.
    """
    if block_wise:
        # NOTE(review): the block container is hard-coded for Llama-style
        # models; a getattr(model, block_name) would support other
        # architectures — confirm before generalizing.
        blocks = model.model.layers
        first_block = blocks[0]
        cached_args, cached_kwargs = [], []

        def intercept_input(module, args, kwargs):
            # Cache the inputs entering the first block on CPU, then abort
            # the forward on purpose: only the activations reaching the
            # block stack are needed, not the full model output.
            cached_args.append(send_to_device(args, 'cpu'))
            cached_kwargs.append(send_to_device(kwargs, 'cpu'))
            raise RuntimeError

        def intercept_output(module, args, kwargs, output):
            # Cache the current block's output, which becomes the next
            # block's positional input. Raise to skip any downstream work.
            if isinstance(output, tuple):
                output = output[0]
            cached_args.append((output,))
            raise RuntimeError

        # Capture the first block's inputs for every calibration batch.
        hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
        for inps in dataloader:
            try:
                model(**inps)
            except RuntimeError:
                # Raised deliberately by intercept_input; any other
                # exception is a real error and must propagate.
                pass
        hook.remove()

        for index, block in enumerate(tqdm(blocks)):
            # Quantize one block at a time, replaying the cached activations.
            with gptq_mode(block,
                           use_quant_activations=False,
                           group_of_parallel_layers=group_of_parallel_layers,
                           act_order=act_order,
                           create_weight_orig=False) as gptq:
                for _ in tqdm(range(gptq.num_layers)):
                    for args, kwargs in zip(cached_args, cached_kwargs):
                        # NOTE(review): device is hard-coded; assumes the
                        # blocks live on CUDA — confirm for offloaded setups.
                        args = send_to_device(args, 'cuda')
                        kwargs = send_to_device(kwargs, 'cuda')
                        block(*args, **kwargs)
                    gptq.update()
            # The cached lists are only read from here on, so a rebind is
            # enough — no deepcopy of the (potentially huge) CPU tensors.
            past_cached_args, past_cached_kwargs = cached_args, cached_kwargs
            cached_args = []

            if index < len(blocks) - 1:
                # Replay the just-quantized block once more to record its
                # outputs, which become the next block's inputs. kwargs
                # (e.g. attention masks / position ids) are reused across
                # all blocks, so cached_kwargs is never cleared.
                hook = block.register_forward_hook(intercept_output, with_kwargs=True)
                for args, kwargs in zip(past_cached_args, past_cached_kwargs):
                    try:
                        args = send_to_device(args, 'cuda')
                        kwargs = send_to_device(kwargs, 'cuda')
                        block(*args, **kwargs)
                    except RuntimeError:
                        # Raised deliberately by intercept_output.
                        pass
                hook.remove()
    else:
        # Original whole-model GPTQ: one full forward pass over the
        # calibration set per layer update.
        with gptq_mode(model,
                       use_quant_activations=False,
                       group_of_parallel_layers=group_of_parallel_layers,
                       act_order=act_order,
                       create_weight_orig=False) as gptq:
            gptq_model = gptq.model
            for _ in tqdm(range(gptq.num_layers)):
                for inps in dataloader:
                    gptq_model(**inps)
                gptq.update()

0 comments on commit f29d7ad

Please sign in to comment.