diff --git a/examples/LLMs/prune_llama2.py b/examples/LLMs/prune_llama.py
similarity index 93%
rename from examples/LLMs/prune_llama2.py
rename to examples/LLMs/prune_llama.py
index 9f12434..1f30b45 100644
--- a/examples/LLMs/prune_llama2.py
+++ b/examples/LLMs/prune_llama.py
@@ -264,15 +264,10 @@ def main():
parser.add_argument('--model', type=str, help='LLaMA model')
parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
- parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
- parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
- parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
- "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
+ parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
parser.add_argument("--cache_dir", default="./cache", type=str )
- parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
parser.add_argument('--save', type=str, default=None, help='Path to save results.')
parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
-
parser.add_argument("--eval_zero_shot", action="store_true")
args = parser.parse_args()
@@ -302,15 +297,17 @@ def main():
for name, m in model.named_modules():
if name.endswith("self_attn"):
num_heads[m.q_proj] = model.config.num_attention_heads
- num_heads[m.k_proj] = model.config.num_attention_heads
- num_heads[m.v_proj] = model.config.num_attention_heads
- head_pruning_ratio = 0.5
+ num_heads[m.k_proj] = model.config.num_key_value_heads
+ num_heads[m.v_proj] = model.config.num_key_value_heads
+
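+    # The same --pruning_ratio is applied to both the attention-head count and the hidden/MLP width.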
+ head_pruning_ratio = args.pruning_ratio
+ hidden_size_pruning_ratio = args.pruning_ratio
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=inputs,
- importance=tp.importance.MagnitudeImportance(),
+ importance=tp.importance.GroupNormImportance(),
global_pruning=False,
- pruning_ratio=0.5,
+ pruning_ratio=hidden_size_pruning_ratio,
ignored_layers=[model.lm_head],
num_heads=num_heads,
prune_num_heads=True,
@@ -320,17 +317,17 @@ def main():
pruner.step()
# Update model attributes
-
num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
+ num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
model.config.num_attention_heads = num_heads
+ model.config.num_key_value_heads = num_key_value_heads
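+    # Keep the per-layer attention/MLP attributes in sync with the pruned projections (GQA-aware).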
for name, m in model.named_modules():
if name.endswith("self_attn"):
m.hidden_size = m.q_proj.out_features
m.num_heads = num_heads
- m.num_key_value_heads = num_heads
+ m.num_key_value_heads = num_key_value_heads
elif name.endswith("mlp"):
model.config.intermediate_size = m.gate_proj.out_features
-
print("----------------- After Pruning -----------------")
print(model)
diff --git a/examples/LLMs/readme.md b/examples/LLMs/readme.md
index 84ebcc1..78f3073 100644
--- a/examples/LLMs/readme.md
+++ b/examples/LLMs/readme.md
@@ -10,11 +10,91 @@ pip install transformers datasets
## 1. Pruning
+### Llama-3 8B
+
+```bash
+python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
+```
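+
+Under the hood, the script builds a Torch-Pruning dependency graph over the Huggingface model and prunes attention heads and hidden dimensions together. Below is a simplified sketch of what `prune_llama.py` does internally; the dummy `example_inputs` and the hard-coded `0.5` ratio are illustrative assumptions, and several arguments of the real script are omitted:
+
+```python
+import torch
+import torch_pruning as tp
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
+model.eval()
+
+# Tell the pruner how many heads each attention projection carries.
+# For GQA models, k_proj / v_proj use num_key_value_heads rather than num_attention_heads.
+num_heads = {}
+for name, m in model.named_modules():
+    if name.endswith("self_attn"):
+        num_heads[m.q_proj] = model.config.num_attention_heads
+        num_heads[m.k_proj] = model.config.num_key_value_heads
+        num_heads[m.v_proj] = model.config.num_key_value_heads
+
+example_inputs = torch.randint(0, model.config.vocab_size, (1, 64))  # dummy token ids
+pruner = tp.pruner.MagnitudePruner(
+    model,
+    example_inputs=example_inputs,
+    importance=tp.importance.GroupNormImportance(),
+    global_pruning=False,
+    pruning_ratio=0.5,               # width pruning ratio
+    ignored_layers=[model.lm_head],  # keep the vocabulary projection intact
+    num_heads=num_heads,
+    prune_num_heads=True,
+)
+pruner.step()
+
+# prune_llama.py then syncs model.config (num_attention_heads, num_key_value_heads,
+# intermediate_size) and the per-layer attributes with the pruned shapes.
+```
+
+The perplexity reported below is measured on the pruned weights as-is; no recovery finetuning is applied by the script.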
+
+
+Output:
+
+```
+----------------- Before Pruning -----------------
+LlamaForCausalLM(
+ (model): LlamaModel(
+ (embed_tokens): Embedding(128256, 4096)
+ (layers): ModuleList(
+ (0-31): 32 x LlamaDecoderLayer(
+ (self_attn): LlamaSdpaAttention(
+ (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
+ (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
+ (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
+ (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
+ (rotary_emb): LlamaRotaryEmbedding()
+ )
+ (mlp): LlamaMLP(
+ (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
+ (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
+ (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
+ (act_fn): SiLU()
+ )
+ (input_layernorm): LlamaRMSNorm()
+ (post_attention_layernorm): LlamaRMSNorm()
+ )
+ )
+ (norm): LlamaRMSNorm()
+ )
+ (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
+)
+----------------- After Pruning -----------------
+LlamaForCausalLM(
+ (model): LlamaModel(
+ (embed_tokens): Embedding(128256, 2048)
+ (layers): ModuleList(
+ (0-31): 32 x LlamaDecoderLayer(
+ (self_attn): LlamaSdpaAttention(
+ (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
+ (k_proj): Linear(in_features=2048, out_features=512, bias=False)
+ (v_proj): Linear(in_features=2048, out_features=512, bias=False)
+ (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
+ (rotary_emb): LlamaRotaryEmbedding()
+ )
+ (mlp): LlamaMLP(
+ (gate_proj): Linear(in_features=2048, out_features=7168, bias=False)
+ (up_proj): Linear(in_features=2048, out_features=7168, bias=False)
+ (down_proj): Linear(in_features=7168, out_features=2048, bias=False)
+ (act_fn): SiLU()
+ )
+ (input_layernorm): LlamaRMSNorm()
+ (post_attention_layernorm): LlamaRMSNorm()
+ )
+ )
+ (norm): LlamaRMSNorm()
+ )
+ (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
+)
+evaluating on wikitext2
+nsamples 35
+sample 0
+wikitext perplexity 41982.296875
+```
+
+### Llama-2 7B
+
```bash
-python prune_llama2.py --model meta-llama/Llama-2-7b-hf
+python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
```
-Output:
+
+
+Output:
+
```
----------------- Before Pruning -----------------
LlamaForCausalLM(
@@ -77,4 +157,6 @@ sample 50
wikitext perplexity 9605.4130859375
```
+
+
diff --git a/torch_pruning/_helpers.py b/torch_pruning/_helpers.py
index df71913..f16ee8b 100644
--- a/torch_pruning/_helpers.py
+++ b/torch_pruning/_helpers.py
@@ -80,6 +80,21 @@ def __call__(self, idxs: _HybridIndex):
new_idxs = [ _HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs]
return new_idxs
+class _ExpandIndexMapping(object):
+ def __init__(self, repeat, reverse=False):
+ self.repeat = repeat
+ self.reverse = reverse
+
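+    # Maps pruning indices across an expand such as repeat_kv in Huggingface GQA.
+    # Forward (reverse=False): each pre-expand channel i fans out to the block
+    # [i*repeat, ..., i*repeat + repeat - 1]; e.g. with repeat=2, indices [0, 1]
+    # become [0, 1, 2, 3]. Reverse: every block of `repeat` post-expand indices
+    # collapses back to the single channel i // repeat.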
+ def __call__(self, idxs: _HybridIndex):
+        if self.reverse:
+ new_idxs = [ _HybridIndex(idx=i.idx // self.repeat, root_idx=i.root_idx) for i in idxs[::self.repeat]]
+ else:
+ new_idxs = [
+ _HybridIndex(idx = i.idx * self.repeat + j, root_idx=i.root_idx)
+ for i in idxs
+ for j in range(self.repeat)
+ ]
+ return new_idxs
class _SplitIndexMapping(object):
def __init__(self, offset, reverse=False):
diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py
index b4bd3f6..5167c3c 100644
--- a/torch_pruning/dependency.py
+++ b/torch_pruning/dependency.py
@@ -289,6 +289,7 @@ def __init__(self):
ops.OPTYPE.ELEMENTWISE: ops.ElementWisePruner(),
ops.OPTYPE.RESHAPE: ops.ReshapePruner(),
ops.OPTYPE.UNBIND: ops.UnbindPruner(),
+ ops.OPTYPE.EXPAND: ops.ExpandPruner(),
ops.OPTYPE.CUSTOMIZED: ops.CustomizedPruner(), # just a placeholder
}
self.REGISTERED_PRUNERS = function.PrunerBox.copy() # shallow copy
@@ -842,6 +843,10 @@ def create_node_if_not_exists(grad_fn):
elif "unbind" in grad_fn.name().lower():
module = ops._UnbindOp(self._op_id)
self._op_id+=1
+ elif "expand" in grad_fn.name().lower():
+ module = ops._ExpandOp(self._op_id)
+            # e.g. the expand emitted by repeat_kv in Huggingface GQA attention
+ self._op_id+=1
elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower():
module = ops._ReshapeOp(self._op_id)
self._op_id+=1
@@ -914,6 +919,8 @@ def update_index_mapping(self):
self._update_reshape_index_mapping(node)
if node.type == ops.OPTYPE.UNBIND:
self._update_unbind_index_mapping(node)
+ if node.type == ops.OPTYPE.EXPAND and torch.__version__ >= "1.8":
+ self._update_expand_index_mapping(node)
def _init_shape_information(self):
for module, node in self.module2node.items():
@@ -949,7 +956,7 @@ def _init_shape_information(self):
offsets.append(offsets[-1] + ch)
node.module.split_sizes = chs
node.module.offsets = offsets
-
+
def _update_flatten_index_mapping(self, fc_node: Node):
if fc_node.type != ops.OPTYPE.LINEAR:
return
@@ -1177,6 +1184,47 @@ def _update_unbind_index_mapping(self, unbind_node: Node):
addressed_dep.append(dep)
break
+ def _update_expand_index_mapping(self, node: Node):
+ out_channels = None
+ for n in node.outputs:
+ recursive_depth = [0]
+ out_channels = self._infer_in_channels_recursively(n, recursive_depth)
+ if recursive_depth[0] > MAX_RECURSION_DEPTH:
+ return
+            if out_channels is not None: # None if the channels could not be inferred (e.g. a residual connection to the model inputs)
+ break
+ if not hasattr(node.grad_fn, '_saved_self_sym_sizes'):
+ #warnings.warn("Expand operation detected but the shape information is not available")
+ return
+
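+        # Only the 5-D expand pattern emitted by Huggingface GQA (repeat_kv) is handled; other expand ops are skipped.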
+ if len(node.grad_fn._saved_self_sym_sizes) != 5:
+ return
+
+ # for Huggingface GQA
+ batch, num_key_value_heads, n_rep, slen, head_dim = node.grad_fn._saved_self_sym_sizes
+ in_channels = num_key_value_heads * n_rep * head_dim
+ if out_channels is None or in_channels is None: return
+ repeat = out_channels // in_channels
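+        # `repeat` is how many times each input channel is duplicated by the expand.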
+ addressed_dep = []
+ for i, out_node in enumerate(node.outputs):
+ for dep in node.dependencies:
+ if any((dep is d) for d in addressed_dep): continue
+ if dep.target == out_node:
+ if node.enable_index_mapping:
+ dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=False))
+ addressed_dep.append(dep)
+ break
+
+ addressed_dep = []
+ for i, out_node in enumerate(node.outputs):
+ for dep in out_node.dependencies:
+ if dep.target == node:
+ if any((dep is d) for d in addressed_dep): continue
+ if node.enable_index_mapping:
+ dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=True))
+ addressed_dep.append(dep)
+ break
+
def infer_channels_between(self, node_1, node_2):
if node_1.type == ops.OPTYPE.SPLIT:
for i, n in enumerate(node_1.outputs):
diff --git a/torch_pruning/ops.py b/torch_pruning/ops.py
index 3befdca..a54f076 100644
--- a/torch_pruning/ops.py
+++ b/torch_pruning/ops.py
@@ -62,6 +62,13 @@ def __init__(self, id, grad_fn):
def __repr__(self):
return "_ElementWiseOp_{}({})".format(self.id, self._grad_fn)
+class _ExpandOp(nn.Module):
+ def __init__(self, id):
+ super(_ExpandOp, self).__init__()
+ self.id = id
+
+ def __repr__(self):
+ return "_ExpandOp_{}()".format(self.id)
######################################################
# Dummy Pruners
@@ -90,6 +97,9 @@ def get_out_channel_groups(self, layer):
class UnbindPruner(DummyPruner):
pass
+class ExpandPruner(DummyPruner):
+ pass
+
class ConcatPruner(DummyPruner):
def prune_out_channels(self, layer, idxs):
if layer.concat_sizes is None:
@@ -188,6 +198,7 @@ class OPTYPE(IntEnum):
GN = 15 # nn.GroupNorm
IN = 16 # nn.InstanceNorm
UNBIND = 17
+ EXPAND = 18
def module2type(module):
@@ -226,6 +237,8 @@ def module2type(module):
return OPTYPE.RESHAPE
elif isinstance(module, _UnbindOp):
return OPTYPE.UNBIND
+ elif isinstance(module, _ExpandOp):
+ return OPTYPE.EXPAND
else:
return OPTYPE.ELEMENTWISE
@@ -263,6 +276,8 @@ def type2class(op_type):
return _ReshapeOp
elif OPTYPE == OPTYPE.UNBIND:
return _UnbindOp
+ elif OPTYPE == OPTYPE.EXPAND:
+ return _ExpandOp
else:
return _ElementWiseOp
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index cf8a3e6..4abebef 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -339,8 +339,9 @@ def _downstream_node_as_root_if_attention(self, group):
is_attention = True
if isinstance(_dep.target.module, tuple(self.root_module_types)) and self.DG.is_in_channel_pruning_fn(_dep.handler):
downstream_dep = _dep
+ idxs = _idxs
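+        # Use the indices captured together with the chosen downstream dependency, rather than whatever `_idxs` holds after the loop ends.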
if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers
- group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs)
+ group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, idxs)
return group
def _round_to(self, n_pruned, current_channels, round_to):