forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DepthwiseConv2d.cu
638 lines (567 loc) · 25.2 KB
/
DepthwiseConv2d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/div_rtn.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/Resize.h>
#include <ATen/native/IndexingUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/_conv_depthwise2d_native.h>
#endif
namespace at {
namespace native {
namespace {
// Launch-configuration helpers used by the host launchers in this file:
// CUDA_NUM_THREADS is passed as the block size and GET_BLOCKS(n) provides the
// grid size for the kernels that are launched over n elements.
using at::cuda::detail::CUDA_NUM_THREADS;
using at::cuda::detail::GET_BLOCKS;
// Builds a placeholder 32-bit packed accessor from a null data pointer and
// all-zero size/stride arrays. The forward launcher passes it to the kernel in
// place of a bias accessor when no bias tensor is defined.
template <typename scalar_t, int ndim, template <typename U> class PtrTraits = DefaultPtrTraits>
PackedTensorAccessor32<scalar_t, ndim, PtrTraits> dummy_packed_accessor32() {
  // Value-initialised, so all ndim entries are 0. The same array is handed in
  // as both the sizes and the strides argument.
  std::array<int64_t, ndim> zero_extents{};
  return PackedTensorAccessor32<scalar_t, ndim, PtrTraits>(
      nullptr, zero_extents.data(), zero_extents.data());
}
// Forward pass of the depthwise 2D convolution.
//
// Each iteration of the CUDA_KERNEL_LOOP_TYPE loop produces one output element
// identified by the flat index `linearIndex` (0 <= linearIndex < totalElements).
// All tensors are read/written through their raw `.data()` pointers with
// hand-computed flat offsets, i.e. the index math assumes densely packed
// (N, C, H, W) input/output and (outputChannels, 1, kH, kW) weight.
//
// Template parameters:
//   kSize   - when non-zero, used as the trip count of both filter loops
//             instead of kernelHeight / kernelWidth (the launcher picks 3 or 1
//             for square 3x3 / 1x1 filters, 0 otherwise).
//   index_t - integer type of the flat output index.
//
// `bias` is only dereferenced when `biasEnabled` is true.
template <int kSize, typename scalar_t, typename index_t>
__global__ void conv_depthwise2d_forward_kernel(
    const PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> output,
    const PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> weight,
    const PackedTensorAccessor32<scalar_t, 1, DefaultPtrTraits> bias,
    bool biasEnabled,
    index_t totalElements,
    const int outputChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;

  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  // The channel divisions below are skipped entirely for multiplier == 1.
  const bool scaleChannels = (depthwiseMultiplier != 1);

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    // Decompose linearIndex into (n, c, h, w) using divide + multiply-subtract
    // instead of modulo:
    //   rowId   == (n * outputChannels + c) * outputHeight + h
    //   planeId ==  n * outputChannels + c
    const int rowId = linearIndex / outputWidth;
    const int w = linearIndex - rowId * outputWidth;
    const int planeId = rowId / outputHeight;
    const int h = rowId - planeId * outputHeight;
    const int n = planeId / outputChannels;
    const int c = planeId - n * outputChannels;

    // Output channel c reads from input channel c / depthwiseMultiplier.
    const int inputChannel = scaleChannels ? c / depthwiseMultiplier : c;
    const int inputChannels =
        scaleChannels ? outputChannels / depthwiseMultiplier : outputChannels;

    // Flat offset of the (n, inputChannel) input plane.
    const index_t offset0 =
        (n * inputChannels + inputChannel) * inputHeight * inputWidth;

    // Accumulate in acc_t, seeded with the bias when one is supplied.
    acc_t value = biasEnabled ? static_cast<acc_t>(bias.data()[c]) : acc_t(0);

    // Walks the filter of channel c in row-major order; advanced once per tap
    // whether or not the tap lands inside the input.
    int weightOffset = c * kernelHeight * kernelWidth;

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (int kH = 0; kH < KH_LIMIT; ++kH) {
      const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
      const bool rowInBounds = (h_in >= 0) && (h_in < inputHeight);
#if !defined(USE_ROCM)
#pragma unroll
#endif
      for (int kW = 0; kW < KW_LIMIT; ++kW) {
        const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;
        // Taps that fall into the padding region contribute nothing.
        if (rowInBounds && (w_in >= 0) && (w_in < inputWidth)) {
          const index_t offset = offset0 + h_in * inputWidth + w_in;
          value += (static_cast<acc_t>(weight.data()[weightOffset]) *
                    static_cast<acc_t>(input.data()[offset]));
        }
        ++weightOffset;
      }
    }

    output.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}
// Gradient of the depthwise 2D convolution with respect to its input.
//
// Each iteration of the CUDA_KERNEL_LOOP_TYPE loop produces one grad_input
// element identified by the flat index `linearIndex`. For that input position
// it visits every output channel derived from the input channel
// (c * depthwiseMultiplier + m) and every filter tap, and accumulates
// weight * grad_output for the taps whose output coordinate is an exact
// multiple of the stride and lies inside the output plane.
//
// Template parameters:
//   kSize  - when non-zero, trip count of both filter loops (replaces
//            kernelHeight / kernelWidth).
//   stride - when non-zero, replaces strideWidth / strideHeight.
//
// Tensors are accessed through raw `.data()` pointers with flat offsets, so
// the index math assumes densely packed layouts.
template <int kSize, int stride, typename scalar_t, typename index_t>
__global__ void conv_depthwise2d_backward_kernel(
    const PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_output,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_input,
    const PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> weight,
    index_t totalElements,
    const int inputChannels,
    const int depthwiseMultiplier,
    const int outputChannels,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  using acc_t = at::acc_type<scalar_t, true>;

  const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight;
  const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth;
  const int strideH = (stride != 0) ? stride : strideHeight;
  const int strideW = (stride != 0) ? stride : strideWidth;

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    // Decompose linearIndex into (n, c, h, w) of grad_input:
    //   rowId   == (n * inputChannels + c) * inputHeight + h
    //   planeId ==  n * inputChannels + c
    const int rowId = linearIndex / inputWidth;
    const int w = linearIndex - rowId * inputWidth;
    const int planeId = rowId / inputHeight;
    const int h = rowId - planeId * inputHeight;
    const int n = planeId / inputChannels;
    const int c = planeId - n * inputChannels;

    acc_t value(0);

    // Note: the two outer loops carry the unroll hint on non-ROCm builds,
    // while the innermost loop carries it on ROCm builds only.
#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
      const int och = (c * depthwiseMultiplier) + multiplier;
      // Walks the filter of output channel och in row-major order.
      int weightOffset = och * kernelHeight * kernelWidth;
#if !defined(USE_ROCM)
#pragma unroll
#endif
      for (int kh = 0; kh < KH_LIMIT; ++kh) {
        // Output row scaled by the stride; only exact multiples are valid.
        const int h_scaled = h + padHeight - kh * dilationHeight;
#if defined(USE_ROCM)
#pragma unroll
#endif
        for (int kw = 0; kw < KW_LIMIT; ++kw) {
          const int w_scaled = w + padWidth - kw * dilationWidth;
          if ((h_scaled % strideH == 0) && (w_scaled % strideW == 0)) {
            const int h_out = h_scaled / strideH;
            const int w_out = w_scaled / strideW;
            if ((h_out >= 0) && (h_out < outputHeight)
                && (w_out >= 0) && (w_out < outputWidth)) {
              const int offset =
                  ((n * outputChannels + och) * outputHeight + h_out)
                  * outputWidth + w_out;
              value += (static_cast<acc_t>(weight.data()[weightOffset]) *
                        static_cast<acc_t>(grad_output.data()[offset]));
            }
          }
          ++weightOffset;
        }
      }
    }

