forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNestedTensorImpl.h
286 lines (262 loc) · 9.98 KB
/
NestedTensorImpl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
#pragma once
#include <ATen/MemoryOverlap.h>
#include <ATen/Tensor.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/irange.h>
namespace at::native {
// Forward declaration so the free functions declared below can name the type
// before its full definition.
struct NestedTensorImpl;
// Declared ahead of the struct because NestedTensorImpl::get_buffer() calls it
// inline; the definition appears further down in this header.
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
// The three helpers below are only declared in this header; their definitions
// are not visible here.
// NOTE(review): judging by the name, this computes an element count from a
// nested-size tensor -- confirm against the definition.
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
// NOTE(review): presumably derives the strides metadata from `nested_size`
// (see the "assume contiguous" constructor) -- confirm against the definition.
at::Tensor construct_nested_strides(const at::Tensor& nested_size);
// NOTE(review): presumably derives the storage offsets metadata from
// `nested_size` -- confirm against the definition.
at::Tensor construct_offsets(const at::Tensor& nested_size);
/**
 * TensorImpl subclass used for nested tensors.
 *
 * On top of what c10::TensorImpl provides, an instance carries three metadata
 * tensors -- nested sizes, nested strides and storage offsets -- describing
 * the underlying tensors that live in the shared storage.
 */
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  // Construct from an explicit storage, dispatch key set and dtype together
  // with all three metadata tensors.
  explicit NestedTensorImpl(
      Storage storage,
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // Construct from a buffer tensor together with all three metadata tensors.
  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // Assumes a contiguous layout: `nested_strides` and `offsets` can be
  // inferred from `nested_sizes`.
  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      const at::Tensor& nested_sizes);

