Add examples for Llama-2 magnitude pruning & add more comments
VainF committed Jun 4, 2024
1 parent 94d7161 commit c9ce227
Showing 6 changed files with 498 additions and 47 deletions.
345 changes: 345 additions & 0 deletions examples/LLMs/prune_llama2.py
@@ -0,0 +1,345 @@

# Code adapted from
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wanda
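#
# Example invocation (a sketch; the checkpoint id and output path below are placeholders,
# any Hugging Face Llama-2 causal LM checkpoint can be substituted):
#   python examples/LLMs/prune_llama2.py --model meta-llama/Llama-2-7b-hf --save_model ./pruned_llama2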

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))

import argparse
import time
import random
import fnmatch
from collections import defaultdict
from importlib.metadata import version

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# Set seed for reproducibility
def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)

# Wrapper for tokenized input IDs
class TokenizerWrapper:
    def __init__(self, input_ids):
        self.input_ids = input_ids

# Load and process wikitext2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    # Load train and test datasets
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Encode datasets
    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

# Load and process c4 dataset
def get_c4(nsamples, seed, seqlen, tokenizer):
    # Load train and validation datasets
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Keep drawing documents until one is longer than seqlen tokens
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        # Sample a random window of seqlen tokens from that document
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Prepare validation dataset
    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc

# Function to select the appropriate loader based on dataset name
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)

# Function to evaluate perplexity (ppl) on a specified model and tokenizer
def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
    # Set dataset
    dataset = "wikitext2"

    # Print status
    print(f"evaluating on {dataset}")

    # Get the test loader
    _, testloader = get_loaders(
        dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer,
    )

    # Evaluate ppl in no grad context to avoid updating the model
    with torch.no_grad():
        ppl_test = eval_ppl_wikitext(model, testloader, 1, device)
    return ppl_test

# Function to evaluate perplexity (ppl) on the calibration (training) samples
def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None):
    # Get input IDs
    # testenc = testenc.input_ids

    # Calculate number of samples
    # nsamples = testenc.numel() // model.seqlen
    nsamples = len(trainloader)

    # List to store negative log likelihoods
    nlls = []
    print(f"nsamples {nsamples}")

    # Loop through each batch
    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        # Calculate end index
        j = min(i + bs, nsamples)

        # Prepare inputs and move to device
        # inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = trainloader[i][0].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        # Forward pass through the model
        lm_logits = model(inputs).logits

        # Shift logits and labels for next token prediction
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        # Compute loss
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        # Calculate negative log likelihood
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)

        # Append to list of negative log likelihoods
        nlls.append(neg_log_likelihood)

    # Compute perplexity
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

    # Empty CUDA cache to save memory
    torch.cuda.empty_cache()

    return ppl.item()

# Function to evaluate perplexity (ppl) specifically on the wikitext dataset
def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    # Get input IDs
    testenc = testenc.input_ids

    # Calculate number of samples
    nsamples = testenc.numel() // model.seqlen

    # List to store negative log likelihoods
    nlls = []
    print(f"nsamples {nsamples}")

    # Loop through each batch
    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        # Calculate end index
        j = min(i + bs, nsamples)

        # Prepare inputs and move to device
        inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        # Forward pass through the model
        lm_logits = model(inputs).logits

        # Shift logits and labels for next token prediction
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        # Compute loss
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        # Calculate negative log likelihood
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)

        # Append to list of negative log likelihoods
        nlls.append(neg_log_likelihood)

    # Compute perplexity
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

    # Empty CUDA cache to save memory
    torch.cuda.empty_cache()

    return ppl.item()


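# Zero-shot accuracy evaluation on common-sense benchmarks via the lm-eval harness (optional; requires the lm_eval package)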
def eval_zero_shot(model_name, model, tokenizer, task_list=["boolq","rte","hellaswag","winogrande","arc_challenge","arc_easy","openbookqa"],
                   num_fewshot=0, use_accelerate=False, add_special_tokens=False):
    from lm_eval import tasks, evaluator
    def pattern_match(patterns, source_list):
        task_names = set()
        for pattern in patterns:
            for matching in fnmatch.filter(source_list, pattern):
                task_names.add(matching)
        return list(task_names)
    task_names = pattern_match(task_list, tasks.ALL_TASKS)
    model_args = f"pretrained={model_name}, cache_dir=./cache"
    limit = None
    if "70b" in model_name or "65b" in model_name:
        limit = 2000
    if use_accelerate:
        model_args = f"pretrained={model_name}, cache_dir=./cache, use_accelerate=True"
    results = evaluator.simple_evaluate(
        model="hf-causal-experimental",
        model_args=model_args,
        tasks=task_names,
        num_fewshot=num_fewshot,
        batch_size=None,
        device=None,
        no_cache=True,
        limit=limit,
        description_dict={},
        decontamination_ngrams_path=None,
        check_integrity=False,
        pretrained_model=model,
        tokenizer=tokenizer,
        add_special_tokens=add_special_tokens
    )

    return results


print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

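# Load a Hugging Face causal LM in float16 with device_map="auto" and attach its max sequence length as model.seqlen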
def get_llm(model_name, cache_dir="./cache"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings
    return model

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
    parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
                        "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

    parser.add_argument("--eval_zero_shot", action="store_true")
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    ##############
    # Pruning
    ##############
    print("----------------- Before Pruning -----------------")
    print(model)
    text = "Hello world."
    inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
    import torch_pruning as tp
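    # Map each q/k/v projection to its head count so Torch-Pruning can remove whole attention heads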
    num_heads = {}
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_attention_heads
            num_heads[m.v_proj] = model.config.num_attention_heads
    head_pruning_ratio = 0.5
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs=inputs,
        importance=tp.importance.MagnitudeImportance(),
        global_pruning=False,
        pruning_ratio=0.5,
        ignored_layers=[model.lm_head],
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=head_pruning_ratio,
    )
    pruner.step()

    # Update model attributes
    num_heads = int((1 - head_pruning_ratio) * model.config.num_attention_heads)
    model.config.num_attention_heads = num_heads
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            m.hidden_size = m.q_proj.out_features
            m.num_heads = num_heads
            m.num_key_value_heads = num_heads
        elif name.endswith("mlp"):
            model.config.intermediate_size = m.gate_proj.out_features

    print("----------------- After Pruning -----------------")
    print(model)

    ppl_test = eval_ppl(args, model, tokenizer, device)
    print(f"wikitext perplexity {ppl_test}")

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

if __name__ == '__main__':
    main()