    grad_input.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}
// Gradient of the depthwise 2D convolution with respect to the weight.
//
// Work decomposition (as established by the code below):
//   - blockIdx.x selects a single grad_weight element, decoded as
//     (ch, kH, kW) with kW varying fastest.
//   - Threads are grouped by C10_WARP_SIZE; group `threadIdx.x / C10_WARP_SIZE`
//     starts at that batch index and advances by the number of groups, and the
//     lanes of a group stride over the outputHeight * outputWidth positions.
//   - Per-thread partial sums are combined with cuda_utils::BlockReduceSum and
//     thread 0 stores the result.
//
// Launch requirements visible from this file: dynamic shared memory is used as
// the reduction scratch buffer (the launcher allocates one acc_t per warp), and
// the launcher asserts that the block size is a multiple of the warp size.
// Tensors are accessed through raw `.data()` pointers with flat int offsets.
// depthwiseMultiplier is used as a divisor and therefore must be non-zero.
template <typename scalar_t, typename index_t=unsigned>
__global__ void conv_depthwise2d_grad_weight_kernel(
const PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_output,
const PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> input,
PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> grad_weight,
const int batchSize,
const int inputChannels,
const int kernelChannels,
const int depthwiseMultiplier,
const int inputWidth, const int inputHeight,
const int outputWidth, const int outputHeight,
const int kernelWidth, const int kernelHeight,
const int strideWidth, const int strideHeight,
const int padWidth, const int padHeight,
const int dilationWidth, const int dilationHeight) {
using acc_t = at::acc_type<scalar_t, true>;
// Number of filter taps per channel.
const int channelStride = kernelWidth * kernelHeight;
// Each Block is responsible for accumulating over a permutation of
// (channels x kH x kW), use blockIdx to determine which one
int bidx = blockIdx.x;
int kW = bidx % kernelWidth;
int kH = (bidx / kernelWidth) % kernelHeight;
int ch = (bidx / channelStride);
// Need to calculate which input channel is associated with this filter
// channel
int inputCh = ch / depthwiseMultiplier;
// Per-thread partial sum, kept in the accumulation type.
acc_t grad(0);
const int laneId = threadIdx.x % C10_WARP_SIZE;
// Index of this thread's warp within the block; doubles as the first batch
// item that warp processes.
const int batch = threadIdx.x / C10_WARP_SIZE;
const int nwarps = blockDim.x / C10_WARP_SIZE;
const int imageElements = outputWidth * outputHeight;
// Use warp per item. In the original kernel, a threadblock was used to sum over NHW.
// Here, we use a warp to sum values over HW dimension, and if batchSize is larger than the
// number of warps, a warp would loop over remaining batch items (e.g. if there are 8 warps,
// warp 0 would go over 0-8-16 etc image, warp 1 over 1-9-17 etc). Later in blockReduce,
// all the warps will be reduced anyway, thus the full reduction will be over NHW, like it
// should be. That allows to get rid of one modulo operation inside the loop (because n/batchIdx
// now does not have to be computed through modulo, you are just looping over it), and
// bring a nice speed-up.
for (int batchIdx = batch; batchIdx < batchSize; batchIdx += nwarps){
// Warp-stride loop over elements in a batch item
for (index_t idx = laneId; idx < imageElements; idx += C10_WARP_SIZE) {
// Need to calculate the following: batch position, and offset into the grad_output
// in height, and width. We can intuit the corresponding position in the input from
// the other parameters we have
int go_w_offset = idx % outputWidth;
int go_h_offset = (idx / outputWidth);
// Input coordinate that this filter tap (kH, kW) pairs with the output position.
int i_w_offset = (go_w_offset * strideWidth) + (kW * dilationWidth) - padWidth;
int i_h_offset = (go_h_offset * strideHeight) + (kH * dilationHeight) - padHeight;
// Positions that map into the padding region contribute nothing.
if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < inputWidth && i_h_offset < inputHeight) {
int inputOffset = ((batchIdx * inputChannels + inputCh) * inputHeight + i_h_offset) * inputWidth + i_w_offset;
// idx already encodes (go_h_offset, go_w_offset) within the output plane.
int outputOffset = ((batchIdx * kernelChannels + ch) * outputHeight ) * outputWidth + idx;
grad += (static_cast<acc_t>(input.data()[inputOffset]) *
static_cast<acc_t>(grad_output.data()[outputOffset]));
}
}
}
// Block-wide barrier; it is outside the loops, so every thread reaches it.
__syncthreads();
// At this point each thread in the block has a local gradient, which we need to
// accumulate prior to writing the global value
extern __shared__ char smem[];
// Reinterpret the dynamic shared memory as acc_t scratch for the reduction.
acc_t* buf = reinterpret_cast<acc_t*>(smem);
acc_t tval = cuda_utils::BlockReduceSum(grad, buf);
// After reduction, first thread in the block has the gradient, so it is responsible
// for writing it to grad_weight
if (threadIdx.x == 0) {
int weightOffset = kW + (kernelWidth * kH) + (kernelWidth * kernelHeight * ch);
grad_weight.data()[weightOffset] = static_cast<scalar_t>(tval);
}
}
void conv_depthwise2d_forward_out(
const Tensor &input,
const Tensor &output,
const Tensor &weight,
const Tensor &bias,
const int kW, const int kH,
const int dW, const int dH,
const int padW, const int padH,
const int dilationW, const int dilationH) {
// Only handle 4D Input Tensors for now
TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4);
TORCH_CHECK(output.is_contiguous());
auto in_sizes = input.sizes();
auto w_sizes = weight.sizes();
// We assume that the input and weight Tensors are shaped properly by
// the caller, so we verify that here to some extent
// Weight Tensor is shape (output_channels, 1, kH, kW)
TORCH_CHECK(w_sizes[1] == 1);
// Input Tensor is shape (N, input_channels, H, W)
// We verify that the # of output_channels is a multiple of input_channels
TORCH_CHECK(w_sizes[0] % in_sizes[1] == 0);
// Bias has same # of channels as output
const bool has_bias = bias.defined();
TORCH_CHECK(!has_bias || (bias.dim() <= 1 && bias.numel() == w_sizes[0]));
// Following the behavior of other THCUNN functions, we shape the output
// Tensor ourselves
int64_t height = in_sizes[2];
int64_t width = in_sizes[3];
int64_t outputChannels = w_sizes[0];
auto out_sizes = conv_output_size(in_sizes, weight.sizes(), {padH, padW}, {dH, dW},
{dilationH, dilationW});
const auto outputWidth = out_sizes[3];
const auto outputHeight = out_sizes[2];
resize_output(output, out_sizes);
int64_t inputChannels = in_sizes[1];
int64_t depthwiseMultiplier = outputChannels / inputChannels;
// One thread per output value
TORCH_CHECK(canUse32BitIndexMath(input) && canUse32BitIndexMath(output));
int32_t n = output.numel();
int blocks = GET_BLOCKS(n);
dim3 grid(blocks);
dim3 block(CUDA_NUM_THREADS);
const auto stream = c10::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"conv_depthwise2d_forward_cuda", [&] {
// Create PackedTensorAccessor
// Kernel currently relies upon all the Tensors to be contiguous, but we made
// them contiguous above
const auto input_a = input.packed_accessor32<scalar_t, 4>();
const auto weight_a = weight.packed_accessor32<scalar_t, 4>();
const auto output_a = output.packed_accessor32<scalar_t, 4>();
const auto bias_a = has_bias ?
bias.packed_accessor32<scalar_t, 1>() :
dummy_packed_accessor32<scalar_t, 1>();
if (kW == 3 && kH == 3) {
conv_depthwise2d_forward_kernel<3> <<<grid, block, 0, stream>>>(
input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
width, height, outputWidth, outputHeight,
kW, kH, dW, dH, padW, padH, dilationW, dilationH);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (kW == 1 && kH == 1) {
conv_depthwise2d_forward_kernel<1> <<<grid, block, 0, stream>>>(
input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
width, height, outputWidth, outputHeight,
kW, kH, dW, dH, padW, padH, dilationW, dilationH);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
conv_depthwise2d_forward_kernel<0> <<<grid, block, 0, stream>>>(
input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
width, height, outputWidth, outputHeight,
kW, kH, dW, dH, padW, padH, dilationW, dilationH);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
// Host-side launcher for the input-gradient kernel.
//
// Validates the arguments, resizes `grad_input` to the shape of `input`, and
// launches conv_depthwise2d_backward_kernel on the current CUDA stream with
// one thread per grad_input element.
//
// Parameters:
//   input       - 4D tensor; only its sizes are used here.
//   grad_output - 4D contiguous tensor (N, output_channels, outH, outW).
//   grad_input  - destination tensor; resized here, must be contiguous.
//   weight      - 4D contiguous tensor whose first dimension equals the
//                 number of grad_output channels.
//   kW/kH, dW/dH, padW/padH, dilationW/dilationH - filter size, stride,
//                 padding and dilation along width / height.
//
// Raises (via TORCH_CHECK) when any precondition is violated or when the
// tensors are too large for 32-bit index math.
void conv_depthwise2d_backward_out(
    const Tensor &input,
    const Tensor &grad_output,
    const Tensor &grad_input,
    const Tensor &weight,
    const int kW, const int kH,
    const int dW, const int dH,
    const int padW, const int padH,
    const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4);
  TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4);

  // Minimal shape checking, as above
  // Same # of elements in batch
  TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]);
  // Same # of filters as outputChannels
  TORCH_CHECK(weight.sizes()[0] == grad_output.sizes()[1]);

