-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconvert_tensorrt.py
24 lines (19 loc) · 995 Bytes
/
convert_tensorrt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from argparse import ArgumentParser
import torch
from mmdet2trt import mmdet2trt
def _str2bool(value):
    """Convert a command-line token such as ``'true'`` or ``'0'`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is
    ``True`` because every non-empty string is truthy, so the flag could
    never be turned off from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised
            boolean spelling (argparse turns this into a usage error).
    """
    # Local import keeps this block self-contained; the file only imports
    # ArgumentParser at module level.
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def main(argv=None):
    """Parse CLI options, run ``mmdet2trt`` and save the resulting weights.

    The config path, checkpoint path, fp16 flag and device are forwarded to
    ``mmdet2trt``; the ``state_dict()`` of the returned model is written to
    ``--save_path`` with ``torch.save``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the original
            behaviour.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--config',
        default='/workspaces/detection_and_tracking/yolox_x_8x8_300e_coco.py',
        help='mmdet Config file')
    parser.add_argument(
        '--checkpoint',
        default='/workspaces/detection_and_tracking/yolox-xl-epoch_345-327map.pth',
        help='mmdet Checkpoint file')
    parser.add_argument(
        '--save_path',
        default='/workspaces/detection_and_tracking/yolox-xl-best.pth',
        help='tensorrt model save path')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    # Accept "--fp16 true|false" as well as a bare "--fp16" (meaning True).
    # When the option is omitted entirely the default stays True.
    parser.add_argument(
        '--fp16',
        type=_str2bool,
        nargs='?',
        const=True,
        default=True,
        help='enable fp16 inference')
    args = parser.parse_args(argv)

    trt_model = mmdet2trt(
        args.config, args.checkpoint, fp16_mode=args.fp16, device=args.device)
    torch.save(trt_model.state_dict(), args.save_path)
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()