-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbound_layers.py
623 lines (546 loc) · 27.9 KB
/
bound_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
import torch
import numpy as np
from torch.nn import DataParallel
from torch.nn import Sequential, Conv2d, Linear, ReLU
from model_defs import Flatten, model_mlp_any
import torch.nn.functional as F
from itertools import chain
import logging
from torch.autograd import Variable
# Configure the root logger at import time (INFO level).
# NOTE(review): calling basicConfig from a library module affects every program
# that imports it; consider moving this into the entry-point script. Switch the
# level to logging.DEBUG to see the per-layer shape logs emitted below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Module-level flag; not referenced in the visible code of this module —
# presumably a debug switch read elsewhere. Verify before removing.
DEBUGG = False
class BoundFlatten(torch.nn.Module):
    """Flatten layer supporting interval (IBP) and backward (CROWN) bound propagation.

    ``bound_opts`` is an optional dict of bound options; only the ``"same-slope"``
    key is read here. ``None`` is treated as an empty dict.
    """

    def __init__(self, bound_opts=None):
        super(BoundFlatten, self).__init__()
        self.bound_opts = bound_opts

    def forward(self, x):
        # Remember the per-example input shape so bound_backward can un-flatten A.
        self.shape = x.size()[1:]
        return x.view(x.size(0), -1)

    def convert_eval(self, bound_opts=None):
        self.bound_opts = bound_opts

    def interval_propagate(self, norm, h_U, h_L, eps):
        """Flatten the interval bounds; the norm is passed through unchanged.

        Returns ``(norm, upper, lower, 0, 0, 0, 0)``. The trailing zeros are the
        loss/unstable/dead/alive counters reported by other bound layers.
        """
        self.upper_u = h_U
        self.lower_l = h_L
        return norm, h_U.view(h_U.size(0), -1), h_L.view(h_L.size(0), -1), 0, 0, 0, 0

    def bound_backward(self, last_uA, last_lA):
        """Reshape A matrices from flattened features back to the pre-flatten shape.

        ``forward`` must have run first, since it records ``self.shape``.
        Returns ``(uA, 0, lA, 0)``. This layer contributes no bias.
        """
        def _bound_oneside(A):
            if A is None:
                return None
            return A.view(A.size(0), A.size(1), *self.shape)

        opts = self.bound_opts or {}
        if opts.get("same-slope", False) and (last_uA is not None) and (last_lA is not None):
            # With "same-slope" the lower A is taken to equal the upper A: reshape once.
            new_bound = _bound_oneside(last_uA)
            return new_bound, 0, new_bound, 0
        return _bound_oneside(last_uA), 0, _bound_oneside(last_lA), 0
