Skip to content

Commit

Permalink
Add examples for LLama-2 & 3
Browse files Browse the repository at this point in the history
  • Loading branch information
VainF committed Jun 4, 2024
1 parent a06bbcc commit 2864357
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 18 deletions.
25 changes: 11 additions & 14 deletions examples/LLMs/prune_llama2.py → examples/LLMs/prune_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,10 @@ def main():
parser.add_argument('--model', type=str, help='LLaMA model')
parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
"ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
parser.add_argument("--cache_dir", default="./cache", type=str )
parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
parser.add_argument('--save', type=str, default=None, help='Path to save results.')
parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

parser.add_argument("--eval_zero_shot", action="store_true")
args = parser.parse_args()

Expand Down Expand Up @@ -302,15 +297,17 @@ def main():
for name, m in model.named_modules():
if name.endswith("self_attn"):
num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_attention_heads
num_heads[m.v_proj] = model.config.num_attention_heads
head_pruning_ratio = 0.5
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads

head_pruning_ratio = args.pruning_ratio
hidden_size_pruning_ratio = args.pruning_ratio
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=inputs,
importance=tp.importance.MagnitudeImportance(),
importance=tp.importance.GroupNormImportance(),
global_pruning=False,
pruning_ratio=0.5,
pruning_ratio=hidden_size_pruning_ratio,
ignored_layers=[model.lm_head],
num_heads=num_heads,
prune_num_heads=True,
Expand All @@ -320,17 +317,17 @@ def main():
pruner.step()

# Update model attributes

num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
model.config.num_attention_heads = num_heads
model.config.num_key_value_heads = num_key_value_heads
for name, m in model.named_modules():
if name.endswith("self_attn"):
m.hidden_size = m.q_proj.out_features
m.num_heads = num_heads
m.num_key_value_heads = num_heads
m.num_key_value_heads = num_key_value_heads
elif name.endswith("mlp"):
model.config.intermediate_size = m.gate_proj.out_features

print("----------------- After Pruning -----------------")
print(model)

Expand Down
86 changes: 84 additions & 2 deletions examples/LLMs/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,91 @@ pip install transformers datasets

## 1. Pruning

### Llama-3 8B

```bash
python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
```

<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
----------------- After Pruning -----------------
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 2048)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=False)
(k_proj): Linear(in_features=2048, out_features=512, bias=False)
(v_proj): Linear(in_features=2048, out_features=512, bias=False)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=2048, out_features=7168, bias=False)
(up_proj): Linear(in_features=2048, out_features=7168, bias=False)
(down_proj): Linear(in_features=7168, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)
evaluating on wikitext2
nsamples 35
sample 0
wikitext perplexity 41982.296875
```

</details>




### Llama-2 7B

```bash
python prune_llama2.py --model meta-llama/Llama-2-7b-hf
python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
```

Output:

<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
LlamaForCausalLM(
Expand Down Expand Up @@ -77,4 +157,6 @@ sample 50
wikitext perplexity 9605.4130859375
```

</details>


15 changes: 15 additions & 0 deletions torch_pruning/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,21 @@ def __call__(self, idxs: _HybridIndex):
new_idxs = [ _HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs]
return new_idxs

class _ExpandIndexMapping(object):
def __init__(self, repeat, reverse=False):
self.repeat = repeat
self.reverse = reverse

