forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
import.h
153 lines (127 loc) · 4.89 KB
/
import.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#pragma once

#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <cstdint>
#include <istream>
#include <memory>
#include <optional>
#include <string>
#include <vector>
// Forward declaration of the reader-adapter interface that the
// unique_ptr/shared_ptr-taking overloads in torch::jit refer to.
namespace caffe2 {
namespace serialize {
class ReadAdapterInterface;
} // namespace serialize
} // namespace caffe2
namespace torch::jit {

class DeserializationStorageContext;

// ---------------------------------------------------------------------------
// import_ir_module
//
// Overloads that import a serialized `Module` using the caller-supplied
// CompilationUnit `cu`. The serialized data can come from:
//   - a file path (`filename`),
//   - a `std::istream`,
//   - a `caffe2::serialize::ReadAdapterInterface`, either handed over via
//     `std::unique_ptr` or shared via `std::shared_ptr`,
//   - a `PyTorchStreamReader` plus storage context (torch.Package format).
//
// Common parameters:
//   device           - optional device; defaults to std::nullopt where a
//                      default is provided. NOTE(review): presumably loaded
//                      tensors are placed on this device when set — confirm
//                      against the implementation.
//   extra_files      - taken by non-const reference, so the callee may write
//                      to it. NOTE(review): assumed to be an in/out map (file
//                      names requested in, contents filled out) — confirm.
//   load_debug_files - defaults to true.
//   restore_shapes   - only on the filename/istream + extra_files overloads;
//                      defaults to false.
// ---------------------------------------------------------------------------

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& filename,
    std::optional<c10::Device> device = std::nullopt,
    bool load_debug_files = true);

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::istream& in,
    std::optional<c10::Device> device = std::nullopt,
    bool load_debug_files = true);

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
    std::optional<c10::Device> device = std::nullopt,
    bool load_debug_files = true);

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& filename,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files = true,
    bool restore_shapes = false);

// For reading unified serialization format from torch.Package.
// All parameters are required; `ts_id` is the TorchScript identifier of the
// module inside the package.
TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
    std::shared_ptr<torch::jit::DeserializationStorageContext> storage_context,
    std::optional<at::Device> device,
    const std::string& ts_id /* torchscript identifier inside package */);

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::istream& in,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files = true,
    bool restore_shapes = false);

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files = true);

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files = true);

/// Loads a serialized `Module` from the given `istream`.
///
/// The istream must contain a serialized `Module`, exported via
/// `torch::jit::ExportModule` in C++.
TORCH_API Module load(
    std::istream& in,
    std::optional<c10::Device> device = std::nullopt,
    bool load_debug_files = true);

/// Same as the `istream` overload above, additionally taking `extra_files`
/// (non-const reference; see the note on `extra_files` for import_ir_module).
TORCH_API Module load(
    std::istream& in,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files = true);

/// Loads a serialized `Module` from the given `filename`.
///
/// The file stored at the location given in `filename` must contain a
/// serialized `Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++.
TORCH_API Module load(
    const std::string& filename,
    std::optional<c10::Device> device = std::nullopt,
    bool load_debug_files = true);

/// Same as the `filename` overload above, additionally taking `extra_files`.
TORCH_API Module load(
    const std::string& filename,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files = true);

/// Loads a serialized `Module` from the given shared_ptr `rai`.
///
/// The reader adapter, which is for customized input stream, must contain a
/// serialized `Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++.
TORCH_API Module load(
    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
    std::optional<c10::Device> device = std::nullopt,
    bool load_debug_files = true);

/// Same as the shared_ptr `rai` overload above, additionally taking
/// `extra_files`.
TORCH_API Module load(
    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files = true);

// Builds a `Module` from an IValue, a map of source files (`source`), a list
// of constants and a version number.
// NOTE(review): the exact role of `ivalue` and the meaning of `version`
// (presumably a serialization format version) are not visible here — confirm.
TORCH_API Module jitModuleFromSourceAndConstants(
    const IValue& ivalue,
    const ExtraFilesMap& source,
    const std::vector<IValue>& constants,
    int32_t version);

// Builds a `Module` from the in-memory buffer `data` (`size` presumably being
// its length in bytes — confirm). The buffer is taken by const reference to
// its shared_ptr, so this declaration adds no owner of its own.
TORCH_API Module parse_and_initialize_jit_module(
    const std::shared_ptr<char>& data,
    size_t size,
    ExtraFilesMap& extra_files,
    std::optional<at::Device> device = std::nullopt);

// Loads a `Module` from the file at `filename`, filling/consulting
// `extra_files` as for the other loaders.
TORCH_API Module load_jit_module_from_file(
    const std::string& filename,
    ExtraFilesMap& extra_files,
    std::optional<at::Device> device = std::nullopt);

// Loads a `Module` from the stream `in`, filling/consulting `extra_files` as
// for the other loaders.
TORCH_API Module load_jit_module_from_stream(
    std::istream& in,
    ExtraFilesMap& extra_files,
    std::optional<at::Device> device = std::nullopt);

// Produces an `ivalue::Object` of the given `type` from `input`.
// NOTE(review): presumably the object-loader callback used while unpickling
// class instances — confirm against the implementation.
TORCH_API c10::intrusive_ptr<c10::ivalue::Object> ObjLoaderFunc(
    const at::StrongTypePtr& type,
    IValue input);

} // namespace torch::jit