V2.0 #385

Merged · 3 commits · Jun 4, 2024
25 changes: 11 additions & 14 deletions examples/LLMs/prune_llama2.py → examples/LLMs/prune_llama.py
@@ -264,15 +264,10 @@ def main():
parser.add_argument('--model', type=str, help='LLaMA model')
parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
"ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
parser.add_argument("--cache_dir", default="./cache", type=str )
parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
parser.add_argument('--save', type=str, default=None, help='Path to save results.')
parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

parser.add_argument("--eval_zero_shot", action="store_true")
args = parser.parse_args()

@@ -302,15 +297,17 @@ def main():
for name, m in model.named_modules():
if name.endswith("self_attn"):
num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_attention_heads
num_heads[m.v_proj] = model.config.num_attention_heads
head_pruning_ratio = 0.5
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads

head_pruning_ratio = args.pruning_ratio
hidden_size_pruning_ratio = args.pruning_ratio
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=inputs,
importance=tp.importance.MagnitudeImportance(),
importance=tp.importance.GroupNormImportance(),
global_pruning=False,
pruning_ratio=0.5,
pruning_ratio=hidden_size_pruning_ratio,
ignored_layers=[model.lm_head],
num_heads=num_heads,
prune_num_heads=True,
@@ -320,17 +317,17 @@
pruner.step()

# Update model attributes

num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
model.config.num_attention_heads = num_heads
model.config.num_key_value_heads = num_key_value_heads
for name, m in model.named_modules():
if name.endswith("self_attn"):
m.hidden_size = m.q_proj.out_features
m.num_heads = num_heads
m.num_key_value_heads = num_heads
m.num_key_value_heads = num_key_value_heads
elif name.endswith("mlp"):
model.config.intermediate_size = m.gate_proj.out_features

print("----------------- After Pruning -----------------")
print(model)

86 changes: 84 additions & 2 deletions examples/LLMs/readme.md
@@ -10,11 +10,91 @@ pip install transformers datasets

## 1. Pruning

### Llama-3 8B

```bash
python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
```

<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
----------------- After Pruning -----------------
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=7168, bias=False)
          (up_proj): Linear(in_features=2048, out_features=7168, bias=False)
          (down_proj): Linear(in_features=7168, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)
evaluating on wikitext2
nsamples 35
sample 0
wikitext perplexity 41982.296875
```

</details>
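
Under the hood, `prune_llama.py` records the head layout of each attention block and then runs Torch-Pruning's `MagnitudePruner`. The snippet below is a condensed sketch of that setup, assuming `model`, `inputs`, and `args` are prepared as in the script; the `head_pruning_ratio` argument is an assumption, since the pruner call is truncated in the diff above:

```python
# Condensed sketch of the pruning setup in prune_llama.py.
import torch_pruning as tp

# Record the head layout of every attention block so whole heads can be pruned.
num_heads = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        num_heads[m.q_proj] = model.config.num_attention_heads
        num_heads[m.k_proj] = model.config.num_key_value_heads   # GQA: fewer KV heads
        num_heads[m.v_proj] = model.config.num_key_value_heads

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=inputs,                           # a batch of calibration token ids
    importance=tp.importance.GroupNormImportance(),  # group-wise (default p=2) weight importance
    global_pruning=False,
    pruning_ratio=args.pruning_ratio,                # width ratio for hidden/MLP channels
    ignored_layers=[model.lm_head],                  # keep the output head intact
    num_heads=num_heads,
    prune_num_heads=True,                            # drop entire attention heads
    head_pruning_ratio=args.pruning_ratio,           # assumed: same ratio for heads
)
pruner.step()

# Afterwards the script syncs model.config (num_attention_heads,
# num_key_value_heads, intermediate_size) with the pruned layer shapes.
```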




### Llama-2 7B

```bash
python prune_llama2.py --model meta-llama/Llama-2-7b-hf
python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
```

Output:

<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
LlamaForCausalLM(
@@ -77,4 +157,6 @@ sample 50
wikitext perplexity 9605.4130859375
```

</details>


15 changes: 15 additions & 0 deletions torch_pruning/_helpers.py
@@ -80,6 +80,21 @@ def __call__(self, idxs: _HybridIndex):
new_idxs = [ _HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs]
return new_idxs

class _ExpandIndexMapping(object):
    def __init__(self, repeat, reverse=False):
        self.repeat = repeat
        self.reverse = reverse

    def __call__(self, idxs: _HybridIndex):
        if self.reverse:
            new_idxs = [ _HybridIndex(idx=i.idx // self.repeat, root_idx=i.root_idx) for i in idxs[::self.repeat]]
        else:
            new_idxs = [
                _HybridIndex(idx = i.idx * self.repeat + j, root_idx=i.root_idx)
                for i in idxs
                for j in range(self.repeat)
            ]
        return new_idxs
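
# Illustration with hypothetical indices: given repeat=2, the forward mapping turns
# pruned indices [0, 1] on the un-expanded tensor into [0, 1, 2, 3] on the expanded
# tensor, while reverse=True keeps every `repeat`-th expanded index and divides it
# back, e.g. [0, 1, 2, 3] -> [0, 1].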

class _SplitIndexMapping(object):
def __init__(self, offset, reverse=False):
50 changes: 49 additions & 1 deletion torch_pruning/dependency.py
@@ -289,6 +289,7 @@ def __init__(self):
ops.OPTYPE.ELEMENTWISE: ops.ElementWisePruner(),
ops.OPTYPE.RESHAPE: ops.ReshapePruner(),
ops.OPTYPE.UNBIND: ops.UnbindPruner(),
ops.OPTYPE.EXPAND: ops.ExpandPruner(),
ops.OPTYPE.CUSTOMIZED: ops.CustomizedPruner(), # just a placeholder
}
self.REGISTERED_PRUNERS = function.PrunerBox.copy() # shallow copy
@@ -842,6 +843,10 @@ def create_node_if_not_exists(grad_fn):
elif "unbind" in grad_fn.name().lower():
module = ops._UnbindOp(self._op_id)
self._op_id+=1
elif "expand" in grad_fn.name().lower():
module = ops._ExpandOp(self._op_id)
# print all attributes
self._op_id+=1
elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower():
module = ops._ReshapeOp(self._op_id)
self._op_id+=1
@@ -914,6 +919,8 @@ def update_index_mapping(self):
self._update_reshape_index_mapping(node)
if node.type == ops.OPTYPE.UNBIND:
self._update_unbind_index_mapping(node)
if node.type == ops.OPTYPE.EXPAND and torch.__version__ >= "1.8":
self._update_expand_index_mapping(node)

def _init_shape_information(self):
for module, node in self.module2node.items():
@@ -949,7 +956,7 @@ def _init_shape_information(self):
offsets.append(offsets[-1] + ch)
node.module.split_sizes = chs
node.module.offsets = offsets

def _update_flatten_index_mapping(self, fc_node: Node):
if fc_node.type != ops.OPTYPE.LINEAR:
return
@@ -1177,6 +1184,47 @@ def _update_unbind_index_mapping(self, unbind_node: Node):
addressed_dep.append(dep)
break

    def _update_expand_index_mapping(self, node: Node):
        out_channels = None
        for n in node.outputs:
            recursive_depth = [0]
            out_channels = self._infer_in_channels_recursively(n, recursive_depth)
            if recursive_depth[0] > MAX_RECURSION_DEPTH:
                return
            if out_channels is not None: # =0 if there is a residual connection to model inputs
                break
        if not hasattr(node.grad_fn, '_saved_self_sym_sizes'):
            #warnings.warn("Expand operation detected but the shape information is not available")
            return

        if len(node.grad_fn._saved_self_sym_sizes) != 5:
            return

        # for Huggingface GQA
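        # In transformers' repeat_kv(), a key/value tensor is viewed as
        # (batch, num_key_value_heads, 1, slen, head_dim) and expand()ed to
        # (batch, num_key_value_heads, n_rep, slen, head_dim) before being reshaped,
        # which is why a 5-D saved size identifies this pattern; the channel ratio
        # computed below recovers n_rep as the repeat factor.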
        batch, num_key_value_heads, n_rep, slen, head_dim = node.grad_fn._saved_self_sym_sizes
        in_channels = num_key_value_heads * n_rep * head_dim
        if out_channels is None or in_channels is None: return
        repeat = out_channels // in_channels
        addressed_dep = []
        for i, out_node in enumerate(node.outputs):
            for dep in node.dependencies:
                if any((dep is d) for d in addressed_dep): continue
                if dep.target == out_node:
                    if node.enable_index_mapping:
                        dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=False))
                    addressed_dep.append(dep)
                    break

        addressed_dep = []
        for i, out_node in enumerate(node.outputs):
            for dep in out_node.dependencies:
                if dep.target == node:
                    if any((dep is d) for d in addressed_dep): continue
                    if node.enable_index_mapping:
                        dep.index_mapping[0] = (_helpers._ExpandIndexMapping(repeat=repeat, reverse=True))
                    addressed_dep.append(dep)
                    break

def infer_channels_between(self, node_1, node_2):
if node_1.type == ops.OPTYPE.SPLIT:
for i, n in enumerate(node_1.outputs):
15 changes: 15 additions & 0 deletions torch_pruning/ops.py
@@ -62,6 +62,13 @@ def __init__(self, id, grad_fn):
def __repr__(self):
return "_ElementWiseOp_{}({})".format(self.id, self._grad_fn)

class _ExpandOp(nn.Module):
    def __init__(self, id):
        super(_ExpandOp, self).__init__()
        self.id = id

    def __repr__(self):
        return "_ExpandOp_{}()".format(self.id)
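
# _ExpandOp holds no parameters; it is a placeholder node that marks where an
# expand() call was traced so the dependency graph can attach _ExpandIndexMapping
# to it (see OPTYPE.EXPAND and ExpandPruner below).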

######################################################
# Dummy Pruners
@@ -90,6 +97,9 @@ def get_out_channel_groups(self, layer):
class UnbindPruner(DummyPruner):
    pass

class ExpandPruner(DummyPruner):
    pass

class ConcatPruner(DummyPruner):
    def prune_out_channels(self, layer, idxs):
        if layer.concat_sizes is None:
@@ -188,6 +198,7 @@ class OPTYPE(IntEnum):
GN = 15 # nn.GroupNorm
IN = 16 # nn.InstanceNorm
UNBIND = 17
EXPAND = 18


def module2type(module):
@@ -226,6 +237,8 @@ def module2type(module):
return OPTYPE.RESHAPE
elif isinstance(module, _UnbindOp):
return OPTYPE.UNBIND
elif isinstance(module, _ExpandOp):
return OPTYPE.EXPAND
else:
return OPTYPE.ELEMENTWISE

@@ -263,6 +276,8 @@ def type2class(op_type):
return _ReshapeOp
elif OPTYPE == OPTYPE.UNBIND:
return _UnbindOp
elif OPTYPE == OPTYPE.EXPAND:
return _ExpandOp
else:
return _ElementWiseOp

3 changes: 2 additions & 1 deletion torch_pruning/pruner/algorithms/metapruner.py
@@ -339,8 +339,9 @@ def _downstream_node_as_root_if_attention(self, group):
is_attention = True
if isinstance(_dep.target.module, tuple(self.root_module_types)) and self.DG.is_in_channel_pruning_fn(_dep.handler):
downstream_dep = _dep
idxs = _idxs # record the indices that matched this downstream dependency
if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers
group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs)
group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, idxs)
return group

def _round_to(self, n_pruned, current_channels, round_to):