forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
function.h
137 lines (111 loc) · 3.95 KB
/
function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#pragma once
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/utils/memory.h>
#include <mutex>
namespace torch {
namespace jit {

// Keyword arguments for a call: argument name -> value.
// Consumed by Function::operator() below.
using Kwargs = std::unordered_map<std::string, IValue>;

// Applies generic (non-specializing) optimizations to `graph`.
// Function::optimized_graph() runs this on a copy of its graph.
TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
// A Function is a pure Graph with no implicit `self` object bound.
// It contains schema information, and the executor that manages the
// execution of the function. script::Method is a wrapper around an
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
  // `function_creator` may be empty; when non-empty it is invoked lazily by
  // ensure_defined() so the compiler can construct methods out of order
  // (see function_creator_ below).
  Function(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      std::function<void(Function&)> function_creator)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        function_creator_(std::move(function_creator)) {}

  // Execute with inputs/outputs passed on `stack`. Defined out of line.
  void run(Stack& stack);

  // Rvalue overload: the stack is consumed by the call.
  void run(Stack&& stack);

  // Convenience call syntax: positional args in `stack`, keyword args in
  // `kwargs`; returns the result as a single IValue. Defined out of line.
  IValue operator()(
      std::vector<IValue> stack,
      const Kwargs& kwargs = Kwargs());

  // The original, non-optimized graph (also used for debugging/inlining).
  std::shared_ptr<Graph> graph() const {
    return graph_;
  }

  // Lazily computes and caches a copy of graph_ with generic (non-
  // specializing) optimizations applied via preoptimizeGraph().
  // Thread-safe: guarded by compile_mutex; the mutex is reentrant, so this
  // may be called while the lock is already held (e.g. from get_executor()).
  std::shared_ptr<Graph> optimized_graph() const {
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    if (optimized_graph_) {
      return *optimized_graph_;
    }
    optimized_graph_ = graph_->copy();
    preoptimizeGraph(*optimized_graph_);
    return *optimized_graph_;
  }

  // Fully qualified name, e.g. "module.submodule.fn".
  const c10::QualifiedName& qualname() const {
    return name_;
  }

  // Unqualified (trailing) component of the name.
  const std::string& name() const {
    return name_.name();
  }

  // if this isn't yet defined, run its method_creator function
  void ensure_defined();

  // Number of graph inputs (for a method this includes the bound `self`).
  size_t num_inputs() const {
    return graph()->inputs().size();
  }

  // Stores the schema; returns *this to allow chained builder-style calls.
  Function& setSchema(FunctionSchema schema) {
    schema_ = make_unique<FunctionSchema>(std::move(schema));
    return *this;
  }

  // Returns the schema; defined out of line. NOTE(review): schema_ is
  // documented below as lazily defaulted, so presumably this computes a
  // default schema from the graph when none was set — confirm in the .cpp.
  const FunctionSchema& getSchema() const;

  // Renders the schema as a string. Precondition: a schema has been set
  // (or cached); asserts otherwise.
  std::string pretty_print_schema() const {
    AT_ASSERT(schema_);
    std::stringstream ss;
    ss << *schema_;
    return ss.str();
  }

  // Debug snapshot of the underlying executor's state; forces executor
  // creation (and therefore definition) via get_executor().
  GraphExecutorState getDebugState() {
    return get_executor().getDebugState();
  }

  // Deprecated: kept for API compatibility; always returns true and warns.
  bool is_optimized() const {
    AT_WARN(
        "Function::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  // Methods (unlike arbitrary graphs) must produce exactly one output;
  // aborts with TORCH_CHECK otherwise.
  void check_single_output() {
    TORCH_CHECK(
        graph()->outputs().size() == 1,
        "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
  }

  // Lazily constructs (and caches) the GraphExecutor over the optimized
  // graph. ensure_defined() runs first so graph_ is populated; the
  // reentrant compile_mutex makes the nested optimized_graph() call safe.
  GraphExecutor& get_executor() {
    ensure_defined();
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    if (executor_) {
      return executor_;
    }
    check_single_output();
    executor_ = GraphExecutor(optimized_graph());
    return executor_;
  }

 private:
  c10::QualifiedName name_;
  // The original, non-optimized graph
  std::shared_ptr<Graph> graph_; // for debugging and for inlining
  // Optimized graph, computed lazily. Used for inlining.
  // Note: this graph is not specialized, only generic optimizations are applied
  // here.
  mutable c10::optional<std::shared_ptr<Graph>> optimized_graph_;
  // Functions are invokable from multiple threads, so this lock needs to be
  // held when we're initializing graph executor for the first time or computing
  // the optimized graph.
  // We're using reentrant mutex so that we don't need to worry about causing a
  // deadlock by calling one method from another (e.g. optimized_graph() from
  // get_executor()).
  mutable std::recursive_mutex compile_mutex;
  GraphExecutor executor_; // for execution
  // an optional function that actually creates the method when
  // ensure_defined() is called. This is used by the compiler so
  // that it can construct methods out of order
  std::function<void(Function&)> function_creator_;
  // if absent, then we generate a default schema based on the graph
  // mutable because getSchema caches the default schema if one is requested
  // before a call to setSchema
  mutable std::unique_ptr<FunctionSchema> schema_;
};
} // namespace jit
} // namespace torch