  // This constructor is used for creating view tensors from nested tensors.
  explicit NestedTensorImpl(
      c10::TensorImpl::ImplType impl_type,
      const at::Tensor& base_tensor,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
  const Tensor& get_nested_sizes() const {
    return nested_sizes_;
  }

  // TODO: don't expose private implementation details like this
  const Tensor& get_nested_strides() const {
    return nested_strides_;
  }

  // Read-only access to the storage offsets metadata tensor.
  const Tensor& get_storage_offsets() const {
    return storage_offsets_;
  }

  // Yields nullopt when dimension `d` is irregular. A dimension i of a
  // NestedTensor counts as regular if the unbound tensors all agree in size
  // at their (i-1)th dimension.
  std::optional<int64_t> opt_size(int64_t d) const;

  /**
   * Size of dimension `d`.
   *
   * Fails a TORCH_CHECK when opt_size(d) holds no value, i.e. when the
   * dimension is irregular.
   */
  int64_t size(int64_t d) const {
    const std::optional<int64_t> optional_size = opt_size(d);
    TORCH_CHECK(
        optional_size.has_value(),
        "Given dimension ",
        d,
        " is irregular and does not have a size.");
    return optional_size.value();
  }

  /**
   * View the nested tensor as a 1 dimensional contiguous tensor.
   *
   * The returned buffer tensor shares its storage_impl with this nested
   * tensor and can therefore be treated as a view of it. Fails a TORCH_CHECK
   * if the nested tensor is not contiguous.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_buffer() const {
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(this),
        "NestedTensor must be contiguous to get buffer.");
    return get_unsafe_storage_as_tensor();
  }

  /**
   * Prefer get_buffer() where possible. This hands back the storage as a
   * tensor directly, which is not safe in general: the caller is responsible
   * for accounting for nested_sizes, nested_strides and storage_offsets.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_unsafe_storage_as_tensor() const {
    const c10::DispatchKeySet dense_key_set = generate_buffer_key_set();
    const int64_t element_count = static_cast<int64_t>(get_buffer_size());
    // A VIEW-type impl over a copy of our storage handle, sized as one flat
    // dimension covering every element the storage can hold.
    auto flat_impl = c10::make_intrusive<TensorImpl>(
        c10::TensorImpl::VIEW, Storage(storage_), dense_key_set, data_type_);
    flat_impl->set_sizes_contiguous(c10::makeArrayRef(element_count));
    return Tensor(flat_impl);
  }

  // Number of elements of this dtype that fit in the storage
  // (storage bytes divided by the item size).
  size_t get_buffer_size() const {
    const size_t total_bytes = storage_.nbytes();
    const size_t bytes_per_element = data_type_.itemsize();
    return total_bytes / bytes_per_element;
  }

 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  // with real implementations
  int64_t numel_custom() const override;
  c10::SymInt sym_numel_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;

  // Both size hooks forward to size(d), so they share its irregular-dimension
  // check.
  int64_t size_custom(int64_t d) const override {
    return size(d);
  }
  c10::SymInt sym_size_custom(int64_t d) const override {
    const int64_t extent = size(d);
    return c10::SymInt{extent};
  }

  IntArrayRef sizes_custom() const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;

  IntArrayRef strides_custom() const override;
  c10::SymIntArrayRef sym_strides_custom() const override;

  // this one is real
  int64_t dim_custom() const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  // Copies tensor metadata from `impl` into this object while keeping this
  // object's own version counter and allow_tensor_metadata_change setting.
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    copy_tensor_metadata(
        impl.get(), // src_impl
        this, // dest_impl
        version_counter(), // version_counter
        allow_tensor_metadata_change()); // allow_tensor_metadata_change
  }

 private:
  // Must be called after any changes to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor nested_sizes_, nested_strides_;

  // The starting positions of the underlying tensors in contiguous buffer
  // i.e. the buffer memory offsets to get the underlying tensors
  // This metadata is kept because, without a strong enough constraint, it
  // cannot be derived from `nested_sizes_` and `nested_strides_`:
  // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
  //    this can happen e.g. after slicing a nested tensor
  // 2. when multiple tensors share a same memory
  // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
  // Some strong enough constraints are:
  // 1. every underlying tensor is contiguous in memory
  //    && nesting in ascending order
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor storage_offsets_;

  // NOTE: -1 here means the size is missing
  // Optional to allow it to be computed lazily from nested.
  // TODO: maybe we can remove this metadata since
  //       we can compute it from `nested_sizes_`
  mutable std::optional<std::vector<int64_t>> opt_sizes_;

  // Shared implementation hook for the two shallow_copy_and_detach overloads
  // (defined out of line).
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  /**
   * Derives a non-nested key_set from this nested tensor's key set.
   *
   * Many nested tensor kernels build a buffer tensor and redispatch to a
   * non-nested kernel; this produces the key set that buffer tensor uses.
   *
   * @return Appropriate key set for non-nested tensor
   */
  inline c10::DispatchKeySet generate_buffer_key_set() const {
    const c10::DispatchKeySet source_keys = this->key_set();
    // Remember autograd participation before any keys are stripped.
    const bool has_autograd =
        source_keys.has_any(c10::autograd_dispatch_keyset);

    // Remove nested tensor specific keys
    const c10::DispatchKeySet without_nested = source_keys -
        c10::DispatchKeySet{
            c10::DispatchKey::NestedTensor,
            c10::DispatchKey::AutogradNestedTensor};

    // Add dense tensor specific keys
    c10::DispatchKeySet dense_keys =
        without_nested | c10::DispatchKeySet{c10::DispatchKey::Dense};

    if (has_autograd) {
      dense_keys =
          c10::DispatchKeySet{c10::DispatchKey::Autograd} | dense_keys;
    }
    return dense_keys;
  }
};
// Returns the NestedTensorImpl behind `tensor` when tensor.is_nested() is
// true, and nullptr otherwise. The pointer is obtained through
// unsafeGetTensorImpl(), so it is non-owning.
inline NestedTensorImpl* get_nested_tensor_impl_or_null(
    const at::Tensor& tensor) {
  return tensor.is_nested()
      ? static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl())
      : nullptr;
}
// Returns the NestedTensorImpl behind `tensor`. Fails a TORCH_CHECK when
// tensor.is_nested() is false. The pointer is obtained through
// unsafeGetTensorImpl(), so it is non-owning.
inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  auto* const raw_impl = tensor.unsafeGetTensorImpl();
  return static_cast<NestedTensorImpl*>(raw_impl);
}
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
int64_t ntensors = nt->size(0);
if (ntensors == 0) {
return true;
}
const Tensor &sizemat = nt->get_nested_sizes(),
&stridemat = nt->get_nested_strides();
const int64_t* offsets_ptr =
nt->get_storage_offsets().const_data_ptr<int64_t>();
int64_t orig_dim = sizemat.size(1);
// nesting scalars
if (orig_dim == 0) {
// each scalar must be contiguous
// if there is blank memory between underlying scalars
for (int64_t i = 0; i < ntensors; i++) {
if (offsets_ptr[i] != i) {
return false;
}
}
}
// nesting tensors
else {
// if any underlying tensor is non-contiguous
const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
*stridemat_ptr = stridemat.const_data_ptr<int64_t>();
for (int64_t i = 0; i < ntensors; i++) {
if (stridemat_ptr[orig_dim - 1] != 1) {
return false;
}
int64_t product = sizemat_ptr[orig_dim - 1];
for (int64_t j = orig_dim - 2; j >= 0; j--) {
if (stridemat_ptr[j] != product) {
return false;
}
product *= sizemat_ptr[j];
}
sizemat_ptr += orig_dim;
stridemat_ptr += orig_dim;
}
// if there is blank memory between underlying tensors
if (offsets_ptr[0] != 0) {
return false;
}
sizemat_ptr = sizemat.const_data_ptr<int64_t>();
stridemat_ptr = stridemat.const_data_ptr<int64_t>();
for (int64_t i = 1; i < ntensors; i++) {
if (offsets_ptr[i] !=
offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
return false;
}
sizemat_ptr += orig_dim;
stridemat_ptr += orig_dim;
}
}
// everything is fine
return true;
}
// Convenience accessor: returns a reference to the nested sizes metadata
// tensor of `tensor`. Goes through get_nested_tensor_impl(), so it fails a
// TORCH_CHECK when `tensor` is not nested.
inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
  const NestedTensorImpl* const impl = get_nested_tensor_impl(tensor);
  return impl->get_nested_sizes();
}
} // namespace at::native