def __call__(self, idxs: _HybridIndex):
if self.reverse == True:
new_idxs = [ _HybridIndex(idx=i.idx // self.repeat, root_idx=i.root_idx) for i in idxs[::self.repeat]]
else:
new_idxs = [
_HybridIndex(idx = i.idx * self.repeat + j, root_idx=i.root_idx)
for i in idxs
for j in range(self.repeat)
]
return new_idxs

class _SplitIndexMapping(object):
def __init__(self, offset, reverse=False):
Expand Down
44 changes: 43 additions & 1 deletion torch_pruning/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def __init__(self):
ops.OPTYPE.ELEMENTWISE: ops.ElementWisePruner(),
ops.OPTYPE.RESHAPE: ops.ReshapePruner(),
ops.OPTYPE.UNBIND: ops.UnbindPruner(),
ops.OPTYPE.EXPAND: ops.ExpandPruner(),
ops.OPTYPE.CUSTOMIZED: ops.CustomizedPruner(), # just a placeholder
}
self.REGISTERED_PRUNERS = function.PrunerBox.copy() # shallow copy
Expand Down Expand Up @@ -842,6 +843,10 @@ def create_node_if_not_exists(grad_fn):
elif "unbind" in grad_fn.name().lower():
module = ops._UnbindOp(self._op_id)
self._op_id+=1
elif "expand" in grad_fn.name().lower():
module = ops._ExpandOp(self._op_id)
# print all attributes
self._op_id+=1
elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower():
module = ops._ReshapeOp(self._op_id)
self._op_id+=1
Expand Down Expand Up @@ -914,6 +919,8 @@ def update_index_mapping(self):
self._update_reshape_index_mapping(node)
if node.type == ops.OPTYPE.UNBIND:
self._update_unbind_index_mapping(node)
if node.type == ops.OPTYPE.EXPAND:
self._update_expand_index_mapping(node)

def _init_shape_information(self):
for module, node in self.module2node.items():
Expand Down Expand Up @@ -949,7 +956,7 @@ def _init_shape_information(self):
offsets.append(offsets[-1] + ch)
node.module.split_sizes = chs
node.module.offsets = offsets

def _update_flatten_index_mapping(self, fc_node: Node):
if fc_node.type != ops.OPTYPE.LINEAR:
return
Expand Down Expand Up @@ -1177,6 +1184,41 @@ def _update_unbind_index_mapping(self, unbind_node: Node):
addressed_dep.append(dep)
break

def _update_expand_index_mapping(self, node: Node):
out_channels = None
for n in node.outputs:
recursive_depth = [0]
out_channels = self._infer_in_channels_recursively(n, recursive_depth)
if recursive_depth[0] > MAX_RECURSION_DEPTH:
return
if out_channels is not None: # =0 if there is a residual connection to model inputs
break
assert hasattr(node.grad_fn, '_saved_self_sym_sizes'), "New version of PyTorch is required for expand operation."
batch, num_key_value_heads, n_rep, slen, head_dim = node.grad_fn._saved_self_sym_sizes
in_channels = num_key_value_heads * n_rep * head_dim
if out_channels is None or in_channels is None: return

repeat = out_channels // in_channels
addressed_dep = []
for i, out_node in enumerate(node.outputs):
for dep in node.dependencies:
if any((dep is d) for d in addressed_dep): continue
if dep.target == out_node:
if node.enable_index_mapping:
dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=False))
addressed_dep.append(dep)
break

addressed_dep = []
for i, out_node in enumerate(node.outputs):
for dep in out_node.dependencies:
if dep.target == node:
if any((dep is d) for d in addressed_dep): continue
if node.enable_index_mapping:
dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=True))
addressed_dep.append(dep)
break

def infer_channels_between(self, node_1, node_2):
if node_1.type == ops.OPTYPE.SPLIT:
for i, n in enumerate(node_1.outputs):
Expand Down
15 changes: 15 additions & 0 deletions torch_pruning/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def __init__(self, id, grad_fn):
def __repr__(self):
return "_ElementWiseOp_{}({})".format(self.id, self._grad_fn)

class _ExpandOp(nn.Module):
def __init__(self, id):
super(_ExpandOp, self).__init__()
self.id = id

def __repr__(self):
return "_ExpandOp_{}()".format(self.id)

######################################################
# Dummy Pruners
Expand Down Expand Up @@ -90,6 +97,9 @@ def get_out_channel_groups(self, layer):
class UnbindPruner(DummyPruner):
pass

class ExpandPruner(DummyPruner):
pass

class ConcatPruner(DummyPruner):
def prune_out_channels(self, layer, idxs):
if layer.concat_sizes is None:
Expand Down Expand Up @@ -188,6 +198,7 @@ class OPTYPE(IntEnum):
GN = 15 # nn.GroupNorm
IN = 16 # nn.InstanceNorm
UNBIND = 17
EXPAND = 18


def module2type(module):
Expand Down Expand Up @@ -226,6 +237,8 @@ def module2type(module):
return OPTYPE.RESHAPE
elif isinstance(module, _UnbindOp):
return OPTYPE.UNBIND
elif isinstance(module, _ExpandOp):
return OPTYPE.EXPAND
else:
return OPTYPE.ELEMENTWISE

Expand Down Expand Up @@ -263,6 +276,8 @@ def type2class(op_type):
return _ReshapeOp
elif OPTYPE == OPTYPE.UNBIND:
return _UnbindOp
elif OPTYPE == OPTYPE.EXPAND:
return _ExpandOp
else:
return _ElementWiseOp

3 changes: 2 additions & 1 deletion torch_pruning/pruner/algorithms/metapruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,9 @@ def _downstream_node_as_root_if_attention(self, group):
is_attention = True
if isinstance(_dep.target.module, tuple(self.root_module_types)) and self.DG.is_in_channel_pruning_fn(_dep.handler):
downstream_dep = _dep
idxs = _idxs
if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers
group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs)
group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, idxs)
return group

def _round_to(self, n_pruned, current_channels, round_to):
Expand Down

0 comments on commit 2864357

Please sign in to comment.