/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
The shared memory resource is time-sliced across warps.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
typename OutputOp_, ///< Output operator
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value)
>
class Epilogue :
public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>,
public EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>
{
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using BaseStreamK = EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// Number of warps per block
using WarpCount = typename Base::WarpCount;
/// Number of threads per block
static int const kBlockThreads = 32 * WarpCount::kCount;
/// Per-thread accumulator tile type
using AccumulatorTile = typename Base::AccumulatorTile;
/// Numerical accumulation element type
using ElementAccumulator = typename WarpMmaOperator::ElementC;
/// Fragment type used by the accumulator tile's fragment iterator
using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Vector type used by the global output iterator
using OutputAccessType = Array<
typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Vector type used by the shared output iterator
using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
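// Shared memory is time-sliced: either across epilogue iterations (when
// kFragmentsPerIteration > 1) or across K partitions (when kPartitionsK > 1);
// the static_assert below requires one of the two to be exactly 1.
// kSmemPointerOffset is the element stride between consecutive slices.
// Illustrative example (values assumed, not taken from this header): with
// StorageShape::kCount == 8192 and kPartitionsK == 2, each slice holds 4096
// elements, and partition i is loaded at pointer offset i * 4096.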
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
"OutputTileIterator::Fragment::kElements must be divisible by OutputTileIterator::kElementsPerAccess.");
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
public:
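// The two "source aspect" policies below let a single templated epilogue body
// (operator()(..., SourceAspect)) serve both codepaths: the case where no
// source tensor is read (e.g. beta == 0) and the general case where a source
// fragment is loaded and fed to the output operator. Both aspects expose the
// same interface: load() and apply_output_operator().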
/// Aspect for when epilogue source is not needed
struct SourceAspectNotNeeded
{
/// Constructor
CUTLASS_DEVICE
SourceAspectNotNeeded()
{}
// No-op
CUTLASS_DEVICE
void load() { }
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
}
}
};
/// Aspect for when epilogue source is needed
struct SourceAspectNeeded
{
OutputTileIterator source_iterator;
typename OutputTileIterator::Fragment source_fragment;
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
static void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
OutputAccessType const *source_frag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
/// Constructor
CUTLASS_DEVICE
SourceAspectNeeded(OutputTileIterator source_iterator) :
source_iterator(source_iterator)
{
source_fragment.clear();
}
// Load addend source fragment from global memory
CUTLASS_DEVICE
void load() {
source_iterator.load(source_fragment);
++source_iterator;
}
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
{
apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
}
};
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
/// Thread index in the threadblock
int thread_idx;
/// Warp index in the threadblock
int warp_idx;
public:
/// Constructor
CUTLASS_DEVICE
Epilogue(
typename Base::SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx) ///< Id of thread within warp
:
Base(shared_storage, thread_idx, warp_idx, lane_idx),
BaseStreamK(thread_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx),
thread_idx(thread_idx),
warp_idx(warp_idx)
{}
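// In stream-K (and split-K) decompositions, several threadblocks may each
// produce a partial accumulator tile for the same output tile. reduce()
// aggregates the peers' partials from the global workspace (via
// EpilogueBaseStreamK), then runs the ordinary shared-memory epilogue on the
// reduced fragment.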
/// Aggregates the accumulator sets shared by peer blocks in the global workspace,
/// performing epilogue computations, writing to output
CUTLASS_DEVICE
void reduce(
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
void *element_workspace,
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator) ///< Tile iterator for addend source
{
// Reduce peer accumulator fragments into one fragment
AccumulatorFragment accum_fragment;
BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);
// Store fragment to shared memory
this->warp_tile_iterator_.store(accum_fragment);
__syncthreads();
// Initialize/load source-fragment data
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
if (output_op.is_source_needed())
{
source_iterator += reduce_fragment_idx;
source_iterator.load(source_fragment);
}
// Load fragment from shared memory
typename SharedLoadIterator::Fragment aligned_accum_fragment;
shared_load_iterator_.load(aligned_accum_fragment);
// Add fragments shared by other k partitions
if (kPartitionsK > 1)
{
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
typename SharedLoadIterator::Fragment aligned_addend_fragment;
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_addend_fragment);
aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment);
}
}
// Compute the output result
typename OutputTileIterator::Fragment output_fragment;
// Apply the output operator
SourceAspectNeeded::apply_output_operator(
output_fragment,
output_op,
aligned_accum_fragment,
source_fragment);
// Store the final result
destination_iterator += reduce_fragment_idx;
destination_iterator.store(output_fragment);
}
/// Perform the epilogue computations and stream the result to global memory.
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
}
/// Perform the epilogue computations and stream the result to global memory. Implements
/// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
{
if (output_op.is_source_needed())
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}
else
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
}
}
/// Perform the epilogue computations and stream the result to global memory. Implements a
/// single codepath, regardless of whether the output op requires addend data to be loaded
CUTLASS_DEVICE
void unified(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
{
if (!output_op.is_source_needed())
{
source_iterator.clear_mask();
__syncthreads(); // Dummy barrier (workaround for CUDA 11.0)
}
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}
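// acc2smem maps a runtime iteration index onto a compile-time index sequence
// so that the accumulator fragment iterator can be advanced by a statically
// known count (see helper<Advance>), letting the compiler fully unroll the
// advance loop instead of iterating on a runtime value.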
template<class Seq>
struct acc2smem;
template <size_t... Seq>
struct acc2smem<cutlass::index_sequence<Seq...>> {
template<int Advance>
CUTLASS_DEVICE
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator &warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
}
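/// Dispatches to helper<Seq> for the Seq matching the runtime index `pos`.
/// The array initializer expands the parameter pack, evaluating
/// `(pos == Seq) && (helper<Seq>(...), 0)` for every Seq; short-circuit
/// evaluation ensures only the matching helper actually runs. `dummy` exists
/// solely to force that expansion.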
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const &iterator_begin,
WarpTileIterator &warp_tile_iterator) {
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
template <typename SourceAspect>
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
SourceAspect source)
{
// Iterator over warp-level accumulator fragment
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
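// Unroll the loop fully only when the output functor is light
// (IterationsUnroll != 0); for heavy functors it is left rolled
// to limit binary size.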
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter)
{
//
// Load the source
//
source.load();
//
// Convert and store fragment
//
__syncthreads();
acc2smem<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
};
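// Usage sketch (illustrative only; names such as EpilogueType, output_op,
// iterator_D, iterator_C, and accumulators are assumed to be defined by the
// surrounding kernel, typically via a Default* epilogue factory):
//
//   __shared__ typename EpilogueType::SharedStorage epilogue_storage;
//   EpilogueType epilogue(epilogue_storage, thread_idx, warp_idx, lane_idx);
//
//   // The source tensor is read only if output_op.is_source_needed():
//   epilogue(output_op, iterator_D, accumulators, iterator_C);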
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////