class BoundLinear(Linear):
    """``torch.nn.Linear`` with interval (IBP) and backward (CROWN) bound propagation.

    ``bound_opts`` is an optional dict; only ``"same-slope"`` is read here and
    ``None`` is treated as an empty dict. Layers without a bias are supported.
    """

    def __init__(self, in_features, out_features, bias=True, bound_opts=None):
        super(BoundLinear, self).__init__(in_features, out_features, bias)
        self.bound_opts = bound_opts

    @staticmethod
    def convert(linear_layer, bound_opts=None):
        """Build a BoundLinear with the same shape and copied parameters as ``linear_layer``."""
        l = BoundLinear(linear_layer.in_features, linear_layer.out_features,
                        linear_layer.bias is not None, bound_opts)
        l.weight.data.copy_(linear_layer.weight.data)
        if linear_layer.bias is not None:
            l.bias.data.copy_(linear_layer.bias.data)
        return l

    def convert_eval(self, bound_opts=None):
        self.bound_opts = bound_opts

    def bound_backward(self, last_uA, last_lA):
        """Back-substitute linear bound coefficients through ``W x + b``.

        For an incoming coefficient tensor ``A`` the next one is ``A @ W``. The
        accumulated bias is ``A @ b``, or 0 when the layer has no bias.
        Returns ``(uA, ubias, lA, lbias)``. A side that is ``None`` on input
        yields ``(None, 0)``.
        """
        def _bound_oneside(last_A, compute_A=True):
            if last_A is None:
                return None, 0
            logger.debug('last_A %s', last_A.size())
            # propagate A to the next layer
            if compute_A:
                next_A = last_A.matmul(self.weight)
                logger.debug('next_A %s', next_A.size())
            else:
                next_A = None
            # compute the bias of this layer (nothing to add without a bias)
            if self.bias is None:
                return next_A, 0
            sum_bias = last_A.matmul(self.bias)
            logger.debug('sum_bias %s', sum_bias.size())
            return next_A, sum_bias

        opts = self.bound_opts or {}
        if opts.get("same-slope", False) and (last_uA is not None) and (last_lA is not None):
            # The lower A is taken to equal the upper A, so A @ W is computed once.
            uA, ubias = _bound_oneside(last_uA, True)
            _, lbias = _bound_oneside(last_lA, False)
            lA = uA
        else:
            uA, ubias = _bound_oneside(last_uA)
            lA, lbias = _bound_oneside(last_lA)
        return uA, ubias, lA, lbias

    def interval_propagate(self, norm, h_U, h_L, eps, C = None):
        """Propagate an input region through ``W x + b``, optionally merged with spec ``C``.

        For ``norm == np.inf`` the region is the box ``[h_L, h_U]``. Otherwise
        it is an Lp ball of radius ``eps`` around ``h_U``, and ``h_L`` is unused.
        When ``C`` is given, ``C (W x + b)`` is bounded per example.
        Returns ``(np.inf, upper, lower, 0, 0, 0, 0)``.
        """
        self.upper_u = h_U
        self.lower_l = h_L
        # merge the specification
        if C is not None:
            # after multiplication with C, we have (batch, output_shape, prev_layer_shape);
            # there is a batch dimension because each example can have a different C
            weight = C.matmul(self.weight)
            bias = None if self.bias is None else C.matmul(self.bias)
        else:
            # weight dimension (this_layer_shape, prev_layer_shape)
            weight = self.weight
            bias = self.bias

        def _center(x):
            # W x (+ b) for every example in the batch
            if C is not None:
                out = weight.matmul(x.unsqueeze(-1))
                if bias is not None:
                    out = out + bias.unsqueeze(-1)
                # drop the trailing (1,) dimension added for the batched matmul
                return out.squeeze(-1)
            if bias is not None:
                # fused multiply-add
                return torch.addmm(bias, x, weight.t())
            return x.matmul(weight.t())

        if norm == np.inf:
            # Linf norm: center of the box maps to the center, radius through |W|
            mid = (h_U + h_L) / 2.0
            diff = (h_U - h_L) / 2.0
            weight_abs = weight.abs()
            center = _center(mid)
            if C is not None:
                deviation = weight_abs.matmul(diff.unsqueeze(-1)).squeeze(-1)
            else:
                deviation = diff.matmul(weight_abs.t())
        else:
            # Lp norm: h_U == h_L is the center and eps the radius; use the dual norm of W's rows
            h = h_U
            dual_norm = np.float64(1.0) / (1 - 1.0 / norm)
            center = _center(h)
            deviation = weight.norm(dual_norm, -1) * eps
        upper = center + deviation
        lower = center - deviation
        return np.inf, upper, lower, 0, 0, 0, 0
