-
Notifications
You must be signed in to change notification settings - Fork 17
/
bench.py
67 lines (58 loc) · 2.87 KB
/
bench.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import time
import torch
from flash_attention import flash_attention, normal_attention
import argparse
# Command-line interface for the benchmark: which implementation to run and the
# sizes of the randomly generated attention inputs.
parser = argparse.ArgumentParser()
# Restrict --type to the two supported implementations. Without `choices`, any
# other value (e.g. a typo such as "Flash") silently benchmarked normal_attention.
parser.add_argument('--type', type=str, required=True, choices=("flash", "normal"), help="flash/normal")
parser.add_argument('--b', type=int, default=1, help="Batch size")
parser.add_argument('--h', type=int, default=2, help="Number of heads")
parser.add_argument('--q_len', type=int, default=4096, help="Length/first dimension of Q matrix")
parser.add_argument('--kv_len', type=int, default=4096, help="Length/first dimension of K/V matrix")
parser.add_argument('--d', type=int, default=512, help="Dimension of vector")
parser.add_argument('--profile', action='store_true', help="For Pytorch profiling")
args = parser.parse_args()
# Random attention inputs of shape (batch, heads, seq_len, d), allocated
# directly on the GPU. Passing device='cuda' (rather than randn(...).to('cuda'))
# keeps Q, K and V as leaf tensors, so requires_grad applies to the tensors that
# are actually handed to the attention functions. It also avoids a CPU
# allocation followed by a host-to-device copy.
Q = torch.randn(args.b, args.h, args.q_len, args.d, device='cuda', requires_grad=True)
K = torch.randn(args.b, args.h, args.kv_len, args.d, device='cuda', requires_grad=True)
V = torch.randn(args.b, args.h, args.kv_len, args.d, device='cuda', requires_grad=True)
# Random 0/1 mask of shape (batch, kv_len).
# NOTE(review): presumably 1 = attend, 0 = masked out; confirm against flash_attention.
mask = torch.randint(0, 2, (args.b, args.kv_len), device='cuda')
# Select the implementation under test. Any --type other than "flash" runs
# normal_attention (same dispatch as before).
attention_fn = flash_attention if args.type == "flash" else normal_attention

# Warm-up runs so that one-time start-up costs are excluded from the measurement.
for _ in range(10):
    attention_fn(Q, K, V, mask)

# Work on CUDA tensors is queued asynchronously. Synchronize before starting the
# clock so leftover warm-up work is not counted. Synchronize again before
# stopping it so the timed call has actually finished, not merely been launched.
torch.cuda.synchronize()
start = time.perf_counter_ns()  # monotonic, high-resolution clock
attention_fn(Q, K, V, mask)
torch.cuda.synchronize()
end = time.perf_counter_ns()
t = (end - start) / 1_000_000  # nanoseconds -> milliseconds
print(f'{t}ms')
# Optional single-call PyTorch profiler run. It writes a TensorBoard trace and
# prints the 10 ops with the highest CUDA memory usage.
if args.profile:
    # Choose the implementation and its trace directory together, so the two
    # always correspond.
    if args.type == "flash":
        profiled_fn, trace_dir = flash_attention, './profiler_logs/bench_log_flash'
    else:
        profiled_fn, trace_dir = normal_attention, './profiler_logs/bench_log_normal'

    profiler_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with torch.profiler.profile(
        activities=profiler_activities,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
        record_shapes=True,
        profile_memory=True,
        with_stack=False,  # stack capture incurs additional overhead; left off
        with_flops=True,
        with_modules=False,  # module hierarchy is only for torchscript models at the moment
    ) as prof:
        profiled_fn(Q, K, V, mask)

    # Results are only available once the profiler context has exited.
    summary = prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10)
    print(summary)