forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
JitLoops.cuh
191 lines (165 loc) · 6.82 KB
/
JitLoops.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#pragma once
#include <ATen/jit_macros.h>
#if AT_USE_JITERATOR()
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#if !AT_ROCM_ENABLED()
#include <ATen/native/cuda/CUDAJitLoops.cuh>
#else
#error Jiterator not supported on ROCm
#endif
namespace at {
namespace native {
/* Note [Jiterator]
The "jiterator" simply just-in-time compiles the same kernels that
Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
build size, and initial CUDA context size.
By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels.
This behavior is controlled with two environment variables:
- USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use
- PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels
The jiterator currently has some limitations, however. It cannot:
- handle math on complex datatypes
- handle kernels with scalar parameters
These improvements will likely come soon.
For examples of how to use the jiterator see the i1 and gcd kernel
implementations, which pass jittable strings implementing their
operations instead of the typical CUDA functors.
To pass a runtime argument (similar to lambda captures in non-JIT kernels),
we need to pass additional arguments to `jitted_gpu_kernel` by value
(packed into its `extra_args` tuple).
Currently only primitive C++ types used for computation are valid.
The order of these extra arguments should be the same as the order in which
they appear in the kernel's function signature (see polygamma for an example).
NOTE: One big restriction being that these arguments should be after the
arguments provided by TensorIterator. Eg. While capturing `n`, where
`scalar_t x` and `scalar_t y` are provided by TensorIterator,
* foo(scalar_t x, scalar_t y, int n) works!
* foo(int n, scalar_t x, scalar_t y) doesn't work
* foo(scalar_t x, int n, scalar_t y) doesn't work
*/
// Entrypoint for jitted GPU kernels.
// Only handles elementwise kernels with `arity` inputs of a common dtype
// (`f_inputs_type`) and a single output. In this file it is invoked with
// arity 1 (one tensor input plus a folded scalar) and arity 2.
// NOTE: this assumes the op's iterator has a common_dtype.
// NOTE: We use std::tuple instead of parameter pack
//  for `extra_args` due to following
// bug on older versions of clang
// https://bugs.llvm.org/show_bug.cgi?id=23029
//
// Template parameters:
//   name          - kernel name, forwarded to jitted_gpu_kernel_impl
//   return_type   - C++ type produced by `f`; if it differs from the output
//                   dtype (iter.dtype(0)) the kernel uses dynamic casting
//   f_inputs_type - C++ type of every input; if any of iter.dtype(1..arity)
//                   differs from it the kernel uses dynamic casting
//   arity         - number of tensor inputs; the iterator must have exactly
//                   this many inputs and one output (asserted)
//   Args...       - types of the runtime extra arguments in `extra_args`
//
// Parameters:
//   iter       - every operand must be on a CUDA device (asserted); operand 0
//                is the output, operands 1..arity are the inputs
//   f          - jittable source string implementing the operation
//   scalar_pos - selects the NoScalar / RhsScalar / LhsScalar kernel variant
//   scalar_val - scalar operand for the Rhs/LhsScalar variants; not used for
//                computation by the NoScalar variant
//   extra_args - runtime arguments, passed by value, that follow the
//                TensorIterator-provided arguments in `f`'s signature
//
// An empty iterator is a no-op. An iterator that cannot use 32-bit indexing
// is split into 32-bit-indexable sub-iterators, each launched recursively.
template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
  //   Maybe it could be refactored?
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub_iter, f, scalar_pos, scalar_val, extra_args);
    }
    return;
  }

  // The dtype checks below index iter.dtype(0) .. iter.dtype(arity) directly.
  // Validate the operand layout first, so a caller whose iterator does not
  // match `arity` fails with a clear message instead of reading past the end
  // of the operand list.
  TORCH_INTERNAL_ASSERT(
      iter.noutputs() == 1,
      "jitted_gpu_kernel: expected exactly 1 output but found ", iter.noutputs());
  TORCH_INTERNAL_ASSERT(
      iter.ninputs() == arity,
      "jitted_gpu_kernel: expected ", arity, " inputs but found ", iter.ninputs());

  // Dynamic casting is needed if the output's dtype differs from the
  // function's return type, or if any input's dtype differs from the common
  // input type.
  // Note: this is intentionally divergent from calling needs_dynamic_casting,
  // which is more general and inspects a lambda to determine if dynamic
  // casting is needed.
  const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
  const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
  bool needs_dynamic_casting = iter.dtype(0) != return_scalar_type;
  // Stop scanning inputs as soon as one mismatch is found.
  for (int i = 1; i <= arity && !needs_dynamic_casting; ++i) {
    needs_dynamic_casting = iter.dtype(i) != inputs_scalar_type;
  }

  if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
    // NOTE: With `scalar_pos=NoScalar`, `scalar_val` is not used for
    // computation in the generated code; it is simply forwarded unchanged
    // (its default value is 0).
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::NoScalar>(
        iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args);
  } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  } else {
    // Any remaining variant is treated as LhsScalar.
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  }
}
// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
//
// Launches a jitted binary kernel. If one input is a CPU scalar, that input
// is folded into a scalar argument instead of being treated as a tensor:
//   - operand 1 is a CPU scalar: its value is read, the operand is removed
//     from `iter`, and an arity-1 kernel is launched with LhsScalar;
//   - else operand 2 is a CPU scalar: same, but with RhsScalar;
//   - else: an ordinary arity-2 kernel is launched.
// Expects exactly three operands (asserted). `iter` is mutated when an
// operand is removed.
template <char const *name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
  // currently jiterator only handles binary functions where both inputs are
  // of the same type (f_inputs_type)
  using opmath_t = at::opmath_type<f_inputs_type>;
  using Variant = at::cuda::jit::BinaryFuncVariant;

  if (iter.is_cpu_scalar(1)) {
    // Read the scalar before dropping the operand that holds it.
    const auto lhs_scalar = iter.scalar_value<opmath_t>(1);
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels device guards, but structured kernels do it right and
    // we can assume the device is already set correctly
    // Presumably index 1 now refers to the remaining tensor input (the former
    // operand 2) -- TODO confirm remove_operand shifts later operands down.
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(
        iter, f, Variant::LhsScalar, lhs_scalar);
    return;
  }

  if (iter.is_cpu_scalar(2)) {
    const auto rhs_scalar = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(
        iter, f, Variant::RhsScalar, rhs_scalar);
    return;
  }

  // Neither input is a CPU scalar: plain two-input launch.
  jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
}
}} // at::native
#endif // AT_USE_JITERATOR()