activation.py
import argparse
from types import MethodType

import torch
from vllm import LLM, SamplingParams
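
# CLI: which model to instrument and which language's token ids to run through it.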
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="meta-llama/Llama-2-7b-hf")
parser.add_argument("-l", "--lang", type=str, default="zh")
args = parser.parse_args()
# Crude family check: any model name containing "llama" gets the LLaMA-style
# gated-MLP hook; everything else is treated as BLOOM-style.
is_llama = 'llama' in args.model.lower()
# enforce_eager disables CUDA graph capture so the monkey-patched forwards below are actually invoked.
model = LLM(model=args.model, tensor_parallel_size=torch.cuda.device_count(), enforce_eager=True)

max_length = model.llm_engine.model_config.max_model_len
num_layers = model.llm_engine.model_config.hf_config.num_hidden_layers
# LLaMA configs expose the FFN width directly; BLOOM's FFN is 4x the hidden size.
intermediate_size = model.llm_engine.model_config.hf_config.intermediate_size if is_llama else model.llm_engine.model_config.hf_config.hidden_size * 4

# Running count, per layer and per intermediate neuron, of positions where the activation is positive.
over_zero = torch.zeros(num_layers, intermediate_size, dtype=torch.int32, device='cuda')
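
# Patched MLP forwards: the same computation as the originals, with the side
# effect of accumulating positive-activation counts into `over_zero`.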
def factory(idx):
    def llama_forward(self, x):
        gate_up, _ = self.gate_up_proj(x)  # (b, l, 2i): fused gate and up projections
        i = gate_up.size(-1)
        gate_up[:, :, : i // 2] = torch.nn.SiLU()(gate_up[:, :, : i // 2])
        activation = gate_up[:, :, : i // 2].float()  # (b, l, i): post-SiLU gate values
        over_zero[idx, :] += (activation > 0).sum(dim=(0, 1))
        x = gate_up[:, :, : i // 2] * gate_up[:, :, i // 2 :]
        x, _ = self.down_proj(x)
        return x

    def bloom_forward(self, x: torch.Tensor):
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        activation = x.float()
        over_zero[idx, :] += (activation > 0).sum(dim=(0, 1))
        x, _ = self.dense_4h_to_h(x)
        return x

    return llama_forward if is_llama else bloom_forward
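
# Replace each layer's MLP forward with the counting version, bound to its layer index.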
for i in range(num_layers):
    if is_llama:
        obj = model.llm_engine.driver_worker.model_runner.model.model.layers[i].mlp
    else:
        obj = model.llm_engine.driver_worker.model_runner.model.transformer.h[i].mlp
    obj.forward = MethodType(factory(i), obj)
lang = args.lang

# Pre-tokenized training corpus for the requested language, stored as a flat 1-D tensor of token ids.
if is_llama:
    ids = torch.load(f'data/id.{lang}.train.llama')
else:
    ids = torch.load(f'data/id.{lang}.train.bloom')

# Truncate to a multiple of max_length (capped at ~100M tokens) and pack into full-length rows.
l = ids.size(0)
l = min(l, 99999744) // max_length * max_length
input_ids = ids[:l].reshape(-1, max_length)
# Generate a single token per row; the generated text is discarded. The point is the
# prefill forward passes, which update `over_zero` through the patched MLPs.
model.generate(prompt_token_ids=input_ids.tolist(), sampling_params=SamplingParams(max_tokens=1))

output = dict(n=l, over_zero=over_zero.to('cpu'))
# Save the token count and per-neuron activation counts for later analysis.
if is_llama:
    torch.save(output, f'data/activation.{lang}.train.llama-7b')
else:
    torch.save(output, f'data/activation.{lang}.train.bloom-7b')
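
# Example invocation (assumes the tokenized id files already exist under data/):
#   python activation.py -m meta-llama/Llama-2-7b-hf -l zh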