From 51baf37cadaf90a470c192f24f97cc3736a46974 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo <1934033+volcacius@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:24:33 +0200 Subject: [PATCH] Examples: initial support for LLMs PTQ (#658) * Examples: WIP LLM block quantization * Add support for block zero-point * Add torch-mlir custom op support * Add test linear Signed-off-by: Alessandro Pappalardo * Update to custom matmul export Signed-off-by: Alessandro Pappalardo * Fix errors Signed-off-by: Alessandro Pappalardo * Fix output shape of custom op Signed-off-by: Alessandro Pappalardo * Add lowering to torch_mlir for single layer Signed-off-by: Alessandro Pappalardo * Some cleanups * WIP llm flow Signed-off-by: Alessandro Pappalardo * Fix (examples/llm): typo in custom quant matmul op (#607) * Test act equalization support * Initial end to end flow * Initial support for QuantMHA on OPT * Fix act equalization * Typos in prints * Reorganize validate * Add initial per row quantizers * Add per row input quantization support * Support group quant slicing * Adopt SliceTensor for block weight partial quant * Add float16 support * Fix scale type name * Add support for LN affine merging * WIP currently broken * Clean up weight eq support * Set weight narrow range always to False * Add fx act equalization, fixes for float16 support * Fix validate * Fix backport imports * Fix example export Signed-off-by: Alessandro Pappalardo * Fix value_trace call in ln affine merging * Add per tensor/row/group dynamic scale support, some dtype improvements * Fix (llm): correct handling of attention mask shape (#652) * ALways export in fp32 base dtype on CPU * Export improvements * Fix errors after latest PR --------- Signed-off-by: Alessandro Pappalardo Co-authored-by: jinchen62 <49575973+jinchen62@users.noreply.github.com> Co-authored-by: Giuseppe Franco --- setup.py | 4 +- src/brevitas/backport/__init__.py | 1 + .../backport/fx/experimental/proxy_tensor.py | 5 +- src/brevitas_examples/llm/README.md | 60 +++ src/brevitas_examples/llm/__init__.py | 0 .../llm/llm_quant/__init__.py | 0 .../llm/llm_quant/bias_corr.py | 26 ++ .../llm/llm_quant/calibrate.py | 26 ++ src/brevitas_examples/llm/llm_quant/data.py | 71 +++ .../llm/llm_quant/equalize.py | 74 ++++ src/brevitas_examples/llm/llm_quant/eval.py | 91 ++++ src/brevitas_examples/llm/llm_quant/export.py | 207 +++++++++ src/brevitas_examples/llm/llm_quant/gptq.py | 33 ++ .../llm/llm_quant/ln_affine_merge.py | 92 ++++ .../llm/llm_quant/mha_layers.py | 177 ++++++++ .../llm/llm_quant/mlir_custom_mm.py | 113 +++++ .../llm/llm_quant/prepare_for_quantize.py | 17 + .../llm/llm_quant/quant_blocks.py | 135 ++++++ .../llm/llm_quant/quantize.py | 284 ++++++++++++ .../llm/llm_quant/quantizers.py | 132 ++++++ .../llm/llm_quant/run_utils.py | 164 +++++++ .../llm_quant/sharded_mlir_group_export.py | 418 ++++++++++++++++++ src/brevitas_examples/llm/main.py | 300 +++++++++++++ .../llm/test_linear_mlir_export.py | 126 ++++++ 24 files changed, 2552 insertions(+), 4 deletions(-) create mode 100644 src/brevitas_examples/llm/README.md create mode 100644 src/brevitas_examples/llm/__init__.py create mode 100644 src/brevitas_examples/llm/llm_quant/__init__.py create mode 100644 src/brevitas_examples/llm/llm_quant/bias_corr.py create mode 100644 src/brevitas_examples/llm/llm_quant/calibrate.py create mode 100644 src/brevitas_examples/llm/llm_quant/data.py create mode 100644 src/brevitas_examples/llm/llm_quant/equalize.py create mode 100644 src/brevitas_examples/llm/llm_quant/eval.py 
create mode 100644 src/brevitas_examples/llm/llm_quant/export.py create mode 100644 src/brevitas_examples/llm/llm_quant/gptq.py create mode 100644 src/brevitas_examples/llm/llm_quant/ln_affine_merge.py create mode 100644 src/brevitas_examples/llm/llm_quant/mha_layers.py create mode 100644 src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py create mode 100644 src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py create mode 100644 src/brevitas_examples/llm/llm_quant/quant_blocks.py create mode 100644 src/brevitas_examples/llm/llm_quant/quantize.py create mode 100644 src/brevitas_examples/llm/llm_quant/quantizers.py create mode 100644 src/brevitas_examples/llm/llm_quant/run_utils.py create mode 100644 src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py create mode 100644 src/brevitas_examples/llm/main.py create mode 100644 src/brevitas_examples/llm/test_linear_mlir_export.py diff --git a/setup.py b/setup.py index 323472c31..b6c303fe8 100644 --- a/setup.py +++ b/setup.py @@ -55,5 +55,5 @@ def read_requirements(filename): 'brevitas_quartznet_preprocess = brevitas_examples.speech_to_text.get_librispeech_data:main', 'brevitas_melgan_preprocess = brevitas_examples.text_to_speech.preprocess_dataset:main', 'brevitas_ptq_imagenet_benchmark = brevitas_examples.imagenet_classification.ptq.ptq_benchmark:main', - 'brevitas_ptq_imagenet_val = brevitas_examples.imagenet_classification.ptq.ptq_evaluate:main' - ],}) + 'brevitas_ptq_imagenet_val = brevitas_examples.imagenet_classification.ptq.ptq_evaluate:main', + 'brevitas_ptq_llm = brevitas_examples.llm.main:main'],}) diff --git a/src/brevitas/backport/__init__.py b/src/brevitas/backport/__init__.py index 00102698a..57912f0e2 100644 --- a/src/brevitas/backport/__init__.py +++ b/src/brevitas/backport/__init__.py @@ -251,3 +251,4 @@ def sym_min(a, b): # Populate magic methods on SymInt and SymFloat import brevitas.backport.fx.experimental.symbolic_shapes +import brevitas.backport.fx diff --git a/src/brevitas/backport/fx/experimental/proxy_tensor.py b/src/brevitas/backport/fx/experimental/proxy_tensor.py index f920e9f8e..430623e9d 100644 --- a/src/brevitas/backport/fx/experimental/proxy_tensor.py +++ b/src/brevitas/backport/fx/experimental/proxy_tensor.py @@ -62,13 +62,14 @@ from torch.utils._python_dispatch import TorchDispatchMode import torch.utils._pytree as pytree +from brevitas import backport +from brevitas.backport import fx from brevitas.backport import SymBool from brevitas.backport import SymFloat from brevitas.backport import SymInt from brevitas.backport.fx import GraphModule from brevitas.backport.fx import Proxy from brevitas.backport.fx import Tracer -import brevitas.backport.fx as fx from brevitas.backport.fx.passes.shape_prop import _extract_tensor_metadata from brevitas.backport.utils._stats import count from brevitas.backport.utils.weak import WeakTensorKeyDictionary @@ -316,7 +317,7 @@ def proxy_call(proxy_mode, func, args, kwargs): # `__torch_dispatch__` is only called on torch ops, which must subclass `OpOverload` # We treat all other functions as an `external_call`, for instance, a function decorated # with `@torch.fx.wrap` - external_call = not isinstance(func, fx.backport._ops.OpOverload) + external_call = not isinstance(func, backport._ops.OpOverload) def can_handle_tensor(x): return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md new file mode 100644 index 000000000..d2ca7fbcb --- /dev/null +++ 
b/src/brevitas_examples/llm/README.md
@@ -0,0 +1,60 @@
+# LLM quantization
+
+## Requirements
+
+- transformers
+- datasets
+- torch_mlir (optional, for torch-mlir based export)
+
+## Run
+
+Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Note that this is currently not supported when export is also enabled, or when MSE-based scales/zero-points are used.
+
+```bash
+usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] [--weight-bit-width WEIGHT_BIT_WIDTH] [--weight-param-method {stats,mse}]
+               [--weight-scale-type {float32,po2}] [--weight-quant-type {sym,asym}] [--weight-quant-granularity {per_channel,per_tensor,per_group}]
+               [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point] [--input-bit-width INPUT_BIT_WIDTH] [--input-param-method {stats,mse}]
+               [--input-scale-type {float32,po2}] [--input-quant-type {sym,asym}] [--input-quant-granularity {per_tensor}] [--quantize-input-zero-point] [--gptq]
+               [--act-calibration] [--bias-corr] [--act-equalization]
+               [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --model MODEL         HF model name. Default: facebook/opt-125m.
+  --seed SEED           Seed for sampling the calibration data. Default: 0.
+  --nsamples NSAMPLES   Number of calibration data samples. Default: 128.
+  --seqlen SEQLEN       Sequence length. Default: 2048.
+  --eval                Eval model PPL on C4.
+  --weight-bit-width WEIGHT_BIT_WIDTH
+                        Weight bit width. Default: 8.
+  --weight-param-method {stats,mse}
+                        How scales/zero-point are determined. Default: stats.
+  --weight-scale-type {float32,po2}
+                        Whether scale is a float value or a po2. Default: po2.
+  --weight-quant-type {sym,asym}
+                        Weight quantization type. Default: asym.
+  --weight-quant-granularity {per_channel,per_tensor,per_group}
+                        Granularity for scales/zero-point of weights. Default: per_group.
+  --weight-group-size WEIGHT_GROUP_SIZE
+                        Group size for per_group weight quantization. Default: 128.
+  --quantize-weight-zero-point
+                        Quantize weight zero-point.
+  --input-bit-width INPUT_BIT_WIDTH
+                        Input bit width. Default: None (disables input quantization).
+  --input-param-method {stats,mse}
+                        How scales/zero-point are determined. Default: stats.
+  --input-scale-type {float32,po2}
+                        Whether input scale is a float value or a po2. Default: float32.
+  --input-quant-type {sym,asym}
+                        Input quantization type. Default: asym.
+  --input-quant-granularity {per_tensor}
+                        Granularity for scales/zero-point of inputs. Default: per_tensor.
+  --quantize-input-zero-point
+                        Quantize input zero-point.
+  --gptq                Apply GPTQ.
+  --act-calibration     Apply activation calibration.
+  --bias-corr           Apply bias correction.
+  --act-equalization    Apply activation equalization (SmoothQuant).
+  --export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}
+                        Model export.
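+
+# Example invocation (illustrative only, not part of the generated help text): 4-bit
+# per-group weight quantization of the default facebook/opt-125m model with GPTQ and
+# bias correction, followed by perplexity evaluation on C4. All flags are documented above.
+python main.py --model facebook/opt-125m --weight-bit-width 4 --weight-quant-granularity per_group --weight-group-size 128 --gptq --bias-corr --eval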
+``` diff --git a/src/brevitas_examples/llm/__init__.py b/src/brevitas_examples/llm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/brevitas_examples/llm/llm_quant/__init__.py b/src/brevitas_examples/llm/llm_quant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/brevitas_examples/llm/llm_quant/bias_corr.py b/src/brevitas_examples/llm/llm_quant/bias_corr.py new file mode 100644 index 000000000..c3be64c83 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/bias_corr.py @@ -0,0 +1,26 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +import torch + +from brevitas.graph.calibrate import bias_correction_mode +from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn + + +@torch.no_grad() +def bias_corr_iter(curr_layer, inps, outs, cached_values): + curr_layer = curr_layer.cuda() + with bias_correction_mode(curr_layer): + for j in range(len(inps)): + inp = inps[j].unsqueeze(0).cuda() + curr_out = curr_layer(inp, **cached_values)[0] + outs[j] = curr_out + curr_layer.cpu() + return outs + + +@torch.no_grad() +def apply_bias_correction(model, dataloader, nsamples, seqlen=2048): + apply_layer_ptq_fn(model, dataloader, nsamples, inference_fn=bias_corr_iter, seqlen=seqlen) diff --git a/src/brevitas_examples/llm/llm_quant/calibrate.py b/src/brevitas_examples/llm/llm_quant/calibrate.py new file mode 100644 index 000000000..9c5f00dac --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/calibrate.py @@ -0,0 +1,26 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +import torch + +from brevitas.graph.calibrate import calibration_mode +from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn + + +@torch.no_grad() +def calibration_iter(curr_layer, inps, outs, cached_values): + curr_layer = curr_layer.cuda() + with calibration_mode(curr_layer): + for j in range(len(inps)): + inp = inps[j].unsqueeze(0).cuda() + curr_out = curr_layer(inp, **cached_values)[0] + outs[j] = curr_out + curr_layer.cpu() + return outs + + +@torch.no_grad() +def apply_calibration(model, dataloader, nsamples, seqlen=2048): + apply_layer_ptq_fn(model, dataloader, nsamples, inference_fn=calibration_iter, seqlen=seqlen) diff --git a/src/brevitas_examples/llm/llm_quant/data.py b/src/brevitas_examples/llm/llm_quant/data.py new file mode 100644 index 000000000..6fa9fef00 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/data.py @@ -0,0 +1,71 @@ +""" +Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: + +Copyright 2023 IST-DASLab + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import random + +from datasets import load_dataset +import torch +from transformers import AutoTokenizer + + +def get_c4(nsamples, seed, seqlen, model, nvalsamples=256): + traindata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, + split='train', + use_auth_token=False) + valdata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation', + use_auth_token=False) + + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + trainloader.append(inp) + + random.seed(0) # hardcoded for validation reproducibility + valenc = [] + for _ in range(nvalsamples): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + + valenc = torch.hstack(valenc) + return trainloader, valenc diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py new file mode 100644 index 000000000..f3e4c3b0d --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -0,0 +1,74 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +import warnings + +import torch + +from brevitas.fx.brevitas_tracer import value_trace +from brevitas.graph.equalize import activation_equalization_mode +from brevitas.graph.equalize import EqualizeGraph +from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn +from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 + + +@torch.no_grad() +def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha): + curr_layer = curr_layer.cuda() + with activation_equalization_mode(curr_layer, alpha, add_mul_node=True, layerwise=True): + for j in range(len(inps)): + inp = inps[j].unsqueeze(0).cuda() + curr_out = curr_layer(inp, **cached_values)[0] + outs[j] = curr_out + curr_layer.cpu() + return outs + + +@torch.no_grad() +def apply_act_equalization( + model, + dtype, + act_equalization_type, + dataloader, + nsamples, + seqlen=2048, + alpha=0.5, + ref_kwargs=None): + if act_equalization_type == 'layerwise': + apply_layer_ptq_fn( + model, + dataloader, + nsamples, + inference_fn=activation_equalization_iter, + seqlen=seqlen, + alpha=alpha) + elif act_equalization_type == 'fx': + assert ref_kwargs is not None, "Ref kwargs required to perform tracing and lift the model into FX." + # We can't do fp16 tracing on CPU as many kernels are not implemented + # So we have to cast to fp32 first, trace, apply equalization, and then cast back + with cast_to_float32(model, dtype): + graph_model = value_trace(model, value_args=ref_kwargs) + # TODO this is currently running on CPU. 
We need Accelerate or a TorchDispatchMode + # or an FX interpreter to run it on GPU + warnings.warn( + "FX mode activation equalization currently runs on CPU, expect it to be slow for large models." + ) + with activation_equalization_mode(graph_model, + alpha, + add_mul_node=False, + layerwise=False): + for input_ids in dataloader: + graph_model(input_ids=input_ids) + else: + raise RuntimeError(f"{act_equalization_type} not supported.") + + +@torch.no_grad() +def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='range'): + # We can't do fp16 tracing on CPU as many kernels are not implemented + # So we have to cast to fp32 first, trace, apply equalization, and then cast back + with cast_to_float32(model, dtype): + graph_model = value_trace(model, value_args=ref_kwargs) + EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model) diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py new file mode 100644 index 000000000..8d7834085 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/eval.py @@ -0,0 +1,91 @@ +""" +Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: + +Copyright 2023 IST-DASLab + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from torch import nn +from tqdm import tqdm + +from brevitas_examples.llm.llm_quant.run_utils import apply_layer_inference_fn +from brevitas_examples.llm.llm_quant.run_utils import get_model_impl +from brevitas_examples.llm.llm_quant.run_utils import InputCatcherException + + +def eval_inference_fn(curr_layer, inps, outs, cached_values): + curr_layer.cuda() + for j in range(len(inps)): + outs[j] = curr_layer(inps[j].unsqueeze(0).cuda(), **cached_values)[0] + curr_layer.cpu() + + +@torch.no_grad() +def model_eval(model, valenc, seqlen): + + nsamples = valenc.numel() // seqlen + + def eval_input_capture_fn(model, data): + for i in range(nsamples): + batch = data[:, (i * seqlen):((i + 1) * seqlen)].cuda() + try: + model(batch) + except InputCatcherException: + pass + + inps = apply_layer_inference_fn( + model, + valenc, + nsamples, + input_capture_fn=eval_input_capture_fn, + inference_fn=eval_inference_fn, + seqlen=seqlen) + + model_impl = get_model_impl(model) + use_cache = model.config.use_cache + model.config.use_cache = False + + if hasattr(model_impl, 'norm') and model_impl.norm is not None: + model_impl.norm = model_impl.norm.cuda() + if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None: + model_impl.final_layer_norm = model_impl.final_layer_norm.cuda() + if hasattr(model_impl, 'project_out') and model_impl.project_out is not None: + model_impl.project_out = model_impl.project_out.cuda() + if hasattr(model, 'lm_head'): + model.lm_head = model.lm_head.cuda() + + valenc = valenc.cuda() + nlls = [] + for i in tqdm(range(nsamples)): + hidden_states = inps[i].unsqueeze(0) + if hasattr(model_impl, 'norm') and model_impl.norm is not None: + hidden_states = model_impl.norm(hidden_states) + if 
hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None: + hidden_states = model_impl.final_layer_norm(hidden_states) + if hasattr(model_impl, 'project_out') and model_impl.project_out is not None: + hidden_states = model_impl.project_out(hidden_states) + lm_logits = hidden_states + if hasattr(model, 'lm_head'): + lm_logits = model.lm_head(lm_logits) + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = valenc[:, (i * seqlen):((i + 1) * seqlen)][:, 1:] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + neg_log_likelihood = loss.float() * seqlen + nlls.append(neg_log_likelihood) + + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)) + model.config.use_cache = use_cache + return ppl diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py new file mode 100644 index 000000000..2edd9f777 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/export.py @@ -0,0 +1,207 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +from abc import ABC +from abc import abstractmethod +from contextlib import contextmanager + +import torch +from torch.nn import Module + +from brevitas.export.common.handler.base import BaseHandler +from brevitas.export.manager import _set_layer_export_handler +from brevitas.export.manager import _set_layer_export_mode +from brevitas.export.manager import _set_proxy_export_handler +from brevitas.export.manager import _set_proxy_export_mode +from brevitas.export.manager import BaseManager +from brevitas.nn import QuantLinear +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector + + +class WeightBlockQuantHandlerBase(BaseHandler, ABC): + handled_layer = WeightQuantProxyFromInjector + + def __init__(self): + super(WeightBlockQuantHandlerBase, self).__init__() + self.int_weight = None + self.scale = None + self.zero_point = None + self.bit_width = None + self.dtype = None + + def scaling_impl(self, proxy_module): + return proxy_module.tensor_quant.scaling_impl + + def zero_point_impl(self, proxy_module): + return proxy_module.tensor_quant.zero_point_impl + + def bit_width_impl(self, proxy_module): + return proxy_module.tensor_quant.msb_clamp_bit_width_impl + + def export_scale(self, proxy_module, bit_width): + scaling_impl = self.scaling_impl(proxy_module) + int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl + int_threshold = int_scaling_impl(bit_width) + threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl( + scaling_impl.wrapped_scaling_impl.parameter_list_stats()) + return threshold / int_threshold + + def export_zero_point(self, proxy_module, scale, bit_width): + zero_point_impl = self.zero_point_impl(proxy_module) + return zero_point_impl.unexpanded_zero_point(scale, bit_width) + + @abstractmethod + def prepare_for_export(self, module): + pass + + @abstractmethod + def forward(self, x): + pass + + +class WeightBlockQuantProxyHandler(WeightBlockQuantHandlerBase): + + def __init__(self): + super().__init__() + self.expanded_scaling_shape = None + self.reshaped_scaling_shape = None + self.expanded_zero_point_shape = None + self.reshaped_zero_point_shape = None + + def prepare_for_export(self, module): + assert len(module.tracked_module_list) == 1, "Shared quantizers not supported." 
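+ # Descriptive note (added comment): cache everything needed to reconstruct the block-quantized
+ # weight at export time: bit width, integer weight, scale and (when non-zero) zero-point, plus
+ # the expanded/reshaped shapes used to broadcast per-group scales back over the weight tensor.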
+ self.bit_width = self.bit_width_impl(module)() + assert self.bit_width <= 8., "Only 8b or lower is supported." + quant_layer = module.tracked_module_list[0] + quant_weight = quant_layer.quant_weight() + self.int_weight = quant_weight.int().detach() + self.dtype = quant_weight.value.dtype + self.scale = self.export_scale(module, self.bit_width).detach() + self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape + self.reshaped_scaling_shape = self.scaling_impl(module).reshaped_scaling_shape + if (quant_weight.zero_point != 0.).any(): + self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach() + self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape + self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape + + def forward(self, x): + scale = self.scale.expand(self.expanded_scaling_shape).contiguous() + # contiguous above is to avoid the reshape below being mapped to a unsafe view + scale = scale.view(self.reshaped_scaling_shape) + int_weight = self.int_weight + if self.zero_point is not None: + zero_point = self.zero_point.expand(self.expanded_zero_point_shape).contiguous() + # contiguous above is to avoid the reshape below being mapped to a unsafe view + zero_point = zero_point.view(self.reshaped_zero_point_shape) + # avoid unsigned subtraction + int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype) + else: + zero_point = torch.zeros_like(scale) + quant_weight = int_weight * scale + return quant_weight, scale, zero_point, self.bit_width + + +class LinearWeightBlockQuantHandler(WeightBlockQuantHandlerBase, ABC): + handled_layer = QuantLinear + + def __init__(self): + super(LinearWeightBlockQuantHandler, self).__init__() + self.group_size = None + + def pack_int_weights(self, bit_width, int_weights): + assert int_weights.dtype in [torch.int8, torch.uint8], "Packing requires (u)int8 input." + if bit_width == 8: + return int_weights + elif bit_width == 4 or bit_width == 2: + packed_int_weights = torch.zeros( + (int_weights.shape[0], int_weights.shape[1] * bit_width // 8), + device=int_weights.device, + dtype=int_weights.dtype) + i = 0 + for column in range(packed_int_weights.shape[1]): + # Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b + # https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346 + for j in range(i, i + (8 // bit_width)): + shift_factor = (bit_width * (j - i)) + packed_int_weights[:, column] |= int_weights[:, j] << shift_factor + i += 8 // bit_width + return packed_int_weights + else: + raise ValueError(f"Bit width {bit_width} not supported.") + + def prepare_for_export(self, module): + self.bit_width = self.bit_width_impl(module.weight_quant)() + assert self.bit_width <= 8., "Only 8b or lower is supported." 
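+ # Descriptive note (added comment): layer-level export path. Cache bias, scale and zero-point
+ # (zeros when the quantizer has none), then pack 2b/4b integer weights into int8 containers;
+ # 8b weights pass through unchanged.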
+ quant_weight = module.quant_weight() + self.bias = module.bias + self.scale = self.export_scale(module.weight_quant, self.bit_width) + if (quant_weight.zero_point != 0.).any(): + self.zero_point = self.export_zero_point( + module.weight_quant, self.scale, self.bit_width) + else: + # if there is no zero-point, export zeroes in the shape of scale + self.zero_point = torch.zeros_like(self.scale) + self.group_size = module.weight_quant.quant_injector.block_size + self.bit_width = int(self.bit_width.cpu().item()) + self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach()) + + @abstractmethod + def forward(self, x): + pass + + +class BlockQuantProxyLevelManager(BaseManager): + + handlers = [WeightBlockQuantProxyHandler] + + @classmethod + def set_export_handler(cls, module): + _set_proxy_export_handler(cls, module) + + +def block_quant_layer_level_manager(export_handlers): + + class BlockQuantLayerLevelManager(BaseManager): + handlers = export_handlers + + @classmethod + def set_export_handler(cls, module): + _set_layer_export_handler(cls, module) + + return BlockQuantLayerLevelManager + + +@contextmanager +def brevitas_proxy_export_mode(model, export_manager=BlockQuantProxyLevelManager): + is_training = model.training + model.eval() + model.apply(export_manager.set_export_handler) + _set_proxy_export_mode(model, enabled=True) + try: + yield model + finally: + _set_proxy_export_mode(model, enabled=False) + model.train(is_training) + + +@contextmanager +def brevitas_layer_export_mode(model, export_manager): + is_training = model.training + model.eval() + model.apply(export_manager.set_export_handler) + _set_layer_export_mode(model, enabled=True) + try: + yield model + finally: + _set_layer_export_mode(model, enabled=False) + model.train(is_training) + + +def replace_call_fn_target(graph_model, src, target): + for node in graph_model.graph.nodes: + if node.op == "call_function" and node.target is src: + node.target = target + graph_model.graph.lint() + graph_model.recompile() diff --git a/src/brevitas_examples/llm/llm_quant/gptq.py b/src/brevitas_examples/llm/llm_quant/gptq.py new file mode 100644 index 000000000..2e73bdf76 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/gptq.py @@ -0,0 +1,33 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +""" + +import torch + +from brevitas.graph.gptq import gptq_mode +from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn + + +@torch.no_grad() +def gptq_iter(curr_layer, inps, outs, cached_values, act_order): + curr_layer = curr_layer.cuda() + with gptq_mode(curr_layer, use_quant_activations=False, act_order=act_order) as gptq: + gptq_layer = gptq.model + for _ in range(gptq.num_layers): + for j in range(len(inps)): + curr_inp = inps[j].unsqueeze(0).cuda() + gptq_layer(curr_inp, **cached_values) + gptq.update() + for j in range(len(inps)): + inp = inps[j].unsqueeze(0).cuda() + curr_out = curr_layer(inp, **cached_values)[0] + outs[j] = curr_out + curr_layer.cpu() + return outs + + +@torch.no_grad() +def apply_gptq(model, dataloader, nsamples, act_order=True, seqlen=2048): + apply_layer_ptq_fn( + model, dataloader, nsamples, inference_fn=gptq_iter, seqlen=seqlen, act_order=act_order) diff --git a/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py new file mode 100644 index 000000000..37aa8d5d3 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/ln_affine_merge.py @@ -0,0 +1,92 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: MIT +""" + +import torch +from torch import nn + +from brevitas.fx import value_trace +from brevitas.graph.equalize import _is_reshaping_op +from brevitas.graph.equalize import _is_scale_invariant_module +from brevitas.graph.utils import get_module +from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 + + +def replace_bias(next_module, new_bias): + new_bias = new_bias.view(-1) + if next_module.bias is not None: + next_module.bias.data.copy_(new_bias) + else: + new_bias = new_bias.to(next_module.weight.device).to(next_module.weight.dtype) + next_module.register_parameter('bias', torch.nn.Parameter(new_bias)) + + +def _merge_ln(layer_norm, next_module, scale_bias_by_weight): + if not layer_norm.elementwise_affine: + return False + if not isinstance(next_module, nn.Linear): + return False + view_shape = (1, -1) + # Merge weight + if scale_bias_by_weight: + layer_norm.bias.data /= layer_norm.weight.data + # We can't do an inplace update as some layers we merge into like lm_head might share the weight tensor + scale = layer_norm.weight.data.view(view_shape).expand_as(next_module.weight) + next_module.weight = torch.nn.Parameter(next_module.weight.clone() * scale) + # Merge bias, new_bias includes the bias of next_module by going through its fwd + inp = layer_norm.bias.data.view(view_shape) + new_bias = next_module(inp) + replace_bias(next_module, new_bias) + return True + + +def merge_layernorm_affine_params(graph_model): + merged_dict = {} + merged_into_layers = [] + scaled_biases = set() + for node in graph_model.graph.nodes: + if node.op == 'call_module': + module = get_module(graph_model, node.target) + if isinstance(module, nn.LayerNorm): + for next in node.users: + while (_is_reshaping_op(next) or _is_scale_invariant_module(graph_model, next)): + next = node.next + if next.op == 'call_module': + next_module = get_module(graph_model, next.target) + scale_bias = node.target not in scaled_biases + merged = _merge_ln(module, next_module, scale_bias_by_weight=scale_bias) + if merged: + print( + f"{module.__class__.__name__} {node.target} merged into {next.target}." 
+ ) + merged_into_layers.append(next.target) + scaled_biases.add(node.target) + if module in merged_dict: + merged_dict[module] &= merged + else: + merged_dict[module] = merged + elif next.op == 'call_method' and next.target == 'size': + continue + else: + raise RuntimeError( + f"Unsupported user node {next.op} with target {next.target}. Disable LN affine merging." + ) + for module, merged in merged_dict.items(): + if merged: + # We preserve weight and bias in case they are used to merge SmoothQuant scales in fx mode later on + module.weight.data.fill_(1.) + module.bias.data.fill_(0.) + else: + raise RuntimeError( + f"Merged only into some users: {merged_dict}. Disable LN affine merging.") + return merged_into_layers + + +@torch.no_grad() +def apply_layernorm_affine_merge(model, dtype, ref_kwargs): + # We can't do fp16 tracing on CPU as many kernels are not implemented + # So we have to cast to fp32 first, trace, apply merging, and then cast back + with cast_to_float32(model, dtype): + graph_model = value_trace(model, value_args=ref_kwargs) + merge_layernorm_affine_params(graph_model) diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py new file mode 100644 index 000000000..cf694d4eb --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -0,0 +1,177 @@ +from typing import Optional, Tuple + +import torch +from torch import nn + +from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.utils.torch_utils import KwargsForwardHook + + +def attention_mask_handler( + attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length): + """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D + (implicit batch_size and n_heads).""" + if len(attention_mask.shape) == 4: + if attention_mask.shape[0] == 1: + attention_mask = attention_mask.repeat(batch_size, 1, 1, 1) + if attention_mask.shape[1] == 1: + attention_mask = attention_mask.repeat(1, num_heads, 1, 1) + if attention_mask.shape[2] == 1: + attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1) + attention_mask = attention_mask.view( + batch_size * num_heads, query_seq_length, key_value_seq_length) + elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1: + # This could happen in Encoder-like architecture + assert query_seq_length == key_value_seq_length + attention_mask = attention_mask.repeat(query_seq_length, 1) + return attention_mask + + +class MultiheadAttentionWrapper(nn.Module): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0., + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None) -> None: + super().__init__() + self.mha = nn.MultiheadAttention( + embed_dim, + num_heads, + dropout, + bias, + add_bias_kv, + add_zero_attn, + kdim, + vdim, + batch_first, + device, + dtype) + + @property + def wrapped_mha(self): + mha = self.mha + # Workaround for activation equalization for when mha is wrapped + # KwargsForwardHook is inserted during act equalization + # EqualizedModule is inserted after act equalization + if isinstance(mha, KwargsForwardHook): + mha = mha.module + if isinstance(mha, EqualizedModule): + mha = mha.layer + return mha + + @property + def num_heads(self): + return self.wrapped_mha.num_heads + + @property + def batch_first(self): + return self.wrapped_mha.batch_first + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, 
unexpected_keys, + error_msgs): + + def set_bias(value): + bias_name = f'{prefix}mha.in_proj_bias' + if bias_name in state_dict: + state_dict[bias_name] += value + else: + state_dict[bias_name] = value + + def set_weight(value): + weight_name = f'{prefix}mha.in_proj_weight' + if weight_name in state_dict: + state_dict[weight_name] += value + else: + state_dict[weight_name] = value + + embed_dim = self.mha.embed_dim + for name, value in list(state_dict.items()): + if prefix + 'q_proj.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[:embed_dim] = value + set_weight(weight) + del state_dict[name] + elif prefix + 'k_proj.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[embed_dim:2 * embed_dim] = value + set_weight(weight) + del state_dict[name] + elif prefix + 'v_proj.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[2 * embed_dim:3 * embed_dim] = value + set_weight(weight) + del state_dict[name] + if prefix + 'q_proj.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[:embed_dim] = value + set_bias(bias) + del state_dict[name] + elif prefix + 'k_proj.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[embed_dim:2 * embed_dim] = value + set_bias(bias) + del state_dict[name] + elif prefix + 'v_proj.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[2 * embed_dim:3 * embed_dim] = value + set_bias(bias) + del state_dict[name] + elif prefix + 'out_proj.weight' in name: + state_dict[prefix + 'mha.out_proj.weight'] = value + del state_dict[name] + elif prefix + 'out_proj.bias' in name: + state_dict[prefix + 'mha.out_proj.bias'] = value + del state_dict[name] + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +class QuantizableOPTAttention(MultiheadAttentionWrapper): + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if key_value_states is None: + key_value_states = hidden_states + if layer_head_mask is not None: + raise RuntimeError("layer_head_mask is not supported.") + if self.batch_first: + batch_size, query_seq_length = hidden_states.shape[:2] + key_value_seq_length = key_value_states.shape[1] + else: + query_seq_length, batch_size = hidden_states.shape[:2] + key_value_seq_length = key_value_states.shape[0] + num_heads = self.num_heads + attention_mask = attention_mask_handler( + attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length) + attn_output, attn_output_weights = self.mha( + hidden_states, + key_value_states, + key_value_states, + attn_mask=attention_mask, + need_weights=output_attentions, + average_attn_weights=False) + past_key_value = None + return attn_output, attn_output_weights, past_key_value diff --git a/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py b/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py new file mode 100644 index 000000000..c4a23f123 --- /dev/null +++ 
b/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py @@ -0,0 +1,113 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +from typing import List, Tuple + +import torch +import torch.utils.cpp_extension +import torch_mlir +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.registry import \ + _rename_python_keyword_parameter_name +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.registry import JitOperator +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.registry import SIG_ATTR_TYPE + +from brevitas.backport.fx._symbolic_trace import wrap + + +def patched_has_value_semantics_function_signature(self): + """Gets the Python function signature for this op's has_value_semantics function. + While this is technically debug-only output, it is useful to copy-paste + it from the debug dump into the library definitions, as many + ops have extra default arguments and stuff that are tedious to write out + right. + """ + + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + parameter_name = _rename_python_keyword_parameter_name(arg["name"]) + return f"{parameter_name}" + + def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + return "None" + + return self._get_function_signature( + "has_value_semantics", parameter_decl_builder, ret_decl_builder) + + +JitOperator.get_has_value_semantics_function_signature = patched_has_value_semantics_function_signature + + +def matmul_rhs_group_quant( + lhs: torch.Tensor, + rhs: torch.Tensor, + rhs_scale: torch.Tensor, + rhs_zero_point: torch.Tensor, + rhs_bit_width: int, + rhs_group_size: int): + # This is just a placeholder for the actual implementation that provides correct shape/device/dtype + if len(lhs.shape) == 3 and len(rhs.shape) == 2: + return torch.randn( + lhs.shape[0], lhs.shape[1], rhs.shape[0], device=lhs.device, dtype=lhs.dtype) + elif len(lhs.shape) == 2 and len(rhs.shape) == 2: + return torch.randn(lhs.shape[0], rhs.shape[0], device=lhs.device, dtype=lhs.dtype) + else: + raise ValueError("Input shapes not supported.") + + +brevitas_lib = torch.library.Library("brevitas", "DEF") +brevitas_lib.define( + "matmul_rhs_group_quant(Tensor lhs, Tensor rhs, Tensor rhs_scale, Tensor rhs_zero_point, int rhs_bit_width, int rhs_group_size) -> Tensor" +) +brevitas_lib.impl("matmul_rhs_group_quant", matmul_rhs_group_quant) + + +def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: + if len(lhs) == 3 and len(rhs) == 2: + return [lhs[0], lhs[1], rhs[0]] + elif len(lhs) == 2 and len(rhs) == 2: + return [lhs[0], rhs[0]] + else: + raise ValueError("Input shapes not supported.") + + +def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: + # output dtype is the dtype of the lhs float input + lhs_rank, lhs_dtype = lhs_rank_dtype + return lhs_dtype + + +def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: + return + + +brevitas_matmul_rhs_group_quant_library = [ + brevitas〇matmul_rhs_group_quant〡shape, + brevitas〇matmul_rhs_group_quant〡dtype, + brevitas〇matmul_rhs_group_quant〡has_value_semantics] + +if __name__ == '__main__': + + class CustomOpExampleModule(torch.nn.Module): + + def 
__init__(self): + super().__init__() + + def forward( + self: torch.nn.Module, + lhs: torch.Tensor, + rhs: torch.Tensor, + rhs_scale: torch.Tensor, + rhs_zero_point: torch.Tensor): + return torch.ops.brevitas.matmul_rhs_group_quant( + lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width=8, rhs_group_size=128) + + mod = CustomOpExampleModule() + mod.eval() + + module = torch_mlir.compile( + mod, (torch.ones(3, 4), torch.ones(5, 4), torch.ones(1), torch.ones(1)), + output_type="torch", + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library) + print(module) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py new file mode 100644 index 000000000..51e38bcd5 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -0,0 +1,17 @@ +from transformers.models.opt.modeling_opt import OPTAttention + +from brevitas.graph import ModuleToModuleByClass +from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention + +QUANTIZABLE_MHA_MAP = {OPTAttention: (QuantizableOPTAttention, {'batch_first': True})} + + +def replace_mha_with_quantizable_layers(model, dtype): + rewriters = [] + for src_module, (quantizable_module, quantizable_module_kwargs) in QUANTIZABLE_MHA_MAP.items(): + rewriter = ModuleToModuleByClass( + src_module, quantizable_module, **quantizable_module_kwargs, dtype=dtype) + rewriters.append(rewriter) + for rewriter in rewriters: + model = rewriter.apply(model) + return model diff --git a/src/brevitas_examples/llm/llm_quant/quant_blocks.py b/src/brevitas_examples/llm/llm_quant/quant_blocks.py new file mode 100644 index 000000000..a4334157b --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/quant_blocks.py @@ -0,0 +1,135 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +""" + +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +import torch.nn as nn + +import brevitas +from brevitas.core.function_wrapper.shape import PermuteDims +from brevitas.core.utils import SliceTensor + + +class OverSubChannelBlockView(brevitas.jit.ScriptModule): + __constants__ = ['scaling_input_shape'] + + def __init__(self, scaling_input_shape, permute_dims: Optional[Tuple[int, ...]]) -> None: + super(OverSubChannelBlockView, self).__init__() + self.scaling_input_shape = scaling_input_shape + if permute_dims is not None: + self.permute_impl = PermuteDims(permute_dims) + else: + self.permute_impl = nn.Identity() + + @brevitas.jit.script_method + def forward(self, x: torch.Tensor): + y = self.permute_impl(x) + y = y.view(self.scaling_input_shape) + return y + + +class ExpandReshapeScalingWrapper(brevitas.jit.ScriptModule): + __constants__ = ['expanded_scaling_shape', 'reshaped_scaling_shape'] + + def __init__(self, wrapped_scaling_impl, expanded_scaling_shape, reshaped_scaling_shape): + super(ExpandReshapeScalingWrapper, self).__init__() + self.wrapped_scaling_impl = wrapped_scaling_impl + self.expanded_scaling_shape = expanded_scaling_shape + self.reshaped_scaling_shape = reshaped_scaling_shape + self.slice_tensor = SliceTensor() + + @brevitas.jit.script_method + def forward(self, x): + scale = self.wrapped_scaling_impl(x) + scale = scale.expand(self.expanded_scaling_shape) + scale = scale.reshape(self.reshaped_scaling_shape) + # slice tensor when required by partial quantization + scale = self.slice_tensor(scale) + return scale + + +class ExpandReshapeZeroPointWrapper(brevitas.jit.ScriptModule): + __constants__ = ['expanded_zero_point_shape', 'reshaped_zero_point_shape'] + + def __init__( + self, wrapped_zero_point_impl, expanded_zero_point_shape, reshaped_zero_point_shape): + super(ExpandReshapeZeroPointWrapper, self).__init__() + self.wrapped_zero_point_impl = wrapped_zero_point_impl + self.expanded_zero_point_shape = expanded_zero_point_shape + self.reshaped_zero_point_shape = reshaped_zero_point_shape + self.slice_tensor = SliceTensor() + + def unexpanded_zero_point(self, unexpanded_scale, bit_width): + """ + This is used at export time. 
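+ The zero-point statistics are kept in their unexpanded per-group shape here, so that the
+ export handler can expand and reshape them to match the weight layout.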
+ """ + zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats() + zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point( + -zero_point_stats, unexpanded_scale, bit_width) + return zero_point + + @brevitas.jit.script_method + def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor): + # We have to break into wrapped_zero_point_impl since we need to expand and reshape + # Before we call into scale_shift_zero_point + zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats() + zero_point_stats = zero_point_stats.expand(self.expanded_zero_point_shape).contiguous() + # contiguous() above is to avoid an unsafe_view below + zero_point_stats = zero_point_stats.reshape(self.reshaped_zero_point_shape) + # slice tensor when required by partial quantization + zero_point_stats = self.slice_tensor(zero_point_stats) + zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point( + -zero_point_stats, scale, bit_width) + return zero_point + + +class RuntimeDynamicStatsScaling(brevitas.jit.ScriptModule): + __constants__ = ['dynamic_scaling_broadcastable_shape'] + + def __init__( + self, + scaling_stats_impl: nn.Module, + dynamic_scaling_broadcastable_shape: Tuple[int, ...], + scaling_stats_input_view_shape_impl: nn.Module) -> None: + super(RuntimeDynamicStatsScaling, self).__init__() + self.scaling_stats_input_view_shape_impl = scaling_stats_input_view_shape_impl + self.stats_impl = scaling_stats_impl + self.dynamic_scaling_broadcastable_shape = dynamic_scaling_broadcastable_shape + + @brevitas.jit.script_method + def forward(self, x) -> Tensor: + x = self.scaling_stats_input_view_shape_impl(x) + x = self.stats_impl(x) + x = x.view(self.dynamic_scaling_broadcastable_shape) + return x + + +class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): + + def __init__(self, group_size: int, group_dim: int, scaling_stats_impl: nn.Module) -> None: + super(RuntimeDynamicGroupStatsScaling, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + self.scaling_stats_impl = scaling_stats_impl + + @brevitas.jit.script_method + def group_scaling_reshape(self, stats_input): + tensor_shape = stats_input.shape + tensor_shape_list = list(tensor_shape) + tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) + tensor_shape_list.insert(self.group_dim + 1, self.group_size) + stats_input = stats_input.view(tensor_shape_list) + return stats_input + + @brevitas.jit.script_method + def forward(self, stats_input) -> Tensor: + stats_input_reshaped = self.group_scaling_reshape(stats_input) + out = self.scaling_stats_impl(stats_input_reshaped) + out = torch.clamp_min(out, min=torch.tensor(1e-6, device=out.device, dtype=out.dtype)) + out = out.expand(stats_input_reshaped.shape) + out = out.reshape(stats_input.shape) + return out diff --git a/src/brevitas_examples/llm/llm_quant/quantize.py b/src/brevitas_examples/llm/llm_quant/quantize.py new file mode 100644 index 000000000..ec4b3b72a --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/quantize.py @@ -0,0 +1,284 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +""" + +from torch import nn + +from brevitas import nn as qnn +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint +from brevitas.graph.quantize import layerwise_quantize +from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint +from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE +from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint +from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPointMSE +from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint +from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPointMSE +from brevitas.quant.scaled_int import Int8ActPerTensorFloat +from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE +from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.scaled_int import Int8WeightPerChannelFloatMSE +from brevitas.quant.scaled_int import Int8WeightPerTensorFloat +from brevitas.quant.scaled_int import Int8WeightPerTensorFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE +from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat +from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat +from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat +from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloat +from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloatMSE +from brevitas_examples.llm.llm_quant.quantizers import IntWeightSymmetricGroupQuant +from brevitas_examples.llm.llm_quant.quantizers import ShiftedUint8ActPerRowFloat +from brevitas_examples.llm.llm_quant.quantizers import ShiftedUint8ActPerRowFloatMSE +from brevitas_examples.llm.llm_quant.quantizers import ShiftedUintWeightAsymmetricGroupQuant + +WEIGHT_QUANT_MAP = { + 'float': { + 'stats': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat}, + 'per_group': { + 'sym': IntWeightSymmetricGroupQuant, 'asym': ShiftedUintWeightAsymmetricGroupQuant}, + }, + 'mse': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE}, + 'per_channel': { + 'sym': Int8WeightPerChannelFloatMSE, 'asym': ShiftedUint8WeightPerChannelFloatMSE}, + },}, + 'po2': { + 'stats': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFixedPoint}, + 'per_channel': { + 'sym': Int8WeightPerChannelFixedPoint},}, + 'mse': { + 'per_tensor': { + 'sym': Int8WeightPerTensorFixedPointMSE}, + 'per_channel': { + 'sym': Int8WeightPerChannelFixedPointMSE},},}} + +INPUT_QUANT_MAP = { + 'static': { + 'float': { + 'stats': { + 'per_tensor': { + 'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}, + 'per_row': { + 'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},}, + 'mse': { + 'per_tensor': { + 'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}, + 'per_row': { + 'sym': Int8ActPerRowFloatMSE, 'asym': 
ShiftedUint8ActPerRowFloatMSE},},}, + 'po2': { + 'stats': { + 'per_tensor': { + 'sym': Int8ActPerTensorFixedPoint},}, + 'mse': { + 'per_tensor': { + 'sym': Int8ActPerTensorFixedPointMSE},},}}, + 'dynamic': { + 'float': { + 'stats': { + 'per_tensor': { + 'sym': Int8ActDynamicPerTensorFloat}, + 'per_row': { + 'sym': Int8ActDynamicPerRowFloat}, + 'per_group': { + 'sym': Int8ActDynamicPerGroupFloat},}}}} + + +def quantize_model( + model, + dtype, + weight_bit_width, + weight_param_method, + weight_scale_precision, + weight_quant_type, + weight_quant_granularity, + weight_group_size, + quantize_weight_zero_point, + input_bit_width=None, + input_scale_precision=None, + input_scale_type=None, + input_param_method=None, + input_quant_type=None, + input_quant_granularity=None, + input_group_size=None, + quantize_input_zero_point=False, + seqlen=None): + """ + Replace float layers with quant layers in the target model + """ + # Retrive base input and weight quantizers + weight_quant = WEIGHT_QUANT_MAP[weight_scale_precision][weight_param_method][ + weight_quant_granularity][weight_quant_type] + if input_bit_width is not None: + input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][input_param_method][ + input_quant_granularity][input_quant_type] + # Some activations in MHA should always be symmetric + sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][ + input_param_method][input_quant_granularity]['sym'] + # Linear layers with 2d input should always be per tensor or per group, as there is no row dimension + if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row': + linear_2d_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][ + input_param_method]['per_tensor'][input_quant_type] + else: + assert input_quant_granularity == 'per_group' + linear_2d_input_quant = input_quant + else: + input_quant = None + sym_input_quant = None + linear_2d_input_quant = None + + # Modify the weight quantizer based on the arguments passed in + weight_quant = weight_quant.let( + **{ + 'bit_width': weight_bit_width, + 'narrow_range': False, + 'block_size': weight_group_size, + 'quantize_zero_point': quantize_weight_zero_point}) + # weight scale is converted to a standalone parameter + # This is done already by default in the per_group quantizer + if weight_quant_granularity != 'per_group': + weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats') + # weight zero-point is converted to a standalone parameter + # This is done already by default in the per_group quantizer + if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group': + weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint) + + # Modify the input quantizers based on the arguments passed in + if input_quant is not None: + input_quant = input_quant.let( + **{ + 'bit_width': input_bit_width, + 'quantize_zero_point': quantize_input_zero_point, + 'dtype': dtype}) + if input_scale_type == 'static' and input_quant_granularity == 'per_row': + # QuantMHA internally always uses Seq, B, E + input_quant = input_quant.let( + **{ + 'per_channel_broadcastable_shape': (seqlen, 1, 1), + 'scaling_stats_permute_dims': (0, 1, 2)}) + elif input_scale_type == 'dynamic': + if input_quant_granularity == 'per_tensor': + input_quant = input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (1, -1, 1), + 'permute_dims': (1, 0, 2), + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_row': + input_quant = 
input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (seqlen, -1, 1), + 'permute_dims': (1, 0, 2), + 'stats_reduce_dim': 2}) + elif input_quant_granularity == 'per_group': + input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size}) + if sym_input_quant is not None: + sym_input_quant = sym_input_quant.let( + **{ + 'bit_width': input_bit_width, + 'quantize_zero_point': quantize_input_zero_point, + 'dtype': dtype}) + if input_scale_type == 'static' and input_quant_granularity == 'per_row': + q_scaled_quant = sym_input_quant.let( + **{ + 'per_channel_broadcastable_shape': (1, seqlen, 1), + 'scaling_stats_permute_dims': (1, 0, 2)}) + k_transposed_quant = sym_input_quant.let( + **{ + 'per_channel_broadcastable_shape': (1, 1, seqlen), + 'scaling_stats_permute_dims': (2, 0, 1)}) + v_quant = q_scaled_quant + attn_output_weights_quant = q_scaled_quant + elif input_scale_type == 'dynamic': + if input_quant_granularity == 'per_tensor': + q_scaled_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1, 1), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + k_transposed_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1, 1), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_row': + q_scaled_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, seqlen, 1), + 'permute_dims': None, + 'stats_reduce_dim': 2}) + k_transposed_quant = sym_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1, seqlen), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_group': + q_scaled_quant = sym_input_quant.let( + **{ + 'group_dim': 2, 'group_size': input_group_size}) + k_transposed_quant = sym_input_quant.let( + **{ + 'group_dim': 1, 'group_size': input_group_size}) + v_quant = q_scaled_quant + attn_output_weights_quant = q_scaled_quant + else: + q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = sym_input_quant + else: + q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = None + if linear_2d_input_quant is not None: + linear_2d_input_quant = linear_2d_input_quant.let( + **{ + 'bit_width': input_bit_width, + 'quantize_zero_point': quantize_input_zero_point, + 'dtype': dtype}) + if input_scale_type == 'dynamic': + # Note: this breaks if applied to 3d Linear inputs, + # in case standard MHA layers haven't been inserted + if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row': + linear_2d_input_quant = linear_2d_input_quant.let( + **{ + 'dynamic_scaling_broadcastable_shape': (-1, 1), + 'permute_dims': None, + 'stats_reduce_dim': 1}) + elif input_quant_granularity == 'per_group': + linear_2d_input_quant = linear_2d_input_quant.let( + **{ + 'group_dim': 1, 'group_size': input_group_size}) + + quant_linear_kwargs = { + 'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype} + + quant_mha_kwargs = { + 'in_proj_input_quant': input_quant, + 'in_proj_weight_quant': weight_quant, + 'in_proj_bias_quant': None, + 'softmax_input_quant': None, + 'attn_output_weights_quant': attn_output_weights_quant, + 'attn_output_weights_signed': False, + 'q_scaled_quant': q_scaled_quant, + 'k_transposed_quant': k_transposed_quant, + 'v_quant': v_quant, + 'out_proj_input_quant': input_quant, + 'out_proj_weight_quant': weight_quant, + 'out_proj_bias_quant': None, + 'out_proj_output_quant': None, + 'batch_first': True, + # activation 
equalization requires packed_in_proj + # since it supports only self-attention + 'packed_in_proj': True, + 'dtype': dtype} + + layer_map = { + nn.Linear: (qnn.QuantLinear, quant_linear_kwargs), + nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)} + layerwise_quantize(model=model, compute_layer_map=layer_map) diff --git a/src/brevitas_examples/llm/llm_quant/quantizers.py b/src/brevitas_examples/llm/llm_quant/quantizers.py new file mode 100644 index 000000000..9848d994c --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/quantizers.py @@ -0,0 +1,132 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +from torch import nn + +from brevitas.core.function_wrapper.shape import OverBatchOverOutputChannelView +from brevitas.core.function_wrapper.shape import OverBatchOverTensorView +from brevitas.core.function_wrapper.shape import OverTensorView +from brevitas.core.scaling import ParameterFromStatsFromParameterScaling +from brevitas.core.stats import AbsMinMax +from brevitas.core.stats import NegativeMinOrZero +from brevitas.core.stats import NegativePercentileOrZero +from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint +from brevitas.inject import this +from brevitas.inject import value +from brevitas.quant.scaled_int import Int8ActPerTensorFloat +from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE +from brevitas.quant.scaled_int import Int8WeightPerChannelFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat +from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatMSE + +from .quant_blocks import * + + +class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat): + """ + Block / group / vector signed symmetric weight quantizer with float scales. + We inherit from a per-channel quantizer to re-use some underlying machinery. + """ + + @value + def expanded_scaling_shape(module, block_size): + if isinstance(module, nn.Conv2d): + return module.weight.size(0), module.weight.size(1) // block_size, block_size, module.weight.size(2), module.weight.size(3) + elif isinstance(module, nn.Linear): + return module.weight.size(0), module.weight.size(1) // block_size, block_size + else: + raise RuntimeError("Module not supported.") + + @value + def scaling_shape(module, block_size): + if isinstance(module, nn.Conv2d): + return module.weight.size(0), module.weight.size(1) // block_size, 1, module.weight.size(2), module.weight.size(3) + elif isinstance(module, nn.Linear): + return module.weight.size(0), module.weight.size(1) // block_size, 1 + else: + raise RuntimeError("Module not supported.") + + @value + def reshaped_scaling_shape(module): + return module.weight.shape + + scaling_input_shape = this.expanded_scaling_shape + scaling_stats_input_view_shape_impl = OverSubChannelBlockView + scaling_impl = ExpandReshapeScalingWrapper + # scale is converted to a parameter right away + wrapped_scaling_impl = ParameterFromStatsFromParameterScaling + keepdim = True + stats_reduce_dim = 2 + # Set bit_width and block size externally + bit_width = None + block_size = None + + +class ShiftedUintWeightAsymmetricGroupQuant(IntWeightSymmetricGroupQuant): + """ + Block / group / vector unsigned asymmetric weight quantizer with float scales and zero-points.
+ """ + zero_point_input_shape = this.scaling_input_shape + reshaped_zero_point_shape = this.reshaped_scaling_shape + zero_point_shape = this.scaling_shape + expanded_zero_point_shape = this.expanded_scaling_shape + zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + zero_point_stats_input_concat_dim = 0 + zero_point_impl = ExpandReshapeZeroPointWrapper + zero_point_stats_impl = NegativeMinOrZero + scaling_stats_impl = AbsMinMax + keepdim = True + # zero-point is converted to a parameter right away + wrapped_zero_point_impl = ParameterFromStatsFromParameterZeroPoint + quantize_zero_point = False + signed = False + + +class Int8ActPerRowFloat(Int8ActPerTensorFloat): + scaling_per_output_channel = True + + +class Int8ActPerRowFloatMSE(Int8ActPerTensorFloatMSE): + scaling_per_output_channel = True + + +class ShiftedUint8ActPerRowFloat(ShiftedUint8ActPerTensorFloat): + scaling_per_output_channel = True + + +class ShiftedUint8ActPerRowFloatMSE(ShiftedUint8ActPerTensorFloatMSE): + scaling_per_output_channel = True + + +class Int8ActDynamicPerTensorFloat(Int8ActPerTensorFloat): + """ + Symmetric quantizer with per tensor dynamic scale. + """ + scaling_impl = RuntimeDynamicStatsScaling + scaling_stats_input_view_shape_impl = OverBatchOverTensorView + scaling_stats_op = 'max' + + +class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat): + """ + Symmetric quantizer with per row dynamic scale. + """ + scaling_impl = RuntimeDynamicStatsScaling + scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView + scaling_stats_op = 'max' + + +class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat): + """ + Symmetric quantizer with per group dynamic scale. + """ + scaling_impl = RuntimeDynamicGroupStatsScaling + keepdim = True + scaling_stats_op = 'max' + + @value + def stats_reduce_dim(group_dim): + return group_dim + 1 diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py new file mode 100644 index 000000000..0ba096c4a --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/run_utils.py @@ -0,0 +1,164 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: + +Copyright 2023 IST-DASLab + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+""" + +from contextlib import contextmanager + +import torch +from torch import nn +from tqdm import tqdm +from transformers.models.opt.modeling_opt import OPTModel + + +def get_model_impl(model): + model_impl = model.model + if isinstance(model_impl, OPTModel): + model_impl = model_impl.decoder + return model_impl + + +class InputCatcherException(Exception): + pass + + +@torch.no_grad() +def calib_input_capture(model, dataloader): + for batch in dataloader: + batch = batch.cuda() + try: + model(batch) + except InputCatcherException: + pass + + +@torch.no_grad() +def capture_first_layer_inputs(input_capture_fn, dataloader, model, model_impl, nsamples, seqlen): + layers = model_impl.layers + + model_impl.embed_tokens = model_impl.embed_tokens.cuda() + if hasattr(model_impl, 'embed_positions'): + model_impl.embed_positions = model_impl.embed_positions.cuda() + if hasattr(model_impl, 'project_in') and model_impl.project_in is not None: + model_impl.project_in = model_impl.project_in.cuda() + if hasattr(model_impl, 'norm'): + model_impl.norm = model_impl.norm.cuda() + if hasattr(model_impl, 'embed_layer_norm'): + model_impl.embed_layer_norm = model_impl.embed_layer_norm.cuda() + + layers[0] = layers[0].cuda() + + dtype = next(iter(model_impl.parameters())).dtype + inps = torch.zeros((nsamples, seqlen, model.config.hidden_size), dtype=dtype).cuda() + cache = {'i': 0} + + class InputCatcher(nn.Module): + + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + if 'position_ids' in kwargs.keys(): + cache['position_ids'] = kwargs['position_ids'] + raise InputCatcherException + + layers[0] = InputCatcher(layers[0]) + input_capture_fn(model, dataloader) + + layers[0] = layers[0].module + layers[0] = layers[0].cpu() + model_impl.embed_tokens = model_impl.embed_tokens.cpu() + if hasattr(model_impl, 'embed_positions'): + model_impl.embed_positions = model_impl.embed_positions.cpu() + if hasattr(model_impl, 'project_in') and model_impl.project_in is not None: + model_impl.project_in = model_impl.project_in.cpu() + if hasattr(model_impl, 'norm'): + model_impl.norm = model_impl.norm.cpu() + if hasattr(model_impl, 'embed_layer_norm'): + model_impl.embed_layer_norm = model_impl.embed_layer_norm.cpu() + + return inps, cache + + +@torch.no_grad() +def apply_layer_inference_fn( + model, + dataloader, + nsamples, + inference_fn, + input_capture_fn, + seqlen=2048, + **inference_fn_kwargs): + model_impl = get_model_impl(model) + layers = model_impl.layers + + use_cache = model.config.use_cache + model.config.use_cache = False + + inps, cache = capture_first_layer_inputs( + input_capture_fn, dataloader, model, model_impl, nsamples, seqlen) + outs = torch.zeros_like(inps) + + cached_values = {} + cached_values['attention_mask'] = cache['attention_mask'] + if 'position_ids' in cache.keys(): + cached_values['position_ids'] = cache['position_ids'] + + for curr_layer in tqdm(layers): + inference_fn(curr_layer, inps, outs, cached_values, **inference_fn_kwargs) + inps, outs = outs, inps + + model.config.use_cache = use_cache + return inps + + +def apply_layer_ptq_fn( + model, dataloader, nsamples, inference_fn, seqlen=2048, **inference_fn_kwargs): + return apply_layer_inference_fn( + model, + dataloader, + nsamples, + inference_fn, + input_capture_fn=calib_input_capture, + seqlen=seqlen, + **inference_fn_kwargs) + + +@contextmanager +def cast_to_float32(model, 
target_dtype): + dtype_dict = {} + for name, p in model.state_dict().items(): + # This allows to pick up duplicated parameters + dtype_dict[name] = p.dtype + if any(dtype != torch.float32 for dtype in dtype_dict.values()): + model.to(dtype=torch.float32) + try: + yield model + finally: + for name, p in {**dict(model.named_parameters()), **dict(model.named_buffers())}.items(): + if name in dtype_dict: + p.data = p.data.to(dtype_dict[name]) + else: + # target_dtype covers any new tensors that might have been + # introduced in the process (e.g. during equalization) + p.data = p.data.to(target_dtype) diff --git a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py new file mode 100644 index 000000000..7f69c2029 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py @@ -0,0 +1,418 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +Based on https://github.com/nod-ai/SHARK/blob/main/apps/language_models/scripts/sharded_vicuna_fp32.py + +Copyright 2023 Nod.ai + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. 
+""" +import argparse +from io import BytesIO +from pathlib import Path +import re +from typing import List + +import torch +from torch._decomp import get_decompositions +import torch_mlir +from torch_mlir import TensorPlaceholder +from tqdm import tqdm + +from brevitas.backport.fx._symbolic_trace import wrap +from brevitas.backport.fx.experimental.proxy_tensor import make_fx +from brevitas_examples.llm.llm_quant.export import block_quant_layer_level_manager +from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager +from brevitas_examples.llm.llm_quant.export import brevitas_layer_export_mode +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from brevitas_examples.llm.llm_quant.export import LinearWeightBlockQuantHandler +from brevitas_examples.llm.llm_quant.export import replace_call_fn_target +from brevitas_examples.llm.llm_quant.mlir_custom_mm import brevitas_matmul_rhs_group_quant_library + + +# Due a tracing issue this annotation needs to be +# in the same module (== file) from which make_fx is called +# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant +# and so we trace a placeholder first and then replace it post tracing +@wrap(visible_to_make_fx=True) +def matmul_rhs_group_quant_placeholder(*args, **kwargs): + return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs) + + +class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler): + + def forward(self, x): + # Due a tracing issue the call to this fn needs to be + # in the same module (== file) from which make_fx is called + out = matmul_rhs_group_quant_placeholder( + x, self.int_weight, self.scale, self.zero_point, self.bit_width, self.group_size) + if self.bias is not None: + out = out + self.bias.view(1, -1) + return out + + +class FirstVicunaLayer(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, hidden_states, attention_mask, position_ids): + outputs = self.model( + hidden_states, attention_mask=attention_mask, position_ids=position_ids, use_cache=True) + next_hidden_states = outputs[0] + past_key_value_out0, past_key_value_out1 = (outputs[-1][0], outputs[-1][1]) + + return (next_hidden_states, past_key_value_out0, past_key_value_out1) + + +class SecondVicunaLayer(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward( + self, hidden_states, attention_mask, position_ids, past_key_value0, past_key_value1): + outputs = self.model( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=(past_key_value0, past_key_value1), + use_cache=True) + next_hidden_states = outputs[0] + past_key_value_out0, past_key_value_out1 = (outputs[-1][0], outputs[-1][1]) + + return (next_hidden_states, past_key_value_out0, past_key_value_out1) + + +def write_in_dynamic_inputs0(module, dynamic_input_size): + new_lines = [] + for line in module.splitlines(): + line = re.sub(f"{dynamic_input_size}x", "?x", line) + if "?x" in line: + line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) + line = re.sub(f" {dynamic_input_size},", " %dim,", line) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub("tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + + +def write_in_dynamic_inputs1(module, dynamic_input_size): + new_lines = [] + for line in module.splitlines(): + if "dim_42 =" in line: + continue + if f"%c{dynamic_input_size}_i64 =" in line: + new_lines.append("%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>") + new_lines.append(f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64") + continue + line = re.sub(f"{dynamic_input_size}x", "?x", line) + if "?x" in line: + line = re.sub("tensor.empty\(\)", "tensor.empty(%dim_42)", line) + line = re.sub(f" {dynamic_input_size},", " %dim_42,", line) + if "tensor.empty" in line and "?x?" in line: + line = re.sub( + "tensor.empty\(%dim_42\)", + "tensor.empty(%dim_42, %dim_42)", + line, + ) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim_42", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + + +def compile_vicuna_layer( + export_context_manager, + export_class, + vicuna_layer, + hidden_states, + attention_mask, + position_ids, + past_key_value0=None, + past_key_value1=None, +): + hidden_states_placeholder = TensorPlaceholder.like(hidden_states, dynamic_axes=[1]) + attention_mask_placeholder = TensorPlaceholder.like(attention_mask, dynamic_axes=[2, 3]) + position_ids_placeholder = TensorPlaceholder.like(position_ids, dynamic_axes=[1]) + + if past_key_value0 is None and past_key_value1 is None: + with export_context_manager(vicuna_layer, export_class): + fx_g = make_fx( + vicuna_layer, + decomposition_table=get_decompositions([ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes,]), + )(hidden_states, attention_mask, position_ids) + print(fx_g.graph) + else: + with export_context_manager(vicuna_layer, export_class): + fx_g = make_fx( + vicuna_layer, + decomposition_table=get_decompositions([ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes,]), + )(hidden_states, attention_mask, position_ids, past_key_value0, past_key_value1) + + def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]: + removed_indexes = [] + for node in fx_g.graph.nodes: + if node.op == "output": + assert (len(node.args) == 1), "Output node must have a single argument" + node_arg = node.args[0] + if isinstance(node_arg, (list, tuple)): + node_arg = list(node_arg) + node_args_len = len(node_arg) + for i in range(node_args_len): + curr_index = node_args_len - (i + 1) + if node_arg[curr_index] is None: + removed_indexes.append(curr_index) + node_arg.pop(curr_index) + node.args = (tuple(node_arg),) + break + + if len(removed_indexes) > 0: + fx_g.graph.lint() + fx_g.graph.eliminate_dead_code() + fx_g.recompile() + removed_indexes.sort() + return removed_indexes + + def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: + """ + Replace tuple with 
tuple element in functions that return one-element tuples. + Returns true if an unwrapping took place, and false otherwise. + """ + unwrapped_tuple = False + for node in fx_g.graph.nodes: + if node.op == "output": + assert (len(node.args) == 1), "Output node must have a single argument" + node_arg = node.args[0] + if isinstance(node_arg, tuple): + if len(node_arg) == 1: + node.args = (node_arg[0],) + unwrapped_tuple = True + break + + if unwrapped_tuple: + fx_g.graph.lint() + fx_g.recompile() + return unwrapped_tuple + + def transform_fx(fx_g): + for node in fx_g.graph.nodes: + if node.op == "call_function": + if node.target in [torch.ops.aten.empty]: + # aten.empty should be filled with zeros. + with fx_g.graph.inserting_after(node): + new_node = fx_g.graph.call_function(torch.ops.aten.zero_, args=(node,)) + node.append(new_node) + node.replace_all_uses_with(new_node) + new_node.args = (node,) + fx_g.graph.lint() + + transform_fx(fx_g) + replace_call_fn_target( + fx_g, + src=matmul_rhs_group_quant_placeholder, + target=torch.ops.brevitas.matmul_rhs_group_quant) + + fx_g.recompile() + removed_none_indexes = _remove_nones(fx_g) + was_unwrapped = _unwrap_single_tuple_return(fx_g) + + fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) + fx_g.recompile() + + print("FX_G recompile") + + def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. + Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() + + strip_overloads(fx_g) + ts_g = torch.jit.script(fx_g) + return ts_g + + +def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_first=True): + mlirs = [] + for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"): + if is_first: + mlir_path = Path(f"{idx}_0.mlir") + vmfb_path = Path(f"{idx}_0.vmfb") + else: + mlir_path = Path(f"{idx}_1.mlir") + vmfb_path = Path(f"{idx}_1.vmfb") + if vmfb_path.exists(): + continue + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states_placeholder = TensorPlaceholder.like(inputs[0], dynamic_axes=[1]) + attention_mask_placeholder = TensorPlaceholder.like(inputs[1], dynamic_axes=[3]) + position_ids_placeholder = TensorPlaceholder.like(inputs[2], dynamic_axes=[1]) + if not is_first: + pkv0_placeholder = TensorPlaceholder.like(inputs[3], dynamic_axes=[2]) + pkv1_placeholder = TensorPlaceholder.like(inputs[4], dynamic_axes=[2]) + print(f"Compiling layer {idx} mlir") + if is_first: + ts_g = compile_vicuna_layer( + export_context_manager, export_class, layer, inputs[0], inputs[1], inputs[2]) + module = torch_mlir.compile( + ts_g, (hidden_states_placeholder, inputs[1], inputs[2]), + output_type="torch", + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False) + else: + ts_g = compile_vicuna_layer( + export_context_manager, + export_class, + layer, + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4]) + module = torch_mlir.compile( + ts_g, + ( + inputs[0], + attention_mask_placeholder, + inputs[2], + pkv0_placeholder, + pkv1_placeholder), + output_type="torch", + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False) + + if is_first: + module = write_in_dynamic_inputs0(str(module), 137) + 
bytecode = module.encode("UTF-8") + bytecode_stream = BytesIO(bytecode) + bytecode = bytecode_stream.read() + + else: + module = write_in_dynamic_inputs1(str(module), 138) + if idx in [0, 5, 6, 7]: + module_str = module + module_str = module_str.splitlines() + new_lines = [] + for line in module_str: + if len(line) < 1000: + new_lines.append(line) + else: + new_lines.append(line[:999]) + module_str = "\n".join(new_lines) + f1_ = open(f"{idx}_1_test.mlir", "w+") + f1_.write(module_str) + f1_.close() + + bytecode = module.encode("UTF-8") + bytecode_stream = BytesIO(bytecode) + bytecode = bytecode_stream.read() + + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + mlirs.append(bytecode) + + return mlirs + + +def sharded_weight_group_export(model, no_custom_packed_export): + + # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, + # which is currently an incredibly hacky process + # please don't change it + SAMPLE_INPUT_LEN = 137 + + placeholder_input0 = ( + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]), + torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)) + + placeholder_input1 = ( + torch.zeros([1, 1, 4096]), + torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]), + torch.zeros([1, 1], dtype=torch.int64), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128])) + + if no_custom_packed_export: + export_context_manager = brevitas_proxy_export_mode + export_class = BlockQuantProxyLevelManager + else: + export_context_manager = brevitas_layer_export_mode + # generate an export_class with the handler declared above + export_class = block_quant_layer_level_manager( + export_handlers=[LinearWeightBlockQuantHandlerFwd]) + + layers0 = [FirstVicunaLayer(layer) for layer in model.model.layers] + mlirs0 = compile_to_vmfb( + placeholder_input0, layers0, export_context_manager, export_class, is_first=True) + + layers1 = [SecondVicunaLayer(layer) for layer in model.model.layers] + mlirs1 = compile_to_vmfb( + placeholder_input1, layers1, export_context_manager, export_class, is_first=False) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py new file mode 100644 index 000000000..0e6fe05ad --- /dev/null +++ b/src/brevitas_examples/llm/main.py @@ -0,0 +1,300 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +""" + +import argparse + +import numpy as np +import torch +from transformers import AutoModelForCausalLM + +from brevitas.export import export_onnx_qcdq +from brevitas.export import export_torch_qcdq +from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction +from brevitas_examples.llm.llm_quant.calibrate import apply_calibration +from brevitas_examples.llm.llm_quant.data import get_c4 +from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization +from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization +from brevitas_examples.llm.llm_quant.eval import model_eval +from brevitas_examples.llm.llm_quant.gptq import apply_gptq +from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge +from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers +from brevitas_examples.llm.llm_quant.quantize import quantize_model +from brevitas_examples.llm.llm_quant.run_utils import get_model_impl + +parser = argparse.ArgumentParser() +parser.add_argument( + '--model', + type=str, + default="facebook/opt-125m", + help='HF model name. Default: facebook/opt-125m.') +parser.add_argument( + '--seed', type=int, default=0, help='Seed for sampling the calibration data. Default: 0.') +parser.add_argument( + '--nsamples', type=int, default=128, help='Number of calibration data samples. Default: 128.') +parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length. Default: 2048.') +parser.add_argument('--eval', action='store_true', help='Eval model PPL on C4.') +parser.add_argument('--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.') +parser.add_argument( + '--weight-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help='How scales/zero-point are determined. Default: stats.') +parser.add_argument( + '--weight-scale-precision', + type=str, + default='float', + choices=['float', 'po2'], + help='Whether scale is a float value or a po2. Default: float.') +parser.add_argument( + '--weight-quant-type', + type=str, + default='asym', + choices=['sym', 'asym'], + help='Weight quantization type. Default: asym.') +parser.add_argument( + '--weight-quant-granularity', + type=str, + default='per_group', + choices=['per_channel', 'per_tensor', 'per_group'], + help='Granularity for scales/zero-point of weights. Default: per_group.') +parser.add_argument( + '--weight-group-size', + type=int, + default=128, + help='Group size for per_group weight quantization. Default: 128.') +parser.add_argument( + '--quantize-weight-zero-point', action='store_true', help='Quantize weight zero-point.') +parser.add_argument( + '--input-bit-width', + type=int, + default=None, + help='Input bit width. Default: None (disables input quantization).') +parser.add_argument( + '--input-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help= + 'How scales/zero-point are determined. Default: stats (percentile for static, absmax or minmax for dynamic).' +) +parser.add_argument( + '--input-scale-precision', + type=str, + default='float', + choices=['float', 'po2'], + help='Whether input scale is a float value or a po2. 
Default: float.') +parser.add_argument( + '--input-scale-type', + type=str, + default='static', + choices=['static', 'dynamic'], + help='Whether input scale is a static value or a dynamic value.') +parser.add_argument( + '--input-quant-type', + type=str, + default='asym', + choices=['sym', 'asym'], + help='Input quantization type. Default: asym.') +parser.add_argument( + '--input-quant-granularity', + type=str, + default='per_tensor', + choices=['per_tensor', 'per_row', 'per_group'], + help='Granularity for scales/zero-point of inputs. Default: per_tensor.') +parser.add_argument( + '--input-group-size', + type=int, + default=64, + help='Group size for per_group input quantization. Default: 64.') +parser.add_argument( + '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.') +parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') +parser.add_argument('--act-calibration', action='store_true', help='Apply activation calibration.') +parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.') +parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.') +parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.') +parser.add_argument( + '--no-float16', + action='store_true', + help='Disable float16 as base datatype and switch to float32.') +parser.add_argument( + '--weight-equalization', + action='store_true', + help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).') +parser.add_argument( + '--act-equalization', + default=None, + choices=[None, 'layerwise', 'fx'], + help='Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes,' + 'while fx merges them whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).' +) +parser.add_argument( + '--export-target', + default=None, + choices=[ + None, + 'onnx_qcdq', + 'torch_qcdq', + 'sharded_torchmlir_group_weight', + 'sharded_packed_torchmlir_group_weight'], + help='Model export.') + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def model_export(model, ref_input, args): + if args.export_target == 'sharded_torchmlir_group_weight': + from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import \ + sharded_weight_group_export + sharded_weight_group_export(model, no_custom_packed_export=True) + elif args.export_target == 'sharded_packed_torchmlir_group_weight': + from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import \ + sharded_weight_group_export + sharded_weight_group_export(model, no_custom_packed_export=False) + elif args.export_target == 'onnx_qcdq': + export_onnx_qcdq(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx") + elif args.export_target == 'torch_qcdq': + export_torch_qcdq(model, ref_input, export_path=f"{args.model.replace('/', '-')}.pt") + + +def validate(args): + if not args.no_quantize: + if args.export_target is not None and args.input_bit_width is not None: + assert args.input_scale_type == 'static', "Only static scale supported for export currently." + if args.export_target == 'sharded_torchmlir_group_weight': + assert args.weight_quant_granularity == 'per_group', "Sharded torch group export requires per group weight quant." + assert args.input_bit_width is None, "Sharded torch group weight export doesn't support input quant." + assert not args.quantize_weight_zero_point, "Quantized weight zero point not supported." 
+ if args.export_target == 'sharded_packed_torchmlir_group_weight': + assert args.weight_quant_granularity == 'per_group', "Sharded torch group export requires per group weight quant." + assert args.input_bit_width is None, "Sharded packed torch group weight export doesn't support input quant." + assert not args.quantize_weight_zero_point, "Quantized weight zero point not supported." + if args.export_target == 'onnx_qcdq': + assert args.weight_quant_granularity != 'per_group', "ONNX QCDQ export doesn't support group weight quantization." + if args.weight_quant_type == 'asym': + assert args.quantize_weight_zero_point, "Quantized weight zero point required." + if args.input_bit_width is not None and args.input_quant_type == 'asym': + assert args.quantize_input_zero_point, "Quantized input zero point required." + if args.export_target == 'torch_qcdq': + assert args.weight_quant_granularity != 'per_group', "TorchScript QCDQ export doesn't support group weight quantization." + if args.weight_quant_type == 'asym': + assert args.quantize_weight_zero_point, "Quantized weight zero point required." + if args.input_bit_width is not None and args.input_quant_type == 'asym': + assert args.quantize_input_zero_point, "Quantized input zero point required." + if (args.input_bit_width and + (args.input_scale_type == 'static' or + (args.input_scale_type == 'dynamic' and args.input_quant_type == 'asym'))): + assert args.act_calibration, "Static input quantization is being applied without activation calibration. Set --act-calibration." + + +def main(): + args = parser.parse_args() + validate(args) + set_seed(args.seed) + + if args.no_float16: + dtype = torch.float32 + else: + dtype = torch.float16 + + kwargs = {"torch_dtype": dtype} + print("Model loading...") + model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + print("Model loaded.") + model.eval() + + if (args.export_target or args.eval or args.act_equalization or args.act_calibration or + args.gptq or args.bias_corr or args.ln_affine_merge or args.weight_equalization): + print("Data loading...") + calibration_loader, val_data = get_c4( + nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=args.seqlen) + print("Data loaded.") + + # Apply LN affine merging before inserting MHA layers + # since currently there is support only for merging into Linear + if args.ln_affine_merge: + print("Apply LN affine merge...") + apply_layernorm_affine_merge(model, dtype, ref_kwargs={'input_ids': calibration_loader[0]}) + print("LN affine merge applied.") + + # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing + # with all the variability in HF implementations + if args.weight_equalization or args.act_equalization == 'fx' or args.input_bit_width: + print("Replace HF MHA with quantizable variants...") + model = replace_mha_with_quantizable_layers(model, dtype) + print("Replacing done.") + + if args.weight_equalization: + print("Apply weight equalization...") + apply_weight_equalization(model, dtype, ref_kwargs={'input_ids': calibration_loader[0]}) + print("Weight equalization applied.") + + if args.act_equalization is not None: + print("Apply act equalization (SmoothQuant)...") + apply_act_equalization( + model, + dtype, + args.act_equalization, + calibration_loader, + args.nsamples, + ref_kwargs={'input_ids': calibration_loader[0]}) + print("Act equalization applied.") + + if not args.no_quantize: + print("Applying model quantization...") + quantize_model( + get_model_impl(model).layers, + dtype=dtype, + 
weight_quant_type=args.weight_quant_type, + weight_bit_width=args.weight_bit_width, + weight_param_method=args.weight_param_method, + weight_scale_precision=args.weight_scale_precision, + weight_quant_granularity=args.weight_quant_granularity, + weight_group_size=args.weight_group_size, + quantize_weight_zero_point=args.quantize_weight_zero_point, + input_bit_width=args.input_bit_width, + input_quant_type=args.input_quant_type, + input_param_method=args.input_param_method, + input_scale_precision=args.input_scale_precision, + input_scale_type=args.input_scale_type, + input_quant_granularity=args.input_quant_granularity, + input_group_size=args.input_group_size, + quantize_input_zero_point=args.quantize_input_zero_point, + seqlen=args.seqlen) + print("Model quantization applied.") + + if args.act_calibration: + print("Apply act calibration...") + apply_calibration(model, calibration_loader, args.nsamples) + print("Act calibration applied.") + + if args.gptq: + print("Applying GPTQ...") + apply_gptq(model, calibration_loader, args.nsamples) + print("GPTQ applied.") + + if args.bias_corr: + print("Applying bias correction...") + apply_bias_correction(model, calibration_loader, args.nsamples) + print("Bias correction applied.") + + if args.eval: + print("Model eval...") + ppl = model_eval(model, val_data, args.seqlen) + print(f"C4 perplexity: {ppl}") + + if args.export_target: + print(f"Export to {args.export_target}") + # Currently we always export on CPU with a float32 container to avoid float16 CPU errors + model = model.cpu().to(dtype=torch.float32) + model_export(model, calibration_loader[0], args) + + +if __name__ == '__main__': + main() diff --git a/src/brevitas_examples/llm/test_linear_mlir_export.py b/src/brevitas_examples/llm/test_linear_mlir_export.py new file mode 100644 index 000000000..417721a58 --- /dev/null +++ b/src/brevitas_examples/llm/test_linear_mlir_export.py @@ -0,0 +1,126 @@ +import argparse + +import torch +from torch import nn +import torch_mlir + +from brevitas.backport.fx._symbolic_trace import wrap +from brevitas.backport.fx.experimental.proxy_tensor import make_fx +from brevitas_examples.llm.llm_quant.export import block_quant_layer_level_manager +from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager +from brevitas_examples.llm.llm_quant.export import brevitas_layer_export_mode +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from brevitas_examples.llm.llm_quant.export import LinearWeightBlockQuantHandler +from brevitas_examples.llm.llm_quant.export import replace_call_fn_target +from brevitas_examples.llm.llm_quant.mlir_custom_mm import brevitas_matmul_rhs_group_quant_library +from brevitas_examples.llm.llm_quant.quantize import quantize_model + + +# Due a tracing issue this annotation needs to be +# in the same module (== file) from which make_fx is called +# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant +# and so we trace a placeholder first and then replace it post tracing +@wrap(visible_to_make_fx=True) +def matmul_rhs_group_quant_placeholder(*args, **kwargs): + return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs) + + +class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler): + + def forward(self, x): + # Due a tracing issue the call to this fn needs to be + # in the same module (== file) from which make_fx is called + out = matmul_rhs_group_quant_placeholder( + x, self.int_weight, self.scale, self.zero_point, self.bit_width, self.group_size) + if 
self.bias is not None: + out = out + self.bias.view(1, -1) + return out + + +class Model(nn.Module): + + def __init__(self): + super().__init__() + self.layer = nn.Linear(128, 256, bias=True) + + def forward(self, x): + return self.layer(x) + + +def quantize_and_export(args): + # Init model + model = Model() + + # Run quantization + quantize_model( + model, + dtype=torch.float32, + weight_quant_type=args.weight_quant_type, + weight_bit_width=args.weight_bit_width, + weight_group_size=args.weight_group_size, + weight_param_method='stats', + weight_scale_precision='float', + weight_quant_granularity='per_group', + quantize_weight_zero_point=False) + + # Run a test forward pass + model(torch.randn(2, 128)) + + # Pick export mode + if not args.no_custom_packed_export: + export_context_manager = brevitas_layer_export_mode + # we generate an export_class since we need to pass in the handler defined above + export_class = block_quant_layer_level_manager( + export_handlers=[LinearWeightBlockQuantHandlerFwd]) + else: + export_context_manager = brevitas_proxy_export_mode + export_class = BlockQuantProxyLevelManager + + # export with make_fx with support for fx wrap + with export_context_manager(model, export_class): + traced_model = make_fx(model)(torch.randn(2, 128)) + + # Replace placeholder for custom op with correct call, if any + replace_call_fn_target( + traced_model, + src=matmul_rhs_group_quant_placeholder, + target=torch.ops.brevitas.matmul_rhs_group_quant) + + # print the output graph + print(traced_model.graph) + + torch_mlir.compile( + traced_model, + torch.randn(2, 128), + output_type="torch", + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=True, + verbose=False) + + +def main(): + parser = argparse.ArgumentParser( + description='Export single linear with weight group quant to torch-mlir.') + parser.add_argument('--weight-bit-width', type=int, default=8, help='Weight bit width.') + parser.add_argument( + '--weight-quant-type', + type=str, + default='asym', + choices=['sym', 'asym'], + help='Weight quantization type.') + parser.add_argument( + '--weight-group-size', + type=int, + default=128, + help='Group size for group weight quantization.') + parser.add_argument( + '--no-custom-packed-export', + action='store_true', + help='Enable export to a custom mm op with packed weights for int2 and int4.') + args = parser.parse_args() + quantize_and_export(args) + + +if __name__ == "__main__": + main()
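quantize_model from llm_quant/quantize.py can also be driven directly: it resolves the requested options through WEIGHT_QUANT_MAP / INPUT_QUANT_MAP and then specializes the selected quantizer class via .let(). A minimal weight-only sketch, mirroring the call in test_linear_mlir_export.py (the toy nn.Sequential model is purely illustrative):

```python
import torch
from torch import nn

from brevitas_examples.llm.llm_quant.quantize import quantize_model

# Toy stand-in for a model block; illustrative only.
model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))

# Resolves to WEIGHT_QUANT_MAP['float']['stats']['per_group']['sym'], i.e.
# IntWeightSymmetricGroupQuant, then customizes it with
# .let(bit_width=4, block_size=32, quantize_zero_point=False, ...)
quantize_model(
    model,
    dtype=torch.float32,
    weight_bit_width=4,
    weight_param_method='stats',
    weight_scale_precision='float',
    weight_quant_type='sym',
    weight_quant_granularity='per_group',
    weight_group_size=32,
    quantize_weight_zero_point=False)

model(torch.randn(2, 128))  # forward pass now goes through QuantLinear layers
```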
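The dynamic activation quantizers are configured through 'stats_reduce_dim', 'dynamic_scaling_broadcastable_shape' and 'permute_dims', which describe how the runtime statistic is reduced and broadcast back over the input. The snippet below is a plain-PyTorch illustration of what per-row dynamic scaling of a 2d Linear input amounts to conceptually; it is not the Brevitas implementation, and the integer range handling is simplified:

```python
import torch

# (batch * seq, embed) input of a Linear, as seen by linear_2d_input_quant
x = torch.randn(8, 4096)

# scaling_stats_op='max' reduced over stats_reduce_dim=1: one absmax per row
row_absmax = x.abs().amax(dim=1, keepdim=True)  # shape (8, 1)

# dynamic_scaling_broadcastable_shape=(-1, 1): the per-row scale broadcasts
# back over the embedding dimension during quantization
scale = row_absmax / 127.0  # simplified int8 symmetric scale
x_int = torch.clamp((x / scale).round(), -128, 127)
x_dequant = x_int * scale
```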
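For the group weight quantizer, the @value functions on IntWeightSymmetricGroupQuant define the views used to compute one scale per block. A worked example of the resulting shapes for a hypothetical nn.Linear with block_size=128 (the layer size is arbitrary):

```python
from torch import nn

layer = nn.Linear(4096, 4096, bias=False)  # hypothetical layer
block_size = 128
out_features, in_features = layer.weight.shape

# View used to compute the stats: one block per slice along dim 2
expanded_scaling_shape = (out_features, in_features // block_size, block_size)  # (4096, 32, 128)

# One scale per block, reduced over stats_reduce_dim=2 with keepdim=True
scaling_shape = (out_features, in_features // block_size, 1)  # (4096, 32, 1)

# The scale is then expanded over each block and reshaped back to the weight shape
reshaped_scaling_shape = tuple(layer.weight.shape)  # (4096, 4096)
```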
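In run_utils.py, apply_layer_inference_fn captures the first decoder block's inputs through the InputCatcher wrapper and then calls inference_fn(curr_layer, inps, outs, cached_values, **kwargs) once per block. The calibration/GPTQ/bias-correction functions wired up in main.py are defined in other files of this PR, but a minimal inference_fn compatible with that calling convention could look like the following sketch:

```python
import torch


@torch.no_grad()
def layer_forward_fn(curr_layer, inps, outs, cached_values):
    # Process one decoder block at a time to keep GPU memory bounded
    curr_layer = curr_layer.cuda()
    for i in range(inps.shape[0]):
        # inps[i] is one captured (seqlen, hidden_size) sample; add a batch dim
        out = curr_layer(inps[i].unsqueeze(0), **cached_values)[0]
        outs[i] = out.squeeze(0)
    curr_layer.cpu()


# apply_layer_ptq_fn(model, dataloader, nsamples, layer_forward_fn) would then
# run this block by block over the captured calibration activations.
```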
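cast_to_float32 temporarily promotes the model to float32 for transforms that are fragile or unsupported in half precision, then restores the recorded per-tensor dtypes on exit, with target_dtype covering tensors created inside the block. A hedged usage sketch (the model is assumed to be a float16 HF causal LM loaded as in main.py; which passes actually need this depends on the pass):

```python
import torch

from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32

with cast_to_float32(model, target_dtype=torch.float16):
    # run a float32-only pass here (e.g. a tracing-based graph rewrite);
    # the body is left as a placeholder in this sketch
    pass

# On exit the original parameters and buffers are float16 again, and any new
# tensors introduced inside the block have been cast to target_dtype.
```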
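write_in_dynamic_inputs0 / write_in_dynamic_inputs1 patch the torch-mlir text so the hard-coded SAMPLE_INPUT_LEN becomes a dynamic dimension. On a single invented line of MLIR the textual rewrite looks like this (in a real module %dim is defined elsewhere; this only demonstrates the substitution):

```python
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import write_in_dynamic_inputs0

line = "%0 = tensor.empty() : tensor<137x4096xf32>"
print(write_in_dynamic_inputs0(line, 137))
# %0 = tensor.empty(%dim) : tensor<?x4096xf32>
```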
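The _remove_nones and _unwrap_single_tuple_return helpers normalize the traced graph's output node before handing it to torch_mlir. The unwrapping step is the usual torch.fx output-node rewrite; a standalone illustration, independent of the Vicuna layers above:

```python
import torch
import torch.fx
from torch import nn


class ReturnsTuple(nn.Module):

    def forward(self, x):
        return (x + 1,)  # one-element tuple return


gm = torch.fx.symbolic_trace(ReturnsTuple())
for node in gm.graph.nodes:
    if node.op == "output" and isinstance(node.args[0], tuple) and len(node.args[0]) == 1:
        node.args = (node.args[0][0],)  # unwrap the single element
gm.graph.lint()
gm.recompile()
print(gm(torch.zeros(2)))  # a plain tensor rather than a one-element tuple
```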
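Finally, given the argparse definitions and the checks in validate() in main.py, a representative weight-only run of the flow (the model name, bit width and group size are only an example) is `python -m brevitas_examples.llm.main --model facebook/opt-125m --weight-bit-width 4 --weight-quant-type sym --weight-quant-granularity per_group --weight-group-size 128 --eval`. Enabling input quantization with a static scale additionally requires `--act-calibration`, as validate() asserts.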