Feat (examples/llm): initial support for loading AWQ results (#673)
* Feat (examples/llm): initial support for loading AWQ results
* Disable yapf on mlir ops def
* Fix pre commit on awq folder
* yapf
* yapf

Signed-off-by: Alessandro Pappalardo <[email protected]>
Showing 9 changed files with 792 additions and 1 deletion.
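For orientation, the sketch below shows roughly how the utilities introduced in this commit fit together for a single decoder layer. It is a hypothetical usage example, not code from the commit: the calibration step that fills input_feat (per-linear input activations), the import paths of the new modules, and the q_config / w_bit values are assumptions.

# Hypothetical usage sketch of the functions added below; imports of auto_scale_block,
# apply_scale, auto_clip_block and apply_clip from the new awq modules are omitted
# because their paths are not shown in this excerpt.
w_bit = 4  # placeholder bit width
q_config = {"zero_point": True, "q_group_size": 128}  # assumed AWQ-style config


def awq_one_layer(layer, input_feat, module_kwargs):
    # layer: an OPTDecoderLayer or LlamaDecoderLayer
    # input_feat: dict mapping relative module names (e.g. 'self_attn.q_proj')
    #             to calibration activations, collected elsewhere (e.g. via hooks)
    # module_kwargs: extra forward kwargs for the layer (e.g. attention_mask)
    scales_list = auto_scale_block(layer, module_kwargs, w_bit, q_config, input_feat)
    # fold the searched scales into the surrounding LayerNorm/Linear pairs and
    # rescale the cached activations so the clipping search sees them post-scaling
    apply_scale(layer, scales_list, input_feat_dict=input_feat)
    # search and apply per-group weight clipping thresholds
    clip_list = auto_clip_block(layer, w_bit, q_config, input_feat)
    apply_clip(layer, clip_list)
    # scales_list and clip_list are the "AWQ results" this commit adds support for loading
    return scales_list, clip_list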
@@ -0,0 +1,118 @@
""" | ||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
Adapted from https://github.com/mit-han-lab/llm-awq, released under the following LICENSE: | ||
MIT License | ||
Copyright (c) 2023 MIT HAN Lab | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
""" | ||

import gc

import torch
import torch.nn as nn

from .quantizer import pseudo_quantize_tensor

__all__ = ["auto_clip_block"]

# weight quantization
@torch.no_grad()
def auto_clip_layer(w, input_feat, n_bit, q_config, n_grid=20, max_shrink=0.5, n_sample_token=512):
    assert w.dim() == 2
    org_w_shape = w.shape
    # w [co, ci] -> [co, 1, n_group, group size]
    # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
    group_size = q_config["q_group_size"] if q_config["q_group_size"] > 0 else w.shape[1]
    input_feat = input_feat.view(-1, input_feat.shape[-1])
    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
    # subsample tokens to keep the grid search cheap
    input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
    w = w.reshape(w.shape[0], 1, -1, group_size)

    oc_batch_size = 256  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []

    for i_b in range(w.shape[0] // oc_batch_size):
        w = w_all[i_b * oc_batch_size:(i_b + 1) * oc_batch_size]

        org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

        best_max_val = org_max_val.clone()
        min_errs = torch.ones_like(org_max_val) * 1e9
        input_feat = input_feat.to(w.device)
        org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

        # shrink the clipping threshold step by step and keep, per group, the value
        # that minimizes the output MSE after fake quantization
        for i_s in range(int(max_shrink * n_grid)):
            max_val = org_max_val * (1 - i_s / n_grid)
            min_val = -max_val
            cur_w = torch.clamp(w, min_val, max_val)
            q_w = pseudo_quantize_tensor(cur_w, n_bit=n_bit, **q_config)
            cur_out = (input_feat * q_w).sum(dim=-1)

            # co, 1, n_group, 1
            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
            del cur_w
            del cur_out
            cur_best_idx = err < min_errs
            min_errs[cur_best_idx] = err[cur_best_idx]
            best_max_val[cur_best_idx] = max_val[cur_best_idx]
        best_max_val_all.append(best_max_val)

    best_max_val = torch.cat(best_max_val_all, dim=0)

    del input_feat
    del org_out
    gc.collect()
    torch.cuda.empty_cache()
    return best_max_val.squeeze(1)


@torch.no_grad()
def auto_clip_block(module, w_bit, q_config, input_feat):

    named_linears = {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

    clip_list = []
    for name in named_linears:
        # due to qk bmm, it is hard to clip precisely
        if any([_ in name for _ in ["q_", "k_"]]):
            continue
        max_val = auto_clip_layer(
            named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
        clip_list.append((name, max_val))
    return clip_list


@torch.no_grad()
def apply_clip(module, clip_list):
    from .utils.module import get_op_by_name
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
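The pseudo_quantize_tensor helper imported from .quantizer above (and again in the next file) is not part of this excerpt. The sketch below is a rough approximation of the group-wise fake quantization it is expected to perform, modeled on the upstream llm-awq implementation; the exact signature and defaults are assumptions.

import torch


# Sketch only: the actual .quantizer module is not shown in this diff.
def pseudo_quantize_tensor(w, n_bit=8, zero_point=True, q_group_size=-1):
    org_shape = w.shape
    # quantize each group (or each row, if no grouping) independently
    group_size = q_group_size if q_group_size > 0 else w.shape[-1]
    w = w.reshape(-1, group_size)
    if zero_point:
        # asymmetric: map [min, max] of each group onto [0, 2**n_bit - 1]
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        scale = (max_val - min_val).clamp(min=1e-5) / (2 ** n_bit - 1)
        zeros = (-torch.round(min_val / scale)).clamp(0, 2 ** n_bit - 1)
        w = (torch.clamp(torch.round(w / scale) + zeros, 0, 2 ** n_bit - 1) - zeros) * scale
    else:
        # symmetric: map [-absmax, absmax] onto the signed integer range
        q_max = 2 ** (n_bit - 1) - 1
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / q_max
        w = torch.clamp(torch.round(w / scale), -q_max - 1, q_max) * scale
    return w.reshape(org_shape)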
@@ -0,0 +1,272 @@
""" | ||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
Adapted from https://github.com/mit-han-lab/llm-awq, released under the following LICENSE: | ||
MIT License | ||
Copyright (c) 2023 MIT HAN Lab | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
""" | ||

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.opt.modeling_opt import OPTDecoderLayer

from .utils.module import get_op_by_name
from .utils.module import get_op_name

__all__ = ["auto_scale_block", "apply_scale"]

@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    # relative weight magnitude (|w| normalized by the per-group, or per-row, max),
    # averaged over output channels
    org_shape = weight.shape
    if q_group_size > 0:
        weight = weight.view(-1, q_group_size)
    scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
    scale = scale.view(org_shape)
    scale = scale.mean(0)
    return scale


@torch.no_grad()
def get_act_scale(x):
    # mean absolute activation per channel
    return x.abs().view(-1, x.shape[-1]).mean(0)

@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0

@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
    assert fc1.out_features == fc2.in_features

    scales = scales.to(fc1.weight.device)

    fc1.weight.div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0

@torch.no_grad()
def auto_scale_block(module, module_kwargs, w_bit, q_config, input_feat):
    from .quantizer import pseudo_quantize_tensor

    # firstly, get the weight quantize function
    if w_bit is not None:

        def w_quantize_func(p):
            return pseudo_quantize_tensor(
                p,
                n_bit=w_bit,
                **q_config,
            ).detach()
    else:

        def w_quantize_func(p):
            return p

    if "use_cache" in module_kwargs:
        module_kwargs.pop("use_cache")

    # find the best scale ratio
    def _search_module_scale(block, linears2scale: list, x, kwargs={}):
        # w: co, ci
        # x: n, ci
        x = x.to(next(block.parameters()).device)
        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
        w_max = get_weight_scale(weight, q_group_size=q_config.get("q_group_size", -1))

        with torch.no_grad():
            org_out = block(x, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        x_max = get_act_scale(x)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
        # sweep ratio over [0, 1): scales grow with activation magnitude (x_max) and
        # shrink with weight magnitude (w_max); keep the ratio with the lowest output MSE
        for ratio in range(n_grid):
            ratio = ratio * 1 / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in linears2scale:
                fc.weight.mul_(scales.view(1, -1))
                fc.weight.data = w_quantize_func(fc.weight.data) / (scales.view(1, -1))
            out = block(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            block.load_state_dict(org_sd)
        if best_ratio == -1:
            print(history)
            raise Exception
        # print(best_ratio)
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
        # module2inspect: if given, we will check the output diff of this module instead of layers
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        # prev_op_name, [layer_name], scale
        return (
            get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)

    scales_list = []  # return the searched scales

    if isinstance(module, OPTDecoderLayer):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn_layer_norm,
                layers=[module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj],
                inp=input_feat['self_attn.q_proj'],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            ))
        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.out_proj],
                inp=input_feat['self_attn.out_proj'],
            ))
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.final_layer_norm,
                layers=[module.fc1],
                inp=input_feat['fc1'],
            ))
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.fc1,
                layers=[module.fc2],
                inp=input_feat['fc2'],
            ))

    elif isinstance(module, LlamaDecoderLayer):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj],
                inp=input_feat['self_attn.q_proj'],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            ))
        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat['mlp.gate_proj'],
                module2inspect=module.mlp,
            ))
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat['mlp.down_proj'],
            ))

    else:
        raise NotImplementedError(f"{type(module)} not supported yet!")

    return scales_list

def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
            scale_ln_fcs(prev_op, layers, scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))
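As a quick sanity check of the invariant apply_scale relies on (not part of the commit): folding a per-channel scale out of a LayerNorm and into the following Linear leaves the float output of the pair unchanged, so only the quantization error is affected by the choice of scales. A minimal standalone illustration:

import torch
import torch.nn as nn

# Toy equivalence check mirroring scale_ln_fcs: divide the LayerNorm affine
# parameters by the scales and multiply the Linear input channels by them.
torch.manual_seed(0)
ln = nn.LayerNorm(16)
fc = nn.Linear(16, 32)
x = torch.randn(4, 16)

ref = fc(ln(x))

scales = torch.rand(16) + 0.5  # arbitrary positive per-channel scales
with torch.no_grad():
    ln.weight.div_(scales)
    ln.bias.div_(scales)
    fc.weight.mul_(scales.view(1, -1))

out = fc(ln(x))
print(torch.allclose(ref, out, atol=1e-5))  # expected: True (up to float rounding)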