-
Notifications
You must be signed in to change notification settings - Fork 0
/
analysis_model.py
46 lines (35 loc) · 1.34 KB
/
analysis_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import time
import argparse
from utils import read_yml
from torchsummaryX import summary
from thop import profile
from road_extractor import build_road_extractor
def parse_args(argv=None):
    """Parse command-line options for the model-analysis script.

    Args:
        argv: Optional sequence of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing
            a list lets callers and tests parse explicit arguments.

    Returns:
        argparse.Namespace with ``config`` (str), ``size`` (int),
        ``batch`` (int) and ``device`` (int).
    """
    parser = argparse.ArgumentParser(description='Analysis road extractor model')
    parser.add_argument('--config', default='configs/LRDNet_RNBD.yml',
                        help='path to the YAML config file that defines the model')
    # The previous help text for --size ("the dir to save logs and models")
    # was a copy-paste error: this value is the spatial size of the dummy input.
    parser.add_argument('--size', type=int, default=512,
                        help='height and width of the square dummy input image')
    parser.add_argument('--batch', type=int, default=1,
                        help='batch size of the dummy input used to measure FPS')
    parser.add_argument('--device', type=int, default=0,
                        help='index of the CUDA device to run on')
    return parser.parse_args(argv)
def _measure_fps(model, inputs, batch, iters=100, warmup=10):
    """Return throughput (images per second) of ``model`` on ``inputs``.

    Runs ``warmup`` untimed forward passes, then times ``iters`` passes and
    returns ``batch / mean_seconds_per_pass``.
    """
    with torch.no_grad():  # inference only: do not build autograd graphs
        # Untimed passes so one-off start-up costs are not counted.
        for _ in range(warmup):
            model(inputs)
        # CUDA kernels launch asynchronously; synchronize so the clock
        # brackets finished work rather than just kernel launches.
        if inputs.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(inputs)
        if inputs.is_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return batch / (elapsed / iters)


def main():
    """Print FLOPs, parameter count and inference FPS of the configured model.

    Builds the road extractor from ``cfg['model']`` of the config file given
    by ``--config``. It places the model on CUDA device ``--device`` in eval
    mode and feeds it an all-zeros tensor of shape
    ``(batch, 3, size, size)``.
    """
    args = parse_args()
    cfg = read_yml(args.config)

    torch.cuda.set_device(args.device)
    model = build_road_extractor(cfg_model=cfg['model']).cuda()
    model.eval()
    inputs = torch.zeros(args.batch, 3, args.size, args.size).cuda()

    # Optional per-layer summary via torchsummaryX; uncomment to enable.
    # summary(model, x=inputs)

    print('=============================thop===============================')
    # NOTE: thop's ``profile`` counts multiply-accumulates (MACs); the value
    # is printed under the "FLOPs" label, as many papers do.
    flops, params = profile(model, (inputs,))
    print('FLOPs:', flops / 1000 ** 3, 'G')
    print('Params:', params / 1000 ** 2, 'M')

    print('FPS:', _measure_fps(model, inputs, args.batch))
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()