#pragma once

#include <ATen/core/jit_type.h>
#include <ATen/core/symbol.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>

#include <unordered_map>

namespace torch {
namespace jit {

struct ModuleInstanceInfo;
constexpr size_t kModuleInstanceInfo = 2;

namespace utils {
std::string get_module_info(const ModuleInstanceInfo& module_instance_info);
} // namespace utils
// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink; it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
struct Scope;
using ScopePtr = c10::intrusive_ptr<Scope>;
using c10::Symbol;
struct TORCH_API Scope : public c10::intrusive_ptr_target {
 private:
  ScopePtr parent_;
  Symbol name_;
  ScopePtr intrusive_from_this();

 public:
  Scope();
  Scope(ScopePtr parent, Symbol name);
  ScopePtr push(Symbol name);
  ScopePtr parent();
  bool isRoot() const;
  bool isBlank() const;
  ScopePtr getRoot();
  size_t getDepth();
  Symbol name() const;
  std::string namesFromRoot(const std::string& separator = "/") const;
};
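
// Usage sketch (illustrative only, not part of this header). It assumes the
// default-constructed Scope acts as the root scope, as Graph does when it
// creates its initial scope.
//
//   ScopePtr root = c10::make_intrusive<Scope>();
//   ScopePtr conv = root->push(Symbol::scope("conv1"));
//   ScopePtr bn = conv->push(Symbol::scope("bn"));
//   bn->isRoot();           // false
//   bn->getRoot() == root;  // true
//   bn->namesFromRoot("/"); // a path such as "conv1/bn"
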
struct Function;
struct InlinedCallStack;
/**
 * ModuleInstanceInfo is a structure that holds the module type and instance
 * name. It also provides public methods to access the module type and
 * instance name.
 *
 * This structure is mainly used as a private member of InlinedCallStack, so
 * that one can follow the callstack to find the relevant module hierarchy.
 */
struct ModuleInstanceInfo {
 private:
  c10::ClassTypePtr module_type_{nullptr};
  std::string instance_name_;

 public:
  ModuleInstanceInfo() = default;
  ModuleInstanceInfo(c10::ClassTypePtr module_type, std::string instance_name);
  c10::ClassTypePtr class_type() {
    return module_type_;
  }
  c10::ClassTypePtr class_type() const {
    return module_type_;
  }
  std::string instance_name() const {
    return instance_name_;
  }
  bool operator==(const ModuleInstanceInfo& rhs) const {
    return (class_type() == rhs.class_type()) &&
        (instance_name() == rhs.instance_name());
  }
};
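
// Usage sketch (illustrative only): `cls` stands in for a module's
// c10::ClassTypePtr obtained elsewhere (e.g. from a loaded module).
//
//   ModuleInstanceInfo info(cls, "conv1");
//   info.class_type();    // returns `cls`
//   info.instance_name(); // returns "conv1"
//
// utils::get_module_info(info), declared above, returns a string form of this
// information, which is useful for debug output.
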
/**
 * InlinedCallStack is an element in a list representing a callstack of
 * functions that have been inlined.
 *
 * Each such element holds info about the current callsite (Function and
 * SourceRange) and a pointer to the next element in the list. The last element
 * in the list represents the innermost function that was inlined.
 *
 * For instance, if a node has a callstack
 *    [foo, source_range1] -> [bar, source_range2]
 * it means that this node was originally from function 'bar' that was called
 * at 'source_range2' in function 'foo' that was called in the current function
 * at 'source_range1'.
 *
 * If a node did not come from any inlined function, its callstack will be
 * empty.
 *
 * The callstack lists only grow; we never remove elements from them, which
 * allows us to reuse the same elements in different lists. For instance, if we
 * inline function 'bar' into 'foo' and then inline 'foo' into two functions
 * 'ham' and 'baz', the callstacks would look like:
 *
 * [baz, source_range3]  --
 *                         \
 *                          --> [foo, source_range1] -> [bar, source_range2]
 *                         /
 * [ham, source_range4]  --
 */
using InlinedCallStackPtr = c10::intrusive_ptr<InlinedCallStack>;
using InlinedCallStackEntry =
    std::tuple<Function*, SourceRange, c10::optional<ModuleInstanceInfo>>;
struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target {
 private:
  c10::optional<InlinedCallStackPtr> callee_;
  Function* fn_;
  // Why fn_name_ exists even though we already have fn_:
  // a serialized callstack is used in circumstances where an InlinedCallStack
  // cannot be constructed at runtime, e.g. in the mobile runtime or in
  // delegated backends. Since we do not have a Function* in those cases, we
  // store the function name instead. fn_name_ does not give access to the
  // same information that Function* does; however, in the mobile/delegated
  // backend runtime we use InlinedCallStack for the exception stack, and for
  // that purpose fn_name_ suffices.
  std::string fn_name_;
  SourceRange source_range_;
  InlinedCallStackPtr intrusive_from_this();
  c10::optional<ModuleInstanceInfo> module_instance_info_;
 public:
  // Constructor for a leaf callstack node.
  InlinedCallStack(Function* fn, SourceRange source_range);

  // Constructor for a leaf callstack node.
  InlinedCallStack(
      Function* fn,
      SourceRange source_range,
      c10::optional<ModuleInstanceInfo> module_instance_info);

  // Constructor for an inner callstack node.
  InlinedCallStack(
      InlinedCallStackPtr callee,
      Function* fn,
      SourceRange source_range);

  // Constructor for an inner callstack node.
  InlinedCallStack(
      InlinedCallStackPtr callee,
      Function* fn,
      SourceRange source_range,
      c10::optional<ModuleInstanceInfo> module_instance_info);

  // Return the next element in the callstack list.
  c10::optional<InlinedCallStackPtr> callee() const;

  // Return the module instance associated with the current element.
  c10::optional<ModuleInstanceInfo> module_instance() const;

  // Return the source range of the node.
  SourceRange source_range() const;

  Function* function() const;

  void set_function_name(std::string fn_name);

  std::string function_name() const;

  // Return the callstack as a vector of InlinedCallStackEntry tuples
  // (Function*, SourceRange, optional ModuleInstanceInfo).
  std::vector<InlinedCallStackEntry> vec();

  void setCallee(c10::optional<InlinedCallStackPtr>);
  bool operator==(const InlinedCallStack& rhs) const {
    // No need to compare fn_, since the source_range equivalence check
    // should suffice.
    return (module_instance().has_value() ==
            rhs.module_instance().has_value()) &&
        (!module_instance().has_value() ||
         module_instance().value() == rhs.module_instance().value()) &&
        callee() == rhs.callee() && source_range() == rhs.source_range();
  }

  bool operator!=(const InlinedCallStack& rhs) const {
    return !(*this == rhs);
  }
};
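
// Usage sketch (illustrative only), mirroring the diagram above. `foo_fn` and
// `bar_fn` stand in for Function pointers and `sr1`/`sr2` for SourceRanges
// obtained elsewhere; kModuleInstanceInfo (defined at the top of this file)
// matches the index of the module-info slot in InlinedCallStackEntry.
//
//   // Element for 'bar', the innermost inlined function; sr2 is where 'bar'
//   // was called inside 'foo'.
//   InlinedCallStackPtr bar_cs =
//       c10::make_intrusive<InlinedCallStack>(bar_fn, sr2);
//   // Element for 'foo'; it points to the 'bar' element as its callee and
//   // holds sr1, the callsite of 'foo' in the current function.
//   InlinedCallStackPtr foo_cs =
//       c10::make_intrusive<InlinedCallStack>(bar_cs, foo_fn, sr1);
//   for (const InlinedCallStackEntry& entry : foo_cs->vec()) {
//     SourceRange range = std::get<1>(entry);
//     c10::optional<ModuleInstanceInfo> mii =
//         std::get<kModuleInstanceInfo>(entry);
//   }
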
// {source range, node name, InlinedCallStack}
// We store the node name because the same debug info is also used for
// profiling, so we need to know the op names as well.
using DebugInfoTuple =
    std::tuple<SourceRange, std::string, InlinedCallStackPtr>;
constexpr size_t kDebugInfoTupleSourceRangeIndex{0};
constexpr size_t kDebugInfoTupleNodeNameIndex{1};
constexpr size_t kDebugInfoTupleInlinedCSIndex{2};
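
// Usage sketch (illustrative only): unpacking a DebugInfoTuple with the index
// constants above instead of bare tuple indices.
//
//   void dump(const DebugInfoTuple& entry) {
//     const SourceRange& range =
//         std::get<kDebugInfoTupleSourceRangeIndex>(entry);
//     const std::string& node_name =
//         std::get<kDebugInfoTupleNodeNameIndex>(entry);
//     const InlinedCallStackPtr& callstack =
//         std::get<kDebugInfoTupleInlinedCSIndex>(entry);
//     // range, node_name, and callstack can now be logged or fed to a
//     // profiler.
//   }
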
} // namespace jit
} // namespace torch