diff --git a/examples/LLMs/prune_llama2.py b/examples/LLMs/prune_llama.py
similarity index 93%
rename from examples/LLMs/prune_llama2.py
rename to examples/LLMs/prune_llama.py
index 9f12434..1f30b45 100644
--- a/examples/LLMs/prune_llama2.py
+++ b/examples/LLMs/prune_llama.py
@@ -264,15 +264,10 @@ def main():
     parser.add_argument('--model', type=str, help='LLaMA model')
     parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
     parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
-    parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
-    parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
-    parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
-                        "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
+    parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
     parser.add_argument("--cache_dir", default="./cache", type=str )
-    parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
     parser.add_argument('--save', type=str, default=None, help='Path to save results.')
     parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
-    parser.add_argument("--eval_zero_shot", action="store_true")
 
     args = parser.parse_args()
 
@@ -302,15 +297,17 @@ def main():
     for name, m in model.named_modules():
         if name.endswith("self_attn"):
             num_heads[m.q_proj] = model.config.num_attention_heads
-            num_heads[m.k_proj] = model.config.num_attention_heads
-            num_heads[m.v_proj] = model.config.num_attention_heads
-    head_pruning_ratio = 0.5
+            num_heads[m.k_proj] = model.config.num_key_value_heads
+            num_heads[m.v_proj] = model.config.num_key_value_heads
+
+    head_pruning_ratio = args.pruning_ratio
+    hidden_size_pruning_ratio = args.pruning_ratio
     pruner = tp.pruner.MagnitudePruner(
         model,
         example_inputs=inputs,
-        importance=tp.importance.MagnitudeImportance(),
+        importance=tp.importance.GroupNormImportance(),
         global_pruning=False,
-        pruning_ratio=0.5,
+        pruning_ratio=hidden_size_pruning_ratio,
         ignored_layers=[model.lm_head],
         num_heads=num_heads,
         prune_num_heads=True,
@@ -320,17 +317,18 @@ def main():
     pruner.step()
 
     # Update model attributes
     num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
+    num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
     model.config.num_attention_heads = num_heads
+    model.config.num_key_value_heads = num_key_value_heads
     for name, m in model.named_modules():
         if name.endswith("self_attn"):
             m.hidden_size = m.q_proj.out_features
             m.num_heads = num_heads
-            m.num_key_value_heads = num_heads
+            m.num_key_value_heads = num_key_value_heads
         elif name.endswith("mlp"):
             model.config.intermediate_size = m.gate_proj.out_features
-
     print("----------------- After Pruning -----------------")
     print(model)
 
diff --git a/examples/LLMs/readme.md b/examples/LLMs/readme.md
index 84ebcc1..78f3073 100644
--- a/examples/LLMs/readme.md
+++ b/examples/LLMs/readme.md
@@ -10,11 +10,91 @@ pip install transformers datasets
 
 ## 1. Pruning
 
+### Llama-3 8B
+
+```bash
+python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
+```
+
+
+Output: + +``` +----------------- Before Pruning ----------------- +LlamaForCausalLM( + (model): LlamaModel( + (embed_tokens): Embedding(128256, 4096) + (layers): ModuleList( + (0-31): 32 x LlamaDecoderLayer( + (self_attn): LlamaSdpaAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=1024, bias=False) + (v_proj): Linear(in_features=4096, out_features=1024, bias=False) + (o_proj): Linear(in_features=4096, out_features=4096, bias=False) + (rotary_emb): LlamaRotaryEmbedding() + ) + (mlp): LlamaMLP( + (gate_proj): Linear(in_features=4096, out_features=14336, bias=False) + (up_proj): Linear(in_features=4096, out_features=14336, bias=False) + (down_proj): Linear(in_features=14336, out_features=4096, bias=False) + (act_fn): SiLU() + ) + (input_layernorm): LlamaRMSNorm() + (post_attention_layernorm): LlamaRMSNorm() + ) + ) + (norm): LlamaRMSNorm() + ) + (lm_head): Linear(in_features=4096, out_features=128256, bias=False) +) +----------------- After Pruning ----------------- +LlamaForCausalLM( + (model): LlamaModel( + (embed_tokens): Embedding(128256, 2048) + (layers): ModuleList( + (0-31): 32 x LlamaDecoderLayer( + (self_attn): LlamaSdpaAttention( + (q_proj): Linear(in_features=2048, out_features=2048, bias=False) + (k_proj): Linear(in_features=2048, out_features=512, bias=False) + (v_proj): Linear(in_features=2048, out_features=512, bias=False) + (o_proj): Linear(in_features=2048, out_features=2048, bias=False) + (rotary_emb): LlamaRotaryEmbedding() + ) + (mlp): LlamaMLP( + (gate_proj): Linear(in_features=2048, out_features=7168, bias=False) + (up_proj): Linear(in_features=2048, out_features=7168, bias=False) + (down_proj): Linear(in_features=7168, out_features=2048, bias=False) + (act_fn): SiLU() + ) + (input_layernorm): LlamaRMSNorm() + (post_attention_layernorm): LlamaRMSNorm() + ) + ) + (norm): LlamaRMSNorm() + ) + (lm_head): Linear(in_features=2048, out_features=128256, bias=False) +) +evaluating on wikitext2 +nsamples 35 +sample 0 +wikitext perplexity 41982.296875 +``` + +
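Editor's note: the projection widths printed above follow from Llama-3 8B's grouped-query attention layout (32 query heads, 8 key/value heads, head_dim 128), so `q_proj` is 4096-wide while `k_proj`/`v_proj` are only 1024-wide. A quick arithmetic check, as a standalone snippet that is not part of `prune_llama.py`:

```python
# Sanity check of the shapes printed above (illustration only, not part of the script).
# Llama-3 8B: 32 query heads, 8 key/value heads, head_dim = 128.
head_dim = 128
num_attention_heads, num_key_value_heads = 32, 8
pruning_ratio = 0.5

q_out  = int((1 - pruning_ratio) * num_attention_heads) * head_dim   # 16 heads -> 2048
kv_out = int((1 - pruning_ratio) * num_key_value_heads) * head_dim   # 4 heads  -> 512
print(q_out, kv_out)  # 2048 512, matching q_proj / k_proj / v_proj after pruning
```

Because whole heads are removed while head_dim stays at 128, the hidden size halves from 4096 to 2048 (16 x 128) and the key/value width halves from 1024 to 512 (4 x 128), which is exactly what the pruned printout shows.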
+ + + + +### Llama-2 7B + ```bash -python prune_llama2.py --model meta-llama/Llama-2-7b-hf +python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5 ``` -Output: + +
+Output: + ``` ----------------- Before Pruning ----------------- LlamaForCausalLM( @@ -77,4 +157,6 @@ sample 50 wikitext perplexity 9605.4130859375 ``` +
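Editor's note: since the `prune_llama.py` hunks above are fragmented, here is a condensed sketch of the GQA-aware pruning flow they implement. Variable names follow the script; `model`, `inputs`, and `args` come from the surrounding code (model loading, calibration, and evaluation are not repeated), and any argument truncated by the hunks (for example the pruner keyword after `prune_num_heads=True`) is an assumption flagged in the comments.

```python
# Condensed sketch assembled from the hunks above; not a drop-in replacement for the script.
import torch_pruning as tp

num_heads = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        num_heads[m.q_proj] = model.config.num_attention_heads   # one group per query head
        num_heads[m.k_proj] = model.config.num_key_value_heads   # GQA: fewer key/value heads
        num_heads[m.v_proj] = model.config.num_key_value_heads

head_pruning_ratio = args.pruning_ratio
hidden_size_pruning_ratio = args.pruning_ratio

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=inputs,
    importance=tp.importance.GroupNormImportance(),
    global_pruning=False,
    pruning_ratio=hidden_size_pruning_ratio,   # width (hidden size / MLP) pruning
    ignored_layers=[model.lm_head],
    num_heads=num_heads,                       # heads are removed as whole groups
    prune_num_heads=True,
    head_pruning_ratio=head_pruning_ratio,     # assumed: this kwarg is cut off by the hunk above
)
pruner.step()

# Afterwards the Hugging Face config and attention modules are resized to the pruned shapes.
num_heads = int((1 - head_pruning_ratio) * model.config.num_attention_heads)
num_key_value_heads = int((1 - head_pruning_ratio) * model.config.num_key_value_heads)
model.config.num_attention_heads = num_heads
model.config.num_key_value_heads = num_key_value_heads
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        m.hidden_size = m.q_proj.out_features
        m.num_heads = num_heads
        m.num_key_value_heads = num_key_value_heads
    elif name.endswith("mlp"):
        model.config.intermediate_size = m.gate_proj.out_features
```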
+ diff --git a/torch_pruning/_helpers.py b/torch_pruning/_helpers.py index df71913..f16ee8b 100644 --- a/torch_pruning/_helpers.py +++ b/torch_pruning/_helpers.py @@ -80,6 +80,21 @@ def __call__(self, idxs: _HybridIndex): new_idxs = [ _HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs] return new_idxs +class _ExpandIndexMapping(object): + def __init__(self, repeat, reverse=False): + self.repeat = repeat + self.reverse = reverse + + def __call__(self, idxs: _HybridIndex): + if self.reverse == True: + new_idxs = [ _HybridIndex(idx=i.idx // self.repeat, root_idx=i.root_idx) for i in idxs[::self.repeat]] + else: + new_idxs = [ + _HybridIndex(idx = i.idx * self.repeat + j, root_idx=i.root_idx) + for i in idxs + for j in range(self.repeat) + ] + return new_idxs class _SplitIndexMapping(object): def __init__(self, offset, reverse=False): diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index b4bd3f6..5167c3c 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -289,6 +289,7 @@ def __init__(self): ops.OPTYPE.ELEMENTWISE: ops.ElementWisePruner(), ops.OPTYPE.RESHAPE: ops.ReshapePruner(), ops.OPTYPE.UNBIND: ops.UnbindPruner(), + ops.OPTYPE.EXPAND: ops.ExpandPruner(), ops.OPTYPE.CUSTOMIZED: ops.CustomizedPruner(), # just a placeholder } self.REGISTERED_PRUNERS = function.PrunerBox.copy() # shallow copy @@ -842,6 +843,10 @@ def create_node_if_not_exists(grad_fn): elif "unbind" in grad_fn.name().lower(): module = ops._UnbindOp(self._op_id) self._op_id+=1 + elif "expand" in grad_fn.name().lower(): + module = ops._ExpandOp(self._op_id) + # print all attributes + self._op_id+=1 elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower(): module = ops._ReshapeOp(self._op_id) self._op_id+=1 @@ -914,6 +919,8 @@ def update_index_mapping(self): self._update_reshape_index_mapping(node) if node.type == ops.OPTYPE.UNBIND: self._update_unbind_index_mapping(node) + if node.type == ops.OPTYPE.EXPAND and torch.__version__ >= "1.8": + self._update_expand_index_mapping(node) def _init_shape_information(self): for module, node in self.module2node.items(): @@ -949,7 +956,7 @@ def _init_shape_information(self): offsets.append(offsets[-1] + ch) node.module.split_sizes = chs node.module.offsets = offsets - + def _update_flatten_index_mapping(self, fc_node: Node): if fc_node.type != ops.OPTYPE.LINEAR: return @@ -1177,6 +1184,47 @@ def _update_unbind_index_mapping(self, unbind_node: Node): addressed_dep.append(dep) break + def _update_expand_index_mapping(self, node: Node): + out_channels = None + for n in node.outputs: + recursive_depth = [0] + out_channels = self._infer_in_channels_recursively(n, recursive_depth) + if recursive_depth[0] > MAX_RECURSION_DEPTH: + return + if out_channels is not None: # =0 if there is a residual connection to model inputs + break + if not hasattr(node.grad_fn, '_saved_self_sym_sizes'): + #warnings.warn("Expand operation detected but the shape information is not available") + return + + if len(node.grad_fn._saved_self_sym_sizes) != 5: + return + + # for Huggingface GQA + batch, num_key_value_heads, n_rep, slen, head_dim = node.grad_fn._saved_self_sym_sizes + in_channels = num_key_value_heads * n_rep * head_dim + if out_channels is None or in_channels is None: return + repeat = out_channels // in_channels + addressed_dep = [] + for i, out_node in enumerate(node.outputs): + for dep in node.dependencies: + if any((dep is d) for d in addressed_dep): continue + if dep.target == out_node: + if 
node.enable_index_mapping: + dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=False)) + addressed_dep.append(dep) + break + + addressed_dep = [] + for i, out_node in enumerate(node.outputs): + for dep in out_node.dependencies: + if dep.target == node: + if any((dep is d) for d in addressed_dep): continue + if node.enable_index_mapping: + dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=True)) + addressed_dep.append(dep) + break + def infer_channels_between(self, node_1, node_2): if node_1.type == ops.OPTYPE.SPLIT: for i, n in enumerate(node_1.outputs): diff --git a/torch_pruning/ops.py b/torch_pruning/ops.py index 3befdca..a54f076 100644 --- a/torch_pruning/ops.py +++ b/torch_pruning/ops.py @@ -62,6 +62,13 @@ def __init__(self, id, grad_fn): def __repr__(self): return "_ElementWiseOp_{}({})".format(self.id, self._grad_fn) +class _ExpandOp(nn.Module): + def __init__(self, id): + super(_ExpandOp, self).__init__() + self.id = id + + def __repr__(self): + return "_ExpandOp_{}()".format(self.id) ###################################################### # Dummy Pruners @@ -90,6 +97,9 @@ def get_out_channel_groups(self, layer): class UnbindPruner(DummyPruner): pass +class ExpandPruner(DummyPruner): + pass + class ConcatPruner(DummyPruner): def prune_out_channels(self, layer, idxs): if layer.concat_sizes is None: @@ -188,6 +198,7 @@ class OPTYPE(IntEnum): GN = 15 # nn.GroupNorm IN = 16 # nn.InstanceNorm UNBIND = 17 + EXPAND = 18 def module2type(module): @@ -226,6 +237,8 @@ def module2type(module): return OPTYPE.RESHAPE elif isinstance(module, _UnbindOp): return OPTYPE.UNBIND + elif isinstance(module, _ExpandOp): + return OPTYPE.EXPAND else: return OPTYPE.ELEMENTWISE @@ -263,6 +276,8 @@ def type2class(op_type): return _ReshapeOp elif OPTYPE == OPTYPE.UNBIND: return _UnbindOp + elif OPTYPE == OPTYPE.EXPAND: + return _ExpandOp else: return _ElementWiseOp diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index cf8a3e6..4abebef 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -339,8 +339,9 @@ def _downstream_node_as_root_if_attention(self, group): is_attention = True if isinstance(_dep.target.module, tuple(self.root_module_types)) and self.DG.is_in_channel_pruning_fn(_dep.handler): downstream_dep = _dep + idxs = _idxs if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers - group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs) + group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, idxs) return group def _round_to(self, n_pruned, current_channels, round_to):
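Editor's note on the new `_ExpandIndexMapping`: Hugging Face's GQA attention repeats each key/value head for every query head by calling `expand` on a 5-D `(batch, num_key_value_heads, n_rep, slen, head_dim)` tensor, which is the shape that the `_saved_self_sym_sizes` check in `_update_expand_index_mapping` targets. The snippet below is a standalone illustration, not library code, of the forward and reverse index arithmetic the class implements; the real class wraps indices in `_HybridIndex` objects inside the dependency graph.

```python
# Standalone illustration (not library code) of the index arithmetic in _ExpandIndexMapping.
# `repeat` plays the role of out_channels // in_channels, i.e. n_rep for GQA.

def expand_forward(idxs, repeat):
    # an index selected before the expand fans out to `repeat` consecutive indices after it
    return [i * repeat + j for i in idxs for j in range(repeat)]

def expand_reverse(idxs, repeat):
    # indices selected after the expand collapse back onto their source index
    return [i // repeat for i in idxs[::repeat]]

repeat = 4                    # e.g. Llama-3 8B: 32 query heads / 8 key/value heads
kv_idxs = [0, 5]              # indices selected on the key/value side
expanded = expand_forward(kv_idxs, repeat)
print(expanded)                           # [0, 1, 2, 3, 20, 21, 22, 23]
print(expand_reverse(expanded, repeat))   # [0, 5]
```

Registering the mapping in both directions lets the dependency graph translate pruned indices across the expand, so a pruning decision on the key/value side stays consistent with the repeated query-side channels and vice versa.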