class BoundConv2d(Conv2d):
    """``torch.nn.Conv2d`` with interval (IBP) and backward (CROWN) bound propagation.

    ``bound_opts`` is an optional dict; only ``"same-slope"`` is read here and
    ``None`` is treated as an empty dict. Layers without a bias are supported.
    ``forward`` must run before ``bound_backward``, because it records the
    input/output shapes used to size the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, bound_opts=None):
        super(BoundConv2d, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                          stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bound_opts = bound_opts

    @staticmethod
    def convert(l, bound_opts=None):
        """Build a BoundConv2d with the same hyper-parameters and copied parameters as ``l``."""
        nl = BoundConv2d(l.in_channels, l.out_channels, l.kernel_size, l.stride, l.padding, l.dilation, l.groups,
                         l.bias is not None, bound_opts)
        nl.weight.data.copy_(l.weight.data)
        if l.bias is not None:
            nl.bias.data.copy_(l.bias.data)
            logger.debug(nl.bias.size())
        logger.debug(nl.weight.size())
        return nl

    def convert_eval(self, bound_opts=None):
        self.bound_opts = bound_opts

    def forward(self, input):
        output = super(BoundConv2d, self).forward(input)
        # per-example (C, H, W) shapes, read by bound_backward (and by full_backward_range)
        self.output_shape = output.size()[1:]
        self.input_shape = input.size()[1:]
        return output

    def bound_backward(self, last_uA, last_lA):
        """Back-substitute linear bound coefficients through the convolution.

        ``A`` is mapped to the input space with ``conv_transpose2d``. The
        accumulated bias is the dot product of ``A`` with the broadcast bias, or
        0 without a bias. Returns ``(uA, ubias, lA, lbias)``.
        """
        def _bound_oneside(last_A, compute_A=True):
            if last_A is None:
                return None, 0
            logger.debug('last_A %s', last_A.size())
            shape = last_A.size()
            # propagate A to the next layer, with batch concatenated together
            if compute_A:
                # conv_transpose2d output size is
                #   (out - 1) * stride - 2 * padding + dilation * (k - 1) + 1 + output_padding,
                # so pick output_padding to recover exactly the recorded input size.
                output_padding0 = (int(self.input_shape[1]) - (int(self.output_shape[1]) - 1) * self.stride[0]
                                   + 2 * self.padding[0] - self.dilation[0] * (int(self.weight.size()[2]) - 1) - 1)
                output_padding1 = (int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[1]
                                   + 2 * self.padding[1] - self.dilation[1] * (int(self.weight.size()[3]) - 1) - 1)
                next_A = F.conv_transpose2d(last_A.view(shape[0] * shape[1], *shape[2:]), self.weight, None,
                                            stride=self.stride, padding=self.padding, dilation=self.dilation,
                                            groups=self.groups, output_padding=(output_padding0, output_padding1))
                next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])
                logger.debug('next_A %s', next_A.size())
            else:
                next_A = None
            if self.bias is None:
                return next_A, 0
            logger.debug('bias %s', self.bias.size())
            # sum A over the spatial dims (3, 4), weight each channel by its bias, sum channels
            sum_bias = (last_A.sum((3, 4)) * self.bias).sum(2)
            logger.debug('sum_bias %s', sum_bias.size())
            return next_A, sum_bias

        opts = self.bound_opts or {}
        # if the slope is the same (Fast-Lin) and both matrices are given, only need to compute one of them
        if opts.get("same-slope", False) and (last_uA is not None) and (last_lA is not None):
            uA, ubias = _bound_oneside(last_uA, True)
            _, lbias = _bound_oneside(last_lA, False)
            lA = uA
        else:
            uA, ubias = _bound_oneside(last_uA)
            lA, lbias = _bound_oneside(last_lA)
        return uA, ubias, lA, lbias

    def interval_propagate(self, norm, h_U, h_L, eps):
        """Propagate an input region through the convolution.

        For ``norm == np.inf`` the region is the box ``[h_L, h_U]``. Otherwise
        it is a ball of radius ``eps`` around ``h_U``.
        Returns ``(np.inf, upper, lower, 0, 0, 0, 0)``.
        """
        self.upper_u = h_U
        self.lower_l = h_L
        if norm == np.inf:
            mid = (h_U + h_L) / 2.0
            diff = (h_U - h_L) / 2.0
            weight_abs = self.weight.abs()
            deviation = F.conv2d(diff, weight_abs, None, self.stride, self.padding, self.dilation, self.groups)
        else:
            # L2 norm
            mid = h_U
            logger.debug('mid %s', mid.size())
            # NOTE(review): this is the L2 norm of each output filter regardless of
            # `norm`; it is the dual norm only when norm == 2.
            # TODO: consider padding here?
            deviation = torch.mul(self.weight, self.weight).sum((1, 2, 3)).sqrt() * eps
            logger.debug('weight %s', self.weight.size())
            logger.debug('deviation %s', deviation.size())
            # (C_out,) -> (1, C_out, 1, 1) to broadcast over batch and spatial dims
            deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            logger.debug('unsqueezed deviation %s', deviation.size())
        center = F.conv2d(mid, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        logger.debug('center %s', center.size())
        upper = center + deviation
        lower = center - deviation
        return np.inf, upper, lower, 0, 0, 0, 0
class BoundReLU(ReLU):
    """ReLU with interval (IBP) and backward (CROWN) bound propagation.

    ``bound_opts`` is an optional dict selecting the lower-bound slope rule used
    in ``bound_backward``. The keys read are "same-slope", "zero-lb", "one-lb",
    "ours", "binary" and "uniform". ``None`` (or no matching key) gives the
    default CROWN rule.
    """

    def __init__(self, prev_layer, inplace=False, bound_opts=None):
        super(BoundReLU, self).__init__(inplace)
        # prev_layer is accepted for API compatibility but not stored; the
        # pre-activation bounds are recorded by interval_propagate() instead.
        self.bound_opts = bound_opts

    ## Convert a ReLU layer to BoundReLU layer
    # @param act_layer ReLU layer object
    # @param prev_layer Pre-activation layer (currently unused by BoundReLU)
    @staticmethod
    def convert(act_layer, prev_layer, bound_opts=None):
        return BoundReLU(prev_layer, act_layer.inplace, bound_opts)

    def convert_eval(self, bound_opts=None):
        self.bound_opts = bound_opts

    def interval_propagate(self, norm, h_U, h_L, eps):
        """Apply ReLU to the box ``[h_L, h_U]`` (Linf only) and collect statistics.

        Stores the pre-activation bounds in ``upper_u``/``lower_l`` for
        ``bound_backward``. Returns
        ``(norm, relu(h_U), relu(h_L), n_unstable, n_unstable, n_dead, n_alive)``.
        """
        assert norm == np.inf
        guard_eps = 1e-5
        # unstable neurons: the pre-activation interval straddles 0 (with a small margin)
        self.unstab = ((h_L < -guard_eps) & (h_U > guard_eps))
        # stored upper and lower bounds will be used for backward bound propagation
        self.upper_u = h_U
        self.lower_l = h_L
        tightness_loss = self.unstab.sum()
        return norm, F.relu(h_U), F.relu(h_L), tightness_loss, tightness_loss, \
            (h_U < 0).sum(), (h_L > 0).sum()

    def bound_backward(self, last_uA, last_lA, lower_d2=None):
        """Back-substitute linear bound coefficients through the ReLU (CROWN).

        With ``l`` / ``u`` the stored pre-activation bounds (clamped so l <= 0 <= u),
        the upper relaxation is the chord ``y <= d_u * (x - l)`` with
        ``d_u = u / (u - l)``. The lower relaxation is ``y >= d_l * x``, where the
        slope ``d_l`` comes from ``lower_d2`` if given, else from ``bound_opts``.

        Args:
            last_uA, last_lA: coefficient tensors shaped (batch, spec, ...) for the
                upper / lower bound, or None to skip that side.
            lower_d2: optional lower slopes to use. NOTE: modified in place.

        Returns:
            ``(uA, ubias, lA, lbias)``. The slopes used are kept in ``self.lower_d``.
        """
        opts = self.bound_opts or {}

        def _fix_stable(d):
            # neurons with l > 0 are always active (slope 1), with u < 0 always dead (slope 0)
            d[self.lower_l.unsqueeze(1) > 0] = 1
            d[self.upper_u.unsqueeze(1) < 0] = 0
            return d

        lb_r = self.lower_l.clamp(max=0)
        ub_r = self.upper_u.clamp(min=0)
        # avoid division by 0 when both lb_r and ub_r are 0
        ub_r = torch.max(ub_r, lb_r + 1e-8)
        # CROWN upper linear bound (chord through (l, 0) and (u, u))
        upper_d = ub_r / (ub_r - lb_r)
        upper_b = - lb_r * upper_d
        # add the specification dimension of A so the slopes broadcast against it
        upper_d = upper_d.unsqueeze(1)

        if lower_d2 is not None:
            # NOTE(review): writes into the caller's tensor. If a leaf tensor that
            # requires grad is passed, autograd will reject the in-place write — confirm callers.
            lower_d = _fix_stable(lower_d2)
        elif opts.get("same-slope", False):
            # the same slope for upper and lower
            lower_d = upper_d
        elif opts.get("zero-lb", False):
            # Always use slope 0 as lower bound (1 only where upper_d >= 1, i.e. l >= 0).
            # Any value between 0 and 1 is a valid lower bound for CROWN
            lower_d = (upper_d >= 1.0).float()
        elif opts.get("one-lb", False):
            # Always use slope 1 as lower bound (0 only where upper_d == 0)
            lower_d = (upper_d > 0.0).float()
        elif opts.get("ours", False):
            # random initial slopes, exposed as a fresh leaf tensor that requires grad
            lower_d = _fix_stable(torch.rand_like(upper_d))
            lower_d = Variable(lower_d.data, requires_grad=True)
        elif opts.get("binary", False):
            lower_d = _fix_stable(torch.randint_like(upper_d, 0, 2))
        elif opts.get("uniform", False):
            lower_d = _fix_stable(torch.rand_like(upper_d))
        else:
            # CROWN: slope 1 when u > -l (upper_d > 0.5), otherwise 0
            lower_d = (upper_d > 0.5).float()

        uA = lA = None
        ubias = lbias = 0
        # Upper bound: positive coefficients use the upper relaxation, negative ones the lower.
        if last_uA is not None:
            pos_uA = last_uA.clamp(min=0)
            if opts.get("same-slope", False):
                # same upper_d and lower_d, no need to check the sign
                uA = upper_d * last_uA
            else:
                neg_uA = last_uA.clamp(max=0)
                uA = upper_d * pos_uA + lower_d * neg_uA
            mult_uA = pos_uA.view(last_uA.size(0), last_uA.size(1), -1)
            ubias = mult_uA.matmul(upper_b.view(upper_b.size(0), -1, 1)).squeeze(-1)
        # Lower bound: the roles of the two relaxations are swapped.
        if last_lA is not None:
            neg_lA = last_lA.clamp(max=0)
            if opts.get("same-slope", False):
                lA = uA if uA is not None else lower_d * last_lA
            else:
                pos_lA = last_lA.clamp(min=0)
                lA = upper_d * neg_lA + lower_d * pos_lA
            mult_lA = neg_lA.view(last_lA.size(0), last_lA.size(1), -1)
            lbias = mult_lA.matmul(upper_b.view(upper_b.size(0), -1, 1)).squeeze(-1)
            # Sanity check: lbias should never be positive (neg_lA <= 0 and upper_b >= 0).
            # Only evaluated with debug logging on, since the test forces a host sync on GPU.
            if logger.isEnabledFor(logging.DEBUG) and bool((lbias > 0).any()):
                logger.debug('there is positive part in the lbias; mult_lA %s should be non-positive '
                             '(positive entries: %s) and upper_b non-negative (negative entries: %s)',
                             tuple(mult_lA.shape), mult_lA[mult_lA > 0], upper_b[upper_b < 0])
        self.lower_d = lower_d
        return uA, ubias, lA, lbias
class BoundSequential(Sequential):
    """Sequential model of Bound* layers that can compute bounds on its outputs.

    Calling the model requires ``method_opt``. "interval_range" runs IBP.
    "backward_range" runs CROWN using the pre-activation bounds stored on the
    ReLU layers. "backward_range_frozen" back-substitutes through only the last
    ``frozen + 1`` layers. "full_backward_range" computes CROWN bounds for every
    ReLU layer. Any other value runs the ordinary ``Sequential`` call.
    """

    def __init__(self, *args):
        super(BoundSequential, self).__init__(*args)

    ## Convert a Pytorch model to a model with bounds
    # @param sequential_model Input pytorch model (a Sequential, or a wrapper whose
    #        ``.module`` is one)
    # @return Converted model
    @staticmethod
    def convert(sequential_model, bound_opts=None):
        layers = []
        if isinstance(sequential_model, Sequential):
            seq_model = sequential_model
        else:
            seq_model = sequential_model.module
        # NOTE: layers of any other type are silently skipped.
        for l in seq_model:
            if isinstance(l, Linear):
                layers.append(BoundLinear.convert(l, bound_opts))
            if isinstance(l, Conv2d):
                layers.append(BoundConv2d.convert(l, bound_opts))
            if isinstance(l, ReLU):
                # BoundReLU does not use prev_layer, so a leading ReLU must not crash here
                layers.append(BoundReLU.convert(l, layers[-1] if layers else None, bound_opts))
            if isinstance(l, Flatten):
                layers.append(BoundFlatten(bound_opts))
        return BoundSequential(*layers)

    def define_bound_opts(self, bound_opts=None):
        self.bound_opts = bound_opts

    def convert_bounds(self, bound_opt_eval=None):
        """Install ``bound_opt_eval`` as the bound options of every layer and of the model."""
        for layer in self:
            layer.convert_eval(bound_opt_eval)
        self.bound_opts = bound_opt_eval

    ## The __call__ function is overwritten for DataParallel
    def __call__(self, *input, **kwargs):
        if "method_opt" in kwargs:
            opt = kwargs.pop("method_opt")
        else:
            raise ValueError("Please specify the 'method_opt' as the last argument.")
        # accepted for BoundDataParallel compatibility and ignored here
        kwargs.pop("disable_multi_gpu", None)
        if opt == "full_backward_range":
            return self.full_backward_range(*input, **kwargs)
        elif opt == "backward_range":
            return self.backward_range(*input, **kwargs)
        elif opt == "backward_range_frozen":
            return self.backward_range_frozen(*input, **kwargs)
        elif opt == "interval_range":
            return self.interval_range(*input, **kwargs)
        else:
            return super(BoundSequential, self).__call__(*input, **kwargs)

    def full_backward_range(self, norm=np.inf, x_U=None, x_L=None, eps=None, C=None, upper=True, lower=True):
        """CROWN with CROWN-computed pre-activation bounds for every ReLU after the first.

        Returns the same tuple as ``backward_range`` for specification ``C``.
        """
        # interval_propagate() returns np.inf as the norm of its output; keep the
        # norm of the input perturbation for the CROWN passes, which start from x.
        input_norm = norm
        h_U = x_U
        h_L = x_L
        modules = list(self._modules.values())
        # Without any ReLU there are no intermediate bounds to refine.
        last_module = len(modules)
        # IBP through the first weight (it is the same bound as CROWN for 1st layer, and IBP can be faster)
        for i, module in enumerate(modules):
            norm, h_U, h_L, _, _, _, _ = module.interval_propagate(norm, h_U, h_L, eps)
            if isinstance(module, BoundReLU):
                # the bounds of this ReLU layer have been set in interval_propagate()
                last_module = i
                break
        # CROWN propagation for all remaining ReLU layers
        for i in range(last_module + 1, len(modules)):
            # bounds are only needed before a ReLU layer
            if not isinstance(modules[i], BoundReLU):
                continue
            prev = modules[i - 1]
            if isinstance(prev, BoundLinear):
                # use the weight of the previous layer as C (batch dim of 1 shared by all examples)
                newC = prev.weight.unsqueeze(0)
                # skip layer i-1 (passed as specification) and back-substitute from layer i-2
                ub, _, lb, _ = self.backward_range(norm=input_norm, x_U=x_U, x_L=x_L, eps=eps, C=newC,
                                                   upper=True, lower=True, modules=modules[:i - 1])
                # add the missing bias term (newC does not carry the bias)
                if prev.bias is not None:
                    ub += prev.bias
                    lb += prev.bias
            elif isinstance(prev, BoundConv2d):
                # unroll the convolutional layer with an identity specification
                c, h, w = prev.output_shape
                newC = torch.eye(c * h * w, device=x_U.device, dtype=x_U.dtype)
                newC = newC.view(1, c * h * w, c, h, w)
                # use CROWN to compute pre-activation bounds starting from layer i-1
                ub, _, lb, _ = self.backward_range(norm=input_norm, x_U=x_U, x_L=x_L, eps=eps, C=newC,
                                                   upper=True, lower=True, modules=modules[:i])
                # reshape to conv output shape; these are pre-activation bounds
                ub = ub.view(ub.size(0), c, h, w)
                lb = lb.view(lb.size(0), c, h, w)
            else:
                raise RuntimeError("Unsupported network structure")
            # set pre-activation bounds for layer i (the ReLU layer)
            modules[i].upper_u = ub
            modules[i].lower_l = lb
        # get the final layer bound with spec C
        return self.backward_range(norm=input_norm, x_U=x_U, x_L=x_L, eps=eps, C=C, upper=upper, lower=lower)

    def _backward_step(self, module, upper_A, lower_A, lower_d_list2, j):
        """Back-substitute through one module.

        Returns ``(upper_A, upper_b, lower_A, lower_b, j)``, where ``j`` is the
        updated index into ``lower_d_list2``. For ReLU layers with the "ours"
        option the slopes used are appended to ``self.lower_d_list``.
        """
        if isinstance(module, BoundReLU) and (module.bound_opts or {}).get("ours", False):
            if lower_d_list2:
                result = module.bound_backward(upper_A, lower_A, lower_d_list2[j])
                j += 1
            else:
                result = module.bound_backward(upper_A, lower_A)
            self.lower_d_list.append(module.lower_d)
        else:
            result = module.bound_backward(upper_A, lower_A)
        return result + (j,)

    @staticmethod
    def _concrete_bound(A, sum_b, norm, box_U, box_L, x_U, eps, sign=-1):
        """Concretize ``A x + sum_b`` over the input region (min if sign=-1, max if +1).

        For ``norm == np.inf`` the region is the box ``[box_L, box_U]``.
        Otherwise it is a ball of radius ``eps`` around ``x_U``.
        Returns ``(bound, center_term, radius_term)``, with the last two detached
        on CPU, or ``(None, None, None)`` when ``A`` is None.
        """
        if A is None:
            return None, None, None
        # A has shape (batch, specification_size, flattened_input_size)
        A = A.view(A.size(0), A.size(1), -1)
        logger.debug('Final A: %s', A.size())
        if norm == np.inf:
            x_ub = box_U.view(x_U.size(0), -1, 1)
            x_lb = box_L.view(x_U.size(0), -1, 1)
            center = (x_ub + x_lb) / 2.0
            diff = (x_ub - x_lb) / 2.0
            logger.debug('A_0 shape: %s', A.size())
            logger.debug('sum_b shape: %s', sum_b.size())
            center_term = A.bmm(center)
            radius_term = sign * A.abs().bmm(diff)
            bound = center_term + radius_term
            logger.debug('bound shape: %s', bound.size())
        else:
            # Lp ball: the extreme value uses the dual norm of each row of A
            x = x_U.view(x_U.size(0), -1, 1)
            dual_norm = np.float64(1.0) / (1 - 1.0 / norm)
            deviation = A.norm(dual_norm, -1) * eps
            center_term = A.bmm(x)
            radius_term = sign * deviation.unsqueeze(-1)
            bound = center_term + radius_term
        bound = bound.squeeze(-1) + sum_b
        return bound, center_term.detach().cpu(), radius_term.detach().cpu()

    def _finish_backward(self, norm, x_U, x_L, eps, upper_A, lower_A, upper_sum_b,
                         ReLU_lower_sum_b, not_ReLU_lower_sum_b, box_U, box_L):
        """Record bias diagnostics, concretize both sides, and build the return tuple."""
        self.ReLU_lower_b = ReLU_lower_sum_b.detach().cpu()
        self.not_ReLU_lower_b = not_ReLU_lower_sum_b.detach().cpu()
        lower_sum_b = ReLU_lower_sum_b + not_ReLU_lower_sum_b + x_U.new([0])
        lb, self.A_x, self.A_norm = self._concrete_bound(lower_A, lower_sum_b, norm, box_U, box_L, x_U, eps, sign=-1)
        ub, _, _ = self._concrete_bound(upper_A, upper_sum_b, norm, box_U, box_L, x_U, eps, sign=+1)
        if ub is None:
            ub = x_U.new([np.inf])
        if lb is None:
            lb = x_L.new([-np.inf])
        self.lower_A = lower_A
        return ub, upper_sum_b.detach(), lb, lower_sum_b.detach()

    def backward_range(self, norm=np.inf, x_U=None, x_L=None, eps=None, C=None, lower_d_list2=None, upper=False, lower=True, modules=None):
        """CROWN: back-substitute ``C`` through ``modules`` and concretize over the input.

        Uses the pre-activation bounds stored on each BoundReLU (``upper_u`` /
        ``lower_l``). ``lower_d_list2`` optionally supplies, in back-to-front
        order, the lower slopes of ReLU layers using the "ours" option.
        Returns ``(ub, upper_sum_b, lb, lower_sum_b)``. A side that is not
        requested is reported as +/-inf.
        """
        # start propagation from the last layer
        modules = list(self._modules.values()) if modules is None else modules
        upper_A = C if upper else None
        lower_A = C if lower else None
        upper_sum_b = lower_sum_b = x_U.new([0])
        ReLU_lower_sum_b = lower_sum_b
        not_ReLU_lower_sum_b = lower_sum_b
        self.lower_d_list = []
        j = 0
        for module in reversed(modules):
            upper_A, upper_b, lower_A, lower_b, j = self._backward_step(module, upper_A, lower_A, lower_d_list2, j)
            if isinstance(module, BoundReLU):
                ReLU_lower_sum_b = lower_b + ReLU_lower_sum_b
            else:
                not_ReLU_lower_sum_b = lower_b + not_ReLU_lower_sum_b
            upper_sum_b = upper_b + upper_sum_b
        return self._finish_backward(norm, x_U, x_L, eps, upper_A, lower_A, upper_sum_b,
                                     ReLU_lower_sum_b, not_ReLU_lower_sum_b, x_U, x_L)

    def backward_range_frozen(self, norm=np.inf, x_U=None, x_L=None, eps=None, C=None, lower_d_list2=None, upper=False, lower=True, modules=None, frozen=5):
        """Like ``backward_range`` but stops after the last ``frozen + 1`` modules.

        The resulting linear bound is concretized, for Linf, over the stored
        ``upper_u`` / ``lower_l`` of the last module visited.
        """
        # start propagation from the last layer
        modules = list(self._modules.values()) if modules is None else modules
        upper_A = C if upper else None
        lower_A = C if lower else None
        upper_sum_b = lower_sum_b = x_U.new([0])
        ReLU_lower_sum_b = lower_sum_b
        not_ReLU_lower_sum_b = lower_sum_b
        self.lower_d_list = []
        j = 0
        for jj, module in enumerate(reversed(modules)):
            if jj <= frozen:
                upper_A, upper_b, lower_A, lower_b, j = self._backward_step(module, upper_A, lower_A, lower_d_list2, j)
                if isinstance(module, BoundReLU):
                    ReLU_lower_sum_b = lower_b + ReLU_lower_sum_b
                else:
                    not_ReLU_lower_sum_b = lower_b + not_ReLU_lower_sum_b
                upper_sum_b = upper_b + upper_sum_b
            if jj + 1 > frozen:
                break
        # bounds stored on the last module visited serve as the input box
        box_U = getattr(module, "upper_u", None)
        box_L = getattr(module, "lower_l", None)
        return self._finish_backward(norm, x_U, x_L, eps, upper_A, lower_A, upper_sum_b,
                                     ReLU_lower_sum_b, not_ReLU_lower_sum_b, box_U, box_L)

    def interval_range(self, norm=np.inf, x_U=None, x_L=None, eps=None, C=None):
        """IBP through all layers; the specification ``C`` is merged into the last one.

        Returns ``(h_U, h_L, losses, unstable, dead, alive)``.
        """
        losses = 0
        unstable = 0
        dead = 0
        alive = 0
        h_U = x_U
        h_L = x_L
        modules = list(self._modules.values())
        for module in modules[:-1]:
            # all internal layers should have Linf norm, except for the first layer
            norm, h_U, h_L, loss, uns, d, a = module.interval_propagate(norm, h_U, h_L, eps)
            # stability loss from initial experiments; not used in CROWN-IBP
            losses += loss
            unstable += uns
            dead += d
            alive += a
        # last layer has C to merge
        norm, h_U, h_L, loss, uns, d, a = modules[-1].interval_propagate(norm, h_U, h_L, eps, C)
        losses += loss
        unstable += uns
        dead += d
        alive += a
        return h_U, h_L, losses, unstable, dead, alive
class BoundDataParallel(DataParallel):
    """DataParallel variant that understands the bound-computation keywords.

    ``disable_multi_gpu=True`` runs the wrapped module directly. Module replicas
    are rebuilt only for ordinary forward passes (``method_opt`` absent or
    "forward"). Other calls reuse the replicas cached by the last rebuild.
    """

    def __init__(self, *inputs, **kwargs):
        super(BoundDataParallel, self).__init__(*inputs, **kwargs)
        self._replicas = None

    # Override of DataParallel.forward
    def forward(self, *inputs, **kwargs):
        run_locally = kwargs.pop("disable_multi_gpu", False)
        if not self.device_ids or run_locally:
            return self.module(*inputs, **kwargs)

        # Bound passes (interval / CROWN-IBP) follow a forward pass without a weight
        # update, so re-replicating only on "forward" saves communication.
        method = kwargs.get("method_opt", "forward")
        if self._replicas is None or method == "forward":
            self._replicas = self.replicate(self.module, self.device_ids)

        for tensor in chain(self.module.parameters(), self.module.buffers()):
            if tensor.device == self.src_device_obj:
                continue
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(self.src_device_obj, tensor.device))

        scattered_inputs, scattered_kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*scattered_inputs[0], **scattered_kwargs[0])
        replicas = self._replicas[:len(scattered_inputs)]
        outputs = self.parallel_apply(replicas, scattered_inputs, scattered_kwargs)
        return self.gather(outputs, self.output_device)