  // Resize grad_input
  auto in_sizes = input.sizes();
  resize_output(grad_input, in_sizes);

  int inputChannels = in_sizes[1];
  int height = in_sizes[2];
  int width = in_sizes[3];

  auto gO_sizes = grad_output.sizes();
  int outputChannels = gO_sizes[1];
  int outputHeight = gO_sizes[2];
  int outputWidth = gO_sizes[3];

  // Same requirement as the forward pass: the # of output channels must be a
  // multiple of the # of input channels. Otherwise the truncated multiplier
  // below would make the kernel skip output channels and silently produce a
  // wrong gradient.
  TORCH_CHECK(outputChannels % inputChannels == 0);
  int depthwiseMultiplier = outputChannels / inputChannels;

  // Kernel currently relies upon all the Tensors to be contiguous
  TORCH_CHECK(grad_output.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(grad_input.is_contiguous());

  // One thread per grad_input value
  TORCH_CHECK(canUse32BitIndexMath(grad_input) &&
              canUse32BitIndexMath(grad_output));
  int32_t n = grad_input.numel();
  int blocks = GET_BLOCKS(n);
  dim3 grid(blocks);
  dim3 block(CUDA_NUM_THREADS);
  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
                                  "conv_depthwise2d_backward_cuda", [&] {
    auto grad_output_a = grad_output.packed_accessor32<scalar_t, 4>();
    auto grad_input_a = grad_input.packed_accessor32<scalar_t, 4>();
    auto weight_a = weight.packed_accessor32<scalar_t, 4>();

    // Pick the kernel instantiation: <kSize, stride>, where a non-zero value
    // bakes the filter size (3x3 / 1x1) or the stride (1 / 2) into the kernel
    // and 0 means "use the runtime argument".
    if (kW == 3 && kH == 3) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<3, 1><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<3, 2><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<3, 0><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (kW == 1 && kH == 1) {
      if (dW == 1 && dH == 1){
        conv_depthwise2d_backward_kernel<1, 1><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else if (dW == 2 && dH == 2) {
        conv_depthwise2d_backward_kernel<1, 2><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        conv_depthwise2d_backward_kernel<1, 0><<<grid, block, 0, stream>>>(
          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    } else if (dW == 1 && dH == 1) {
      conv_depthwise2d_backward_kernel<0, 1><<<grid, block, 0, stream>>>(
        grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
        height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else if (dW == 2 && dH == 2) {
      conv_depthwise2d_backward_kernel<0, 2><<<grid, block, 0, stream>>>(
        grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
        height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      conv_depthwise2d_backward_kernel<0, 0><<<grid, block, 0, stream>>>(
        grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
        height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}
// Crude benchmarks suggest 256 is better than 512 and 1024
// TODO: Autotune/use better heuristics, improve speed more.
//
// Returns the block size for the grad_weight kernel: one warp per batch item,
// capped at MAX_BLOCK_SIZE threads.
int getGradParamsNumThreads(int batchSize) {
  //warp per item in a batch, up to a maximum
  constexpr int64_t MAX_BLOCK_SIZE = 256;
  // Multiply in 64 bits: batchSize * warp_size overflows a 32-bit int for
  // very large batch sizes, which previously yielded an invalid block size
  // instead of the cap.
  const int64_t wanted =
      static_cast<int64_t>(batchSize) * at::cuda::warp_size();
  return static_cast<int>(std::min(wanted, MAX_BLOCK_SIZE));
}
// Host-side launcher for the weight-gradient kernel.
//
// Validates the arguments, resizes `grad_weight` to
// (output_channels, 1, kH, kW), and launches
// conv_depthwise2d_grad_weight_kernel on the current CUDA stream with one
// block per grad_weight element.
//
// Parameters:
//   input       - 4D contiguous tensor (N, input_channels, H, W).
//   grad_output - 4D contiguous tensor (N, output_channels, outH, outW).
//   grad_weight - destination tensor; resized here, must be contiguous.
//   kW/kH, dW/dH, padW/padH, dilationW/dilationH - filter size, stride,
//                 padding and dilation along width / height.
//
// Raises (via TORCH_CHECK) when any precondition is violated or when the
// tensors are too large for 32-bit index math.
void conv_depthwise2d_grad_weight_out(
    const Tensor &input,
    const Tensor &grad_output,
    const Tensor &grad_weight,
    const int kW, const int kH,
    const int dW, const int dH,
    const int padW, const int padH,
    const int dilationW, const int dilationH) {
  // Only handle 4D Input Tensors for now
  TORCH_CHECK(input.numel() > 0 && input.dim() == 4);
  TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4);

  // Minimal shape checking as above
  // Same # of elements in batch
  TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]);

  auto in_sizes = input.sizes();
  int batchSize = in_sizes[0];
  int inputChannels = in_sizes[1];
  int height = in_sizes[2];
  int width = in_sizes[3];

  auto gO_sizes = grad_output.sizes();
  int outputChannels = gO_sizes[1];
  int outputHeight = gO_sizes[2];
  int outputWidth = gO_sizes[3];

  // Same requirement as the forward pass: the # of output channels must be a
  // multiple of the # of input channels. This also guarantees a multiplier of
  // at least 1; the kernel divides by it to find the input channel.
  TORCH_CHECK(outputChannels % inputChannels == 0);
  int depthwiseMultiplier = outputChannels / inputChannels;

  resize_output(grad_weight, {outputChannels, 1, kH, kW});

  // Kernel currently relies upon all the Tensors to be contiguous
  TORCH_CHECK(grad_output.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(grad_weight.is_contiguous());

  // We parallelize so that each block computes a single value in grad_weight
  TORCH_CHECK(canUse32BitIndexMath(input) &&
              canUse32BitIndexMath(grad_output));
  int blocks = outputChannels * kH * kW;

  // Make sure we have enough threads to perform the reduction, and use this number
  // to create the shared memory size for the reduction
  dim3 grid(blocks);
  dim3 block(getGradParamsNumThreads(batchSize));
  const auto stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
                                  "conv_depthwise2d_grad_weight_cuda", [&] {
    const auto grad_output_a = grad_output.packed_accessor32<scalar_t, 4>();
    const auto input_a = input.packed_accessor32<scalar_t, 4>();
    const auto grad_weight_a = grad_weight.packed_accessor32<scalar_t, 4>();
    using acc_t = at::acc_type<scalar_t, true>;
    int warp_size = at::cuda::warp_size();
    // The kernel splits the block into whole warps.
    TORCH_INTERNAL_ASSERT(block.x % warp_size == 0);
    // One acc_t slot of reduction scratch per warp.
    int smem = (block.x / warp_size) * sizeof(acc_t);
    conv_depthwise2d_grad_weight_kernel<<<grid, block, smem, stream>>>(
      grad_output_a, input_a, grad_weight_a, batchSize, inputChannels, outputChannels, depthwiseMultiplier,
      width, height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}
} // namespace (anonymous)
// Out-variant entry point of the depthwise 2D convolution forward pass.
//
// kernel_size, stride, padding and dilation must each hold exactly two
// values, ordered (height, width). Input, weight and (if present and defined)
// bias are converted to contiguous form via expect_contiguous() before being
// handed to conv_depthwise2d_forward_out, which resizes and fills `out`.
// Returns `out`.
const Tensor& conv_depthwise2d_cuda_out(
    const Tensor &input_,
    const Tensor &weight_,
    IntArrayRef kernel_size,
    const c10::optional<Tensor> &bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    const Tensor &out) {
  TORCH_CHECK(kernel_size.size() == 2);
  TORCH_CHECK(stride.size() == 2);
  TORCH_CHECK(padding.size() == 2);
  TORCH_CHECK(dilation.size() == 2);

  const auto input = input_.expect_contiguous();
  const auto weight = weight_.expect_contiguous();

  // A missing or undefined bias is represented by an owned,
  // default-constructed Tensor.
  const bool bias_given = bias_opt.has_value() && bias_opt->defined();
  const auto bias = bias_given
      ? bias_opt->expect_contiguous()
      : c10::MaybeOwned<Tensor>::owned(c10::in_place);

  // The launcher takes width before height, hence index 1 first.
  const auto kernel_h = kernel_size[0];
  const auto kernel_w = kernel_size[1];
  const auto stride_h = stride[0];
  const auto stride_w = stride[1];
  const auto pad_h = padding[0];
  const auto pad_w = padding[1];
  const auto dilation_h = dilation[0];
  const auto dilation_w = dilation[1];

  conv_depthwise2d_forward_out(
      *input,
      out,
      *weight,
      *bias,
      kernel_w, kernel_h,
      stride_w, stride_h,
      pad_w, pad_h,
      dilation_w, dilation_h);
  return out;
}
// Functional entry point of the depthwise 2D convolution forward pass:
// allocates an empty result tensor with the input's options and delegates to
// conv_depthwise2d_cuda_out, which resizes and fills it.
Tensor conv_depthwise2d_cuda(
    const Tensor &input,
    const Tensor &weight,
    IntArrayRef kernel_size,
    const c10::optional<Tensor> &bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation) {
  // Zero-element placeholder; the out variant gives it its final shape.
  Tensor result = at::empty({0}, input.options());
  conv_depthwise2d_cuda_out(
      input, weight, kernel_size, bias, stride, padding, dilation, result);
  return result;
}
std::tuple<Tensor&, Tensor&> conv_depthwise2d_backward_cuda_out(
const Tensor & grad_output_,
const Tensor & self_,
const Tensor & weight_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
Tensor & grad_input,
Tensor & grad_weight) {
auto grad_output = grad_output_.expect_contiguous();
if (grad_weight.defined()) {
auto self = self_.expect_contiguous();
conv_depthwise2d_grad_weight_out(
*self, *grad_output, grad_weight,
kernel_size[1], kernel_size[0],
stride[1], stride[0],
padding[1], padding[0],
dilation[1], dilation[0]);
}
if (grad_input.defined()) {
auto weight = weight_.expect_contiguous();
conv_depthwise2d_backward_out(
self_, *grad_output, grad_input, *weight,
kernel_size[1], kernel_size[0],
stride[1], stride[0],
padding[1], padding[0],
dilation[1], dilation[0]);
}
return std::forward_as_tuple(grad_input, grad_weight);
}
std::tuple<Tensor, Tensor> conv_depthwise2d_backward_cuda(
const Tensor& grad_output,
const Tensor& self,
const Tensor& weight,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
std::array<bool, 2> output_mask) {
Tensor grad_input;
Tensor grad_weight;
if (output_mask[0]) {
grad_input = at::empty({0}, grad_output.options());
}
if (output_mask[1]) {
grad_weight = at::empty({0}, grad_output.options());
}
return conv_depthwise2d_backward_cuda_out(
grad_output,
self,
weight,
kernel_size,
stride,
padding,
dilation,
grad_input,
grad_weight);
}
// Hooks conv_depthwise2d_backward_cuda up to conv_depthwise2d_backward_stub.
// NOTE(review): the stub and the REGISTER_CUDA_DISPATCH macro are declared
// outside this file; presumably this makes the function the CUDA
// implementation behind that stub - confirm against the stub's declaration.
REGISTER_CUDA_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_cuda);
} // namespace native
} // namespace at