Add examples for Llama-2 magnitude pruning & add more comments
Showing 6 changed files with 498 additions and 47 deletions.
@@ -0,0 +1,345 @@
# Code adapted from
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wanda

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))

import argparse
import time
import fnmatch
import random
from collections import defaultdict
from importlib.metadata import version

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# Set seed for reproducibility
def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)

# Wrapper for tokenized input IDs
class TokenizerWrapper:
    def __init__(self, input_ids):
        self.input_ids = input_ids

# Load and process wikitext2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    # Load train and test datasets
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Encode datasets
    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc
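
# Note: each calibration sample above is a pair (inp, tar) of shape [1, seqlen];
# tar masks every position except the last one with -100, so only the final
# token would contribute if a loss were computed on these targets.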

# Load and process c4 dataset
def get_c4(nsamples, seed, seqlen, tokenizer):
    # Load train and validation datasets
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Resample until a document longer than seqlen is found
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Prepare validation dataset (truncated to 256 * seqlen tokens)
    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc

# Select the appropriate loader based on dataset name
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unknown dataset: {name}")
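
# Example (hypothetical) usage, assuming a Llama-2 tokenizer is available:
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)
#   trainloader, testenc = get_loaders("wikitext2", nsamples=128, seed=0, seqlen=2048, tokenizer=tokenizer)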

# Function to evaluate perplexity (ppl) on a specified model and tokenizer
def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
    # Set dataset
    dataset = "wikitext2"

    # Print status
    print(f"evaluating on {dataset}")

    # Get the test loader
    _, testloader = get_loaders(
        dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer,
    )

    # Evaluate ppl in no-grad context to avoid updating the model
    with torch.no_grad():
        ppl_test = eval_ppl_wikitext(model, testloader, 1, device)
    return ppl_test
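
# eval_ppl always scores the (possibly pruned) model on the wikitext2 test
# split with chunk length model.seqlen; batch size is fixed to 1 in the call above.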

# Function to evaluate perplexity (ppl) on calibration samples drawn from the training set
def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None):
    # Calculate number of samples
    nsamples = len(trainloader)

    # List to store negative log likelihoods
    nlls = []
    print(f"nsamples {nsamples}")

    # Loop through each batch (each trainloader entry holds a single sequence,
    # so this path assumes bs=1)
    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        # Calculate end index
        j = min(i + bs, nsamples)

        # Prepare inputs and move to device
        inputs = trainloader[i][0].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        # Forward pass through the model
        lm_logits = model(inputs).logits

        # Shift logits and labels for next-token prediction
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        # Compute loss
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        # Calculate negative log likelihood
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)

        # Append to list of negative log likelihoods
        nlls.append(neg_log_likelihood)

    # Compute perplexity
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

    # Empty CUDA cache to save memory
    torch.cuda.empty_cache()

    return ppl.item()

# Function to evaluate perplexity (ppl) specifically on the wikitext test set
def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    # Get input IDs
    testenc = testenc.input_ids

    # Calculate number of samples
    nsamples = testenc.numel() // model.seqlen

    # List to store negative log likelihoods
    nlls = []
    print(f"nsamples {nsamples}")

    # Loop through each batch
    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        # Calculate end index
        j = min(i + bs, nsamples)

        # Prepare inputs and move to device
        inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        # Forward pass through the model
        lm_logits = model(inputs).logits

        # Shift logits and labels for next-token prediction
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        # Compute loss
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        # Calculate negative log likelihood
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)

        # Append to list of negative log likelihoods
        nlls.append(neg_log_likelihood)

    # Compute perplexity
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

    # Empty CUDA cache to save memory
    torch.cuda.empty_cache()

    return ppl.item()
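
# Perplexity here follows the standard definition:
#   ppl = exp( (1 / (nsamples * seqlen)) * sum_i nll_i ),
# where nll_i is the mean token cross-entropy of chunk i scaled back up by
# seqlen * batch_size, i.e. an exponentiated average per-token loss.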


def eval_zero_shot(model_name, model, tokenizer, task_list=["boolq", "rte", "hellaswag", "winogrande", "arc_challenge", "arc_easy", "openbookqa"],
                   num_fewshot=0, use_accelerate=False, add_special_tokens=False):
    from lm_eval import tasks, evaluator

    def pattern_match(patterns, source_list):
        task_names = set()
        for pattern in patterns:
            for matching in fnmatch.filter(source_list, pattern):
                task_names.add(matching)
        return list(task_names)

    task_names = pattern_match(task_list, tasks.ALL_TASKS)
    model_args = f"pretrained={model_name},cache_dir=./cache"
    limit = None
    if "70b" in model_name or "65b" in model_name:
        limit = 2000
    if use_accelerate:
        model_args = f"pretrained={model_name},cache_dir=./cache,use_accelerate=True"
    results = evaluator.simple_evaluate(
        model="hf-causal-experimental",
        model_args=model_args,
        tasks=task_names,
        num_fewshot=num_fewshot,
        batch_size=None,
        device=None,
        no_cache=True,
        limit=limit,
        description_dict={},
        decontamination_ngrams_path=None,
        check_integrity=False,
        pretrained_model=model,
        tokenizer=tokenizer,
        add_special_tokens=add_special_tokens
    )

    return results
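
# Example (hypothetical) call, assuming an lm-eval-harness version that exposes
# the "hf-causal-experimental" model type used above (as in the wanda setup):
#   results = eval_zero_shot(args.model, model, tokenizer, task_list=["boolq", "rte"])
#   print(results["results"])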


print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="./cache"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings
    return model
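
# device_map="auto" lets accelerate shard the fp16 weights across all visible
# GPUs; model.seqlen is taken from the config's max_position_embeddings
# (4096 for Llama-2) and is used as the evaluation window length above.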

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
    parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
                        "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

    parser.add_argument("--eval_zero_shot", action="store_true")
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model:  # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    ##############
    # Pruning
    ##############
    print("----------------- Before Pruning -----------------")
    print(model)
    text = "Hello world."
    inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)

    import torch_pruning as tp
    num_heads = {}
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_attention_heads
            num_heads[m.v_proj] = model.config.num_attention_heads
    head_pruning_ratio = 0.5
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs=inputs,
        importance=tp.importance.MagnitudeImportance(),
        global_pruning=False,
        pruning_ratio=0.5,
        ignored_layers=[model.lm_head],
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=head_pruning_ratio,
    )
    pruner.step()
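    # MagnitudePruner removes the channels and attention heads with the smallest
    # weight magnitude, layer by layer (global_pruning=False). With the settings
    # above, 50% of the width and 50% of the heads are pruned, while lm_head is
    # left untouched via ignored_layers.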

    # Update model attributes
    num_heads = int((1 - head_pruning_ratio) * model.config.num_attention_heads)
    model.config.num_attention_heads = num_heads
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            m.hidden_size = m.q_proj.out_features
            m.num_heads = num_heads
            m.num_key_value_heads = num_heads
        elif name.endswith("mlp"):
            model.config.intermediate_size = m.gate_proj.out_features
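    # Keeping config.num_attention_heads / intermediate_size and the per-module
    # num_heads attributes consistent with the pruned projection shapes is what
    # allows the model to run forward passes and be reloaded after save_pretrained.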

    print("----------------- After Pruning -----------------")
    print(model)

    ppl_test = eval_ppl(args, model, tokenizer, device)
    print(f"wikitext perplexity {ppl_test}")

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

if __name__ == '__main__':
    main()
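
# Example invocation (script name and model path are placeholders, not part of
# this commit):
#   python llama2_magnitude_prune.py --model meta-llama/Llama-2-7b-hf --save_model ./pruned_llama2_7b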