-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathunet.py
486 lines (371 loc) · 16.2 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras import layers
from tensorflow.keras.initializers import TruncatedNormal
import math as m
def build_model_denoise(unet_args=None):
    """Wrap ``MultiStage_denoise`` in a functional ``tf.keras.Model``.

    The model input is a tensor of shape ``(batch, None, None, 2)``.

    Args:
        unet_args: configuration object forwarded unchanged to
            ``MultiStage_denoise`` (which reads e.g. ``num_stages``,
            ``depth``, ``activation`` from it).

    Returns:
        A ``tf.keras.Model``. When ``unet_args.num_stages > 1`` its outputs
        are ``[outputs_stage_2, outputs_stage_1]``; otherwise it has the
        single output produced by the one-stage network.
    """
    inputs = Input(shape=(None, None, 2))
    network_outputs = MultiStage_denoise(unet_args=unet_args)(inputs)
    # MultiStage_denoise.call returns a (stage_2, stage_1) pair only when it
    # has more than one stage; with a single stage it returns one tensor,
    # so it must not be unpacked in that case.
    if unet_args.num_stages > 1:
        outputs_stage_2, outputs_stage_1 = network_outputs
        outputs = [outputs_stage_2, outputs_stage_1]
    else:
        outputs = network_outputs
    # Encapsulating MultiStage_denoise in a keras.Model object
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model
class DenseBlock(layers.Layer):
    '''
    [B, T, F, C] => [B, T, F, N]
    DenseNet-style block made of ``num_layers`` densely connected convolutions.

    The first convolution sees the block input. Each later convolution sees
    the newest output concatenated (newest first, along channels) with
    everything it was fed before. Only the last convolution's output is
    returned.
    '''
    def __init__(self, num_layers, N, ksize, activation):
        '''
        num_layers: how many convolutions the block stacks
        N: filters of every convolution
        ksize: kernel size of every convolution
        activation: activation handed to every Conv2D
        '''
        super(DenseBlock, self).__init__()
        self.activation = activation
        # Symmetric pre-padding lets the stride-1 VALID convs keep T and F.
        self.paddings_1 = get_paddings(ksize)
        self.num_layers = num_layers
        self.H = [
            layers.Conv2D(filters=N,
                          kernel_size=ksize,
                          kernel_initializer=TruncatedNormal(),
                          strides=1,
                          padding='VALID',
                          activation=self.activation)
            for _ in range(num_layers)
        ]

    def call(self, x):
        pads = self.paddings_1
        latest = self.H[0](tf.pad(x, pads, mode='SYMMETRIC'))
        for conv in self.H[1:]:
            # Dense connectivity: prepend the newest feature map.
            x = tf.concat([latest, x], axis=-1)
            latest = conv(tf.pad(x, pads, mode='SYMMETRIC'))
        return latest
