-
Notifications
You must be signed in to change notification settings - Fork 2
/
convert.py
44 lines (30 loc) · 1.25 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# python3 convert.py input-model.tar output-model.pth
import argparse
import torch
from backbone.repvgg import repvgg_model_convert
from model import RepNet6D, RepNet5D
# Command-line interface: convert.py LOAD SAVE [-a ARCH]
parser = argparse.ArgumentParser(description='6DoFHPE Conversion')
# Input checkpoint; convert() reads its 'model_state_dict' entry.
parser.add_argument('load', metavar='LOAD',
                    help="path to the training checkpoint to convert "
                         "(must contain a 'model_state_dict' entry)")
# Output path, passed to repvgg_model_convert as save_path.
parser.add_argument('save', metavar='SAVE',
                    help='path where the converted weights file is written')
parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-B1g4',
                    help='backbone architecture name passed to RepNet6D '
                         '(default: %(default)s)')
def load_filtered_state_dict(model, snapshot):
    """Load into ``model`` only those ``snapshot`` entries whose keys it already has.

    Entries of ``snapshot`` with keys unknown to ``model.state_dict()`` are
    dropped. Parameters/buffers that ``snapshot`` does not mention keep their
    current values, because the merge starts from the model's own state dict.

    Adapted from a snippet by user apaszke on discuss.pytorch.org.

    Args:
        model: module whose ``state_dict()`` / ``load_state_dict()`` are used.
        snapshot: mapping of state-dict keys to tensors (e.g. from a checkpoint).
    """
    merged = model.state_dict()
    # Only overwrite keys that already exist, so the merged dict keeps exactly
    # the model's key set and load_state_dict (strict by default) accepts it.
    merged.update(
        (name, value) for name, value in snapshot.items() if name in merged
    )
    model.load_state_dict(merged)
def convert(argv=None):
    """Convert a training-time RepNet6D checkpoint with ``repvgg_model_convert``.

    Parses the command line, builds a ``RepNet6D`` (``deploy=False``), loads
    the matching entries of the checkpoint's ``'model_state_dict'``, and hands
    the model to ``repvgg_model_convert``, which writes the result to the
    SAVE path.

    Args:
        argv: optional list of argument strings. ``None`` (the default, used
            when run as a script) makes argparse read ``sys.argv``.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    args = parser.parse_args(argv)

    print('Loading model.')
    model = RepNet6D(backbone_name=args.arch,
                     backbone_file='',
                     deploy=False,
                     pretrained=False)

    # Nothing here moves the model to a GPU, so load the checkpoint onto the
    # CPU as well. Without map_location, a checkpoint saved from CUDA tensors
    # would require a CUDA device just to run this conversion.
    # NOTE(review): torch.load unpickles the file -- only convert checkpoints
    # from trusted sources.
    saved_state_dict = torch.load(args.load, map_location='cpu')
    load_filtered_state_dict(model, saved_state_dict['model_state_dict'])

    print('Converting model.')
    repvgg_model_convert(model, save_path=args.save)
    print('Done.')


if __name__ == '__main__':
    convert()