Examples: initial support for LLMs PTQ (#658)
* Examples: WIP LLM block quantization

* Add support for block zero-point

* Add torch-mlir custom op support

* Add test linear

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Update to custom matmul export

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Fix errors

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Fix output shape of custom op

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Add lowering to torch_mlir for single layer

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Some cleanups

* WIP llm flow

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Fix (examples/llm): typo in custom quant matmul op (#607)

* Test act equalization support

* Initial end to end flow

* Initial support for QuantMHA on OPT

* Fix act equalization

* Typos in prints

* Reorganize validate

* Add initial per row quantizers

* Add per row input quantization support

* Support group quant slicing

* Adopt SliceTensor for block weight partial quant

* Add float16 support

* Fix scale type name

* Add support for LN affine merging

* WIP currently broken

* Clean up weight eq support

* Set weight narrow range always to False

* Add fx act equalization, fixes for float16 support

* Fix validate

* Fix backport imports

* Fix example export

Signed-off-by: Alessandro Pappalardo <[email protected]>

* Fix value_trace call in ln affine merging

* Add per tensor/row/group dynamic scale support, some dtype improvements

* Fix (llm): correct handling of attention mask shape (#652)

* Always export in fp32 base dtype on CPU

* Export improvements

* Fix errors after latest PR

---------

Signed-off-by: Alessandro Pappalardo <[email protected]>
Co-authored-by: jinchen62 <[email protected]>
Co-authored-by: Giuseppe Franco <[email protected]>
3 people authored Jul 17, 2023
1 parent b783650 commit 51baf37
Showing 24 changed files with 2,552 additions and 4 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -55,5 +55,5 @@ def read_requirements(filename):
'brevitas_quartznet_preprocess = brevitas_examples.speech_to_text.get_librispeech_data:main',
'brevitas_melgan_preprocess = brevitas_examples.text_to_speech.preprocess_dataset:main',
'brevitas_ptq_imagenet_benchmark = brevitas_examples.imagenet_classification.ptq.ptq_benchmark:main',
'brevitas_ptq_imagenet_val = brevitas_examples.imagenet_classification.ptq.ptq_evaluate:main'
],})
'brevitas_ptq_imagenet_val = brevitas_examples.imagenet_classification.ptq.ptq_evaluate:main',
'brevitas_ptq_llm = brevitas_examples.llm.main:main'],})
1 change: 1 addition & 0 deletions src/brevitas/backport/__init__.py
@@ -251,3 +251,4 @@ def sym_min(a, b):

# Populate magic methods on SymInt and SymFloat
import brevitas.backport.fx.experimental.symbolic_shapes
import brevitas.backport.fx
5 changes: 3 additions & 2 deletions src/brevitas/backport/fx/experimental/proxy_tensor.py
@@ -62,13 +62,14 @@
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree

from brevitas import backport
from brevitas.backport import fx
from brevitas.backport import SymBool
from brevitas.backport import SymFloat
from brevitas.backport import SymInt
from brevitas.backport.fx import GraphModule
from brevitas.backport.fx import Proxy
from brevitas.backport.fx import Tracer
import brevitas.backport.fx as fx
from brevitas.backport.fx.passes.shape_prop import _extract_tensor_metadata
from brevitas.backport.utils._stats import count
from brevitas.backport.utils.weak import WeakTensorKeyDictionary
@@ -316,7 +317,7 @@ def proxy_call(proxy_mode, func, args, kwargs):
# `__torch_dispatch__` is only called on torch ops, which must subclass `OpOverload`
# We treat all other functions as an `external_call`, for instance, a function decorated
# with `@torch.fx.wrap`
external_call = not isinstance(func, fx.backport._ops.OpOverload)
external_call = not isinstance(func, backport._ops.OpOverload)

def can_handle_tensor(x):
return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
60 changes: 60 additions & 0 deletions src/brevitas_examples/llm/README.md
@@ -0,0 +1,60 @@
# LLM quantization

## Requirements

- transformers
- datasets
- torch_mlir (optional for torch-mlir based export)
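
A minimal install sketch for the dependencies above (package names are assumed to match PyPI):

```bash
# Calibration data loading and HF model handling
pip install transformers datasets

# torch_mlir is optional and only needed for the torch-mlir export targets;
# it ships pre-release wheels, so follow the torch-mlir project's own install
# instructions rather than a plain `pip install`.
```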

## Run

Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. This is currently unsupported when export is enabled or when MSE-based scales/zero-points are used.

```bash
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval] [--weight-bit-width WEIGHT_BIT_WIDTH] [--weight-param-method {stats,mse}]
[--weight-scale-type {float32,po2}] [--weight-quant-type {sym,asym}] [--weight-quant-granularity {per_channel,per_tensor,per_group}]
[--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point] [--input-bit-width INPUT_BIT_WIDTH] [--input-param-method {stats,mse}]
[--input-scale-type {float32,po2}] [--input-quant-type {sym,asym}] [--input-quant-granularity {per_tensor}] [--quantize-input-zero-point] [--gptq]
[--act-calibration] [--bias-corr] [--act-equalization]
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]

optional arguments:
-h, --help show this help message and exit
--model MODEL HF model name. Default: facebook/opt-125m.
--seed SEED Seed for sampling the calibration data. Default: 0.
--nsamples NSAMPLES Number of calibration data samples. Default: 128.
--seqlen SEQLEN Sequence length. Default: 2048.
--eval Eval model PPL on C4.
--weight-bit-width WEIGHT_BIT_WIDTH
Weight bit width. Default: 8.
--weight-param-method {stats,mse}
How scales/zero-point are determined. Default: stats.
--weight-scale-type {float32,po2}
Whether scale is a float value or a po2. Default: po2.
--weight-quant-type {sym,asym}
Weight quantization type. Default: asym.
--weight-quant-granularity {per_channel,per_tensor,per_group}
Granularity for scales/zero-point of weights. Default: per_group.
--weight-group-size WEIGHT_GROUP_SIZE
Group size for per_group weight quantization. Default: 128.
--quantize-weight-zero-point
Quantize weight zero-point.
--input-bit-width INPUT_BIT_WIDTH
Input bit width. Default: None (disables input quantization).
--input-param-method {stats,mse}
How scales/zero-point are determined. Default: stats.
--input-scale-type {float32,po2}
Whether input scale is a float value or a po2. Default: float32.
--input-quant-type {sym,asym}
Input quantization type. Default: asym.
--input-quant-granularity {per_tensor}
Granularity for scales/zero-point of inputs. Default: per_tensor.
--quantize-input-zero-point
Quantize input zero-point.
--gptq Apply GPTQ.
--act-calibration Apply activation calibration.
--bias-corr Apply bias correction.
--act-equalization Apply activation equalization (SmoothQuant).
--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}
Model export.
```
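
As an illustration of how these options compose, a few example invocations (the `brevitas_ptq_llm` entry point comes from the console script registered in `setup.py`, and running `main.py` directly with the same flags should be equivalent; the flag values below are examples, not tuned recommendations):

```bash
# 4-bit per-group weight-only PTQ with GPTQ and bias correction,
# followed by perplexity evaluation on C4
brevitas_ptq_llm --model facebook/opt-125m --weight-bit-width 4 \
    --weight-quant-granularity per_group --weight-group-size 128 \
    --gptq --bias-corr --eval

# Same weight settings plus 8-bit input quantization with activation
# calibration and activation equalization (SmoothQuant)
brevitas_ptq_llm --model facebook/opt-125m --weight-bit-width 4 \
    --input-bit-width 8 --act-calibration --act-equalization \
    --bias-corr --eval

# BREVITAS_JIT speeds up quantization, but only when no export target and no
# MSE-based scales/zero-points are requested (see the note above)
BREVITAS_JIT=1 brevitas_ptq_llm --model facebook/opt-125m --weight-bit-width 4 --eval
```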
Empty file.
Empty file.
26 changes: 26 additions & 0 deletions src/brevitas_examples/llm/llm_quant/bias_corr.py
@@ -0,0 +1,26 @@
"""
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""

import torch

from brevitas.graph.calibrate import bias_correction_mode
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn


@torch.no_grad()
def bias_corr_iter(curr_layer, inps, outs, cached_values):
    curr_layer = curr_layer.cuda()
    with bias_correction_mode(curr_layer):
        for j in range(len(inps)):
            inp = inps[j].unsqueeze(0).cuda()
            curr_out = curr_layer(inp, **cached_values)[0]
            outs[j] = curr_out
    curr_layer.cpu()
    return outs


@torch.no_grad()
def apply_bias_correction(model, dataloader, nsamples, seqlen=2048):
    apply_layer_ptq_fn(model, dataloader, nsamples, inference_fn=bias_corr_iter, seqlen=seqlen)
26 changes: 26 additions & 0 deletions src/brevitas_examples/llm/llm_quant/calibrate.py
@@ -0,0 +1,26 @@
"""
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""

import torch

from brevitas.graph.calibrate import calibration_mode
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn


@torch.no_grad()
def calibration_iter(curr_layer, inps, outs, cached_values):
    curr_layer = curr_layer.cuda()
    with calibration_mode(curr_layer):
        for j in range(len(inps)):
            inp = inps[j].unsqueeze(0).cuda()
            curr_out = curr_layer(inp, **cached_values)[0]
            outs[j] = curr_out
    curr_layer.cpu()
    return outs


@torch.no_grad()
def apply_calibration(model, dataloader, nsamples, seqlen=2048):
    apply_layer_ptq_fn(model, dataloader, nsamples, inference_fn=calibration_iter, seqlen=seqlen)
71 changes: 71 additions & 0 deletions src/brevitas_examples/llm/llm_quant/data.py
@@ -0,0 +1,71 @@
"""
Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE:
Copyright 2023 IST-DASLab
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import random

from datasets import load_dataset
import torch
from transformers import AutoTokenizer


def get_c4(nsamples, seed, seqlen, model, nvalsamples=256):
    traindata = load_dataset(
        'allenai/c4',
        'allenai--c4',
        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
        split='train',
        use_auth_token=False)
    valdata = load_dataset(
        'allenai/c4',
        'allenai--c4',
        data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
        split='validation',
        use_auth_token=False)

    try:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    except:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        trainloader.append(inp)

    random.seed(0)  # hardcoded for validation reproducibility
    valenc = []
    for _ in range(nvalsamples):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
            if tmp.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        valenc.append(tmp.input_ids[:, i:j])

    valenc = torch.hstack(valenc)
    return trainloader, valenc
74 changes: 74 additions & 0 deletions src/brevitas_examples/llm/llm_quant/equalize.py
@@ -0,0 +1,74 @@
"""
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""

import warnings

import torch

from brevitas.fx.brevitas_tracer import value_trace
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import EqualizeGraph
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32


@torch.no_grad()
def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
    curr_layer = curr_layer.cuda()
    with activation_equalization_mode(curr_layer, alpha, add_mul_node=True, layerwise=True):
        for j in range(len(inps)):
            inp = inps[j].unsqueeze(0).cuda()
            curr_out = curr_layer(inp, **cached_values)[0]
            outs[j] = curr_out
    curr_layer.cpu()
    return outs


@torch.no_grad()
def apply_act_equalization(
        model,
        dtype,
        act_equalization_type,
        dataloader,
        nsamples,
        seqlen=2048,
        alpha=0.5,
        ref_kwargs=None):
    if act_equalization_type == 'layerwise':
        apply_layer_ptq_fn(
            model,
            dataloader,
            nsamples,
            inference_fn=activation_equalization_iter,
            seqlen=seqlen,
            alpha=alpha)
    elif act_equalization_type == 'fx':
        assert ref_kwargs is not None, "Ref kwargs required to perform tracing and lift the model into FX."
        # We can't do fp16 tracing on CPU as many kernels are not implemented
        # So we have to cast to fp32 first, trace, apply equalization, and then cast back
        with cast_to_float32(model, dtype):
            graph_model = value_trace(model, value_args=ref_kwargs)
            # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
            # or an FX interpreter to run it on GPU
            warnings.warn(
                "FX mode activation equalization currently runs on CPU, expect it to be slow for large models."
            )
            with activation_equalization_mode(graph_model,
                                              alpha,
                                              add_mul_node=False,
                                              layerwise=False):
                for input_ids in dataloader:
                    graph_model(input_ids=input_ids)
    else:
        raise RuntimeError(f"{act_equalization_type} not supported.")


@torch.no_grad()
def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='range'):
    # We can't do fp16 tracing on CPU as many kernels are not implemented
    # So we have to cast to fp32 first, trace, apply equalization, and then cast back
    with cast_to_float32(model, dtype):
        graph_model = value_trace(model, value_args=ref_kwargs)
        EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
91 changes: 91 additions & 0 deletions src/brevitas_examples/llm/llm_quant/eval.py
@@ -0,0 +1,91 @@
"""
Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE:
Copyright 2023 IST-DASLab
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
from torch import nn
from tqdm import tqdm

from brevitas_examples.llm.llm_quant.run_utils import apply_layer_inference_fn
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
from brevitas_examples.llm.llm_quant.run_utils import InputCatcherException


def eval_inference_fn(curr_layer, inps, outs, cached_values):
    curr_layer.cuda()
    for j in range(len(inps)):
        outs[j] = curr_layer(inps[j].unsqueeze(0).cuda(), **cached_values)[0]
    curr_layer.cpu()


@torch.no_grad()
def model_eval(model, valenc, seqlen):

    nsamples = valenc.numel() // seqlen

    def eval_input_capture_fn(model, data):
        for i in range(nsamples):
            batch = data[:, (i * seqlen):((i + 1) * seqlen)].cuda()
            try:
                model(batch)
            except InputCatcherException:
                pass

    inps = apply_layer_inference_fn(
        model,
        valenc,
        nsamples,
        input_capture_fn=eval_input_capture_fn,
        inference_fn=eval_inference_fn,
        seqlen=seqlen)

    model_impl = get_model_impl(model)
    use_cache = model.config.use_cache
    model.config.use_cache = False

    if hasattr(model_impl, 'norm') and model_impl.norm is not None:
        model_impl.norm = model_impl.norm.cuda()
    if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None:
        model_impl.final_layer_norm = model_impl.final_layer_norm.cuda()
    if hasattr(model_impl, 'project_out') and model_impl.project_out is not None:
        model_impl.project_out = model_impl.project_out.cuda()
    if hasattr(model, 'lm_head'):
        model.lm_head = model.lm_head.cuda()

    valenc = valenc.cuda()
    nlls = []
    for i in tqdm(range(nsamples)):
        hidden_states = inps[i].unsqueeze(0)
        if hasattr(model_impl, 'norm') and model_impl.norm is not None:
            hidden_states = model_impl.norm(hidden_states)
        if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None:
            hidden_states = model_impl.final_layer_norm(hidden_states)
        if hasattr(model_impl, 'project_out') and model_impl.project_out is not None:
            hidden_states = model_impl.project_out(hidden_states)
        lm_logits = hidden_states
        if hasattr(model, 'lm_head'):
            lm_logits = model.lm_head(lm_logits)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = valenc[:, (i * seqlen):((i + 1) * seqlen)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * seqlen
        nlls.append(neg_log_likelihood)

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
    model.config.use_cache = use_cache
    return ppl
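
Taken together, the helpers above suggest a calibrate-then-evaluate flow. A hedged sketch of how they might be chained (the actual wiring lives in `brevitas_examples/llm/main.py`, which is not part of the excerpt shown here; the quantization step itself is omitted, and CUDA is assumed, as in the helpers above):

```python
# Hypothetical driver sketch, not the shipped brevitas_examples.llm.main flow.
from transformers import AutoModelForCausalLM

from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction
from brevitas_examples.llm.llm_quant.calibrate import apply_calibration
from brevitas_examples.llm.llm_quant.data import get_c4
from brevitas_examples.llm.llm_quant.eval import model_eval

nsamples, seqlen = 128, 2048
model_name = 'facebook/opt-125m'

model = AutoModelForCausalLM.from_pretrained(model_name)
calib_loader, val_data = get_c4(nsamples=nsamples, seed=0, seqlen=seqlen, model=model_name)

# ... quantize `model` here (handled by other files in this commit that are
# not part of the excerpt above) ...

apply_calibration(model, calib_loader, nsamples, seqlen=seqlen)  # activation calibration
apply_bias_correction(model, calib_loader, nsamples, seqlen=seqlen)
ppl = model_eval(model, val_data, seqlen)  # perplexity on the C4 validation slice
print(f"C4 perplexity: {ppl}")
```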
