This repository has been archived by the owner on Apr 5, 2023. It is now read-only.
forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Unique.cu
164 lines (141 loc) · 6.08 KB
/
Unique.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/adjacent_difference.h>
#include <thrust/execution_policy.h>
#include <thrust/scan.h>
#include <thrust/scatter.h>
#include <thrust/sort.h>
#include <thrust/unique.h>

#include <tuple>
#include <vector>
namespace at {
namespace native{
namespace {
// Computes the sorted unique elements of `self`, treated as a flat sequence.
//
// Parameters:
//   self            input tensor; a contiguous copy of it is sorted, the
//                   input itself is not modified.
//   return_inverse  when true, also compute for every input element the
//                   position of its value inside the returned unique tensor.
//
// Returns a tuple (output, inverse_indices):
//   output           1-D tensor holding the unique values in sorted order.
//   inverse_indices  kLong tensor with the sizes of the input when
//                    `return_inverse` is true, otherwise an empty kLong
//                    tensor of size {0}.
//
// All Thrust algorithms run on the current CUDA stream using the THC
// caching allocator for their temporary storage.
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_cuda_template(
    const Tensor& self,
    const bool return_inverse) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  const Tensor& input = self.contiguous();
  int64_t num_inp = input.numel();

  // sort & unique: work on a flattened private copy so `self` stays intact.
  Tensor output = input.clone();
  output = output.view(-1);
  scalar_t* output_data = output.data<scalar_t>();

  Tensor inverse_indices;
  if (!return_inverse) {
    inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
    thrust::sort(policy, output_data, output_data + num_inp);
  } else {
    // Sort the values while carrying along each value's original position.
    Tensor sorted_indices = at::arange(0, num_inp, self.type().toScalarType(kLong));
    int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
    thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);

    Tensor inv_loc = at::empty({num_inp}, self.type().toScalarType(kLong));
    inverse_indices = at::empty({num_inp}, self.type().toScalarType(kLong));
    int64_t* inv_loc_ptr = inv_loc.data<int64_t>();
    int64_t* inverse_indices_ptr = inverse_indices.data<int64_t>();

    // An empty input has no element 0 to patch below; skipping the whole
    // pipeline avoids an out-of-range index on `inv_loc[0]`.
    if (num_inp > 0) {
      // inv_loc[i] = 1 where sorted value i differs from its predecessor.
      thrust::adjacent_difference(
          policy, output_data, output_data + num_inp, inv_loc_ptr,
          [=] __device__ (scalar_t a, scalar_t b) -> int64_t {
            if (a != b) {
              return 1;
            } else {
              return 0;
            }
          });
      // adjacent_difference copies the first input value into the first
      // output slot; the first sorted element always maps to unique index 0.
      inv_loc[0] = 0;
      // Running sum turns the "is new value" flags into unique-value indices.
      thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
      // Route each unique-value index back to the element's original position.
      thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
    }
    inverse_indices.resize_(input.sizes());
  }

  // Collapse runs of equal values in the sorted buffer and trim the tensor.
  int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
  output.resize_(num_out);

