-
Notifications
You must be signed in to change notification settings - Fork 65
/
update_ops.py
72 lines (56 loc) · 2.38 KB
/
update_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding: utf-8 -*-
# The codegen script to build oopb opset functions
import os
from onnxconverter_common import onnx_ops
# OPTIONAL_ARG_BEGIN: index into a function's ``co_varnames`` of the first
# parameter after 'operator_name' (the script checks that indices 3 and 4 are
# 'container' and 'operator_name'); parameters from here on are re-exposed
# as keyword arguments of the generated wrapper methods.
# OPTIONAL_ARG_OMIT: number of leading entries of ``__defaults__`` to drop so
# that the remaining defaults line up with the parameters starting at
# OPTIONAL_ARG_BEGIN (presumably the ``operator_name`` default — verify
# against onnx_ops).
OPTIONAL_ARG_BEGIN, OPTIONAL_ARG_OMIT = 5, 1
def str_func_kwarg(func_obj):
    """Render the optional parameters of ``func_obj`` as a keyword-argument list.

    Produces a string such as ``", axis=1, mode='constant'"`` meant to be
    appended to a generated ``def`` line. Parameter names start at index
    ``OPTIONAL_ARG_BEGIN`` of the positional parameters; the first
    ``OPTIONAL_ARG_OMIT`` entries of ``func_obj.__defaults__`` are skipped so
    the remaining defaults are paired with those names.

    :param func_obj: a plain Python function (must have ``__code__``).
    :return: the rendered string, or ``''`` if there is nothing to render.
    """
    code_obj = func_obj.__code__
    # Bound by co_argcount: co_varnames also lists keyword-only parameters
    # and local variables after the positional parameters.
    lst_covars = code_obj.co_varnames[OPTIONAL_ARG_BEGIN:code_obj.co_argcount]
    # __defaults__ is None when the function declares no defaults at all.
    lst_defaults = (func_obj.__defaults__ or ())[OPTIONAL_ARG_OMIT:]
    parts = []
    for covar, val in zip(lst_covars, lst_defaults):
        # repr() gives a valid Python string literal even when the value
        # contains quotes, backslashes or newlines; hand-quoting did not.
        def_val = repr(val) if isinstance(val, str) else str(val)
        parts.append(", {}={}".format(covar, def_val))
    return ''.join(parts)
def str_pair_kwarg(func_obj):
    """Render the optional positional parameters of ``func_obj`` as pass-through keywords.

    Every positional parameter name ``p`` from index ``OPTIONAL_ARG_BEGIN`` up
    to ``co_argcount`` becomes ``", p=p"``, so a generated wrapper forwards its
    own arguments unchanged to the wrapped function.

    :param func_obj: a plain Python function (must have ``__code__``).
    :return: the concatenated string, or ``''`` if there are no such parameters.
    """
    code = func_obj.__code__
    param_names = code.co_varnames[OPTIONAL_ARG_BEGIN:code.co_argcount]
    return ''.join(", {0}={0}".format(param) for param in param_names)
def format_line(line):
    """Return ``line`` unchanged (identity hook; not called elsewhere in this script)."""
    formatted = line
    return formatted
def gen_apply_func(name, func_obj):
    """Generate the source of one oopb wrapper method for ``onnx_ops.<name>``.

    The method is named after ``name`` with its ``apply_`` prefix removed and
    forwards to ``self.apply_op(onnx_ops.<name>, inputs, name, outputs, ...)``,
    re-exposing the optional keyword arguments of ``func_obj``.

    :param name: attribute name in onnx_ops; expected to start with ``apply_``.
    :param func_obj: the function object ``onnx_ops.<name>``.
    :return: the method source, starting with a newline (a blank separator line
        once printed) and without a trailing newline.
    """
    prefix = 'apply_'
    method_name = name[len(prefix):]
    # The body line must be indented deeper than the ``def`` line, otherwise
    # the regenerated oopb.py raises IndentationError on import.
    # NOTE(review): assumes the target class body uses 4-space indentation.
    output = "\n"
    output += "    def {}(self, inputs, name=None, outputs=None{}):\n".format(
        method_name, str_func_kwarg(func_obj))
    output += "        return self.apply_op(onnx_ops.{}, inputs, name, outputs{})".format(
        name, str_pair_kwarg(func_obj))
    return output
# Marker comment in oopb.py: everything before it is kept verbatim, the marker
# is re-emitted, and everything after it is replaced by generated wrappers.
HINT_LINE = " # !!!!CODE-AUTOGEN!!!! #"

# Relative path, resolved against the current working directory, so the script
# must be run from the directory containing the onnxconverter_common package.
fname_oopb = 'onnxconverter_common/oopb.py'

# Names expected at co_varnames[3:5] of an apply_* function for it to get a
# generated wrapper.
STANDARD_ARGS = ('container', 'operator_name')

fname_tmp = fname_oopb + '.tmp'
try:
    with open(fname_oopb, 'r', encoding='utf-8') as old_one:
        with open(fname_tmp, 'w', encoding='utf-8') as new_one:
            # Copy the hand-written part of oopb.py up to the marker line.
            for line in old_one:
                if line.strip() == HINT_LINE.strip():
                    break
                new_one.write(line)
            else:
                # The old readline() loop spun forever at EOF ('' never matches
                # the marker) while growing the tmp file; fail loudly instead.
                raise RuntimeError("marker line {!r} not found in {}".format(
                    HINT_LINE.strip(), fname_oopb))

            print(HINT_LINE, file=new_one)
            print(
                " # The following code was generated by ../update_ops.py",
                file=new_one)
            apply_fx = {v1: v2 for v1, v2 in onnx_ops.__dict__.items() if v1.startswith('apply_')}
            for v1, v2 in apply_fx.items():
                # Skip constant and constant_of_shape since oopb overrides them.
                if v1.startswith('apply_constant'):
                    continue
                code_obj = getattr(v2, '__code__', None)
                if code_obj is None:
                    # Not a plain Python function; nothing to introspect or wrap.
                    continue
                args = code_obj.co_varnames[3:3 + len(STANDARD_ARGS)]
                if tuple(args) == STANDARD_ARGS:
                    print(gen_apply_func(v1, v2), file=new_one)
except BaseException:
    # Do not leave a half-written temp file behind; propagate the error.
    if os.path.exists(fname_tmp):
        os.remove(fname_tmp)
    raise

# os.replace overwrites the destination atomically (also on Windows, unlike
# os.rename), so oopb.py is never missing in the middle of the update.
os.replace(fname_tmp, fname_oopb)
print("{} updated successfully!".format(fname_oopb))