-
Notifications
You must be signed in to change notification settings - Fork 5
/
convert_tp_weigths.py
101 lines (87 loc) · 3.89 KB
/
convert_tp_weigths.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from pathlib import Path
import torch
import os
import tqdm
from torch import nn
from transformers import AutoModelForCausalLM
from safetensors.torch import save_file
def match_suffix(text, suffix):
    """Return True if *text* ends with *suffix*.

    Fixed: the previous slice-based check (``text[-len(suffix):] == suffix``)
    misbehaved for an empty suffix, because ``text[-0:]`` is the whole string.
    ``str.endswith`` handles that edge case correctly and is the idiomatic form.
    """
    return text.endswith(suffix)
def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
    """BLOOM-specific tensor-parallel sharding mechanism.

    Loads ``model_name`` with ``transformers`` and splits its state dict into
    ``tp_world_size`` per-rank shards saved under ``path`` (via safetensors'
    ``save_file``, despite the ``.bin`` extension kept for cache compatibility).

    Sharding rules:
      * column-parallel params (QKV, dense_h_to_4h) are split along dim 0;
      * row-parallel weights (attention dense, dense_4h_to_h, lm_head) are
        split along dim 1;
      * row-parallel biases are kept on rank 0 and zeroed on other ranks
        (they are added once after the all-reduce);
      * everything else is duplicated across ranks.

    Args:
        model_name: HF hub id or local path of the model to shard.
        path: Directory that receives the per-rank shard files.
        tp_world_size: Number of tensor-parallel ranks; every sharded
            dimension must be divisible by it.
        dtype: Torch dtype to load the model weights in.

    Returns:
        List of ``tp_world_size`` shard file paths, one per rank.
        (Fixed: the previous version appended each path a second time inside
        the save loop, returning 2x tp_world_size entries with duplicates.)
    """
    save_paths = [
        path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.bin"
        for tp_rank in range(tp_world_size)
    ]
    if all(save_path.exists() for save_path in save_paths):
        print("Loading already cached values")
        return save_paths

    model: nn.Module = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
    # (Removed a dead duplicate `model.state_dict()` call here.)
    shards_state_dicts = [{} for _ in range(tp_world_size)]
    state_dict = model.state_dict()
    keys = list(state_dict.keys())
    for state_name in keys:
        print(state_name)
        state = state_dict[state_name]
        if any(match_suffix(state_name, candidate) for candidate in [
            "self_attention.query_key_value.weight",
            "self_attention.query_key_value.bias",
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_h_to_4h.bias",
            # "transformer.word_embeddings.weight",
            # "lm_head.weight"
        ]):
            # Column-parallel: split the output dimension across ranks.
            output_size = state.shape[0]
            assert output_size % tp_world_size == 0
            block_size = output_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=0)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[0] == block_size
                # .contiguous(): split() returns views and safetensors
                # requires contiguous tensors for serialization.
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone().contiguous()
        elif any(match_suffix(state_name, candidate) for candidate in [
            "self_attention.dense.weight",
            "mlp.dense_4h_to_h.weight",
            "lm_head.weight"
        ]):
            # Row-parallel: split the input dimension across ranks.
            input_size = state.shape[1]
            assert input_size % tp_world_size == 0
            block_size = input_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=1)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[1] == block_size
                # dim=1 shards are never contiguous views; clone() preserves
                # strides, so force contiguity for safetensors.
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone().contiguous()
        elif any(match_suffix(state_name, candidate) for candidate in [
            "self_attention.dense.bias",
            "mlp.dense_4h_to_h.bias",
        ]):
            # Row-parallel bias: rank 0 keeps the real values, other ranks
            # get zeros so the post-all-reduce sum adds the bias exactly once.
            shards_state_dicts[0][state_name] = state.detach().clone()
            for tp_rank in range(1, tp_world_size):
                shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
        else:
            # We duplicate parameters across tp ranks
            for tp_rank in range(tp_world_size):
                shards_state_dicts[tp_rank][state_name] = state.detach().clone()

        del state_dict[state_name]  # delete key from state_dict
        del state  # delete tensor to keep peak memory down

    # Save one file per rank.
    for tp_rank, (save_path, shard_state_dict) in enumerate(zip(save_paths, shards_state_dicts)):
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_file(shard_state_dict.copy(), str(save_path))
        # NOTE: do NOT append save_path here — save_paths already contains
        # every rank's path; appending again returned duplicated entries.

    return save_paths
def main():
    """Shard bigscience/bloom into 16 tensor-parallel ranks under weights/."""
    shard_model(
        "bigscience/bloom",
        Path("weights/"),
        tp_world_size=16,
        dtype=torch.bfloat16,
    )
def main2():
    """Re-serialize previously saved .pty TP shards as safetensors files.

    Reads each of the 16 rank shards from a fixed host path and writes it,
    under the same file name, into ``weights/bigscience/``.
    """
    for rank in tqdm.tqdm(range(16)):
        shard_name = f"bloom_tp-rank-{rank}-of-16.pty"
        filename = f"/home/thomas_wang_huggingface_co/models/bigscience/{shard_name}"
        state_dict = torch.load(filename, map_location="cpu")
        save_file(state_dict.copy(), os.path.join("weights", "bigscience", shard_name))
if __name__ == "__main__":
    # Entry point runs the .pty -> safetensors conversion (main2), not the
    # full sharding pipeline (main) — switch manually as needed.
    main2()