forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_autograd.py
255 lines (209 loc) · 9.25 KB
/
gen_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""
To run this file by hand from the root of the PyTorch
repository, run:
python -m tools.autograd.gen_autograd \
build/aten/src/ATen/Declarations.yaml \
$OUTPUT_DIR
Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/
"""
# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
# gen_autograd_functions.py: generates subclasses of torch::autograd::Node
# gen_variable_type.py: generates VariableType.h which contains all tensor methods
# gen_python_functions.py: generates Python bindings to THPVariable
#
import argparse
import copy
import os
import yaml
import re
from collections import defaultdict
from .utils import YamlLoader, split_name_params
# See NOTE [ Autograd View Variables ] in variable.h for details.
# A map: function name => two options:
# 1. name of the argument that all outputs are view of
# 2. map: output idx => name of the argument that this result is view of
VIEW_FUNCTIONS = {
'numpy_T': 'self',
'alias': 'self',
'as_strided': 'self',
'diagonal': 'self',
'expand': 'self',
'narrow': 'self',
'permute': 'self',
'select': 'self',
'slice': 'self',
'squeeze': 'self',
't': 'self',
'transpose': 'self',
'unfold': 'self',
'unsqueeze': 'self',
'view': 'self',
'unbind': 'self',
'_indices': 'self',
'_values': 'self',
'indices': 'self',
'values': 'self',
# sparse_coo ctor output should really be views of both indices and values,
# but we only supports making as view of a single varible, and indices is
# discrete anyways.
# FIXME: clone indices on construction.
'sparse_coo_tensor_with_dims_and_tensors': 'values',
}
# note: some VIEW_FUNCTIONS are just compositions of the view functions above
# this list contains both the root view functions and any that are purely composed
# of viewing functions, and is used by the JIT to determine when an operator
# returns a view of its inputs
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({'chunk', 'split'})
def format_return_type(returns):
if len(returns) == 0:
return 'void'
elif len(returns) == 1:
return returns[0]['type']
else:
return_types = [r['type'] for r in returns]
return 'std::tuple<{}>'.format(','.join(return_types))
def get_simple_type(arg):
simple_type = arg['type']
simple_type = simple_type.replace(' &', '').replace('const ', '')
simple_type = simple_type.replace('Generator *', 'Generator')
opt_match = re.match(r'c10::optional<(.+)>', simple_type)
if opt_match:
simple_type = '{}?'.format(opt_match.group(1))
return simple_type
def load_aten_declarations(path):
with open(path, 'r') as f:
declarations = yaml.load(f, Loader=YamlLoader)
# enrich declarations with additional information
selected_declarations = []
for declaration in declarations:
if declaration.get('deprecated'):
continue
for arg in declaration['arguments']:
arg['simple_type'] = get_simple_type(arg)
for ret in declaration['returns']:
ret['simple_type'] = get_simple_type(ret)
declaration['formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['arguments']]
declaration['args'] = [arg['name'] for arg in declaration['arguments']]
declaration['type_method_formals'] = [arg['type'] + ' ' + arg['name']
for arg in declaration['arguments']]
declaration['type_method_args'] = [arg['name'] for arg in declaration['arguments']]
declaration['api_name'] = declaration['name']
declaration['return_type'] = format_return_type(declaration['returns'])
declaration['base_name'] = declaration['name']
selected_declarations.append(declaration)
return selected_declarations
def load_deprecated_signatures(aten_decls, deprecated_path):
def group_declarations_by_signature():
d = defaultdict(list)
for declaration in aten_decls:
name = declaration['name']
base_name = name[:-1] if declaration['inplace'] else name
simple_types = [arg['simple_type'] for arg in declaration['arguments']]
signature = '{}({})'.format(base_name, ', '.join(simple_types))
d[signature].append(declaration)
return d
with open(deprecated_path, 'r') as f:
deprecated_defs = yaml.load(f, Loader=YamlLoader)
declarations = []
declarations_by_signature = group_declarations_by_signature()
def get_signature(name, params, call_args):
# create a mapping of parameter name to parameter type
types = dict([param.split(' ')[::-1] for param in params if param != '*'])
# if the name in the call is not in the parameter list, assume it's
# a literal Scalar
rearranged_types = [types.get(arg, 'Scalar') for arg in call_args]
return '{}({})'.format(name, ', '.join(rearranged_types))
for deprecated in deprecated_defs:
aten_name, call_args = split_name_params(deprecated['aten'])
name, params = split_name_params(deprecated['name'])
signature = get_signature(aten_name, params, call_args)
for declaration in declarations_by_signature[signature]:
declaration = copy.deepcopy(declaration)
declaration['deprecated'] = True
declaration['call_args'] = call_args
call_arg_to_idx = {arg: i for i, arg in enumerate(call_args)}
original_args = declaration['arguments']
# Create an arguments list that uses the types from the original
# ATen declaration, but the ordering and parameter names from
# the deprecated overload. Any default parameter values from the
# original ATen declaration are ignored.
arguments = []
kwarg_only = False
for param in params:
if param == '*':
kwarg_only = True
continue
_, param_name = param.split(' ')
original = original_args[call_arg_to_idx[param_name]]
arguments.append({
'name': param_name,
'kwarg_only': kwarg_only,
'type': original['type'],
'simple_type': original['simple_type'],
'dynamic_type': original['dynamic_type'],
'output': original.get('output', False),
})
declaration['arguments'] = arguments
declarations.append(declaration)
return declarations
def gen_autograd(aten_path, out, autograd_dir, disable_autograd=False):
aten_decls = load_aten_declarations(aten_path)
# Parse and load derivatives.yaml
from .load_derivatives import load_derivatives
autograd_functions = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)
template_path = os.path.join(autograd_dir, 'templates')
# Generate VariableType.h/cpp
if not disable_autograd:
from .gen_variable_type import gen_variable_type
gen_variable_type(out, aten_decls, template_path)
# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_lib
gen_autograd_functions_lib(
out, autograd_functions, template_path)
# Generate variable_factories.h
from .gen_variable_factories import gen_variable_factories
gen_variable_factories(
out, aten_decls, template_path, disable_autograd=disable_autograd)
def gen_autograd_python(aten_path, out, autograd_dir):
# TODO Deduplicate these four variable assignments
aten_decls = load_aten_declarations(aten_path)
# Parse and load derivatives.yaml
from .load_derivatives import load_derivatives
autograd_functions = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), aten_decls)
template_path = os.path.join(autograd_dir, 'templates')
# Load deprecated signatures
deprecated = load_deprecated_signatures(
aten_decls, os.path.join(autograd_dir, 'deprecated.yaml'))
# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_python
gen_autograd_functions_python(
out, autograd_functions, template_path)
# Generate Python bindings
from . import gen_python_functions
gen_python_functions.gen_py_variable_methods(
out, aten_decls + deprecated, template_path)
gen_python_functions.gen_py_torch_functions(
out, aten_decls + deprecated, template_path)
gen_python_functions.gen_py_nn_functions(
out, aten_decls, template_path)
def main():
parser = argparse.ArgumentParser(
description='Generate autograd C++ files script')
parser.add_argument('declarations', metavar='DECL',
help='path to Declarations.yaml')
parser.add_argument('out', metavar='OUT',
help='path to output directory')
parser.add_argument('autograd', metavar='AUTOGRAD',
help='path to autograd directory')
args = parser.parse_args()
gen_autograd(args.declarations, args.out, args.autograd)
if __name__ == '__main__':
main()