"""Implementation of exporting LLaMA PyTorch model to TinyChatEngine format.
Usage:
python llama_exporter.py <path of hugging face model checkpoint> <output dir>
Example commandline:
python tools/llama_exporter.py --model models/llama2-chat/hf7B --output models/LLaMA_7B_2_chat
"""
import argparse
import math
import os
import struct

import torch
from transformers import LlamaForCausalLM
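
# For reference, the directory layout written by the exporter (as produced by the
# functions below) looks like this; every .bin file holds raw float32 tensor bytes
# unless noted otherwise:
#
#   <output>/
#     lm_head.bin
#     decoder/
#       embed_tokens/weight.bin
#       norm/weight.bin
#       layer0/ ... layerN/
#         input_layernorm/weight.bin
#         post_attention_layernorm/weight.bin
#         gate_proj/weight.bin, down_proj/weight.bin, up_proj/weight.bin
#         self_attn/
#           q_proj/weight.bin, k_proj/weight.bin, v_proj/weight.bin, o_proj/weight.bin
#           qk_bmm/alpha.bin                     (a single packed float)
#           rotary_emb/cos_cached.bin, sin_cached.bin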

@torch.no_grad()
def _export_model(model, prefix):
    # Export the LM head, then recurse into the decoder stack.
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    with open(os.path.join(f"{outpath}", "lm_head.bin"), "wb") as f:
        f.write(model.lm_head._parameters["weight"].cpu().float().numpy().tobytes())
    _export_llama_model(model.model, os.path.join(f"{outpath}", "decoder"))

def _export_embed_tokens(embed_tokens, prefix):
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    with open(os.path.join(f"{outpath}", "weight.bin"), "wb") as f:
        f.write(embed_tokens.weight.cpu().float().numpy().tobytes())

def _export_llama_model(model, prefix):
    # Export the embedding table, the final RMSNorm, and every decoder layer.
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    _export_embed_tokens(model.embed_tokens, os.path.join(outpath, "embed_tokens"))
    _export_LlamaRMSNorm(model.norm, os.path.join(outpath, "norm"))
    for idx, layer in enumerate(model.layers):
        _export_llama_layer(layer, os.path.join(outpath, f"layer{idx}"))

def _export_LlamaRMSNorm(op, prefix):
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    with open(os.path.join(f"{outpath}", "weight.bin"), "wb") as f:
        f.write(op.weight.cpu().float().numpy().tobytes())

def _export_llama_layer(layer, prefix):
    # Export one decoder layer: attention block, the two RMSNorms, and the MLP projections.
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    _export_attention_params(layer.self_attn, os.path.join(outpath, "self_attn"))
    _export_LlamaRMSNorm(layer.input_layernorm, os.path.join(outpath, "input_layernorm"))
    _export_LlamaRMSNorm(
        layer.post_attention_layernorm,
        os.path.join(outpath, "post_attention_layernorm"),
    )
    _export_linearfp(layer.mlp.gate_proj, os.path.join(outpath, "gate_proj"))
    _export_linearfp(layer.mlp.down_proj, os.path.join(outpath, "down_proj"))
    _export_linearfp(layer.mlp.up_proj, os.path.join(outpath, "up_proj"))

def _export_linearfp(op, prefix):
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    with open(os.path.join(f"{outpath}", "weight.bin"), "wb") as f:
        f.write(op._parameters["weight"].cpu().float().numpy().tobytes())

def _export_rotaryEmbedding(op, prefix):
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    with open(os.path.join(f"{outpath}", "cos_cached.bin"), "wb") as f:
        f.write(op.cos_cached.cpu().float().numpy().tobytes())
    with open(os.path.join(f"{outpath}", "sin_cached.bin"), "wb") as f:
        f.write(op.sin_cached.cpu().float().numpy().tobytes())

def _export_BMM_F32T(alpha, prefix):
    # The QK^T batched matmul only needs its scalar scaling factor.
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    with open(os.path.join(f"{outpath}", "alpha.bin"), "wb") as f:
        f.write(struct.pack("f", alpha))

def _export_attention_params(attn, prefix: str):
    # Export the four attention projections, the 1/sqrt(head_dim) softmax scale,
    # and the rotary-embedding caches.
    outpath = prefix
    os.makedirs(outpath, exist_ok=True)
    _export_linearfp(attn.k_proj, os.path.join(outpath, "k_proj"))
    _export_linearfp(attn.v_proj, os.path.join(outpath, "v_proj"))
    _export_linearfp(attn.q_proj, os.path.join(outpath, "q_proj"))
    _export_linearfp(attn.o_proj, os.path.join(outpath, "o_proj"))
    qk_bmm_alpha = 1 / math.sqrt(attn.head_dim)
    _export_BMM_F32T(qk_bmm_alpha, os.path.join(outpath, "qk_bmm"))
    _export_rotaryEmbedding(attn.rotary_emb, os.path.join(outpath, "rotary_emb"))
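
# Illustrative only: the helper below is not part of the original exporter. It
# sketches how one of the exported weight.bin files could be read back with numpy
# and checked against the in-memory tensor, since every file above is written as
# raw float32 bytes. The name _sanity_check_export and its arguments are hypothetical.
def _sanity_check_export(tensor, bin_path):
    import numpy as np  # local import so the exporter itself does not require numpy

    exported = np.fromfile(bin_path, dtype=np.float32)
    reference = tensor.detach().cpu().float().numpy().ravel()
    assert exported.shape == reference.shape, "element count mismatch"
    assert np.allclose(exported, reference), "exported values differ from the model weights"

# Example (hypothetical paths):
#   _sanity_check_export(model.lm_head.weight, "models/LLaMA_7B_2_chat/lm_head.bin")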

def main():
    """Export a LLaMA model to TinyChatEngine format."""
    parser = argparse.ArgumentParser(description="Export a LLaMA PyTorch model to the TinyChatEngine format.")
    parser.add_argument("--hf_path", type=str, help="Path to huggingface model hub", default=None)
    parser.add_argument("--model", type=str, help="Path of the LLaMA torch model")
    parser.add_argument("--output", type=str, help="Output directory of the exported model")
    args = parser.parse_args()

    if args.hf_path is None:
        if not os.path.exists(args.model):
            print(f"The model path '{args.model}' does not exist.")
            return

        if not os.path.exists(args.output):
            print(f"The output path '{args.output}' does not exist. Creating a new directory...")
            os.makedirs(args.output, exist_ok=True)

        print("Loading model...")
        if args.model.endswith(".pt"):
            # A raw .pt state dict: instantiate the matching base architecture from
            # the Hugging Face hub, then load the checkpoint weights into it.
            if args.model.split("/")[-1].lower().startswith("llama-2"):
                if args.model.split("-")[2].lower() == "7b":
                    print("Loading LLaMA 7B model...")
                    model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", torch_dtype=torch.float16)
                elif args.model.split("-")[2].lower() == "13b":
                    print("Loading LLaMA 13B model...")
                    model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-13b-hf", torch_dtype=torch.float16)
            elif args.model.split("/")[-1].lower().startswith("codellama"):
                if args.model.split("-")[1].lower() == "7b":
                    print("Loading CodeLlama 7B model...")
                    model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-Instruct-hf", torch_dtype=torch.float16)
                elif args.model.split("-")[1].lower() == "13b":
                    print("Loading CodeLlama 13B model...")
                    model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-13b-Instruct-hf", torch_dtype=torch.float16)
            else:
                print("Model not supported.")
                return

            model.load_state_dict(torch.load(args.model))
        else:
            model = LlamaForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16)
    else:
        model = LlamaForCausalLM.from_pretrained(args.hf_path, torch_dtype=torch.bfloat16)

    print("Start exporting LLaMA model...")
    _export_model(model, args.output)
    print("Finished exporting LLaMA model.")

if __name__ == "__main__":
    main()