forked from NVIDIA/DALI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkernel.h
171 lines (151 loc) · 5.81 KB
/
kernel.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
// Copyright (c) 2018-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DALI_KERNELS_KERNEL_H_
#define DALI_KERNELS_KERNEL_H_
#include <functional>
#include <type_traits>
#include <vector>
#include "dali/kernels/context.h"
#include "dali/core/tensor_view.h"
#include "dali/kernels/kernel_params.h"
#include "dali/kernels/kernel_req.h"
#include "dali/kernels/kernel_traits.h"
#include "dali/core/tuple_helpers.h"
#include "dali/core/util.h"
namespace dali {
/**
 * @brief Returns the version of the available cuFFT library
 *
 * @returns MAJOR*1000 + MINOR*10 + PATCH or -1 if not available
 */
int GetCufftVersion();
/**
* @brief Defines the DALI kernel API. See dali::kernels::examples::Kernel for details
*/
namespace kernels {
namespace examples {
/**
 * @brief DALI Kernel example
 *
 * This class represents a "concept" of a DALI kernel.
 * A kernel must provide two non-overloaded functions:
 * Run and Setup.
 *
 * Run and Setup functions are expected to accept arguments in strictly specified order:
 * Setup(KernelContext, [inputs], [arguments])
 * Run(KernelContext, [outputs], [inputs], [arguments])
 * Additionally, both of these functions accept the same sets of inputs and arguments.
 *
 * The kernel can be run directly, or its inputs, outputs and arguments can be tied
 * into tuples and the kernel can then be configured and launched using:
 *
 * `dali::kernels::kernel::Setup`
 *
 * `dali::kernels::kernel::Run`
 *
 * Programmers can check whether their type satisfies the conditions for being a kernel
 * by instantiating check_kernel<KernelType>. If the type does not meet the requirements,
 * static_asserts should produce meaningful diagnostics that will help to rectify the problem.
 *
 * The member functions below are only declared; they illustrate the expected signatures.
 */
template <typename OutputType, typename Input1, typename Input2>
struct Kernel {
/**
 * @brief Returns kernel output(s) shape(s) and additional memory requirements
 *
 * Setup receives full input tensor lists and any extra arguments that
 * are going to be passed to a subsequent call to Run.
 *
 * @remarks The inputs are provided mainly to inspect their shapes; actually looking at the
 * data may degrade performance severely.
 *
 * @param context - environment of the kernel; cuda stream, batch info, etc.
 * At the time of call to Setup, its scratch area is undefined.
 *
 * @param in1 - example input, consisting of a list of 3D tensors with element type Input1
 * @param in2 - example input, consisting of a 4D tensor with element type Input2
 * @param aux - some extra parameters (e.g. convolution kernel, mask)
 * @return requirements (output shape(s), additional memory) that the context passed
 *         to Run must satisfy
 */
KernelRequirements Setup(
KernelContext &context,
const InListGPU<Input1, 3> &in1,
const InTensorGPU<Input2, 4> &in2,
const std::vector<float> &aux);
/**
 * @brief Runs the kernel
 *
 * Run processes the inputs and populates the pre-allocated output. Output shape is expected
 * to match that returned by Setup.
 *
 * @param context - environment; provides scratch memory, cuda stream, batch info, etc.
 * Scratch area must satisfy requirements returned by Setup.
 * @param out - pre-allocated output, a list of 3D tensors with element type OutputType;
 *              populated by Run
 * @param in1 - example input, consisting of a list of 3D tensors with element type Input1
 * @param in2 - example input, consisting of a 4D tensor with element type Input2
 * @param aux - some extra parameters (e.g. convolution kernel, mask); the same values
 *              that were passed to Setup
 */
void Run(
KernelContext &context,
const OutListGPU<OutputType, 3> &out,
const InListGPU<Input1, 3> &in1,
const InTensorGPU<Input2, 4> &in2,
const std::vector<float> &aux);
};
} // namespace examples
/**
* @brief A collection of pseudo-methods to operate on Kernel classes/objects
*/
namespace kernel {
// avoid retyping "Kernel" every second word...
/** @brief Bundle of inputs accepted by Kernel's Setup and Run (kernel_inputs<Kernel>) */
template <typename Kernel>
using inputs = kernel_inputs<Kernel>;
/** @brief Bundle of outputs populated by Kernel's Run (kernel_outputs<Kernel>) */
template <typename Kernel>
using outputs = kernel_outputs<Kernel>;
/** @brief Bundle of extra arguments accepted by Kernel's Setup and Run (kernel_args<Kernel>) */
template <typename Kernel>
using args = kernel_args<Kernel>;
/** @brief Execution environment passed to Setup and Run */
using Context = KernelContext;
/** @brief Result of Setup; describes what Run needs from the Context */
using Requirements = KernelRequirements;
/**
 * @brief Gets requirements for given Kernel
 *
 * First validates (at compile time, via check_kernel) that the kernel type satisfies
 * the kernel concept, then invokes `instance.Setup` through apply_all with the
 * context, the `input` bundle and the `args` bundle.
 *
 * @param instance - the kernel object whose Setup method is invoked
 * @param context - execution environment (without scratch memory)
 * @param input - kernel inputs, convertible to kernel_inputs<Kernel>
 * @param args - kernel extra arguments, convertible to kernel_args<Kernel>
 * @return the requirements returned by `instance.Setup`
 */
template <typename Kernel>
Requirements Setup(
    Kernel &instance,
    Context &context,
    const inputs<Kernel> &input,
    const args<Kernel> &args) {
  // Check the kernel type with any const qualifier stripped - the same check kernel::Run
  // performs - so both entry points validate identically when Kernel is deduced as const.
  check_kernel<std::remove_const_t<Kernel>>();
  return apply_all(std::mem_fn(&Kernel::Setup), instance, context, input, args);
}
/**
 * @brief Executes a Kernel on an input set
 *
 * The kernel concept is verified at compile time (check_kernel) on the kernel type with
 * any const qualifier removed; then `instance.Run` is invoked through apply_all with the
 * context followed by the `output`, `input` and `args` bundles, in that order.
 *
 * @param instance - the kernel object whose Run method is invoked
 * @param context - execution environment (with scratch memory)
 * @param output - kernel outputs, convertible to kernel_outputs<Kernel>
 * @param input - kernel inputs, convertible to kernel_inputs<Kernel>
 * @param args - kernel extra arguments, convertible to kernel_args<Kernel>
 */
template <typename Kernel>
void Run(
    Kernel &instance,
    Context &context,
    const outputs<Kernel> &output,
    const inputs<Kernel> &input,
    const args<Kernel> &args) {
  using KernelType = std::remove_const_t<Kernel>;
  check_kernel<KernelType>();

  auto run_method = std::mem_fn(&Kernel::Run);
  apply_all(run_method, instance, context, output, input, args);
}
} // namespace kernel
} // namespace kernels
} // namespace dali
#endif // DALI_KERNELS_KERNEL_H_