diff --git a/examples/LLMs/eval_ppl.py b/examples/LLMs/eval_ppl.py new file mode 100644 index 0000000..f0e5951 --- /dev/null +++ b/examples/LLMs/eval_ppl.py @@ -0,0 +1,272 @@ +# Code adapted from https://github.com/locuslab/wanda/blob/main/main.py +import argparse +from importlib.metadata import version +import os +import time +import fnmatch +import random +import numpy as np +from collections import defaultdict +import torch +import torch.nn as nn +from transformers import AutoTokenizer, AutoModelForCausalLM +from datasets import load_dataset + +# Set seed for reproducibility +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + +# Wrapper for tokenized input IDs +class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + +# Load and process wikitext2 dataset +def get_wikitext2(nsamples, seed, seqlen, tokenizer): + # Load train and test datasets + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + # Encode datasets + trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt') + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + + # Generate samples from training set + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +# Load and process c4 dataset +def get_c4(nsamples, seed, seqlen, tokenizer): + # Load train and validation datasets + traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') + valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') + + # Generate samples from training set + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] > seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + # Prepare validation dataset + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + valenc = TokenizerWrapper(valenc) + return trainloader, valenc + +# Function to select the appropriate loader based on dataset name +def get_loaders(name, nsamples=128, seed=0, seqlen=4096, tokenizer=None): + if 'wikitext2' in name: + return get_wikitext2(nsamples, seed, seqlen, tokenizer) + if "c4" in name: + return get_c4(nsamples, seed, seqlen, tokenizer) + +# Function to evaluate perplexity (ppl) on a specified model and tokenizer +def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")): + # Set dataset + dataset = "wikitext2" + + # Print status + print(f"evaluating on {dataset}") + + # Get the test loader + _, testloader = get_loaders( + dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer, + ) + + # Evaluate ppl in no grad context to avoid updating the model + with torch.no_grad(): + ppl_test = eval_ppl_wikitext(model, testloader, 1, device) + return ppl_test + +# Function to evaluate perplexity (ppl) specifically on the 
wikitext dataset +def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None): + # Get input IDs + # testenc = testenc.input_ids + + # Calculate number of samples + # nsamples = testenc.numel() // model.seqlen + nsamples = len(trainloader) + + # List to store negative log likelihoods + nlls = [] + print(f"nsamples {nsamples}") + + # Loop through each batch + for i in range(0,nsamples,bs): + if i % 50 == 0: + print(f"sample {i}") + + # Calculate end index + j = min(i+bs, nsamples) + + # Prepare inputs and move to device + # inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device) + inputs = trainloader[i][0].to(device) + inputs = inputs.reshape(j-i, model.seqlen) + + # Forward pass through the model + lm_logits = model(inputs).logits + + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + + # Compute loss + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) + + # Calculate negative log likelihood + neg_log_likelihood = loss.float() * model.seqlen * (j-i) + + # Append to list of negative log likelihoods + nlls.append(neg_log_likelihood) + + # Compute perplexity + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) + + # Empty CUDA cache to save memory + torch.cuda.empty_cache() + + return ppl.item() + +# Function to evaluate perplexity (ppl) specifically on the wikitext dataset +def eval_ppl_wikitext(model, testenc, bs=1, device=None): + # Get input IDs + testenc = testenc.input_ids + + # Calculate number of samples + nsamples = testenc.numel() // model.seqlen + + # List to store negative log likelihoods + nlls = [] + print(f"nsamples {nsamples}") + + # Loop through each batch + for i in range(0,nsamples,bs): + if i % 50 == 0: + print(f"sample {i}") + + # Calculate end index + j = min(i+bs, nsamples) + + # Prepare inputs and move to device + inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device) + inputs = inputs.reshape(j-i, model.seqlen) + + # Forward pass through the model + lm_logits = model(inputs).logits + + # Shift logits and labels for next token prediction + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = inputs[:, 1:] + + # Compute loss + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) + + # Calculate negative log likelihood + neg_log_likelihood = loss.float() * model.seqlen * (j-i) + + # Append to list of negative log likelihoods + nlls.append(neg_log_likelihood) + + # Compute perplexity + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) + + # Empty CUDA cache to save memory + torch.cuda.empty_cache() + + return ppl.item() + +print('torch', version('torch')) +print('transformers', version('transformers')) +print('accelerate', version('accelerate')) +print('# of gpus: ', torch.cuda.device_count()) + +def get_llm(model_name, cache_dir="./assets/cache"): + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16, + #cache_dir=cache_dir, + device_map="auto" + ) + model.seqlen = 4096 if model.config.max_position_embeddings>=4096 else model.config.max_position_embeddings + return model + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, help='LLaMA model') + parser.add_argument('--mask', type=str, default=None, help="Path to the mask ckpt") + parser.add_argument('--seed', type=int, default=0, 
help='Seed for sampling the calibration data.') + parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.') + parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level') + parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"]) + parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt", + "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"]) + parser.add_argument("--cache_dir", default="./assets/cache", type=str ) + parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix") + parser.add_argument('--save', type=str, default=None, help='Path to save results.') + parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.') + args = parser.parse_args() + + # Setting seeds for reproducibility + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + + model_name = args.model.split("/")[-1] + print(f"loading llm model {args.model}") + model = get_llm(args.model, args.cache_dir) + + if args.mask is not None: + if args.mask.endswith(".pt"): # raw mask ckpt, this will be quite large (~6GB for 7b model) + mask_ckpt = torch.load(args.mask, map_location='cpu') + model_state = model.state_dict() + for k, v in mask_ckpt.items(): + k_original = k.replace(".mask", "") + model_state[k_original] *= v.to(model_state[k_original].device).float() + model.load_state_dict(model_state) + elif args.mask.endswith(".npz"): # compressed mask ckpt, this will be much smaller (~500MB for 7b model) + mask_ckpt = np.load(args.mask) + model_state = model.state_dict() + for k, v in mask_ckpt.items(): + k_original = k.replace(".mask", "") + v = np.unpackbits(v) # to bits + mask = torch.from_numpy(v).to(model_state[k_original].device).float() + mask = mask.view(*model_state[k_original].shape) # reshape the mask + model_state[k_original] *= mask # apply the mask + model.load_state_dict(model_state) + model.eval() + + tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) + device = torch.device("cuda:0") + if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here. 
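Aside on the `.npz` mask format consumed above: the loader expects bit-packed 0/1 masks stored under keys that are the parameter names plus a `.mask` suffix, and reshapes them to the weight shape after `np.unpackbits`. A minimal sketch of how such a checkpoint could be produced is shown below; the helper name and the convention of masking only 2-D weight matrices are assumptions inferred from the loading code, not part of this patch.

```python
# Hypothetical helper (not part of this patch): build a compressed .npz mask
# that matches the unpackbits-based loader in eval_ppl.py.
import numpy as np
import torch

def save_compressed_mask(model, path="mask.npz"):
    packed = {}
    for name, param in model.named_parameters():
        if param.dim() < 2:  # assumption: only weight matrices carry sparsity masks
            continue
        bits = (param != 0).to(torch.uint8).cpu().numpy().reshape(-1)
        packed[name + ".mask"] = np.packbits(bits)  # ~8x smaller than a bool array
    np.savez_compressed(path, **packed)
```

Note that `np.unpackbits` pads to a multiple of eight bits, so the round trip is exact only when `param.numel()` is a multiple of 8, which holds for the transformer weight matrices targeted here.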
+ device = model.hf_device_map["lm_head"] + print("use device ", device) + ppl_test = eval_ppl(args, model, tokenizer, device) + print(f"wikitext perplexity {ppl_test}") + + if args.save_model: + model.save_pretrained(args.save_model) + tokenizer.save_pretrained(args.save_model) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/LLMs/prune_llama.py b/examples/LLMs/prune_llama.py index 303489a..54e9ac2 100644 --- a/examples/LLMs/prune_llama.py +++ b/examples/LLMs/prune_llama.py @@ -1,4 +1,3 @@ - # Code adapted from # https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py # https://github.com/locuslab/wanda @@ -297,38 +296,50 @@ def main(): num_heads[m.q_proj] = model.config.num_attention_heads num_heads[m.k_proj] = model.config.num_key_value_heads num_heads[m.v_proj] = model.config.num_key_value_heads - + _is_gqa = model.config.num_attention_heads != model.config.num_key_value_heads head_pruning_ratio = args.pruning_ratio hidden_size_pruning_ratio = args.pruning_ratio - pruner = tp.pruner.MagnitudePruner( + importance = tp.importance.GroupNormImportance(p=2, group_reduction='mean') #tp.importance.ActivationImportance(p=2, target_types=[torch.nn.Linear]) + pruner = tp.pruner.MetaPruner( model, example_inputs=inputs, - importance=tp.importance.GroupNormImportance(), + importance=importance, global_pruning=False, pruning_ratio=hidden_size_pruning_ratio, ignored_layers=[model.lm_head], num_heads=num_heads, prune_num_heads=True, - prune_head_dims=False, + prune_head_dims=False, # we do not prune head dims so that we don't need to prune the ROPE head_pruning_ratio=head_pruning_ratio, ) + #with torch.no_grad(): + # with importance.compute_importance(model): + # calibration_data = "We recommend at least a 1TB hard drive for 4 channels, more if you plan on using 8MP \/ 4K cameras.\nDahua's Lite Series network video recorders offer excellent performance and high recording quality for IP video surveillance applications. For applications where details are critical for identification, this professional NVR provides a powerful processor with up to 4K resolution. Additionally, the NVR features a mouse shortcut operation menu, remote management and control, center storage, edge storage, and back up storage." 
+ # calibration_data = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device) + # _ = model(calibration_data) pruner.step() - # Update model attributes - num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads ) - num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads ) - model.config.num_attention_heads = num_heads - model.config.num_key_value_heads = num_key_value_heads + # Update model attributes + model.config.hidden_size = model.lm_head.in_features for name, m in model.named_modules(): if name.endswith("self_attn"): m.hidden_size = m.q_proj.out_features - m.num_heads = num_heads - m.num_key_value_heads = num_key_value_heads + m.num_heads = m.hidden_size // m.head_dim + model.config.num_attention_heads = m.num_heads + #m.head_dim = m.q_proj.out_features // m.num_heads + if not _is_gqa: + m.num_key_value_heads = m.num_heads + m.num_key_value_groups = m.num_heads // m.num_key_value_heads elif name.endswith("mlp"): model.config.intermediate_size = m.gate_proj.out_features + if not _is_gqa: + model.config.num_key_value_heads = model.config.num_attention_heads print("----------------- After Pruning -----------------") print(model) + print(model.config) + num_params = sum(p.numel() for p in model.parameters()) + print(f"num_params {num_params}") ppl_test = eval_ppl(args, model, tokenizer, device) print(f"wikitext perplexity {ppl_test}") diff --git a/examples/LLMs/readme.md b/examples/LLMs/readme.md index decb9d3..5f1fbb4 100644 --- a/examples/LLMs/readme.md +++ b/examples/LLMs/readme.md @@ -55,8 +55,8 @@ LlamaForCausalLM( (0-31): 32 x LlamaDecoderLayer( (self_attn): LlamaSdpaAttention( (q_proj): Linear(in_features=2048, out_features=2048, bias=False) - (k_proj): Linear(in_features=2048, out_features=512, bias=False) - (v_proj): Linear(in_features=2048, out_features=512, bias=False) + (k_proj): Linear(in_features=2048, out_features=1024, bias=False) + (v_proj): Linear(in_features=2048, out_features=1024, bias=False) (o_proj): Linear(in_features=2048, out_features=2048, bias=False) (rotary_emb): LlamaRotaryEmbedding() ) @@ -66,18 +66,50 @@ LlamaForCausalLM( (down_proj): Linear(in_features=7168, out_features=2048, bias=False) (act_fn): SiLU() ) - (input_layernorm): LlamaRMSNorm() - (post_attention_layernorm): LlamaRMSNorm() + (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05) + (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05) ) ) - (norm): LlamaRMSNorm() + (norm): LlamaRMSNorm((2048,), eps=1e-05) + (rotary_emb): LlamaRotaryEmbedding() ) (lm_head): Linear(in_features=2048, out_features=128256, bias=False) ) +LlamaConfig { + "_name_or_path": "meta-llama/Meta-Llama-3-8B", + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 7168, + "max_position_embeddings": 8192, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.44.2", + "use_cache": true, + "vocab_size": 128256 +} + +num_params 2337409024 evaluating on wikitext2 nsamples 35 sample 0 -wikitext perplexity 41982.296875 +wikitext perplexity 552648.25 ``` @@ -142,19 +174,51 @@ LlamaForCausalLM( (down_proj): 
Linear(in_features=5504, out_features=2048, bias=False) (act_fn): SiLU() ) - (input_layernorm): LlamaRMSNorm() - (post_attention_layernorm): LlamaRMSNorm() + (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05) + (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05) ) ) - (norm): LlamaRMSNorm() + (norm): LlamaRMSNorm((2048,), eps=1e-05) + (rotary_emb): LlamaRotaryEmbedding() ) (lm_head): Linear(in_features=2048, out_features=32000, bias=False) ) +LlamaConfig { + "_name_or_path": "meta-llama/Llama-2-7b-hf", + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5504, + "max_position_embeddings": 4096, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 32, + "num_key_value_heads": 16, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000.0, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.44.2", + "use_cache": true, + "vocab_size": 32000 +} + +num_params 1750206464 evaluating on wikitext2 nsamples 83 sample 0 sample 50 -wikitext perplexity 9605.4130859375 +wikitext perplexity 8479.0673828125 ``` diff --git a/examples/timm_models/prune_timm_models.py b/examples/timm_models/prune_timm_models.py index f3a4dec..98a8dcb 100644 --- a/examples/timm_models/prune_timm_models.py +++ b/examples/timm_models/prune_timm_models.py @@ -38,69 +38,39 @@ def main(): ignored_layers = [] num_heads = {} pruning_ratio_dict = {} - import random + print("========Before pruning========") + print(model) + base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) + pruner = tp.pruner.MetaPruner( + model, + example_inputs, + global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers. 
+ importance=imp, # importance criterion for parameter selection + iterative_steps=1, # the number of iterations to achieve target pruning ratio + pruning_ratio=args.pruning_ratio, # target pruning ratio + pruning_ratio_dict=pruning_ratio_dict, + num_heads=num_heads, + ignored_layers=ignored_layers, + ) + for g in pruner.step(interactive=True): + g.prune() - population = [ - [0.265625,0.234375,0.265625,0.265625,0.93359375,0.328125,0.2265625,0.58984375,0.54296875,0.701171875,0.919921875,0.04296875,0.796875,0.240966796875,0.07763671875], - [0.96875,0.578125,0.3515625,0.6328125,0.7578125,0.7109375,0.8984375,0.533203125,0.0703125,0.697265625,0.451171875,0.626953125,0.935546875,0.294921875,0.5244140625], - [0.25,0.421875,0.171875,0.4921875,0.71875,0.51953125,0.71875,0.876953125,0.896484375,0.626953125,0.646484375,0.490234375,0.65234375,0.599609375,0.0341796875], - [0.015625,0.015625,0.078125,0.4375,0.59375,0.6953125,0.73828125,0.611328125,0.787109375,0.76171875,0.25,0.427734375,0.154296875,0.592529296875,0.298583984375] - ] + for m in model.modules(): + # Attention layers + if hasattr(m, 'num_heads'): + if hasattr(m, 'qkv'): + m.num_heads = num_heads[m.qkv] + m.head_dim = m.qkv.out_features // (3 * m.num_heads) + elif hasattr(m, 'qkv_proj'): + m.num_heads = num_heads[m.qkv_proj] + m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads) - for ratios in population: - k = 0 - for m in model.modules(): - #if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes: - if isinstance(m, nn.Linear) and m.out_features == model.num_classes: - ignored_layers.append(m) - print("Ignore classifier layer: ", m) - - # Attention layers - if hasattr(m, 'num_heads'): - if hasattr(m, 'qkv'): - num_heads[m.qkv] = m.num_heads - print("Attention layer: ", m.qkv, m.num_heads) - elif hasattr(m, 'qkv_proj'): - num_heads[m.qkv_proj] = m.num_heads - - elif isinstance(m, nn.Conv2d): - pruning_ratio_dict[m] = ratios[k] - k+=1 - - - print("========Before pruning========") - print(model) - base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) - pruner = tp.pruner.MetaPruner( - model, - example_inputs, - global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
- importance=imp, # importance criterion for parameter selection - iterative_steps=1, # the number of iterations to achieve target pruning ratio - pruning_ratio=args.pruning_ratio, # target pruning ratio - pruning_ratio_dict=pruning_ratio_dict, - num_heads=num_heads, - ignored_layers=ignored_layers, - ) - for g in pruner.step(interactive=True): - g.prune() - - for m in model.modules(): - # Attention layers - if hasattr(m, 'num_heads'): - if hasattr(m, 'qkv'): - m.num_heads = num_heads[m.qkv] - m.head_dim = m.qkv.out_features // (3 * m.num_heads) - elif hasattr(m, 'qkv_proj'): - m.num_heads = num_heads[m.qqkv_projkv] - m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads) - - print("========After pruning========") - print(model) - test_output = model(example_inputs) - pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) - print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9)) - print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6)) + print("========After pruning========") + print(model) + test_output = model(example_inputs) + pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) + print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9)) + print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6)) if __name__=='__main__': main() \ No newline at end of file diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 1580e41..8de5d87 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -18,10 +18,6 @@ def test_pruner(): [tp.importance.GroupNormImportance, tp.pruner.GroupNormPruner], [tp.importance.BNScaleImportance, tp.pruner.BNScalePruner], [tp.importance.GroupNormImportance, tp.pruner.GrowingRegPruner], - [tp.importance.MagnitudeImportance, tp.pruner.GroupNormPruner], - [tp.importance.LAMPImportance, tp.pruner.GroupNormPruner], - [tp.importance.OBDCImportance, tp.pruner.GroupNormPruner], - [tp.importance.FPGMImportance, tp.pruner.GroupNormPruner], ]: if imp_cls == tp.importance.OBDCImportance: imp = imp_cls(num_classes=1000) diff --git a/torch_pruning/_helpers.py b/torch_pruning/_helpers.py index f16ee8b..f0345ca 100644 --- a/torch_pruning/_helpers.py +++ b/torch_pruning/_helpers.py @@ -80,20 +80,20 @@ def __call__(self, idxs: _HybridIndex): new_idxs = [ _HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs] return new_idxs -class _ExpandIndexMapping(object): - def __init__(self, repeat, reverse=False): +class _GQAIndexMapping(object): + def __init__(self, repeat, head_dim, reverse=False): self.repeat = repeat self.reverse = reverse + self.head_dim = head_dim def __call__(self, idxs: _HybridIndex): - if self.reverse == True: - new_idxs = [ _HybridIndex(idx=i.idx // self.repeat, root_idx=i.root_idx) for i in idxs[::self.repeat]] + head_dim = self.head_dim + repeat = self.repeat + if self.reverse == True: + new_idxs = [ _HybridIndex(idx=( i.idx - i.idx // (head_dim * repeat) * head_dim * (repeat - 1) - i.idx//head_dim%repeat * head_dim ), root_idx=None) for i in idxs ] else: - new_idxs = [ - _HybridIndex(idx = i.idx * self.repeat + j, root_idx=i.root_idx) - for i in idxs - for j in range(self.repeat) - ] + new_idxs = [] + return new_idxs class _SplitIndexMapping(object): diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index 105c224..588ce65 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -499,6 +499,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args): if 
len(new_indices) == 0: continue + if (new_dep.target in visited_node) and group.has_pruning_op( new_dep, new_indices ): @@ -1199,33 +1200,32 @@ def _update_expand_index_mapping(self, node: Node): #warnings.warn("Expand operation detected but the shape information is not available") return - if len(node.grad_fn._saved_self_sym_sizes) != 5: - return + # for Huggingface GQA only, will support more expand operations in the future + if len(node.grad_fn._saved_self_sym_sizes) == 5: + batch, num_key_value_heads, n_rep, slen, head_dim = node.grad_fn._saved_self_sym_sizes + in_channels = num_key_value_heads * n_rep * head_dim + if out_channels is None or in_channels is None: return + repeat = out_channels // in_channels + addressed_dep = [] - # for Huggingface GQA - batch, num_key_value_heads, n_rep, slen, head_dim = node.grad_fn._saved_self_sym_sizes - in_channels = num_key_value_heads * n_rep * head_dim - if out_channels is None or in_channels is None: return - repeat = out_channels // in_channels - addressed_dep = [] - for i, out_node in enumerate(node.outputs): - for dep in node.dependencies: - if any((dep is d) for d in addressed_dep): continue - if dep.target == out_node: - if node.enable_index_mapping: - dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=False)) - addressed_dep.append(dep) - break - - addressed_dep = [] - for i, out_node in enumerate(node.outputs): - for dep in out_node.dependencies: - if dep.target == node: + for i, in_node in enumerate(node.inputs): + for dep in node.dependencies: if any((dep is d) for d in addressed_dep): continue - if node.enable_index_mapping: - dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=True)) - addressed_dep.append(dep) - break + if dep.target == in_node: + if node.enable_index_mapping: + dep.index_mapping[0] = (_helpers._GQAIndexMapping(repeat=repeat, reverse=True, head_dim=head_dim)) + addressed_dep.append(dep) + break + + addressed_dep = [] + for i, in_node in enumerate(node.inputs): + for dep in in_node.dependencies: + if dep.target == node: + if any((dep is d) for d in addressed_dep): continue + if node.enable_index_mapping: + dep.index_mapping[0] = (_helpers._GQAIndexMapping(repeat=repeat, reverse=False, head_dim=head_dim)) + addressed_dep.append(dep) + break def infer_channels_between(self, node_1, node_2): if node_1.type == ops.OPTYPE.SPLIT: diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index f564ada..fd9f78d 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -242,7 +242,10 @@ def __init__( initial_total_channels = 0 initial_total_heads = 0 for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): - group = self._downstream_node_as_root_if_attention(group) + _is_atten, qkv_layers = self._is_atten_group(group) + if _is_atten: + group = self._downstream_node_as_root_if_attention(group) + if group is None: continue initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) ) for dep, _ in group: if dep.target.module in self.num_heads and self.DG.is_out_channel_pruning_fn(dep.handler): @@ -327,7 +330,7 @@ def _check_pruning_ratio(self, group) -> bool: return False return True - def _is_attn_group(self, group) -> bool: + def _is_atten_group(self, group) -> bool: is_attn = False qkv_layers = [] for dep, _ in group: @@ -339,7 +342,7 @@ def 
_is_attn_group(self, group) -> bool: return is_attn, qkv_layers def _get_channel_groups(self, group) -> int: - ch_groups = 1 + ch_groups = [] #has_unbind = False #unbind_node = None @@ -349,7 +352,7 @@ def _get_channel_groups(self, group) -> int: channel_groups = self.out_channel_groups if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.in_channel_groups if module in channel_groups: - ch_groups = channel_groups[module] + ch_groups.append(channel_groups[module]) #if dep.source.type==ops.OPTYPE.UNBIND: # has_unbind = True @@ -357,7 +360,9 @@ def _get_channel_groups(self, group) -> int: #if has_unbind and ch_groups>1: # ch_groups = ch_groups // len(unbind_node.outputs) - return ch_groups # no channel grouping + if len(ch_groups) == 0: + return 1 + return max(ch_groups) # no channel grouping def _downstream_node_as_root_if_attention(self, group): # Use a downstream node as the root if torch.unbind exists. TODO: find a general way to handle torch.unbind in timm @@ -371,7 +376,8 @@ def _downstream_node_as_root_if_attention(self, group): idxs = _idxs if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, idxs) - return group + return group + return None def _round_to(self, n_pruned, current_channels, round_to): rounded_channels = current_channels - n_pruned @@ -379,6 +385,7 @@ def _round_to(self, n_pruned, current_channels, round_to): n_pruned = current_channels - rounded_channels return max(n_pruned, 0) + @torch.no_grad() def _prune(self) -> typing.Generator: if self.current_step > self.iterative_steps: @@ -387,54 +394,57 @@ def _prune(self) -> typing.Generator: ############################################## # Initialize ranking scopes - # Will perform indepenedent importance ranking & pruning within each scope + # A scope is a set of layers that will be ranked together to determine their relative importance. # This feature is useful for implementing ranking strategies such as local pruning, global pruning, customized pruning ratios or isomorphic pruning (ECCV 2024): https://arxiv.org/abs/2407.04616 # There are two pre-defined scopes: DEFAULT_SCOPE and ATTN_HEAD_SCOPE - # - DEFAULT_SCOPE: groups will be assigned to this scope if not specified. It is used for simple global pruning. + # - DEFAULT_SCOPE: a group will be assigned to this scope for global ranking if not specified # - ATTN_HEAD_SCOPE: for multi-head attention pruning ############################################## DEFAULT_SCOPE = "DEFAULT_SCOPE" ATTN_HEAD_SCOPE = "ATTN_HEAD_SCOPE" - ranking_scope = {DEFAULT_SCOPE: [], ATTN_HEAD_SCOPE: {}} # ATTN_HEAD_SCOPE will be a dict, because we need to index these groups later + ############################################## # 1. Pre-compute importance for each group and assign them to different scopes ############################################## - - for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): if self._check_pruning_ratio(group): - - # Compute raw importance score - group = self._downstream_node_as_root_if_attention(group) # use a downstream node as the root node for attention layers + # Re-order the group and use a downstream node as the root node for attention layers. + # This will not change the group structure, but make index mapping easier for attention layers. 
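The closed-form expression introduced above in `_helpers.py` (`_GQAIndexMapping` with `reverse=True`) is dense, so here is a small standalone check of what the arithmetic computes: it maps a channel index in the expanded (`repeat_kv`-style) K/V space back to the corresponding channel of the un-expanded K/V projection. The sizes are made up for illustration.

```python
# Standalone sanity check of the reverse GQA index mapping (illustrative sizes).
head_dim, repeat, num_kv_heads = 4, 2, 2  # assumed toy configuration

def reverse_map(idx):
    # same arithmetic as _GQAIndexMapping.__call__ with reverse=True
    return idx - idx // (head_dim * repeat) * head_dim * (repeat - 1) - idx // head_dim % repeat * head_dim

for idx in range(num_kv_heads * repeat * head_dim):
    kv_head = idx // (repeat * head_dim)
    d = idx % head_dim
    # every repeated copy of a K/V channel collapses back to the same source index
    assert reverse_map(idx) == kv_head * head_dim + d
```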
+ _is_atten, qkv_layers = self._is_atten_group(group) + if _is_atten: + group = self._downstream_node_as_root_if_attention(group) + if group is None: continue ch_groups = self._get_channel_groups(group) imp = self.estimate_importance(group) # raw importance score group_size = len(imp) // ch_groups - if imp is None: continue - if ch_groups > 1: - # Corresponding elements of each group will be removed together. - # So we average importance across groups here. For example: + if imp is None: continue + if ch_groups > 1: # layers with dimension grouping, such as GroupConv, GroupNorm, Multi-head attention, etc. + # We average importance across groups here. For example: # imp = [1, 2, 3, 4, 5, 6] with ch_groups=2. # We have two groups [1,2,3] and [4,5,6]. # The average importance should be [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5] - dim_imp = imp.view(ch_groups, -1).mean(dim=0) + dim_imp = imp.view(ch_groups, -1).mean(dim=0).cpu() else: # no grouping - dim_imp = imp + dim_imp = imp.cpu() # Importance scores for Attention Heads - _is_attn, qkv_layers = self._is_attn_group(group) - if _is_attn and self.prune_num_heads and self.get_target_head_pruning_ratio(qkv_layers[0])>0: - # average importance of each group. For example: - # the importance score of the group + _is_atten, qkv_layers = self._is_atten_group(group) + if _is_atten and self.prune_num_heads and self.get_target_head_pruning_ratio(qkv_layers[0])>0: + # average importance over heads + # Example: if we have the importance score: # imp = [1, 2, 3, 4, 5, 6] with num_heads=2 # Note: head1 = [1, 2, 3], head2 = [4, 5, 6] # the average importance is [(1+2+3)/3, (4+5+6)/3] = [2, 5] - head_imp = imp.view(ch_groups, -1).mean(1) # average importance by head. + + # GQA: the number of heads for KV might be different from Q + num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) # get the maximum number of heads + head_imp = imp.view(num_heads, -1).mean(1).cpu() # average importance by head. ranking_scope[ATTN_HEAD_SCOPE][group] = (qkv_layers, head_imp) + - - # Scope 1: User-defined scope (Priority 1) + # Scope 1: User-defined pruning ratios is_user_defined_scope = False for dep, _ in group: for module, pruning_fn in zip([dep.source.module, dep.target.module], [dep.trigger, dep.handler]): @@ -447,13 +457,16 @@ def _prune(self) -> typing.Generator: ranking_scope[scope_name] = [] ranking_scope[scope_name].append(record) is_user_defined_scope = True + # A bit messy here. Will refactor in the future. 
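The two reductions described in the comments above are easy to sanity-check in isolation; the snippet below simply reproduces the worked examples from those comments (`ch_groups=2` and `num_heads=2` over the same six scores).

```python
import torch

imp = torch.tensor([1., 2., 3., 4., 5., 6.])  # example scores from the comments

# channel groups: average corresponding elements of each group
dim_imp = imp.view(2, -1).mean(dim=0)   # -> tensor([2.5, 3.5, 4.5])

# attention heads: average all elements within each head
head_imp = imp.view(2, -1).mean(1)      # -> tensor([2., 5.])
```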
if is_user_defined_scope: break if is_user_defined_scope: break if is_user_defined_scope: continue - record = (group, ch_groups, group_size, self.per_step_pruning_ratio[self.current_step], dim_imp) # otherwise, use the default pruning ratio - # Scope 3: Isomorphic Pruning (Priority 3) + # otherwise, use the default pruning ratio + record = (group, ch_groups, group_size, self.per_step_pruning_ratio[self.current_step], dim_imp) + + # Scope 2: Isomorphic Pruning if self.isomorphic: scope_name = "Isomorphic_" # we transform the graph structure into a string tag for easy comparison for dep, _ in group: # if isomorphic, the source and target modules should have the same **layer type** and **pruning function** @@ -465,10 +478,10 @@ def _prune(self) -> typing.Generator: ranking_scope[scope_name] = [] ranking_scope[scope_name].append(record) - elif self.global_pruning: # Scope 4: global pruning + elif self.global_pruning: # Scope 3: use the default scope for global pruning ranking_scope[DEFAULT_SCOPE].append(record) - else: # Scope 5: local pruning + else: # Scope 4: always create a new scope if local pruning module_name = self.DG._module2name[group[0][0].source.module] ranking_scope[module_name] = [ record ] @@ -478,9 +491,8 @@ def _prune(self) -> typing.Generator: ############################################## # 2. Thresholding by ranking all importance scores within each scope ############################################## - - # Find the threshold for the Multi-head attention scope - if len(ranking_scope[ATTN_HEAD_SCOPE])>0: + # Find the threshold for the Multi-head attention scope if global pruning is enabled + if len(ranking_scope[ATTN_HEAD_SCOPE])>0 and self.global_pruning: concat_head_imp = torch.cat([local_imp[-1] for local_imp in ranking_scope[ATTN_HEAD_SCOPE].values()], dim=0) target_head_pruning_ratio = self.per_step_head_pruning_ratio[self.current_step] n_heads_removed = len(concat_head_imp) - int( @@ -491,14 +503,8 @@ def _prune(self) -> typing.Generator: topk_head_imp, _ = torch.topk(concat_head_imp, k=n_heads_removed, largest=False) head_thres = topk_head_imp[-1] + # Width pruning width_pruning_scope_names = [ k for k in ranking_scope.keys() if k!=ATTN_HEAD_SCOPE] - #for name in width_pruning_scope_names: # truncate the name if lenth exceeds 10 - # print(f"Ranking Scope: {name[:50]} Scope Size={len(ranking_scope[name])}") - # if len(ranking_scope[name])>0: - # for i in range(len(ranking_scope[name])): - # print(ranking_scope[name][i][0], ranking_scope[name][i][-2]) - # Handle other scopes for width pruning. - for scope_id, scope_name in enumerate(width_pruning_scope_names): if not self.global_pruning: @@ -520,7 +526,6 @@ def _prune(self) -> typing.Generator: if n_pruned>0: topk_imp, topk_indices = torch.topk(concat_imp, k=n_pruned, largest=False) thres = topk_imp[-1] - ############################################## # 3. Pruning in each scope ############################################## @@ -528,61 +533,95 @@ def _prune(self) -> typing.Generator: module = group[0].dep.target.module pruning_fn = group[0].dep.handler get_channel_fn = self.DG.get_out_channels if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.DG.get_in_channels - - # Prune feature dims/channels + _is_atten, qkv_layers = self._is_atten_group(group) + + # Prune dims/channels pruning_indices = [] - if len(records)>0 and n_pruned>0: - if ch_groups > 1: # re-compute importance for each channel group if grouping is enabled - if self.global_pruning: # for global pruning, the n_pruned may be shared by multiple layers. 
For each layer, we should know how many channels/dim should be pruned. - n_pruned_per_group = len((imp <= thres).nonzero().view(-1)) - else: # for local pruning, we can directly use the n_pruned since each scope only contains one layer - n_pruned_per_group = n_pruned - if n_pruned_per_group>0: - if self.round_to: - n_pruned_per_group = self._round_to(n_pruned_per_group, group_size, self.round_to) - _is_attn, _ = self._is_attn_group(group) - if not _is_attn or self.prune_head_dims==True: - raw_imp = self.estimate_importance(group) # re-compute importance - for chg in range(ch_groups): # determine pruning indices for each channel group independently - sub_group_imp = raw_imp[chg*group_size: (chg+1)*group_size] - sub_imp_argsort = torch.argsort(sub_group_imp) - sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group]+chg*group_size - pruning_indices.append(sub_pruning_idxs) - else: # standard pruning - if self.global_pruning: - _pruning_indices = (imp <= thres).nonzero().view(-1) - else: - _pruning_indices = topk_indices - imp_argsort = torch.argsort(imp) - if len(_pruning_indices)>0 and self.round_to: - n_pruned = len(_pruning_indices) - current_channels = get_channel_fn(module) - n_pruned = self._round_to(n_pruned, current_channels, self.round_to) - _pruning_indices = imp_argsort[:n_pruned] + if not _is_atten or self.prune_head_dims: + if self.global_pruning: + _pruning_indices = (imp <= thres).nonzero().view(-1) + else: + _pruning_indices = topk_indices + imp_argsort = torch.argsort(imp) + if len(_pruning_indices)>0 and self.round_to: # recompute the number of pruned channels if round_to is enabled + n_pruned = len(_pruning_indices) + current_channels = get_channel_fn(module) + n_pruned = self._round_to(n_pruned, current_channels, self.round_to) + _pruning_indices = imp_argsort[:n_pruned] + if ch_groups>1: # if channel grouping is enabled, we repeat the pruning indices for each channel group + for g_id in range(ch_groups): + pruning_indices.append(_pruning_indices+g_id*group_size) + else: pruning_indices.append(_pruning_indices) - - # Prune heads - if len(ranking_scope[ATTN_HEAD_SCOPE])>0 and n_heads_removed>0: + + # Prune Attention Heads + if len(ranking_scope[ATTN_HEAD_SCOPE])>0: if group in ranking_scope[ATTN_HEAD_SCOPE]: qkv_layers, head_imp = ranking_scope[ATTN_HEAD_SCOPE][group] - if not self.global_pruning: + num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) + _is_gqa = not all([self.num_heads[qkv_layer]==num_heads for qkv_layer in qkv_layers]) + + if not self.global_pruning: # local pruning n_heads_removed_per_group = int(self.get_target_head_pruning_ratio(qkv_layers[0]) * len(head_imp)) - head_pruning_indices = torch.topk(head_imp, k=n_heads_removed_per_group, largest=False)[1] # local ranking - else: + if not _is_gqa: + head_pruning_indices = torch.topk(head_imp, k=n_heads_removed_per_group, largest=False)[1] # local ranking + else: # chunk the head imp + num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) + num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) + n_heads_removed_per_group = n_heads_removed_per_group // num_kv_heads + head_pruning_indices = [] + for kv_head_id in range(num_kv_heads): + head_imp_kv = head_imp[kv_head_id * num_heads//num_kv_heads: (kv_head_id+1) * num_heads//num_kv_heads] + head_pruning_indices_kv = torch.topk(head_imp_kv, k=n_heads_removed_per_group, largest=False)[1] + head_pruning_indices.append(head_pruning_indices_kv + kv_head_id*num_heads//num_kv_heads) + head_pruning_indices = 
torch.cat(head_pruning_indices, 0) + + else: # global pruning head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) # global ranking + if _is_gqa: + num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) + n_heads_removed_per_group = len(head_pruning_indices) // num_kv_heads + head_pruning_indices = [] + for kv_head_id in range(num_kv_heads): + head_imp_kv = head_imp[kv_head_id * len(head_imp)//num_kv_heads: (kv_head_id+1) * len(head_imp)//num_kv_heads] + head_pruning_indices_kv = torch.topk(head_imp_kv, k=n_heads_removed_per_group, largest=False)[1] + head_pruning_indices.append(head_pruning_indices_kv + kv_head_id*num_kv_heads) + head_pruning_indices = torch.cat(head_pruning_indices, 0) + if len(head_pruning_indices)>0: + if len(qkv_layers)==1: + head_dim = qkv_layers[0].out_features // (self.num_heads[qkv_layers[0]]*3) + else: + head_dim = qkv_layers[0].out_features // self.num_heads[qkv_layers[0]] + for head_id in head_pruning_indices: - pruning_indices.append( torch.arange(head_id*group_size, (head_id+1)*group_size, device=head_imp.device) ) + pruning_indices.append( torch.arange(head_id*head_dim, (head_id+1)*head_dim, device=head_imp.device) ) + + num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) for qkv_layer in qkv_layers: - self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning - self.out_channel_groups[qkv_layer] = self.num_heads[qkv_layer] # update out_channel_groups + if self.num_heads[qkv_layer] == num_heads: + self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning + self.out_channel_groups[qkv_layer] = self.num_heads[qkv_layer] # update out_channel_groups + if len(pruning_indices)==0: continue pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist() if isinstance(self.importance, OBDCImportance): self.importance.adjust_fisher(group, pruning_indices) + # create pruning group group = self.DG.get_pruning_group( module, pruning_fn, pruning_indices) + + if _is_atten: + _is_gqa = not all([self.num_heads[qkv_layer]==self.num_heads[qkv_layers[0]] for qkv_layer in qkv_layers]) + if _is_gqa and self.prune_num_heads: + num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) + kv_layers = [qkv_layer for qkv_layer in qkv_layers if self.num_heads[qkv_layer]==num_kv_heads] + for i in range(len(group)): + dep, idxs = group[i] + if dep.target.module in kv_layers: + group[i] = (dep, []) # disable pruning for the kv layers if GQA is enabled + if self.DG.check_pruning_group(group): - yield group # yield the group for interactive pruning \ No newline at end of file + yield group # yield the group for interactive pruning diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py index 3e12b24..8b4969f 100644 --- a/torch_pruning/pruner/importance.py +++ b/torch_pruning/pruner/importance.py @@ -11,6 +11,8 @@ import numpy as np from collections import OrderedDict from ..utils.compute_mat_grad import ComputeMatGrad +import random +import warnings __all__ = [ # Base Class @@ -142,9 +144,13 @@ def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[ reduced_imp = torch.ones_like(group_imp[0]) * -99999 else: reduced_imp = torch.zeros_like(group_imp[0]) - + + n_imp = 0 for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)): imp = imp.to(reduced_imp.device) + if any([r is None for r in root_idxs]): + #warnings.warn("Root idxs contain None values. 
Skipping this layer...") + continue if self.group_reduction == "sum" or self.group_reduction == "mean": reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) # accumulated importance elif self.group_reduction == "max": # keep the max importance @@ -165,9 +171,10 @@ def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[ reduced_imp = torch.stack(group_imp, dim=0) # no reduction else: raise NotImplementedError - + n_imp += 1 + if self.group_reduction == "mean": - reduced_imp /= len(group_imp) + reduced_imp /= n_imp return reduced_imp @torch.no_grad() @@ -397,7 +404,9 @@ class RandomImportance(Importance): @torch.no_grad() def __call__(self, group, **kwargs): _, idxs = group[0] - return torch.rand(len(idxs)) + score = list(range(len(idxs))) + random.shuffle(score) + return torch.tensor(score, dtype=torch.float32) class GroupTaylorImportance(GroupNormImportance): @@ -817,3 +826,62 @@ class TaylorImportance(GroupTaylorImportance): class HessianImportance(GroupHessianImportance): pass + +from contextlib import contextmanager + +class ActivationImportance(GroupNormImportance): + + @contextmanager + def compute_importance(self, model): + + @torch.no_grad() + def _compute_importance_hook(module, input, output): + + if isinstance(module, nn.Linear): + dim = input[0].shape[-1] + module._importance = input[0].abs().view(-1, dim).sum(0) + elif isinstance(module, nn.Conv2d): + dim = input[0].shape[1] + module._importance = input[0].abs().mean((0, 2, 3)) + return + + hooks = [] + for m in model.modules(): + if isinstance(m, tuple(self.target_types)): + hooks.append(m.register_forward_hook(_compute_importance_hook)) + + yield + + for h in hooks: + h.remove() + + @torch.no_grad() + def __call__(self, group): + group_imp = [] + group_idxs = [] + for i, (dep, idxs) in enumerate(group): + idxs.sort() + layer = dep.target.module + prune_fn = dep.handler + root_idxs = group[i].root_idxs + + if not isinstance(layer, tuple(self.target_types)): + continue + + # Conv/Linear Output + if prune_fn in [ + function.prune_conv_in_channels, + function.prune_linear_in_channels, + ]: + if not hasattr(layer, "_importance"): + warnings.warn("Layer {} does not have _importance attribute.".format(layer)) + continue + local_imp = layer._importance[idxs] + group_imp.append(local_imp) + group_idxs.append(root_idxs) + + if len(group_imp) == 0: # skip groups without parameterized layers + return None + group_imp = self._reduce(group_imp, group_idxs) + group_imp = self._normalize(group_imp, self.normalizer) + return group_imp \ No newline at end of file
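For reference, a minimal end-to-end sketch of how the new `ActivationImportance` is meant to be driven, mirroring the commented-out calibration block in `prune_llama.py`: run a calibration forward pass inside the `compute_importance` context so the hooks can record input activations, then call `pruner.step()`. The toy model, the calibration batch, and the 0.5 pruning ratio are placeholders, not values from this patch.

```python
# Hedged usage sketch for ActivationImportance; model/data/ratio are placeholders.
import torch
import torch.nn as nn
import torch_pruning as tp

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
example_inputs = torch.randn(1, 16)
calibration_inputs = torch.randn(64, 16)

imp = tp.importance.ActivationImportance(p=2, target_types=[nn.Linear])
pruner = tp.pruner.MetaPruner(
    model, example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    ignored_layers=[model[-1]],  # keep the output layer intact
)

with torch.no_grad():
    with imp.compute_importance(model):  # registers forward hooks on target layers
        _ = model(calibration_inputs)    # fills each layer's _importance buffer
pruner.step()
print(model)
```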