-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loadModel.py
66 lines (60 loc) · 2.62 KB
/
loadModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
from collections import OrderedDict
def loadModel(model, pathToModel, dataParallelModel=False):
    """Restore a saved model from disk, handling GPU/CPU and DataParallel cases.

    Parameters
    ----------
    model : torch.nn.Module
        Instance whose weights are populated when the checkpoint is a state
        dict (ignored when the checkpoint is a whole pickled model).
    pathToModel : str
        Path to the checkpoint file.
    dataParallelModel : bool
        True when the checkpoint is a state dict saved from an
        nn.DataParallel-wrapped model (keys prefixed with ``module.``).

    Returns
    -------
    torch.nn.Module or None
        The restored model, or ``None`` when loading fails.
    """
    try:
        # LOAD TRAINED MODEL INTO GPU
        if torch.cuda.is_available() and not dataParallelModel:
            # Whole pickled model restored straight onto the GPU.
            model = torch.load(pathToModel)
            print("\n--------model restored--------\n")
            return model
        elif torch.cuda.is_available() and dataParallelModel:
            # DataParallel state dict restored onto the GPU.
            state_dict = torch.load(pathToModel)
            print(state_dict.keys())
            model.load_state_dict(state_dict)
            print("\n--------DataParallel GPU model restored--------\n")
            return model
        # LOAD MODEL TRAINED ON GPU INTO CPU
        elif not torch.cuda.is_available() and dataParallelModel:
            # Remap GPU-saved storages onto CPU, then strip the "module."
            # prefix that nn.DataParallel adds to every parameter key.
            state_dict = torch.load(pathToModel, map_location=lambda storage, loc: storage)
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                # BUGFIX: original tested only k[0] == 'm', which would also
                # mangle legitimate keys such as "mlp.weight"; test the full
                # "module." prefix instead.
                if k.startswith('module.'):
                    new_state_dict[k[len('module.'):]] = v
                else:
                    new_state_dict[k] = v
            model.load_state_dict(new_state_dict)
            print("\n--------GPU data parallel model restored--------\n")
            return model
        else:
            # Whole pickled GPU model loaded on a CPU-only machine.
            # BUGFIX: removed dead `storage = {"cuda": "cpu"}` — the lambda's
            # own parameter shadowed it, so it was never used.
            model = torch.load(pathToModel, map_location=lambda storage, loc: storage)
            print("\n--------GPU model restored--------\n")
            return model
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; any load failure is reported as "no saved model".
        print("\n--------no saved model found--------\n")
        return None
def loadStateDict(model, pathToStateDict):
    """Load a saved state dict into ``model``.

    Maps GPU-saved tensors onto CPU when CUDA is unavailable, and falls back
    to stripping the ``module.`` key prefix left by nn.DataParallel when the
    keys do not match the model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to receive the weights (mutated in place).
    pathToStateDict : str
        Path to the saved state dict file.

    Returns
    -------
    torch.nn.Module
        The same ``model`` with weights loaded.

    Raises
    ------
    Propagates file/IO errors from ``torch.load`` (the original bare
    ``except`` hid them, then crashed on an unconditional reload anyway).
    """
    if torch.cuda.is_available():
        state_dict = torch.load(pathToStateDict)
        success_msg = "\n--------GPU state dict restored and loaded into GPU--------\n"
    else:
        # Remap GPU-saved storages onto CPU for CPU-only machines.
        state_dict = torch.load(pathToStateDict, map_location=lambda storage, loc: storage)
        success_msg = "\n--------GPU state dict restored, loaded into CPU--------\n"
    try:
        model.load_state_dict(state_dict)
        print(success_msg)
    except RuntimeError:
        # Key mismatch: the dict was likely saved from an nn.DataParallel
        # model whose keys carry a "module." prefix. Strip it and retry,
        # reusing the already-loaded dict instead of reloading from disk.
        # BUGFIX: narrowed from a bare `except:` (which also swallowed file
        # errors) and replaced the fragile k[0] == 'm' prefix test.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[len('module.'):]] = v
            else:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        print("\n--------GPU data parallel state dict restored--------\n")
    return model