# generate_prediction_pytorch.py
#
# Computes embeddings from a pretrained backbone for the ImageNet train and
# validation loaders and for the ObjectNet loader, then saves the embeddings
# and labels to ./results/<model>.npz.
import argparse
import os
from types import SimpleNamespace
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from PyContrast.pycontrast.networks.build_backbone import build_model
from torch_utils import get_loaders_imagenet, get_loaders_objectnet
# Device and dtype used for every tensor and model in this script.
# Prefer the first CUDA device; fall back to the CPU when CUDA is not
# available so the module does not fail on GPU-less hosts.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32
# Per-model settings for the backbones built through PyContrast's build_model.
# Each entry holds the values assigned to the args namespace plus the path of
# the checkpoint whose 'model' entry is loaded into the network.
_PYCONTRAST_CONFIGS = {
    'resnet50_infomin': {
        'jigsaw': True,
        'arch': 'resnet50',
        'head': 'mlp',
        'feat_dim': 128,
        'checkpoint': 'checkpoints/InfoMin_800.pth',
    },
    'resnext152_infomin': {
        'jigsaw': True,
        'arch': 'resnext152v1',
        'head': 'mlp',
        'feat_dim': 128,
        'checkpoint': 'checkpoints/InfoMin_resnext152v1_e200.pth',
    },
    'resnet50_mocov2': {
        'jigsaw': False,
        'arch': 'resnet50',
        'head': 'linear',
        'feat_dim': 2048,
        'checkpoint': 'checkpoints/MoCov2.pth',
    },
}


def _load_pycontrast_backbone(config):
    """Build a PyContrast model from *config* and load its checkpoint weights."""
    args = SimpleNamespace()
    args.jigsaw = config['jigsaw']
    args.arch, args.head, args.feat_dim = config['arch'], config['head'], config['feat_dim']
    args.mem = 'moco'
    args.modal = 'RGB'
    net, _ = build_model(args)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    # map_location='cpu' deserializes the weights on the CPU instead of on the
    # device they were saved from; load_state_dict copies them into the model,
    # which is moved to `device` afterwards.
    checkpoint = torch.load(config['checkpoint'], map_location='cpu')
    state = checkpoint['model']
    # Drop every 'module.' fragment from the parameter names so the keys can
    # match the bare (unwrapped) model.
    renamed = {key.replace('module.', ''): value for key, value in state.items()}
    net.load_state_dict(renamed, strict=False)  # no head, don't need linear model
    return net.to(device=device)


def get_model(model='resnet50_infomin'):
    """Return the pretrained network called *model*, moved to ``device``.

    Args:
        model: One of 'resnet50_infomin', 'resnext152_infomin',
            'resnet50_mocov2' or 'resnet50_swav'.

    Returns:
        The network on ``device``. For the three PyContrast-built models the
        checkpoint under ``checkpoints/`` is loaded non-strictly; for
        'resnet50_swav' the network comes from ``torch.hub`` with its last
        child module removed.

    Raises:
        ValueError: If *model* is not one of the supported names.
    """
    config = _PYCONTRAST_CONFIGS.get(model)
    if config is not None:
        return _load_pycontrast_backbone(config)

    if model == 'resnet50_swav':
        hub_net = torch.hub.load('facebookresearch/swav', 'resnet50')
        # Keep every child module except the last one.
        trunk = nn.Sequential(*list(hub_net.children())[:-1])
        return trunk.to(device=device)

    raise ValueError('Wrong model')
def eval_swav(model, loader):
    """Run *model* over every batch of *loader* and collect outputs and labels.

    The model is called with the batch as its only argument.

    Args:
        model: A ``torch.nn.Module`` already placed on ``device``.
        loader: Iterable yielding ``(data, target)`` tensor pairs.

    Returns:
        Tuple ``(outputs, labels)`` of NumPy arrays, each the per-batch
        results concatenated along axis 0.
    """
    # Inference only: put layers such as BatchNorm/Dropout in evaluation mode
    # and skip autograd bookkeeping so no graph is kept for each batch.
    model.eval()
    outputs = []
    labels = []
    with torch.no_grad():
        for data, target in tqdm(loader):
            data = data.to(device=device, dtype=dtype)
            output = model(data)
            outputs.append(output.detach().cpu().numpy())
            # Labels are only stored, so they never need to visit the device.
            labels.append(target.detach().cpu().numpy())
    rss = np.concatenate(outputs, axis=0)
    lbs = np.concatenate(labels, axis=0)
    return rss, lbs
