block_expansion.py
"""Deepen a causal LM by block expansion: keep the original transformer
layers in order and, after every `split` layers, insert a copy of the
current layer whose `down_proj` and `o_proj` weights are zeroed, so each
new block initially acts as an identity mapping."""

import argparse

import torch
from transformers import AutoModelForCausalLM


def main():
    # Set up the argument parser
    parser = argparse.ArgumentParser(description="Arguments for deepening a model by block expansion")
    parser.add_argument("--model_path", default="meta-llama/Llama-2-7b-hf", type=str, help="path of the original model")
    parser.add_argument("--output_path", default="pytorch_model.bin", type=str, help="where to save the deepened model checkpoint")
    parser.add_argument("--original_layers", default=32, type=int, help="number of layers in the original model")
    parser.add_argument("--layers", default=40, type=int, help="number of layers in the deepened model")
    # Parse the arguments
    args = parser.parse_args()

    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16)
    ckpt = model.state_dict()

    # Insert one duplicated block after every `split` original layers.
    split = args.original_layers // (args.layers - args.original_layers)

    layer_cnt = 0
    output = {}
    for i in range(args.original_layers):
        # Copy the original layer, renumbered to its position in the deepened model.
        for k in ckpt:
            if f"layers.{i}." in k:
                output[k.replace(f"layers.{i}.", f"layers.{layer_cnt}.")] = ckpt[k]
        layer_cnt += 1

        if (i + 1) % split == 0:
            # Append a duplicate of this layer. Zeroing `down_proj` and `o_proj`
            # makes both residual branches of the new block output zero, so the
            # expanded model initially computes the same function as the original.
            for k in ckpt:
                if f"layers.{i}." in k:
                    if "down_proj" in k or "o_proj" in k:
                        output[k.replace(f"layers.{i}.", f"layers.{layer_cnt}.")] = torch.zeros_like(ckpt[k])
                    else:
                        output[k.replace(f"layers.{i}.", f"layers.{layer_cnt}.")] = ckpt[k]
            layer_cnt += 1

    assert layer_cnt == args.layers

    # Copy all non-layer weights (embeddings, final norm, lm_head) unchanged.
    for k in ckpt:
        if "layers" not in k:
            output[k] = ckpt[k]

    torch.save(output, args.output_path)


if __name__ == "__main__":
    main()
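
Usage: with the defaults above, the script expands a 32-layer Llama-2-7B to 40 layers, e.g. python block_expansion.py --model_path meta-llama/Llama-2-7b-hf --output_path pytorch_model.bin --original_layers 32 --layers 40. The saved state dict can then be loaded into a model configured for the new depth. A minimal sketch, not part of this file, assuming the standard Hugging Face config attribute num_hidden_layers and the default paths above:

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    # Build an empty model with the deepened layer count, then load the
    # expanded checkpoint produced by block_expansion.py.
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    config.num_hidden_layers = 40  # must match --layers
    model = AutoModelForCausalLM.from_config(config)
    model.load_state_dict(torch.load("pytorch_model.bin"))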