  THCudaCheck(cudaGetLastError());
  return std::tuple<Tensor, Tensor>(output, inverse_indices);
}
// Computes the unique slices of `self` along dimension `dim`.
//
// Parameters:
//   self            input tensor.
//   dim             dimension along which slices are compared.
//   return_inverse  when true, also compute for every slice along `dim` the
//                   position of that slice inside the returned unique tensor.
//
// Returns a tuple (output, inverse_indices):
//   output           the unique slices in lexicographically sorted order,
//                    laid out along `dim`.
//   inverse_indices  kLong tensor of length self.size(dim) when
//                    `return_inverse` is true, otherwise an empty kLong
//                    tensor of size {0}.
//
// The inverse indices are produced entirely on the device with the same
// adjacent_difference / inclusive_scan / scatter pipeline used by
// _unique_cuda_template, rather than by a host loop issuing one tiny tensor
// operation (and synchronization) per row.
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
    const Tensor& self,
    const int64_t dim,
    const bool return_inverse) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  // Bring `dim` to the front and flatten the rest, so every slice along
  // `dim` becomes one contiguous row of `numel` elements.
  Tensor input_flat = self.transpose(dim, 0);
  auto orig_sizes = input_flat.sizes().vec();
  input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
  scalar_t* input_flat_ptr = input_flat.data<scalar_t>();

  Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
  int64_t* indices_ptr = indices.data<int64_t>();
  int64_t numel = input_flat.size(1);

  // sort indices using data: rows are ordered lexicographically.
  thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(),
    [=] __device__ (int64_t a, int64_t b) -> bool {
      for (int64_t i = 0; i < numel; ++i) {
        scalar_t lhs = input_flat_ptr[i + a * numel];
        scalar_t rhs = input_flat_ptr[i + b * numel];
        if (lhs < rhs) {
          return true;
        } else if (lhs > rhs) {
          return false;
        }
      }
      return false;
    });

  Tensor input_sorted = input_flat.index_select(0, indices);
  const int64_t num_rows = input_sorted.size(0);

  // Row ids 0..num_rows-1 into `input_sorted`; read by the inverse-index
  // pass first, then compacted in place by thrust::unique below.
  scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
  Tensor input_sorted_indices = at::arange(0, num_rows, self.type().toScalarType(kLong));
  int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();

  // calculate inverse indices (must happen before thrust::unique rewrites
  // `input_sorted_indices`).
  Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
  if (return_inverse) {
    int64_t size = self.size(dim);
    inverse_indices.resize_(size);
    // With zero rows there is no element 0 to patch and nothing to scatter.
    if (num_rows > 0) {
      Tensor inv_loc = at::empty({num_rows}, self.type().toScalarType(kLong));
      int64_t* inv_loc_ptr = inv_loc.data<int64_t>();
      int64_t* inverse_indices_ptr = inverse_indices.data<int64_t>();

      // inv_loc[i] = 1 where sorted row i differs from sorted row i - 1.
      thrust::adjacent_difference(
          policy, input_sorted_indices_ptr, input_sorted_indices_ptr + num_rows, inv_loc_ptr,
          [=] __device__ (int64_t a, int64_t b) -> int64_t {
            for (int64_t i = 0; i < numel; ++i) {
              scalar_t lhs = input_sorted_ptr[i + a * numel];
              scalar_t rhs = input_sorted_ptr[i + b * numel];
              if (lhs != rhs) {
                return 1;
              }
            }
            return 0;
          });
      // The first sorted row always maps to unique index 0.
      inv_loc[0] = 0;
      // Running sum turns the "is new row" flags into unique-row indices.
      thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_rows, inv_loc_ptr);
      // indices[i] is the original position of sorted row i.
      thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_rows, indices_ptr, inverse_indices_ptr);
    }
  }

  // get unique tensors: keep the first row id of every run of equal rows.
  auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
    [=] __device__ (int64_t a, int64_t b) -> bool {
      for (int64_t i = 0; i < numel; ++i) {
        scalar_t lhs = input_sorted_ptr[i + a * numel];
        scalar_t rhs = input_sorted_ptr[i + b * numel];
        if (lhs != rhs) {
          return false;
        }
      }
      return true;
    });
  input_sorted_indices.resize_(last - input_sorted_indices_ptr);
  Tensor output = input_sorted.index_select(0, input_sorted_indices);

  // reshape back: restore the trailing dimensions, let the leading one be
  // inferred from the number of unique rows, then undo the transpose.
  auto new_sizes = std::vector<int64_t>(orig_sizes);
  new_sizes[0] = -1;
  output = output.view(new_sizes);
  output = output.transpose(0, dim);

  THCudaCheck(cudaGetLastError());
  return std::tuple<Tensor, Tensor>(output, inverse_indices);
}
} // namespace
// CUDA entry point for `unique` over the whole tensor.
//
// Dispatches on the scalar type of `self` and forwards to
// _unique_cuda_template. `sorted` is not forwarded: the current CUDA
// implementation of unique always sorts, because Thrust offers no
// hashtable implementation to build an unsorted variant on.
std::tuple<Tensor, Tensor>
_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
  std::tuple<Tensor, Tensor> result =
      AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
        return _unique_cuda_template<scalar_t>(self, return_inverse);
      });
  return result;
}
// CUDA entry point for `unique` along a single dimension.
//
// Dispatches on the scalar type of `self` and forwards `self`, `dim` and
// `return_inverse` to _unique_dim_cuda_template. `sorted` is accepted but
// not passed on to the implementation.
std::tuple<Tensor, Tensor>
_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
  std::tuple<Tensor, Tensor> result =
      AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
        return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
      });
  return result;
}
} // namespace native
} // namespace at