-
Notifications
You must be signed in to change notification settings - Fork 91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support fuse bn into ConvTranspose. #106
base: master
Are you sure you want to change the base?
Conversation
44658d1
to
6e5ac70
Compare
Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024 |
6e5ac70
to
5d4d388
Compare
@daquexian Done, please review. |
Signed-off-by: wenyuchi.wyc <[email protected]>
5d4d388
to
ff1229e
Compare
Hello, i try to used this commit to fuse the bn layer and convtranspose layer in my model and find some bugs: From the doc of onnx website (https://onnx.ai/onnx/operators/onnx__ConvTranspose.html), the shape of weight array of convtranspose is (Cin, Cout, K, K), which is different to normal Conv layer (Cout, Cin, K, K). |
Hi, i would like to share my codes for fusing convtranspose and bn. It has been tested on my own model. I hope it will help others who have the same issue. import numpy as np
import onnx
import sclblonnx as so
model = onnx.load('../onnx/models/backbone_clean.onnx')
all_initializer = model.graph.initializer
all_node = model.graph.node
ConvTranspose_list = []
BatchNormalization_list = []
for i, node in enumerate(all_node):
# search convtranspose and batchnormalization
if node.op_type == "ConvTranspose":
# print(i, node.name, node.op_type, node.input, node.output)
ConvTranspose_list.append(node)
if node.op_type == "BatchNormalization":
# print(i, node.name, node.op_type, node.input, node.output)
BatchNormalization_list.append(node)
valid_ConvTranspose_list = []
for node in ConvTranspose_list:
output = node.output
for bn_node in BatchNormalization_list:
bn_inputs = bn_node.input
if output[0] in bn_inputs:
valid_ConvTranspose_list.append({"conv": node, "bn": bn_node})
continue
# print(valid_ConvTranspose_list)
param_dict = {}
for node in valid_ConvTranspose_list:
conv = node["conv"]
bn = node["bn"]
# find params
param_name = list(conv.input) + list(bn.input)
for i, initializer in enumerate(all_initializer):
if initializer.name in param_name:
param_dict[initializer.name] = onnx.numpy_helper.to_array(initializer)
# print(param_dict)
for node in valid_ConvTranspose_list:
conv = node["conv"]
bn = node["bn"]
bn_eps = bn.attribute[0].f
bn_mom = bn.attribute[1].f
bn_w = param_dict[bn.input[1]] # [Cout, ]
bn_b = param_dict[bn.input[2]] # [Cout, ]
bn_mean = param_dict[bn.input[3]] # [Cout, ]
bn_var = param_dict[bn.input[4]] # [Cout, ]
conv_w = param_dict[conv.input[1]] # [Cin, Cout, H, W]
if len(conv.input) > 2:
conv_b = param_dict[conv.input[2]]
else:
conv_b = np.zeros_like(bn_b) # [Cout, ]
conv_w_tran = conv_w.transpose(1, 0, 2, 3)
Cout = conv_w_tran.shape[0]
conv_w_reshape = conv_w_tran.reshape([Cout, -1])
w_bn = np.diag(bn_w / (np.sqrt(bn_eps + bn_var)))
new_conv_w = np.matmul(w_bn, conv_w_reshape).reshape(conv_w_tran.shape).transpose(1, 0, 2, 3)
bn_b_tmp = bn_b - (np.multiply(bn_w, bn_mean) / (np.sqrt(bn_eps + bn_var)))
new_conv_b = np.matmul(bn_w, conv_b) + bn_b_tmp
new_node = onnx.helper.make_node(
name=conv.name+'_bn',
op_type="ConvTranspose",
inputs=[conv.input[0], conv.name+'_bn.weights', conv.name+'_bn.bias'],
outputs=[bn.output[0]],
dilations=conv.attribute[0].ints,
group=conv.attribute[1].i,
kernel_shape=conv.attribute[2].ints,
pads=conv.attribute[3].ints,
strides=conv.attribute[4].ints
)
initializer_w = onnx.helper.make_tensor(
name=conv.name+'_bn.weights',
data_type=onnx.helper.TensorProto.DataType.FLOAT,
dims=new_conv_w.shape,
vals=new_conv_w.tobytes(),
raw=True
)
initializer_b = onnx.helper.make_tensor(
name=conv.name+'_bn.bias',
data_type=onnx.helper.TensorProto.DataType.FLOAT,
dims=new_conv_b.shape,
vals=new_conv_b.tobytes(),
raw=True
)
model.graph.initializer.append(initializer_w)
model.graph.initializer.append(initializer_b)
# insert node
for i, node in enumerate(all_node):
if conv.name == node.name:
model.graph.node.insert(i, new_node)
break
# clean node
model.graph.node.remove(conv)
model.graph.node.remove(bn)
onnx.checker.check_model(model)
onnx.save(model, '../onnx/models/backbone_fuse.onnx')
graph = so.graph_from_file('../onnx/models/backbone_fuse.onnx')
graph = so.clean(graph)
so.check(graph)
so.graph_to_file(graph, '../onnx/models/backbone_fuse.onnx') |
No description provided.