class FinalBlock(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, 2]
    Output head: one 3x3 convolution without activation that maps the
    feature maps onto the 2 channels of the output complex spectrogram.
    '''
    def __init__(self):
        super(FinalBlock, self).__init__()
        kernel = (3, 3)
        self.paddings_2 = get_paddings(kernel)
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=kernel,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)

    def call(self, inputs):
        # Symmetric padding + VALID 3x3 conv keeps T and F unchanged.
        padded = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        return self.conv2(padded)
class SAM(layers.Layer):
    '''
    [B, T, F, n_feat] , [B, T, F, 2] => [B, T, F, n_feat] , [B, T, F, 2]
    Supervised Attention Module between the first and the second stage.

    Its purpose is to let only the most relevant first-stage features
    propagate to the second stage:
      * a 3x3 convolution (conv2) maps the U-Net output features to a
        2-channel residual; the stage-1 prediction ``pred`` is that residual
        plus the original input spectrogram;
      * attention masks M are computed from ``pred`` with a 3x3 convolution
        (conv3) followed by a sigmoid;
      * the features are re-projected by a 3x3 convolution (conv1), gated
        element-wise by M, and added back to the incoming features.

    The feature input must have ``n_feat`` channels (it is added to conv1's
    output) and ``input_spectrogram`` must have 2 channels (it is added to
    conv2's output).

    Returns (attention-guided features for the next stage, stage-1 prediction).
    '''
    def __init__(self, n_feat):
        '''
        n_feat: channel count of the features entering and leaving the module
        '''
        super(SAM, self).__init__()
        ksize = (3, 3)
        self.paddings_1 = get_paddings(ksize)
        self.conv1 = layers.Conv2D(filters=n_feat,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        self.paddings_2 = get_paddings(ksize)
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        self.paddings_3 = get_paddings(ksize)
        self.conv3 = layers.Conv2D(filters=n_feat,
                                   kernel_size=ksize,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        # Not used by call(); kept so the layer's attributes are unchanged.
        self.cropadd = CropAddBlock()

    def call(self, inputs, input_spectrogram):
        # Re-projected features that will be gated by the attention masks.
        x1 = tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
        x1 = self.conv1(x1)
        # Residual prediction (2 channels).
        x = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        x = self.conv2(x)
        # Stage-1 prediction = residual + input spectrogram.
        pred = layers.Add()([x, input_spectrogram])
        # Attention masks in (0, 1), computed from the stage-1 prediction.
        x3 = tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
        M = self.conv3(x3)
        M = tf.keras.activations.sigmoid(M)
        x1 = layers.Multiply()([x1, M])
        # Residual connection: features to the next stage.
        x1 = layers.Add()([x1, inputs])
        return x1, pred
class AddFreqEncoding(layers.Layer):
    '''
    [B, T, F, C] => [B, T, F, C + num_encodings]
    Appends fixed cosine positional encodings along the frequency axis as
    extra channels (with the default num_encodings=10, 2 channels become 12).

    Channel k (k = 0 .. num_encodings-1) holds cos(2**k * pi * f / (F - 1))
    for frequency bin f. It is identical for every batch element and time
    frame. The encodings are precomputed for exactly ``f_dim`` bins, so the
    input's frequency axis (axis 2) must have size ``f_dim``.
    '''
    def __init__(self, f_dim, num_encodings=10):
        '''
        f_dim: number of frequency bins F of the inputs (fixed at construction)
        num_encodings: number of cosine channels to append (default 10)
        Raises ValueError if f_dim < 2 (positions are divided by f_dim - 1)
        or if num_encodings < 1.
        '''
        super(AddFreqEncoding, self).__init__()
        if f_dim < 2:
            raise ValueError(f"f_dim must be at least 2, got {f_dim}")
        if num_encodings < 1:
            raise ValueError(
                f"num_encodings must be at least 1, got {num_encodings}")
        self.f_dim = f_dim  # f_dim is fixed
        self.num_encodings = num_encodings
        pi = tf.constant(m.pi, dtype='float32')
        # Normalised frequency position in [0, 1], shape (f_dim,).
        n = tf.cast(tf.range(f_dim) / (f_dim - 1), 'float32')
        # Shape (f_dim, num_encodings); channel k is cos(2**k * pi * n).
        self.fembeddings = tf.stack(
            [tf.math.cos(2**k * pi * n) for k in range(num_encodings)],
            axis=-1)

    def call(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # get batch size
        time_dim = tf.shape(input_tensor)[1]  # get time dimension
        # Repeat the (F, num_encodings) table over batch and time.
        fembeddings_2 = tf.broadcast_to(
            self.fembeddings,
            [batch_size_tensor, time_dim, self.f_dim, self.num_encodings])
        return tf.concat([input_tensor, fembeddings_2], axis=-1)
def get_paddings(K):
    '''
    Return the ``tf.pad`` paddings for a 4-D tensor such that a stride-1
    VALID convolution with kernel size K = (kh, kw) keeps axes 1 and 2
    unchanged. Axes 0 and 3 are never padded.

    For each kernel size k the total padding is k - 1: k // 2 before, and
    k // 2 after for odd k or k // 2 - 1 after for even k.
    '''
    def _before_after(k):
        before = k // 2
        after = before if k % 2 == 1 else before - 1
        return [before, after]

    return tf.constant([[0, 0], _before_after(K[0]), _before_after(K[1]), [0, 0]])
class Decoder(layers.Layer):
    '''
    [B, T', F', C] , skip connections => [B, T, F, Ns[0]]
    Decoder (expanding) half of the U-Net subnetwork.

    Holds one D_Block per level. call() runs them from the deepest level
    (index depth-1) up to level 0, giving each the skip connection with the
    same index.
    '''
    def __init__(self, Ns, Ss, unet_args):
        '''
        Ns: channel widths per level; D_Block i outputs Ns[i] channels
        Ss: stride per level, used for upsampling
        unet_args: configuration; activation, depth and num_tfc are read
        '''
        super(Decoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        kernel = (3, 3)
        self.paddings_3 = get_paddings(kernel)
        # NOTE: conv2d_3 and cropadd are never used by call(); they are kept
        # so the layer's attributes stay unchanged.
        self.conv2d_3 = layers.Conv2D(filters=self.Ns[self.depth],
                                      kernel_size=kernel,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=1,
                                      padding='VALID',
                                      activation=self.activation)
        self.cropadd = CropAddBlock()
        self.dblocks = [
            D_Block(layer_idx=level,
                    N=self.Ns[level],
                    S=self.Ss[level],
                    activation=self.activation,
                    num_tfc=unet_args.num_tfc)
            for level in range(self.depth)
        ]

    def call(self, inputs, contracting_layers):
        features = inputs
        # Deepest level first, each one paired with its skip connection.
        for level in reversed(range(self.depth)):
            features = self.dblocks[level](features, contracting_layers[level])
        return features
class Encoder(tf.keras.Model):
    '''
    [B, T, F, C] => skip connections , [B, T', F', Ns[depth]]
    Encoder (contracting) half of the U-Net subnetwork.

    Runs ``depth`` E_Blocks followed by an I_Block bottleneck. Returns the
    bottleneck features and a dict {level: skip tensor}. The skip for level i
    is the pre-downsampling output of E_Block i, with Ns[i] channels.
    '''
    def __init__(self, Ns, Ss, unet_args):
        '''
        Ns: channel widths per level; Ns[depth] is the bottleneck width
        Ss: stride per level, used for downsampling
        unet_args: configuration; activation, depth and num_tfc are read
        '''
        super(Encoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        # Kept for backward compatibility of the attribute only; call() now
        # returns skip tensors in a fresh dict instead of storing them here.
        self.contracting_layers = {}
        self.eblocks = []
        for i in range(self.depth):
            self.eblocks.append(E_Block(layer_idx=i,
                                        N0=self.Ns[i],
                                        N=self.Ns[i + 1],
                                        S=self.Ss[i],
                                        activation=self.activation,
                                        num_tfc=unet_args.num_tfc))
        self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)

    def call(self, inputs):
        x = inputs
        # Local per-call container. Storing these tensors on self would keep
        # (possibly graph-traced) tensors alive between calls.
        contracting_layers = {}
        for i in range(self.depth):
            x, x_contract = self.eblocks[i](x)
            contracting_layers[i] = x_contract
        x = self.i_block(x)
        return x, contracting_layers
class MultiStage_denoise(tf.keras.Model):
    '''
    Multi-stage U-Net spectrogram denoiser.

    Input: [B, T, F, 2] spectrogram. Output: when ``num_stages > 1``, the
    tuple (pred_stage_2, pred_stage_1); otherwise the single stage-1
    prediction. Every prediction has 2 channels.

    Stage 1: optional frequency encoding -> 7x7 conv -> Encoder -> Decoder.
    With more than one stage, SAM turns the stage-1 features into the
    stage-1 prediction and attention-guided features. Stage 2 then:
      * re-extracts features from the (optionally encoded) input with a
        second 7x7 conv;
      * concatenates them with either the SAM features (use_SAM) or the raw
        stage-1 features;
      * runs a second Encoder/Decoder followed by FinalBlock.
    '''
    def __init__(self, unet_args=None):
        '''
        unet_args: configuration object, required despite the None default.
            Reads activation, depth, use_fencoding, f_dim (only when
            use_fencoding), use_SAM and num_stages, plus num_tfc via
            Encoder/Decoder.
        Raises ValueError if unet_args is None.
        '''
        if unet_args is None:
            # Everything below dereferences unet_args; fail with a clear
            # message instead of an AttributeError on NoneType.
            raise ValueError("MultiStage_denoise requires unet_args; got None")
        super(MultiStage_denoise, self).__init__()
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        if unet_args.use_fencoding:
            self.freq_encoding = AddFreqEncoding(unet_args.f_dim)
        self.use_sam = unet_args.use_SAM
        self.use_fencoding = unet_args.use_fencoding
        self.num_stages = unet_args.num_stages
        # Encoder channel widths per level (Ns[depth] is the bottleneck width)
        # and per-level strides.
        self.Ns = [32, 64, 64, 128, 128, 256, 512]
        self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]
        # initial feature extractor (stage 1)
        ksize = (7, 7)
        self.paddings_1 = get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=1,
                                      padding='VALID',
                                      activation=self.activation)
        self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)
        self.cropconcat = CropConcatBlock()
        self.cropadd = CropAddBlock()
        self.finalblock = FinalBlock()
        if self.num_stages > 1:
            self.sam_1 = SAM(self.Ns[0])
            # initial feature extractor (stage 2)
            ksize = (7, 7)
            self.paddings_2 = get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
                                          kernel_size=ksize,
                                          kernel_initializer=TruncatedNormal(),
                                          strides=1,
                                          padding='VALID',
                                          activation=self.activation)
            self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)

    @tf.function()
    def call(self, inputs):
        if self.use_fencoding:
            # Appends the frequency-encoding channels (10 by default).
            x_w_freq = self.freq_encoding(inputs)
        else:
            x_w_freq = inputs
        # initial feature extractor
        x = tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
        x = self.conv2d_1(x)  # Ns[0] channels
        x, contracting_layers_s1 = self.encoder_s1(x)
        # decoder
        feats_s1 = self.decoder_s1(x, contracting_layers_s1)  # Ns[0] feature channels
        if self.num_stages > 1:
            # SAM module: attention-guided features + stage-1 prediction
            Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)
            # initial feature extractor (stage 2)
            x = tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
            x = self.conv2d_2(x)
            if self.use_sam:
                x = tf.concat([x, Fout], axis=-1)
            else:
                x = tf.concat([x, feats_s1], axis=-1)
            x, contracting_layers_s2 = self.encoder_s2(x)
            feats_s2 = self.decoder_s2(x, contracting_layers_s2)  # Ns[0] feature channels
            # consider implementing a third stage?
            pred_stage_2 = self.finalblock(feats_s2)
            return pred_stage_2, pred_stage_1
        else:
            pred_stage_1 = self.finalblock(feats_s1)
            return pred_stage_1
class I_Block(layers.Layer):
    '''
    [B, T, F, C] => [B, T, F, N]
    Intermediate block: a DenseBlock with a residual connection. The
    shortcut passes through a linear 1x1 convolution so its channel count
    matches N.
    '''
    def __init__(self, N, activation, num_tfc, **kwargs):
        '''
        N: output channels
        activation: activation used inside the DenseBlock
        num_tfc: number of convolutions in the DenseBlock
        '''
        super(I_Block, self).__init__(**kwargs)
        self.tfc = DenseBlock(num_tfc, N, (3, 3), activation)
        # 1x1 projection of the block input (no activation).
        self.conv2d_res = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        padding='VALID')

    def call(self, inputs):
        dense_out = self.tfc(inputs)
        shortcut = self.conv2d_res(inputs)
        return layers.Add()([dense_out, shortcut])
class E_Block(layers.Layer):
    '''
    [B, T, F, C] => [B, ceil(T/S[0]), ceil(F/S[1]), N] , [B, T, F, N0]
    Encoder block: an I_Block (N0 channels) followed by a strided
    convolution with kernel (S[0]+2, S[1]+2) that downsamples by S.
    Returns (downsampled features, pre-downsampling features used as the
    skip connection).
    '''
    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
        '''
        layer_idx: level index inside the encoder
        N0: channels of the I_Block output (the skip connection)
        N: channels after downsampling
        S: (time, freq) stride of the downsampling convolution
        activation: activation for the I_Block and the strided conv
        num_tfc: number of convolutions inside the I_Block's DenseBlock
        '''
        super(E_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N0 = N0
        self.N = N
        self.S = S
        self.activation = activation
        self.num_tfc = num_tfc
        self.i_block = I_Block(N0, activation, num_tfc)
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_2 = get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(filters=N,
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=S,
                                      padding='VALID',
                                      activation=self.activation)

    def call(self, inputs, training=None, **kwargs):
        x = self.i_block(inputs)
        x_down = tf.pad(x, self.paddings_2, mode='SYMMETRIC')
        x_down = self.conv2d_2(x_down)
        return x_down, x

    def get_config(self):
        # Every constructor argument must be present so that
        # E_Block.from_config(layer.get_config()) can rebuild the layer.
        config = super(E_Block, self).get_config()
        config.update(layer_idx=self.layer_idx,
                      N0=self.N0,
                      N=self.N,
                      S=self.S,
                      activation=self.activation,
                      num_tfc=self.num_tfc)
        return config
class D_Block(layers.Layer):
    '''
    [B, t, f, C] , bridge [B, T, F, Cb] => [B, T, F, N]
    Decoder block: upsamples by S, merges with the skip connection
    ``bridge`` and refines with an I_Block.

    Upsampling is the sum of two branches:
      * a learned VALID transposed conv, with output S*t + 2 per axis;
      * nearest-neighbour upsampling, with output S*t per axis, projected by
        a 1x1 conv when its channel count differs from N.
    The result is center-cropped to the bridge's spatial size and
    concatenated with it. The output matches the bridge's T and F provided
    S*t >= T and S*f >= F (cropping can only shrink).
    '''
    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
        '''
        layer_idx: level index inside the decoder
        N: output channels
        S: (time, freq) upsampling factor
        activation: activation for the convolutions and the I_Block
        num_tfc: number of convolutions inside the I_Block's DenseBlock
        '''
        super(D_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N = N
        self.S = S
        self.activation = activation
        self.num_tfc = num_tfc
        ksize = (S[0] + 2, S[1] + 2)
        # Not used by call() (the transposed conv runs on the unpadded
        # input); kept so the layer's attributes stay unchanged.
        self.paddings_1 = get_paddings(ksize)
        self.tconv_1 = layers.Conv2DTranspose(filters=N,
                                              kernel_size=ksize,
                                              kernel_initializer=TruncatedNormal(),
                                              strides=S,
                                              activation=self.activation,
                                              padding='VALID')
        self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
        self.projection = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        activation=self.activation,
                                        padding='VALID')
        self.cropadd = CropAddBlock()
        self.cropconcat = CropConcatBlock()
        self.i_block = I_Block(N, activation, num_tfc)

    def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs):
        # Learned upsampling branch (no pre-padding is applied).
        x = self.tconv_1(inputs)
        # Fixed upsampling branch; match channels to N if necessary.
        x2 = self.upsampling(inputs)
        if x2.shape[-1] != x.shape[-1]:
            x2 = self.projection(x2)
        # Crop the (larger) transposed-conv output to x2's size and add.
        x = self.cropadd(x, x2)
        # Crop to the skip connection's size and concatenate on channels.
        x = self.cropconcat(x, bridge)
        x = self.i_block(x)
        return x

    def get_config(self):
        # Every constructor argument must be present so that
        # D_Block.from_config(layer.get_config()) can rebuild the layer.
        config = super(D_Block, self).get_config()
        config.update(layer_idx=self.layer_idx,
                      N=self.N,
                      S=self.S,
                      activation=self.activation,
                      num_tfc=self.num_tfc)
        return config
class CropAddBlock(layers.Layer):
    '''
    Crop ``down_layer`` along axes 1 and 2 to the size of ``x``, then return
    their element-wise sum. The crop is centered, with the offset rounded
    down.
    '''
    def call(self, down_layer, x, **kwargs):
        big = tf.shape(down_layer)
        small = tf.shape(x)
        top = (big[1] - small[1]) // 2
        left = (big[2] - small[2]) // 2
        cropped = down_layer[:, top:top + small[1], left:left + small[2], :]
        return layers.Add()([cropped, x])
class CropConcatBlock(layers.Layer):
    '''
    Crop ``down_layer`` along axes 1 and 2 to the size of ``x``, then
    concatenate the cropped tensor and ``x`` (in that order) on the last
    axis. The crop is centered, with the offset rounded down.
    '''
    def call(self, down_layer, x, **kwargs):
        big = tf.shape(down_layer)
        small = tf.shape(x)
        top = (big[1] - small[1]) // 2
        left = (big[2] - small[2]) // 2
        cropped = down_layer[:, top:top + small[1], left:left + small[2], :]
        return tf.concat([cropped, x], axis=-1)