Merge pull request #45 from majianjia/dev
Dev
majianjia authored Jun 8, 2019
2 parents e04f20b + c724fe3 commit 1bbf1d3
Showing 1 changed file with 78 additions and 34 deletions.
scripts/nnom_utils.py (112 changes: 78 additions & 34 deletions)
@@ -117,44 +117,31 @@ def is_shift_layer(layer):
        'conv1d' in layer.name or
        'dense' in layer.name or
        'softmax' in layer.name or
-       'add' in layer.name or
+       ('add' in layer.name and 'zero' not in layer.name) or # the name, zero_padding contains 'add'
        'subtract' in layer.name or
        'multiply' in layer.name or
        ('activation' in layer.name and layer.get_config()['activation'] == 'softmax')
     ):
         return True
     return False
 
-def generate_weights(model, name='weights.h', shift_list=None):
-    # Quantize weights to 8-bits using (min,max) and write to file
-    f = open(name, 'w')
-    f.close()
-
-    for curr_idx, layer in enumerate(model.layers):
-        if (not layer.weights):
-            continue
-
-        # before merging bn layer, check if the bn is "legally" after Conv
-        if('batch_normalization' in layer.name) and \
-            ('conv2d' not in layer._inbound_nodes[0].inbound_layers[0].name):
-            raise Exception('Currently only support batch_normalization after conv2d', layer.name,
-                            layer._inbound_nodes[0].inbound_layers[0].name)
-
-        # try to fuse BN layer to convolutional
-        if ('conv2d' in layer.name) and \
+def fuse_bn_to_conv(layer):
+    # try to fuse BN layer to convolutional
+    if ('conv' in layer.name) and \
         ('batch_normalization' in layer._outbound_nodes[0].outbound_layer.name):
 
-            print("fusing batch normalization to", layer.name)
-            bn_layer = layer._outbound_nodes[0].outbound_layer
-            c_w = layer.get_weights()[0]
-            c_b = layer.get_weights()[1]
-            print('original weight max', c_w.max(), 'min', c_w.min())
-            print('original bias max', c_b.max(), 'min', c_b.min())
-            bn_gamma = bn_layer.get_weights()[0]
-            bn_beta = bn_layer.get_weights()[1]
-            bn_mean = bn_layer.get_weights()[2]
-            bn_variance = bn_layer.get_weights()[3]
-
+        print("fusing batch normalization to", layer.name)
+        bn_layer = layer._outbound_nodes[0].outbound_layer
+        c_w = layer.get_weights()[0]
+        c_b = layer.get_weights()[1]
+        print('original weight max', c_w.max(), 'min', c_w.min())
+        print('original bias max', c_b.max(), 'min', c_b.min())
+        bn_gamma = bn_layer.get_weights()[0]
+        bn_beta = bn_layer.get_weights()[1]
+        bn_mean = bn_layer.get_weights()[2]
+        bn_variance = bn_layer.get_weights()[3]
+
+        if ('conv2d' in layer.name):
             epsilon = 1e-3 # default epsilon for tf.slim.batch_norm
             for l in range(c_w.shape[3]):
                 for k in range(c_w.shape[2]):
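A note on the first change in this hunk: auto-generated Keras names such as zero_padding2d_1 contain the substring 'add' (inside 'padding'), so the old test wrongly treated padding layers as shift layers. A minimal illustration, with a made-up layer name:

    name = 'zero_padding2d_1'                       # typical auto-generated Keras layer name
    print('add' in name)                            # True  -- 'p(add)ing' matched the old test
    print('add' in name and 'zero' not in name)     # False -- the new guard excludes padding layers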
@@ -171,12 +158,49 @@ def generate_weights(model, name='weights.h', shift_list=None):
                 depth_dim = c_w.shape[3]
             for l in range(depth_dim):
                 c_b[l] = (bn_gamma[l] * (c_b[l] - bn_mean[l]) / np.sqrt(bn_variance[l] + epsilon)) + bn_beta[l]
+        # conv1d
+        else:
+            epsilon = 1e-3 # default epsilon for tf.slim.batch_norm
+            for k in range(c_w.shape[2]):
+                for j in range(c_w.shape[1]):
+                    for i in range(c_w.shape[0]):
+                        if "depthwise" in layer.name: # depthwise batchnorm params are ordered differently
+                            c_w[i][j][k] *= bn_gamma[j] / np.sqrt(bn_variance[j] + epsilon)
+                        else:
+                            c_w[i][j][k] *= bn_gamma[k] / np.sqrt(bn_variance[k] + epsilon)
 
-            print('fused weight max', c_w.max(), 'min', c_w.min())
-            print('fused bias max', c_b.max(), 'min', c_b.min())
-            # write the weights back to the layer
-            # after that, the model will be destroyed.. need a better way to pass the new weight
-            layer.set_weights([c_w, c_b])
+            if "depthwise" in layer.name:
+                depth_dim = c_w.shape[1]
+            else:
+                depth_dim = c_w.shape[2]
+            for l in range(depth_dim):
+                c_b[l] = (bn_gamma[l] * (c_b[l] - bn_mean[l]) / np.sqrt(bn_variance[l] + epsilon)) + bn_beta[l]
+
+        print('fused weight max', c_w.max(), 'min', c_w.min())
+        print('fused bias max', c_b.max(), 'min', c_b.min())
+        # write the weights back to the layer
+        # after that, the model will be destroyed.. need a better way to pass the new weight
+        layer.set_weights([c_w, c_b])
+
+def generate_weights(model, name='weights.h', shift_list=None):
+    # Quantize weights to 8-bits using (min,max) and write to file
+    f = open(name, 'w')
+    f.close()
+
+    for curr_idx, layer in enumerate(model.layers):
+        if (not layer.weights):
+            continue
+
+        # before merging bn layer, check if the bn is "legally" after Conv
+        if('batch_normalization' in layer.name) and \
+            ('conv' not in layer._inbound_nodes[0].inbound_layers[0].name):
+            raise Exception('Currently only support batch_normalization after conv', layer.name,
+                            layer._inbound_nodes[0].inbound_layers[0].name)
+
+        # try to fuse BN layer to convolutional
+        if ('conv' in layer.name) and \
+            ('batch_normalization' in layer._outbound_nodes[0].outbound_layer.name):
+            fuse_bn_to_conv(layer)
 
         # generate weights and bias now
         weight_dec_shift = 0
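The fusion implemented above folds batch normalization into the preceding convolution: each weight becomes w * gamma / sqrt(var + epsilon) and each bias becomes gamma * (b - mean) / sqrt(var + epsilon) + beta, so the fused convolution reproduces conv followed by BN. A minimal numpy sketch of that identity for a single channel (not part of the commit; values are illustrative):

    import numpy as np

    eps = 1e-3
    w, b = 0.5, 0.1                                 # one conv weight and bias
    gamma, beta, mean, var = 1.2, -0.3, 0.05, 0.8   # BN parameters for that channel
    x = 2.0                                         # a sample input value

    bn_out = gamma * ((w * x + b) - mean) / np.sqrt(var + eps) + beta
    w_fused = w * gamma / np.sqrt(var + eps)
    b_fused = gamma * (b - mean) / np.sqrt(var + eps) + beta

    assert np.isclose(bn_out, w_fused * x + b_fused)  # fused conv == conv + BN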
@@ -532,6 +556,26 @@ def is_skipable_layer(layer):
         elif('1d' in layer.name):
             fp.write('\tlayer[{0}] = model.hook(UpSample(kernel(1,{1})), layer[{2}]);\n'.format(
                 id, cfg['size'][0], LI[inp][0]))
+        # zero padding
+        elif ('zero_padding' in layer.name):
+            inp = layer.input.name.replace(':','/').split('/')[0]
+            cfg = layer.get_config()
+            if('2d' in layer.name):
+                fp.write('\tlayer[{0}] = model.hook(ZeroPadding(border({1},{2},{3},{4})), layer[{5}]);\n'.format(
+                    id, cfg['padding'][0][0], cfg['padding'][0][1], cfg['padding'][1][0], cfg['padding'][1][1], LI[inp][0]))
+            elif('1d' in layer.name):
+                fp.write('\tlayer[{0}] = model.hook(ZeroPadding(border(0,0,{1},{2})), layer[{3}]);\n'.format(
+                    id, cfg['padding'][0], cfg['padding'][1], LI[inp][0]))
+        # Cropping
+        elif ('cropping' in layer.name):
+            inp = layer.input.name.replace(':','/').split('/')[0]
+            cfg = layer.get_config()
+            if('2d' in layer.name):
+                fp.write('\tlayer[{0}] = model.hook(Cropping(border({1},{2},{3},{4})), layer[{5}]);\n'.format(
+                    id, cfg['cropping'][0][0], cfg['cropping'][0][1], cfg['cropping'][1][0], cfg['cropping'][1][1], LI[inp][0]))
+            elif('1d' in layer.name):
+                fp.write('\tlayer[{0}] = model.hook(Cropping(border(0,0,{1},{2})), layer[{3}]);\n'.format(
+                    id, cfg['cropping'][0], cfg['cropping'][1], LI[inp][0]))
 
         # others
         elif('concatenate' in layer.name):
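For reference, each new branch above turns a Keras padding/cropping config into a single generated C call. A sketch of the 2D padding case with a hypothetical config and layer indices (Keras stores ZeroPadding2D padding as ((top, bottom), (left, right))):

    cfg = {'padding': ((1, 1), (2, 2))}   # hypothetical ZeroPadding2D config
    layer_id, input_id = 5, 4             # hypothetical layer indices
    print('\tlayer[{0}] = model.hook(ZeroPadding(border({1},{2},{3},{4})), layer[{5}]);'.format(
        layer_id, cfg['padding'][0][0], cfg['padding'][0][1],
        cfg['padding'][1][0], cfg['padding'][1][1], input_id))
    # -> layer[5] = model.hook(ZeroPadding(border(1,1,2,2)), layer[4]);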
