Skip to content

Commit

Permalink
Feat: Update LLM entry-point (#987)
Browse files Browse the repository at this point in the history
* Feat (example/llm): Added fnuz/ocp args

* Docs (example/llm): typo fix in args description

* Feat (example/llm): Add zero bias to linear layers when doing bias correction.

* Fix (example/llm): Remove unnecessary forward pass

* Feat (example/llm): Leveraged data utils from optimum-amd integration

* Feat (example/llm): Load KV Cache to correct `dtype`

* Feat (example/llm): Added progress bar to bias correction

* Fix (example/llm): Fix formatting.

* feat (example/llm): Switched `ln_affine_merge` to use HF's tracer

* feat (example/llm): decompose `quantize_model` into component parts.

* Fix (example/llm): Assert that TorchQCDQ export & Eval aren't both enabled.

* feat (example/llm): Added option not to quantize the last linear layer

* Fix precommit

* Fix (example/llm): disable embedded lookup quantization
  • Loading branch information
nickfraser authored Aug 20, 2024
1 parent d7cfc04 commit b9eecf7
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 35 deletions.
3 changes: 2 additions & 1 deletion src/brevitas_examples/llm/llm_quant/bias_corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
"""

import torch
from tqdm import tqdm

from brevitas.graph.calibrate import bias_correction_mode


@torch.no_grad()
def apply_bias_correction(model, dataloader):
with bias_correction_mode(model):
for inps in dataloader:
for inps in tqdm(dataloader):
model(**inps)
108 changes: 108 additions & 0 deletions src/brevitas_examples/llm/llm_quant/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Adapted from https://github.com/huggingface/optimum-amd, released under the following LICENSE:
MIT License
Copyright (c) 2023 Hugging Face
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import random
from typing import Any, Optional, Union

import numpy as np
from optimum.amd.brevitas.data_utils import DatasetToDevice
from optimum.amd.brevitas.data_utils import get_c4
from optimum.amd.brevitas.data_utils import get_wikitext2
from optimum.utils.normalized_config import NormalizedConfigManager
import torch
from transformers import AutoConfig


def get_dataset_for_model(
model_name_or_path: str,
dataset_name: str,
tokenizer: Any,
nsamples: int = 128,
seqlen: int = 2048,
seed: int = 0,
split: str = "train",
fuse_sequences: bool = True,
require_fx: bool = False,
device: Optional[Union[str, torch.device]] = None,
):
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
get_dataset_map = {
"wikitext2": get_wikitext2,
"c4": get_c4,}
if split not in ["train", "validation"]:
raise ValueError(f"The split need to be 'train' or 'validation' but found {split}")
if dataset_name not in get_dataset_map:
raise ValueError(
f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}")
get_dataset_fn = get_dataset_map[dataset_name]

data = get_dataset_fn(
tokenizer=tokenizer,
nsamples=nsamples,
seqlen=seqlen,
split=split,
fuse_sequences=fuse_sequences,
seed=seed)

# In case the dataset is loaded to be used with an fx.GraphModule, we need to add empty past_key_values inputs in the dataset.
if require_fx:
config = AutoConfig.from_pretrained(model_name_or_path)

normalized_config_class = NormalizedConfigManager.get_normalized_config_class(
config.model_type)
normalized_config = normalized_config_class(config)

num_heads = normalized_config.num_attention_heads
if hasattr(normalized_config, "num_key_value_heads"):
num_kv_heads = normalized_config.num_key_value_heads
else:
num_kv_heads = num_heads
head_dim = normalized_config.hidden_size // num_heads
num_layers = normalized_config.num_layers

for sample in data:
sample["past_key_values"] = tuple((
torch.zeros(
1,
num_kv_heads,
0,
head_dim,
device=sample["input_ids"].device,
dtype=sample["input_ids"].dtype),
torch.zeros(
1,
num_kv_heads,
0,
head_dim,
device=sample["input_ids"].device,
dtype=sample["input_ids"].dtype),
) for _ in range(num_layers))

data = DatasetToDevice(data, device=device)

return data
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/llm_quant/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def create_validation_dataloader(data, seqlen, device):
@torch.no_grad()
def model_eval(model, valenc, seqlen):
nsamples = len(valenc)
dev = next(iter(model.parameters())).device
with torch.no_grad():
nlls = []
for inps in valenc:
lm_logits = model(**inps)['logits']
shift_logits = lm_logits[:, :-1, :].contiguous()
dev = shift_logits.device
shift_labels = inps['input_ids'][:, 1:].to(dev)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
Expand Down
6 changes: 2 additions & 4 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch
from torch import nn

from brevitas.fx import value_trace
from brevitas.graph.equalize import _is_reshaping_op
from brevitas.graph.equalize import _is_scale_invariant_module
from brevitas.graph.utils import get_module
Expand Down Expand Up @@ -84,9 +83,8 @@ def merge_layernorm_affine_params(graph_model):


@torch.no_grad()
def apply_layernorm_affine_merge(model, dtype, ref_kwargs):
def apply_layernorm_affine_merge(graph_model, dtype):
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply merging, and then cast back
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
with cast_to_float32(graph_model, dtype):
merge_layernorm_affine_params(graph_model)
16 changes: 16 additions & 0 deletions src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings

import torch
from transformers.models.opt.modeling_opt import OPTAttention

from brevitas.graph import ModuleToModuleByClass
Expand All @@ -21,3 +22,18 @@ def replace_mha_with_quantizable_layers(model, dtype):
for rewriter in rewriters:
model = rewriter.apply(model)
return model


@torch.no_grad()
def add_zero_bias_to_linear(model: torch.nn.Module) -> torch.nn.Module:
for name, module in model.named_modules():
if type(module) == torch.nn.Linear:
if module.bias is None:
module.register_parameter(
"bias",
torch.nn.Parameter(
torch.zeros((module.weight.shape[0],),
device=module.weight.device,
dtype=module.weight.dtype)),
)
return model
Loading

0 comments on commit b9eecf7

Please sign in to comment.