#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
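
// NestedTensorImpl backs PyTorch's prototype nested tensors: buffer_ packs
// every component's elements into one flat tensor, and each row of the 2-D
// nested_size_tensor_ records the shape of one component.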
namespace at {
namespace native {

// Computes, for each dimension of the nested tensor, the common size across
// all components, or -1 where the components disagree.
inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
  if (sizes.dim() == 0) {
    return std::vector<int64_t>();
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  // result[0] is the number of components, i.e. the number of rows of `sizes`.
  std::vector<int64_t> result(1, sizes.sizes()[0]);
  if (sizes.dim() > 0) {
    size_t nested_dim = result.size();
    int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
    result.resize(nested_dim + sizes.sizes()[1]);
    int64_t sizes_size_0 = sizes.sizes()[0];
    int64_t sizes_size_1 = sizes.sizes()[1];
    // Seed the trailing dims with the shape of the first component.
    for (const auto i : c10::irange(sizes_size_1)) {
      result[nested_dim + i] = sizes_ptr[i];
    }
    // Downgrade a dim to -1 as soon as any component disagrees with it.
    for (const auto j : c10::irange(sizes_size_1)) {
      for (const auto i : c10::irange(sizes_size_0)) {
        if (result[nested_dim + j] &&
            (result[nested_dim + j] != sizes_ptr[i * sizes.size(1) + j])) {
          result[nested_dim + j] = -1;
        }
      }
    }
  }
  return result;
}
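
// Worked example (illustrative values, not code from this file): for a nested
// tensor with components of shapes (2, 3) and (4, 3), `sizes` is
// [[2, 3], [4, 3]]. `result` starts as {2} (the component count), is seeded
// with the first row to {2, 2, 3}, and the mismatching dim 0 of the second
// row then forces the final value {2, -1, 3}.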
NestedTensorImpl::NestedTensorImpl(
    at::Tensor buffer,
    at::Tensor nested_size_tensor)
    : TensorImpl(
          // Dispatch as NestedTensor on the backend the buffer lives on.
          (c10::DispatchKeySet(DispatchKey::NestedTensor) |
           c10::DispatchKeySet(
               buffer.is_cuda() ? BackendComponent::CUDABit
                                : BackendComponent::CPUBit)),
          buffer.dtype(),
          buffer.device()),
      buffer_(std::move(buffer)),
      nested_size_tensor_(std::move(nested_size_tensor)),
      opt_sizes_(construct_opt_sizes(nested_size_tensor_)) {
  TORCH_WARN_ONCE(
      "The PyTorch API of nested tensors is in prototype stage and will change "
      "in the near future.");
  TORCH_INTERNAL_ASSERT(
      buffer_.is_cuda() || buffer_.is_cpu(),
      "NestedTensorImpl buffer must be either CUDA or CPU but got ",
      buffer_);
  TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous());
  int64_t size_dim = nested_size_tensor_.dim();
  TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
  // Strip the autograd and ADInplaceOrView dispatch keys: the prototype does
  // not route nested tensors through autograd.
  remove_autograd_key();
  key_set_ =
      key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
  refresh_dim();
  set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}
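
// Construction sketch (hypothetical values, not code from this file): for two
// components of shapes (2, 3) and (4, 3) packed back to back, a caller of the
// constructor above would pass
//   buffer             = a flat tensor of 2*3 + 4*3 = 18 elements
//   nested_size_tensor = [[2, 3], [4, 3]]  (a contiguous (2, 2) int64 tensor)
// yielding an impl with dim() == 3 and opt_sizes_ == {2, -1, 3}.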

void NestedTensorImpl::refresh_dim() {
  const auto my_dim =
      nested_size_tensor_.dim() ? nested_size_tensor_.sizes()[1] + 1 : 1;
  sizes_and_strides_.resize(my_dim);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim() == my_dim);
}
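
// For example, with the (2, 2) nested_size_tensor sketched above, my_dim is
// 2 + 1 == 3: one implicit batch dimension plus the two dimensions recorded
// per component. A 0-dim size tensor yields my_dim == 1.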

int64_t NestedTensorImpl::dim_custom() const {
  return dim_default();
}

int64_t NestedTensorImpl::numel_custom() const {
  TORCH_CHECK(false, "numel is disabled.");
}

bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
  TORCH_CHECK(false, "is_contiguous is disabled.");
}

IntArrayRef NestedTensorImpl::sizes_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}

c10::SymIntArrayRef NestedTensorImpl::sym_sizes_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}

c10::SymIntArrayRef NestedTensorImpl::sym_sizes() const {
  return sym_sizes_custom();
}

IntArrayRef NestedTensorImpl::strides_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
}

const char* NestedTensorImpl::tensorimpl_type_name() const {
  return "NestedTensorImpl";
}
} // namespace native
} // namespace at