Feat (examples/llm): initial support for loading AWQ results (#673)
* Feat (examples/llm): initial support for loading AWQ results

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Disable yapf on mlir ops def

* Fix pre commit on awq folder

* yapf

* yapf

---------

Signed-off-by: Alessandro Pappalardo <[email protected]>
volcacius authored Jul 17, 2023
1 parent 26cc63c commit 64c2028
Showing 9 changed files with 792 additions and 1 deletion.
118 changes: 118 additions & 0 deletions src/brevitas_examples/llm/llm_quant/awq/auto_clip.py
@@ -0,0 +1,118 @@
"""
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
Adapted from https://github.com/mit-han-lab/llm-awq, released under the following LICENSE:
MIT License
Copyright (c) 2023 MIT HAN Lab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import gc

import torch
import torch.nn as nn

from .quantizer import pseudo_quantize_tensor

__all__ = ["auto_clip_block"]


# weight quantization
@torch.no_grad()
def auto_clip_layer(w, input_feat, n_bit, q_config, n_grid=20, max_shrink=0.5, n_sample_token=512):
    assert w.dim() == 2
    org_w_shape = w.shape
    # w [co, ci] -> [co, 1, n_group, group size]
    # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
    group_size = q_config["q_group_size"] if q_config["q_group_size"] > 0 else w.shape[1]
    input_feat = input_feat.view(-1, input_feat.shape[-1])
    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
    input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
    w = w.reshape(w.shape[0], 1, -1, group_size)

    oc_batch_size = 256  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []

    for i_b in range(w.shape[0] // oc_batch_size):
        w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]

        org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

        best_max_val = org_max_val.clone()
        min_errs = torch.ones_like(org_max_val) * 1e9
        input_feat = input_feat.to(w.device)
        org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

        for i_s in range(int(max_shrink * n_grid)):
            max_val = org_max_val * (1 - i_s / n_grid)
            min_val = -max_val
            cur_w = torch.clamp(w, min_val, max_val)
            q_w = pseudo_quantize_tensor(cur_w, n_bit=n_bit, **q_config)
            cur_out = (input_feat * q_w).sum(dim=-1)

            # co, 1, n_group, 1
            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
            del cur_w
            del cur_out
            cur_best_idx = err < min_errs
            min_errs[cur_best_idx] = err[cur_best_idx]
            best_max_val[cur_best_idx] = max_val[cur_best_idx]
        best_max_val_all.append(best_max_val)

    best_max_val = torch.cat(best_max_val_all, dim=0)

    del input_feat
    del org_out
    gc.collect()
    torch.cuda.empty_cache()
    return best_max_val.squeeze(1)


@torch.no_grad()
def auto_clip_block(module, w_bit, q_config, input_feat):

    named_linears = {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

    clip_list = []
    for name in named_linears:
        # due to qk bmm, it is hard to clip precisely
        if any([_ in name for _ in ["q_", "k_"]]):
            continue
        max_val = auto_clip_layer(
            named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
        clip_list.append((name, max_val))
    return clip_list


@torch.no_grad()
def apply_clip(module, clip_list):
    from .utils.module import get_op_by_name
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
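
For context only (not part of the diff): a minimal usage sketch of the clipping helpers above. `decoder_layer` (a transformer block) and `input_feat` (a dict mapping each Linear submodule name to its captured calibration activations) are assumed to come from a prior calibration pass; the only `q_config` key this file reads is `q_group_size`, so any further keys expected by `pseudo_quantize_tensor` are omitted here.

# Hypothetical driver; decoder_layer and input_feat are assumed to exist.
q_config = {"q_group_size": 128}  # illustrative; pseudo_quantize_tensor may expect more keys
clip_list = auto_clip_block(decoder_layer, w_bit=4, q_config=q_config, input_feat=input_feat)
apply_clip(decoder_layer, clip_list)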
272 changes: 272 additions & 0 deletions src/brevitas_examples/llm/llm_quant/awq/auto_scale.py
@@ -0,0 +1,272 @@
"""
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
Adapted from https://github.com/mit-han-lab/llm-awq, released under the following LICENSE:
MIT License
Copyright (c) 2023 MIT HAN Lab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.opt.modeling_opt import OPTDecoderLayer

from .utils.module import get_op_by_name
from .utils.module import get_op_name

__all__ = ["auto_scale_block", "apply_scale"]


@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    org_shape = weight.shape
    if q_group_size > 0:
        weight = weight.view(-1, q_group_size)
    scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
    scale = scale.view(org_shape)
    scale = scale.mean(0)
    return scale


@torch.no_grad()
def get_act_scale(x):
    return x.abs().view(-1, x.shape[-1]).mean(0)


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
    assert fc1.out_features == fc2.in_features

    scales = scales.to(fc1.weight.device)

    fc1.weight.div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def auto_scale_block(module, module_kwargs, w_bit, q_config, input_feat):
    from .quantizer import pseudo_quantize_tensor

    # firstly, get the weight quantize function
    if w_bit is not None:

        def w_quantize_func(p):
            return pseudo_quantize_tensor(
                p,
                n_bit=w_bit,
                **q_config,
            ).detach()
    else:

        def w_quantize_func(p):
            return p

    if "use_cache" in module_kwargs:
        module_kwargs.pop("use_cache")

    # find the best scale ratio
    def _search_module_scale(block, linears2scale: list, x, kwargs={}):
        # w: co, ci
        # x: n, ci
        x = x.to(next(block.parameters()).device)
        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
        w_max = get_weight_scale(weight, q_group_size=q_config.get("q_group_size", -1))

        with torch.no_grad():
            org_out = block(x, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        x_max = get_act_scale(x)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
        for ratio in range(n_grid):
            ratio = ratio * 1 / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in linears2scale:
                fc.weight.mul_(scales.view(1, -1))
                fc.weight.data = w_quantize_func(fc.weight.data) / (scales.view(1, -1))
            out = block(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            block.load_state_dict(org_sd)
        if best_ratio == -1:
            print(history)
            raise Exception
        # print(best_ratio)
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
        # module2inspect: if given, we will check the output diff of this module instead of layers
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        # prev_op_name, [layer_name], scale
        return (
            get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)

    scales_list = []  # return the searched scales

    if isinstance(module, OPTDecoderLayer):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn_layer_norm,
                layers=[module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj],
                inp=input_feat['self_attn.q_proj'],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            ))
        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.out_proj],
                inp=input_feat['self_attn.out_proj'],
            ))
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.final_layer_norm,
                layers=[module.fc1],
                inp=input_feat['fc1'],
            ))
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.fc1,
                layers=[module.fc2],
                inp=input_feat['fc2'],
            ))

    elif isinstance(module, LlamaDecoderLayer):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj],
                inp=input_feat['self_attn.q_proj'],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            ))
        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat['mlp.gate_proj'],
                module2inspect=module.mlp,
            ))
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat['mlp.down_proj'],
            ))

    else:
        raise NotImplementedError(f"{type(module)} not supported yet!")

    return scales_list


def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
            scale_ln_fcs(prev_op, layers, scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))
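
For context only (not part of the diff): a sketch of how the scaling entry points above might be driven for a single decoder layer. `decoder_layer`, `module_kwargs` (extra forward kwargs recorded during calibration), and `input_feat` are assumed inputs; passing `input_feat` as `input_feat_dict` rescales the cached activations in place so that a subsequent clipping pass (see auto_clip.py above) sees the scaled inputs.

# Hypothetical driver for one decoder layer; names below are illustrative.
scales_list = auto_scale_block(
    decoder_layer, module_kwargs, w_bit=4, q_config={"q_group_size": 128}, input_feat=input_feat)
apply_scale(decoder_layer, scales_list, input_feat_dict=input_feat)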