def eval(model, loader, kwargs=None):
    """Run *model* over every batch of *loader* and collect outputs and labels.

    The model is called as ``model(data, mode=2)``.
    NOTE(review): ``mode=2`` is presumably the model's feature-extraction /
    testing path; confirm against the model's ``forward`` definition.
    NOTE: this function shadows the builtin ``eval``; the name is kept so
    existing callers continue to work.

    Args:
        model: A ``torch.nn.Module`` already placed on ``device`` whose
            forward accepts a ``mode`` keyword.
        loader: Iterable yielding ``(data, target)`` tensor pairs.
        kwargs: Unused. Optional so the function can be called with just
            ``(model, loader)``.

    Returns:
        Tuple ``(outputs, labels)`` of NumPy arrays, each the per-batch
        results concatenated along axis 0.
    """
    # Inference only: put layers such as BatchNorm/Dropout in evaluation mode
    # and skip autograd bookkeeping so no graph is kept for each batch.
    model.eval()
    outputs = []
    labels = []
    with torch.no_grad():
        for data, target in tqdm(loader):
            data = data.to(device=device, dtype=dtype)
            output = model(data, mode=2)
            outputs.append(output.detach().cpu().numpy())
            # Labels are only stored, so they never need to visit the device.
            labels.append(target.detach().cpu().numpy())
    rss = np.concatenate(outputs, axis=0)
    lbs = np.concatenate(labels, axis=0)
    return rss, lbs
# Dataset root directories read by eval_and_save(). These are machine-specific
# absolute paths; adjust them to the local dataset layout.
imagenet_path = '/home/chaimb/ILSVRC/Data/CLS-LOC'
objectnet_path = '/home/chaimb/objectnet-1.0'
def eval_and_save(model='resnet50_infomin'):
    """Compute embeddings with *model* and save them to ``./results/<model>.npz``.

    Embeddings and labels are computed for the ImageNet train and validation
    loaders and for the ObjectNet loader. The archive holds the arrays
    ``train_embs``, ``train_labs``, ``val_embs``, ``val_labs``, ``obj_embs``
    and ``obj_labs``.

    Args:
        model: Model name understood by ``get_model``.
    """
    mdl = get_model(model)
    bs = 32 if model == 'resnet50_infomin' else 16
    train_loader, val_loader = get_loaders_imagenet(imagenet_path, bs, bs, 224, 8, 1, 0)
    # Only the first of the five values returned for ObjectNet is used here.
    obj_loader, _, _, _, _ = get_loaders_objectnet(objectnet_path, imagenet_path, bs, 224, 8, 1, 0)

    if 'swav' in model:
        def extract(loader):
            return eval_swav(mdl, loader)
    else:
        def extract(loader):
            # eval() declares a third `kwargs` parameter that its body never
            # reads; pass it explicitly so the call matches the signature.
            return eval(mdl, loader, None)

    train_embs, train_labs = extract(train_loader)
    val_embs, val_labs = extract(val_loader)
    obj_embs, obj_labs = extract(obj_loader)

    os.makedirs('./results', exist_ok=True)
    np.savez(os.path.join('./results', model + '.npz'), train_embs=train_embs, train_labs=train_labs,
             val_embs=val_embs, val_labs=val_labs, obj_embs=obj_embs, obj_labs=obj_labs)
# Model names accepted on the command line.
models = ['resnet50_infomin', 'resnext152_infomin', 'resnet50_mocov2', 'resnet50_swav']


def build_parser():
    """Return the command-line parser; ``--model`` must be one of ``models``."""
    parser = argparse.ArgumentParser(description='IM')
    parser.add_argument('--model', dest='model', type=str, default='resnext152_infomin',
                        choices=models,
                        help='Model: one of ' + ', '.join(models))
    return parser


# Run the evaluation only when executed as a script, so importing this module
# neither parses sys.argv nor starts a full evaluation run.
if __name__ == '__main__':
    args = build_parser().parse_args()
    eval_and_save(args.model)