import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
from torchvision.models import vgg16_bn
from tempfile import NamedTemporaryFile
# This is largely a copy of source code found on:
# https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres-imagenet.ipynb
class deboning_512:
    def __init__(self, model_name):
        """Load the 512x512 deboning model and build the learner used by debone()."""
        self.model_name = '../../../models/deboning_512_200steps'
        if model_name is not None:
            self.model_name = model_name

        defaults.cmap = 'binary'

        CPU = 1
        if CPU == 1:
            defaults.device = torch.device("cpu")
        else:
            # default is GPU
            torch.cuda.set_device(0)

        path = Path('data')
        path_input_512 = Path('data/input_512')
        path_bone_512 = Path('data/bone_512')
        path_tissue_512 = Path('data/tissue_512')

        bs, size = 8, 512
        if CPU != 1:
            free = gpu_mem_get_free_no_cache()
            # the max size of the test image depends on the available GPU RAM
            if free > 8200:
                bs, size = 16, 128
            else:
                bs, size = 8, 128
            print(f"using bs={bs}, size={size}, have {free}MB of GPU RAM free")
        arch = models.resnet34

        # sample = 0.1
        sample = False

        tfms = get_transforms()

        # we want to predict the tissue from the input
        src = ImageImageList.from_folder(path_input_512)
        if sample:
            src = src.filter_by_rand(sample, seed=42)
        src = src.split_by_rand_pct(0.1, seed=42)

        # ok, we need to use path_tissue_512 as the target for the input of path_input_512
        def get_data(bs, size):
            data = (src.label_from_func(lambda x: path_tissue_512/x.relative_to(path_input_512))
                    .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
                    .databunch(bs=bs).normalize(do_y=True))
            data.c = 3
            return data

        data = get_data(bs, size)
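        # Each input_512 image is paired with the tissue_512 image at the same relative
        # path, so the network learns to map the raw input image to its soft-tissue-only
        # version; data.c = 3 tells fastai the model produces a 3-channel image output.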
        #
        # Feature list
        #
        def gram_matrix(x):
            n, c, h, w = x.size()
            x = x.view(n, c, -1)
            return (x @ x.transpose(1, 2)) / (c * h * w)
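        # gram_matrix returns the (c x c) matrix of channel-wise feature correlations,
        # normalised by c*h*w; it feeds the "gram" (style) terms of FeatureLoss below.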
        if CPU == 1:
            vgg_m = vgg16_bn(True).features.eval()
        else:
            vgg_m = vgg16_bn(True).features.cuda().eval()
        requires_grad(vgg_m, False)

        # indices of the layers just before each MaxPool2d in VGG16-BN
        blocks = [i-1 for i, o in enumerate(children(vgg_m)) if isinstance(o, nn.MaxPool2d)]

        base_loss = F.l1_loss
        class FeatureLoss(nn.Module):
            def __init__(self, m_feat, layer_ids, layer_wgts):
                super().__init__()
                self.m_feat = m_feat
                self.loss_features = [self.m_feat[i] for i in layer_ids]
                self.hooks = hook_outputs(self.loss_features, detach=False)
                self.wgts = layer_wgts
                self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
                                                  ] + [f'gram_{i}' for i in range(len(layer_ids))]

            def make_features(self, x, clone=False):
                self.m_feat(x)
                return [(o.clone() if clone else o) for o in self.hooks.stored]

            def forward(self, input, target):
                out_feat = self.make_features(target, clone=True)
                in_feat = self.make_features(input)
                self.feat_losses = [base_loss(input, target)]
                self.feat_losses += [base_loss(f_in, f_out)*w
                                     for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
                self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                                     for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
                self.metrics = dict(zip(self.metric_names, self.feat_losses))
                return sum(self.feat_losses)

            def __del__(self): self.hooks.remove()

        feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5, 15, 2])
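        # feat_loss sums a pixel-wise L1 term, L1 terms on the activations of the three
        # VGG blocks selected by blocks[2:5] (weighted 5, 15 and 2), and Gram-matrix
        # ("style") terms on those same activations.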
        #
        # Create the network
        #
        wd = 1e-3
        self.learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, callback_fns=LossMetrics,
                                  blur=True, norm_type=NormType.Weight)
        gc.collect()

        #
        defaults.device = 'cpu'
        #_=learn.load('../../../models/deboning_512_120steps')
        _ = self.learn.load(self.model_name)

        # This one is not needed, we can use predict immediately.
        #data_mr = (ImageImageList.from_folder(path_input_512).split_by_rand_pct(0.1, seed=42)
        #           .label_from_func(lambda x: path_tissue_512/x.relative_to(path_input_512))
        #           .transform(get_transforms(), size=(512,512), tfm_y=True)
        #           .databunch(bs=2).normalize(do_y=True))
        #
        #learn.data = data_mr
    def debone(self, image_path):
        path_input = Path('data/input')
        fn = path_input/'img_2050.png'
        if image_path is not None:
            fn = image_path

        # convert to grayscale and resize to the 512x512 resolution the model expects
        img = PIL.Image.open(fn)
        img = img.convert('L').resize([512, 512])

        # dirty: round-trip through a temporary PNG so fastai's open_image can read it
        fn = NamedTemporaryFile(suffix='.png')
        img.save(fn)
        img = open_image(fn)

        _, img_hr, b = self.learn.predict(img)
        #show_image(img, figsize=(18,15), interpolation='nearest');
        #show_image(img_hr, figsize=(18,15), interpolation='nearest');
        return img, img_hr

if __name__ == "__main__":
    # When run as a script, perform a quick test prediction.
    # When imported as a module, nothing below is executed.
    pre = deboning_512('../../../models/deboning_512_200steps')
    img, img_hr = pre.debone('data/input/img_2050.png')
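    # A minimal sketch of how the result could be saved (the output path is an
    # assumption; fastai v1's Image.save writes the prediction back out as an image):
    # img_hr.save('data/output/img_2050_deboned.png')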