-
Notifications
You must be signed in to change notification settings - Fork 919
/
convert.py
47 lines (39 loc) · 1.54 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
from pathlib import Path

import numpy
from transformers import AutoModel
def replace_key(key: str) -> str:
key = key.replace(".layer.", ".layers.")
key = key.replace(".self.key.", ".key_proj.")
key = key.replace(".self.query.", ".query_proj.")
key = key.replace(".self.value.", ".value_proj.")
key = key.replace(".attention.output.dense.", ".attention.out_proj.")
key = key.replace(".attention.output.LayerNorm.", ".ln1.")
key = key.replace(".output.LayerNorm.", ".ln2.")
key = key.replace(".intermediate.dense.", ".linear1.")
key = key.replace(".output.dense.", ".linear2.")
key = key.replace(".LayerNorm.", ".norm.")
key = key.replace("pooler.dense.", "pooler.")
return key
def convert(bert_model: str, mlx_model: str) -> None:
    """Download a Hugging Face BERT checkpoint and save its weights as an
    MLX-loadable ``.npz`` archive.

    Args:
        bert_model: Hugging Face model id (or local path) of a BERT-like model.
        mlx_model: Output path for the ``.npz`` weight file.
    """
    model = AutoModel.from_pretrained(bert_model)
    # Rename each state-dict key to the MLX module layout and convert the
    # tensors to numpy arrays for serialization.
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    # Ensure the output directory exists; numpy.savez would otherwise raise
    # FileNotFoundError for a path like the default "weights/...npz".
    Path(mlx_model).parent.mkdir(parents=True, exist_ok=True)
    numpy.savez(mlx_model, **tensors)
if __name__ == "__main__":
    # Command-line entry point: build the parser from a small spec table,
    # then run the conversion with the parsed options.
    parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    option_specs = (
        (
            "--bert-model",
            "bert-base-uncased",
            "The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
        ),
        (
            "--mlx-model",
            "weights/bert-base-uncased.npz",
            "The output path for the MLX BERT weights.",
        ),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    cli_args = parser.parse_args()
    convert(cli_args.bert_model, cli_args.mlx_model)