forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFlashAttentionKernel.cpp
1271 lines (1208 loc) · 47.7 KB
/
FlashAttentionKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at::native {
namespace {
// Fused "scale and add mask" kernel:
//   out[i] = val * a[i] + b[i]        for i in [0, size)
//
// a    - input row of `size` elements of type T1 (the callers in this file
//        pass the same pointer for `a` and `out`, i.e. it runs in place).
// b    - mask row of type T2; each element is converted to T1 before the add.
// size - number of elements to process.
// out  - destination row of type T1.
// val  - scale applied to `a` before the mask is added.
// is_b_stride_zero: if the stride of b is 0 (mask broadcasting case), b is
//   treated as a single scalar and only b[0] is read.
//
// With GCC 11 and SVE enabled, `is_b_stride_zero` is a runtime argument rather
// than a template parameter (and the branches below use a plain `if`).
// Presumably this works around a toolchain issue; TODO confirm.
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
template <typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
    T1* a,
    T2* b,
    const int& size,
    T1* out,
    T1& val,
    bool is_b_stride_zero) {
#else
template <bool is_b_stride_zero, typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
    T1* a,
    T2* b,
    const int& size,
    T1* out,
    T1& val) {
#endif
  const auto vec_size1 = at::vec::Vectorized<T1>::size();
  const auto vec_size2 = at::vec::Vectorized<T2>::size();
  // If one T2 vector holds twice as many lanes as one T1 vector (reduced
  // floating-point mask with a wider accumulation type), two T1 vectors are
  // needed to cover each T2 vector; otherwise one is enough.
  constexpr int64_t T1_n =
      (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
  constexpr int64_t T2_n = 1;
  auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
  int64_t i = 0;
  // Vectorized body: advances by one T2 vector per iteration.
  for (; i < size - (size % vec_size2); i += vec_size2) {
    auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
    at::vec::VectorizedN<T2, T2_n> b_n;
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
    if (is_b_stride_zero) {
#else
    if constexpr(is_b_stride_zero) {
#endif
      // Broadcast the single mask value to every lane.
      b_n = at::vec::VectorizedN<T2, T2_n>((T1)b[0]);
    } else {
      b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
    }
    // Convert the mask lanes to T1 so they can be added to the scaled input.
    auto b_n_convert = at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
    auto res = a_n * vec_scale + b_n_convert;
    res.store(out + i);
  }
  // Scalar tail for the remaining (size % vec_size2) elements.
  for (; i < size; i++) {
    auto tmp0 = a[i];
    T1 tmp1;
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
    if (is_b_stride_zero) {
#else
    if constexpr(is_b_stride_zero) {
#endif
      tmp1 = (T1)b[0];
    } else {
      tmp1 = (T1)b[i];
    }
    out[i] = tmp0 * val + tmp1;
  }
}
// 1) out = exp(a - val)
// 2) val = sum(out)
template <typename T1, typename T2>
inline void _exp_reduce_sum_fusion_kernel(
T1* a,
const int& size,
T2* out,
T1& val) {
auto vec_size = vec::Vectorized<T1>::size();
auto vec_max = vec::Vectorized<T1>(val);
T1 tmp_sum = 0;
auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = tmp0 - vec_max;
auto tmp2 = tmp1.exp_u20();
vec_tmp_sum += tmp2;
_store(out + i, tmp2);
}
tmp_sum = vec::vec_reduce_all<T1>(
[](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) {
return x + y;
},
vec_tmp_sum);
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = tmp0 - val;
auto tmp2 = exp(tmp1);
tmp_sum += tmp2;
out[i] = tmp2;
}
val = tmp_sum;
}
// 1) out = a * scale
// 2) max = max(out)
template <typename scalar_t>
inline void _mul_reduce_max_fusion_kernel(
const scalar_t* a,
const scalar_t& scale,
const int& size,
scalar_t* out,
scalar_t& max) {
auto vec_size = vec::Vectorized<scalar_t>::size();
auto vec_scale = vec::Vectorized<scalar_t>(scale);
scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
auto tmp1 = tmp0 * vec_scale;
vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
_store(out + i, tmp1);
}
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = tmp0 * scale;
tmp_max = std::max(tmp_max, tmp1);
out[i] = tmp1;
}
max = std::max(
tmp_max,
vec::vec_reduce_all<scalar_t>(
[](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
return vec::maximum(x, y);
},
vec_tmp_max));
}
// conditional_data_ptr: picks which of two buffers a caller should use.
//  * Same-type overload: returns `ptr`. The second buffer must be null
//    (enforced with TORCH_CHECK).
//  * Overload enabled when is_reduced_floating_point_v<scalar_t> holds and
//    the first buffer is float*: returns the reduced-precision buffer `ptr2`
//    and ignores `ptr`.
template <typename scalar_t>
static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
  TORCH_CHECK(ptr2 == nullptr);
  scalar_t* const chosen = ptr;
  return chosen;
}

template <
    typename scalar_t,
    typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
static inline scalar_t* conditional_data_ptr(
    [[maybe_unused]] float* ptr,
    scalar_t* ptr2) {
  scalar_t* const chosen = ptr2;
  return chosen;
}
template <typename scalar_t>
inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
using Vec = Vectorized<scalar_t>;
Vec data_vec = Vec(val);
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
data_vec.store(data + d);
}
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
for (; d < size; d++) {
data[d] = val;
}
}
// Rewrites `attn_mask` in place as a 4d view expanded to
// (mask_batch, mask_heads, qSize, kvSize).
//
// Supported mask shapes (guaranteed by check_attn_mask_shape):
//   2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
//   4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
// The leading two dims keep their full size only for a 4d mask whose dim
// matches batchSize / num_head; otherwise they become 1 so they broadcast.
void reshape_attn_mask_to_4d(
    Tensor& attn_mask,
    int64_t batchSize,
    int64_t num_head,
    int64_t qSize,
    int64_t kvSize) {
  const bool is_4d = attn_mask.dim() == 4;
  const int64_t mask_batch =
      (is_4d && attn_mask.size(0) == batchSize) ? batchSize : 1;
  const int64_t mask_heads =
      (is_4d && attn_mask.size(1) == num_head) ? num_head : 1;
  const int64_t mask_rows = attn_mask.size(-2);
  const int64_t mask_cols = attn_mask.size(-1);
  auto as_4d = attn_mask.view({mask_batch, mask_heads, mask_rows, mask_cols});
  attn_mask = as_4d.expand({mask_batch, mask_heads, qSize, kvSize});
}
template <typename scalar_t>
inline void copy_value_with_pad(
const scalar_t* value_ptr,
scalar_t* dst_ptr,
int64_t rows,
int64_t cols,
int64_t prows,
int64_t pcols,
int64_t ldi) {
auto vec_size = at::vec::Vectorized<scalar_t>::size();
int64_t i = 0;
for (; i < rows; i++) {
int64_t j = 0;
for (; j < cols - (cols % vec_size); j += vec_size) {
auto vec_v =
at::vec::Vectorized<scalar_t>::loadu(value_ptr + i * ldi + j);
vec_v.store(dst_ptr + i * pcols + j);
}
if (j < cols) {
auto vec_v = at::vec::Vectorized<scalar_t>::loadu(
value_ptr + i * ldi + j, cols - j);
vec_v.store(dst_ptr + i * pcols + j, cols - j);
}
// col padding
auto psize = pcols - cols;
if (psize > 0) {
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
int64_t pj = 0;
for (; pj < psize - (psize % vec_size); pj += vec_size) {
zero_vec.store(dst_ptr + i * pcols + cols + pj);
}
if (pj < psize) {
zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj);
}
}
}
// row padding
for (; i < prows; i++) {
auto zero_vec = at::vec::Vectorized<scalar_t>(0);
int64_t j = 0;
for (; j < pcols - (pcols % vec_size); j += vec_size) {
zero_vec.store(dst_ptr + i * pcols + j);
}
if (j < pcols) {
zero_vec.store(dst_ptr + i * pcols + j, pcols - j);
}
}
}
template <typename scalar_t>
inline void pad_remain_row_col_zero(
scalar_t* value_ptr,
int rows,
int cols,
int prows,
int pcols,
int ldi) {
auto psize = pcols - cols;
if (psize == 0 && prows == rows) {
return;
}
auto vec_size = at::vec::Vectorized<scalar_t>::size();
auto zero = at::vec::Vectorized<scalar_t>(0);
if (psize > 0) {
for (int i = 0; i < rows; i++) {
int j = 0;
for (; j < psize - (psize % vec_size); j += vec_size) {
zero.store(value_ptr + i * ldi + cols + j);
}
if (j < psize) {
zero.store(value_ptr + i * ldi + cols + j, psize - j);
}
}
}
for (int i = rows; i < prows; i++) {
int j = 0;
for (; j < pcols - (pcols % vec_size); j += vec_size) {
zero.store(value_ptr + i * ldi + j);
}
if (j < pcols) {
zero.store(value_ptr + i * ldi + j, pcols - j);
}
}
}
// Tiled scaled-dot-product attention forward pass on CPU with an online
// (running max / running sum) softmax.
//
// Template parameters:
//   scalar_t      - element type of q, k, v and output.
//   mask_t        - element type of the optional additive attention mask.
//   q_split_size  - maximum number of query rows per tile.
//   kv_split_size - maximum number of key/value rows per tile.
//   with_pack     - allows the oneDNN pack + brgemm path. It is only taken
//                   when the heuristic below sets need_pack.
//
// Parameters:
//   output    - written in place; addressed as (Batch, Q_seq_len, Num_heads,
//               Dim_per_head) through its strides.
//   logsumexp - written in place with qk_max + log(qk_sum) per query row;
//               addressed as (Batch, Q_seq_len, Num_heads) through its
//               strides, element type accum_t.
//   q, k, v   - (Batch x Num_heads x seq_len x Dim_per_head), transposed below.
//   dropout_p - not referenced in this function.
//   is_causal - if true, for query row r only key columns <= r are kept.
//   attn_mask - optional additive mask (2d or 4d), viewed/expanded to 4d.
//   scale     - optional softmax scale, resolved by sdp::calculate_scale.
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size, bool with_pack=false>
void cpu_flash_attention(
    const Tensor& output,
    const Tensor& logsumexp,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  // Query (Batch x Num_heads x Q_seq_len  x Dim_per_head)
  //    -> (Batch x Q_seq_len  x Num_heads x Dim_per_head)
  // Key   (Batch x Num_heads x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
  // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
  //    -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
  // The transposed tensors are only read through their strides below.
  at::Tensor query = q.transpose(1, 2);
  at::Tensor key = k.transpose(1, 2);
  at::Tensor value = v.transpose(1, 2);

  constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
  // All softmax math is done in the op-math (accumulation) type.
  using accum_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      sdp::calculate_scale(query, scale).as_float_unchecked();

  // Sizes
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
        "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(1);
  int64_t kvSize = value.size(1);
  int64_t num_head = query.size(2);
  int64_t headSize = query.size(3);

  // An empty mask (numel() == 0) is treated as no mask.
  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
  if (has_attn_mask) {
    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
  }

  // Strides
  int64_t qStrideB = query.stride(0);
  int64_t qStrideM = query.stride(1);
  int64_t qStrideH = query.stride(2);
  int64_t kStrideB = key.stride(0);
  int64_t kStrideN = key.stride(1);
  int64_t kStrideH = key.stride(2);
  int64_t vStrideB = value.stride(0);
  int64_t vStrideN = value.stride(1);
  int64_t vStrideH = value.stride(2);
  int64_t oStrideB = output.stride(0);
  int64_t oStrideM = output.stride(1);
  int64_t oStrideH = output.stride(2);
  int64_t lStrideB = logsumexp.stride(0);
  int64_t lStrideM = logsumexp.stride(1);
  int64_t lStrideH = logsumexp.stride(2);
  // Mask dimensions of size 1 get stride 0 so the same values are reused
  // (broadcast) along that dimension.
  int64_t mStrideB =
      (has_attn_mask && attn_mask.value().size(0) > 1)
      ? attn_mask.value().stride(0)
      : 0;
  int64_t mStrideH =
      (has_attn_mask && attn_mask.value().size(1) > 1)
      ? attn_mask.value().stride(1)
      : 0;
  int64_t mStrideM =
      (has_attn_mask && attn_mask.value().size(2) > 1)
      ? attn_mask.value().stride(2)
      : 0;
  int64_t mStrideN =
      (has_attn_mask && attn_mask.value().size(3) > 1)
      ? attn_mask.value().stride(3)
      : 0;

  // Tile sizes are clamped to the sequence lengths.
  // qSlice / kvSlice: number of tiles (ceil division).
  // kvTail: length of the last kv tile, in [1, kvSplitSize].
  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
  int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
  int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
  int64_t num_thread = at::get_num_threads();

  const auto dtype = query.scalar_type();
  const auto accumulate_dtype = toOpMathType(dtype);

  // Whether pack is needed
  bool need_pack = false;
  // Block size of packing B matrix
  int64_t packb_size = 64;
  // Use packb_size due to the limitation:
  // oneDNN pack only supports output leading dimension being one of (16, 32, 48, 64)
  // For instance,
  // for q @ k.T [qSplitSize, headSize] * [headSize, kvSplitSize] = [qSplitSize, kvSplitSize],
  // we need to split kvSplitSize with packb_size for packing k.T,
  // for (q @ k.T) @ v [qSplitSize, kvSplitSize] x [kvSplitSize, headSize] -> [qSplitSize, headSize],
  // we need to split headSize with packb_size for packing v
  // TODO Simplify the check when oneDNN supports fused pack with transpose and has better performance
  if (with_pack) {
    need_pack = num_head >= 4 && headSize % packb_size == 0 && kvSize >= packb_size;
    if (need_pack) {
      // Heuristic: compare the per-thread gemm work to the total pack work
      // (both scaled by 1/1024; the products are integer-divided before
      // conversion to float).
      float pack_size = batchSize * num_head * kvSize * headSize / 1024;
      float gemm_size_per_thread =
          (batchSize * num_head * qSlice + num_thread - 1) / num_thread *
          qSplitSize * (is_causal ? qSize : kvSize) * headSize / 1024;
      float gsize = gemm_size_per_thread / pack_size;
      // When the number of gemm is much greater than the number of pack,
      // the pack and padding overhead can be overlapped.
      if (pack_size < 2688) {
        need_pack = gsize >= 36 || (gsize >= 24 && headSize > packb_size);
      } else if (pack_size < 16384) {
        need_pack = gsize >= (is_causal ? 54 : 52);
      } else {
        need_pack = gsize >= (is_causal ? 54 : 40);
      }
    }
  }

  // "r" sizes: rounded up to a multiple of packb_size when packing.
  int64_t rHeadSize = need_pack ? (headSize + packb_size - 1) / packb_size * packb_size : headSize;
  int64_t rkvSplitSize = need_pack ? (kvSplitSize + packb_size - 1) / packb_size * packb_size : kvSplitSize;
  int64_t rkvTail = need_pack ? (kvTail + packb_size - 1) / packb_size * packb_size : kvTail;
  // NOTE(review): when kvSize >= kv_split_size, kvSlice already includes the
  // tail tile, so this appears to reserve one extra tail's worth of space
  // (harmless over-allocation) -- confirm intent.
  int64_t rkvSize = kv_split_size > kvSize ? rkvTail : rkvSplitSize * kvSlice + rkvTail;

  // oneDNN pack does not support odd K now, we need also pad odd K
  // "e" sizes: incremented by one when odd and packing is enabled.
  bool headSize_even = headSize % 2 == 0;
  int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
  int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
  int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;

  // allocate per thread temp buf (accumulate type)
  // Layout of one thread's slice (see the pointer carving in the main loop):
  //   [qk: qSplitSize x rkvSplitSize][qk_max: qSplitSize]
  //   [qk_sum: qSplitSize][dst: qSplitSize x rHeadSize]
  int64_t size_per_thread =
      /* qk     */ qSplitSize * rkvSplitSize +
      /* qk_max */ qSplitSize +
      /* qk_sum */ qSplitSize +
      /* dst    */ qSplitSize * rHeadSize;

  at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
  // Per-thread [qSplitSize x ekvSplitSize] scalar_t buffer for the exp(qk)
  // values; zero-sized unless scalar_t is a reduced floating-point type.
  at::Tensor buf_reduced = at::empty(
    {num_thread,
     qSplitSize,
     is_reduced_type ? ekvSplitSize : 0},
     query.options());

  // Data ptrs
  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
  mask_t* mask_data = has_attn_mask
      ? attn_mask.value().data_ptr<mask_t>()
      : nullptr;
  scalar_t* out_data = output.data_ptr<scalar_t>();
  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
  accum_t* buf_data = buf.data_ptr<accum_t>();
  scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;

  // Buffer to store padding query
  // (one [qSplitSize x eheadSize] slice per thread; only used for odd headSize
  // on the pack path).
  scalar_t* query_padding_ptr = nullptr;
  std::unique_ptr<scalar_t[]> query_padding_data;
  if (!headSize_even && need_pack) {
    query_padding_data = std::make_unique<scalar_t[]>(num_thread * qSplitSize * eheadSize);
    query_padding_ptr = query_padding_data.get();
  }

  // Buffer to store Key and Value after transforms
  // key_reorder:   per (batch, head) an [eheadSize x rkvSize] packed region.
  // value_reorder: per (batch, head) a [kv_padding_size x rHeadSize] packed
  //                region, where kv_padding_size is the kv length with each
  //                tile padded to an even length.
  scalar_t* key_reorder_ptr = nullptr;
  std::unique_ptr<scalar_t[]> key_reorder_data;
  scalar_t* value_reorder_ptr = nullptr;
  std::unique_ptr<scalar_t[]> value_reorder_data;
  int kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
  if (need_pack) {
    key_reorder_data = std::make_unique<scalar_t[]>(batchSize * num_head * eheadSize * rkvSize);
    key_reorder_ptr = key_reorder_data.get();
    value_reorder_data = std::make_unique<scalar_t[]>(batchSize * num_head * kv_padding_size * rHeadSize);
    value_reorder_ptr = value_reorder_data.get();
  }

  // Reorder K, V
  // One work item per (batch i, head j, kv tile l).
  if (need_pack) {
    at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
      int64_t i = 0, j = 0, l = 0, n = 0;
      at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice);
      std::unique_ptr<scalar_t[]> transpose_buffer = std::make_unique<scalar_t[]>(eheadSize * packb_size);
      scalar_t* transpose_buffer_ptr = transpose_buffer.get();
      std::unique_ptr<scalar_t[]> v_copy_buffer = std::make_unique<scalar_t[]>(ekvSplitSize * packb_size);
      scalar_t* v_copy_buffer_ptr = v_copy_buffer.get();
      for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
        n = l * kvSplitSize;
        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
        int64_t ekvBlockSize = kvBlockSize % 2 == 0 ? kvBlockSize : kvBlockSize + 1;

        // Split kvSplitSize with packb_size
        // [kvSplitSize, headSize] -> [div_up(kvSplitSize, packb_size), packb_size, headSize]
        // Transpose [packb_size, headSize] -> [headSize, packb_size]
        // Pack transposed buffer
        for (int64_t b = 0; b < kvBlockSize; b += packb_size) {
          bool tail = kvBlockSize - b < packb_size;
          // TODO Use fused pack with transpose support when oneDNN supports such usage
          // NOTE(review): the transpose reinterprets scalar_t as uint16_t, so
          // this path assumes a 2-byte scalar_t -- confirm callers guarantee it.
          utils::transpose<uint16_t>(
              tail ? kvBlockSize - b : packb_size,
              headSize,
              /* src_ptr */
              reinterpret_cast<const uint16_t*>(
                  k_data + i * kStrideB + j * kStrideH + n * kStrideN +
                      b * kStrideN),
              /* ld_src */ kStrideN,
              /* dst */ reinterpret_cast<uint16_t*>(transpose_buffer_ptr),
              /* ld_dst */ packb_size);
          // Pad [headSize, x] -> [eheadSize, x]
          if (!headSize_even) {
            pad_remain_row_col_zero<scalar_t>(
                transpose_buffer_ptr,
                headSize,
                packb_size,
                eheadSize,
                packb_size,
                packb_size);
          }
          // Pack
          cpublas::pack(
              /* K */ eheadSize,
              /* N */ packb_size,
              /* ld_in */ packb_size,
              /* ld_out */ packb_size,
              /* dt_in */ dtype,
              /* dt_out */ dtype,
              transpose_buffer_ptr,
              key_reorder_ptr + i * num_head * eheadSize * rkvSize +
                  j * eheadSize * rkvSize + n * eheadSize + b * eheadSize);
        }

        // Split headSize with packb_size
        // [kvSplitSize, headSize] -> [kvSplitSize, div_up(headSize, packb_size), packb_size]
        for (int64_t b = 0; b < headSize; b += packb_size) {
          // Do copy due to the limitation of input_ld of oneDNN pack:
          // Regarding packing [K, N], only input_ld == N is supported
          // TODO: remove the copy when pack supports input_ld >= N
          copy_value_with_pad<scalar_t>(
              v_data + i * vStrideB + j * vStrideH + n * vStrideN + b,
              v_copy_buffer_ptr,
              kvBlockSize,
              (headSize - b < packb_size) ? headSize - b : packb_size,
              ekvBlockSize,
              packb_size,
              vStrideN);
          cpublas::pack(
              ekvBlockSize,
              packb_size,
              packb_size,
              packb_size,
              dtype,
              dtype,
              v_copy_buffer_ptr,
              value_reorder_ptr +
                  i * num_head * kv_padding_size * rHeadSize +
                  j * kv_padding_size * rHeadSize + n * rHeadSize +
                  ekvBlockSize * b);
        }
        // Move to the next query
        at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice);
      }
    });
  }

  // Main loop: one work item per (batch i, head j, query tile k).
  at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
    int64_t i = 0, j = 0, k = 0;
    data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
    // The thread index selects this thread's slice of each scratch buffer.
    int ompIdx = at::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* qk_data = buf_ptr;
    accum_t* qk_max_data = qk_data + qSplitSize * rkvSplitSize;
    accum_t* qk_sum_data = qk_max_data + qSplitSize;
    accum_t* dst_data = qk_sum_data + qSplitSize;
    scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize : nullptr;
    scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
            ? query_padding_ptr + ompIdx * qSplitSize * eheadSize
            : nullptr;

    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
      int64_t m = k * qSplitSize;
      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
      // Initialize max and sum
      fill_stub(qk_max_data,
          -std::numeric_limits<accum_t>::infinity(), qBlockSize);
      fill_stub(qk_sum_data,
          static_cast<accum_t>(0), qBlockSize);
      // With is_causal, keys at or beyond m + qBlockSize are never visited.
      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
      if (!headSize_even && need_pack) {
        // Pad query if headSize is not even
        // [qBlockSize, headSize] -> [qBlockSize, eheadSize]
        copy_value_with_pad<scalar_t>(
            q_data + i * qStrideB + j * qStrideH + m * qStrideM,
            query_t_padding_ptr,
            qBlockSize,
            headSize,
            qBlockSize,
            eheadSize,
            qStrideM
        );
      }
      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
        int64_t ekvBlockSize = (need_pack && kvBlockSize % 2 != 0) ? kvBlockSize + 1 : kvBlockSize;
        // Row stride of qk_data for this tile.
        int64_t rkvBlockSize = kvBlockSize == kvSplitSize ? rkvSplitSize : rkvTail;
        // Calculate scale * q @ k.T
        // (the scale itself is applied later: in _scale_attn_mask_fusion_kernel
        // when a mask is present, otherwise in _mul_reduce_max_fusion_kernel).
        // NOTE(review): on the pack path only the at::Half branch computes
        // anything; presumably with_pack is only enabled for Half -- confirm.
        if (need_pack) {
          if constexpr (std::is_same_v<scalar_t, at::Half>) {
            for (int64_t b = 0; b < kvBlockSize; b += packb_size) {
              cpublas::brgemm(
                  qBlockSize,
                  packb_size,
                  eheadSize,
                  headSize_even ? qStrideM : eheadSize,
                  packb_size,
                  rkvBlockSize,
                  1.f,
                  0.f,
                  !headSize_even
                      ? query_t_padding_ptr
                      : q_data + i * qStrideB + j * qStrideH + m * qStrideM,
                  key_reorder_ptr + i * num_head * eheadSize * rkvSize +
                      j * eheadSize * rkvSize + n * eheadSize + b * eheadSize,
                  qk_data + b);
            }
          }
        } else {
          cpublas::gemm(
              TransposeType::Transpose,
              TransposeType::NoTranspose,
              kvBlockSize,
              qBlockSize,
              headSize,
              static_cast<accum_t>(1),
              k_data + i * kStrideB + j * kStrideH +
                  n * kStrideN,
              kStrideN,
              q_data + i * qStrideB + j * qStrideH +
                  m * qStrideM,
              qStrideM,
              static_cast<accum_t>(0),
              qk_data,
              kvBlockSize);
        }
        // Apply causal mask, fill unused with -inf
        // Only the last visited kv tile can contain keys past a query row.
        // last_col is the tile-local index of the last key that query row
        // (m + row) may attend to.
        if (is_causal && num_keys - n <= kvSplitSize) {
          for (const auto row : c10::irange(qBlockSize)) {
            int64_t last_col = m + row - n;
            accum_t* row_ptr = qk_data + row * rkvBlockSize;
            fill_stub(row_ptr + last_col + 1,
                -std::numeric_limits<accum_t>::infinity(),
                kvBlockSize - last_col - 1);
          }
        }
        // Update attention weights with attention mask
        // And apply scaling factor
        // qk <- qk * scaling + attn_mask
        // (mStrideN == 0 means one mask value per row, broadcast over keys.)
        if (has_attn_mask) {
          for (int64_t row = 0; row < qBlockSize; ++row) {
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
            _scale_attn_mask_fusion_kernel(
                qk_data + row * rkvBlockSize,
                mask_data + i * mStrideB + j * mStrideH +
                    (m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
                kvBlockSize,
                qk_data + row * rkvBlockSize,
                scaling_factor,
                mStrideN == 0);
#else
            if (mStrideN == 0) {
              _scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
                  qk_data + row * rkvBlockSize,
                  mask_data + i * mStrideB + j * mStrideH +
                      (m + row) * mStrideM,
                  kvBlockSize,
                  qk_data + row * rkvBlockSize,
                  scaling_factor);
            } else {
              _scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
                  qk_data + row * rkvBlockSize,
                  mask_data + i * mStrideB + j * mStrideH +
                      (m + row) * mStrideM + n,
                  kvBlockSize,
                  qk_data + row * rkvBlockSize,
                  scaling_factor);
            }
#endif
          }
        }
        // Update coefficients with Softmax
        // Online softmax: per row, the running max and sum are updated with
        // this tile, and the partial output (dst) is rescaled by
        // exp(old_max - new_max) before this tile's contribution is added.
        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
        for (int64_t row = 0; row < qBlockSize; ++row) {
          if (has_attn_mask) {
            // max per row
            tmp_max = at::vec::reduce_all<accum_t>(
                [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
                qk_data + row * rkvBlockSize,
                kvBlockSize);
          } else {
            // apply scaling factor and max per row in fusion
            _mul_reduce_max_fusion_kernel(
                qk_data + row * rkvBlockSize,
                scaling_factor,
                kvBlockSize,
                qk_data + row * rkvBlockSize,
                tmp_max);
          }
          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
          if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
            // to avoid `nan = exp2f(-inf - (-inf))`
            // Every score seen so far for this row is -inf: write zero weights.
            fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
              static_cast<scalar_t>(0), kvBlockSize);
          } else {
            tmp_sum = tmp_max;
            // qk <- exp(qk - max) and sum per row
            // (written to qk_reduced_data for reduced types, otherwise in
            // place into qk_data).
            _exp_reduce_sum_fusion_kernel(
                qk_data + row * rkvBlockSize, kvBlockSize,
                conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
                tmp_sum);
            // exp_tmp <- exp(max[row] - max)
            exp_tmp = std::exp(qk_max_data[row] - tmp_max);
            // sum[row] <- sum + exp_tmp * sum[row]
            qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
            // max[row] <- max
            qk_max_data[row] = tmp_max;
            // dst <- dst * exp_tmp
            if (n > 0) {
              vec::map<accum_t>(
                [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
                dst_data + row * rHeadSize,
                dst_data + row * rHeadSize,
                headSize);
            }
          }
          if (need_pack && kvBlockSize % 2 != 0) {
            // Pad: [qSplitSize,kvSplitSize] -> [qSplitSize,kvSplitSize + 1]
            *(qk_reduced_data + row * (1 + kvBlockSize) + kvBlockSize) = scalar_t(0);
          }
        }
        // Calculate Softmax(q @ k.T) @ v
        // (beta is 0 for the first kv tile and 1 afterwards, so dst
        // accumulates across tiles).
        if (need_pack) {
          // Offset of this tile inside the even-padded, packed V buffer.
          int64_t psize = n / kvSplitSize * ekvSplitSize;
          if constexpr (std::is_same_v<scalar_t, at::Half>) {
            for (int64_t b = 0; b < headSize; b += packb_size) {
              cpublas::brgemm(
                    qBlockSize,
                    packb_size,
                    ekvBlockSize,
                    ekvBlockSize,
                    packb_size,
                    rHeadSize,
                    1.0,
                    n == 0 ? 0.f : 1.f,
                    qk_reduced_data,
                    value_reorder_ptr +
                        i * num_head * kv_padding_size * rHeadSize +
                        j * kv_padding_size * rHeadSize + psize * rHeadSize +
                        b * ekvBlockSize,
                    dst_data + b);
            }
          }
        } else {
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::NoTranspose,
              headSize,
              qBlockSize,
              kvBlockSize,
              static_cast<accum_t>(1),
              v_data + i * vStrideB + j * vStrideH +
                  n * vStrideN,
              vStrideN,
              conditional_data_ptr(qk_data, qk_reduced_data),
              kvBlockSize,
              n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
              dst_data,
              headSize);
        }
      }

      // dst <- dst / sum[row]
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        // Row sums for full masked out rows are 0, we set them to 1
        // in order to avoid NaNs in the output and instead set fully
        // masked out rows to 0
        qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
        qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
        accum_t sum_reciprocal = 1 / qk_sum_data[row];
        vec::map<scalar_t>(
            [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
            out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
            dst_data + row * rHeadSize,
            headSize);
      }

      // Store logsumexp for backward
      // lse = max + log(sum), using the max/sum adjusted above for fully
      // masked rows.
      accum_t* lse_ptr = lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
      for (const auto row : c10::irange(qBlockSize)) {
        lse_ptr[row * lStrideM] = qk_max_data[row]
            + std::log(qk_sum_data[row]);
      }

      // Move to the next query
      data_index_step(i, batchSize, j, num_head, k, qSlice);
    }
  });
  // Called once after the work when the pack/brgemm path was used --
  // presumably releases brgemm state held by cpublas; see cpublas::brgemm_release.
  if (need_pack) {
    cpublas::brgemm_release();
  }
}
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention_backward(
const at::Tensor& grad_q,
const at::Tensor& grad_k,
const at::Tensor& grad_v,
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const at::Tensor& logsumexp,
double dropout_p,
bool is_causal,
std::optional<Tensor> attn_mask,
std::optional<double> scale) {
constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
using accum_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<accum_t>;
accum_t scaling_factor =
sdp::calculate_scale(query, scale).as_float_unchecked();
// Sizes
TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
"scaled_dot_product_attention_flash_attention_backward: Q/K/V should have the same head size");
// Query (Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value (Batch x KV_seq_len x Num_heads x Dim_per_head)
int64_t batchSize = query.size(0);
int64_t qSize = query.size(1);
int64_t kvSize = value.size(1);
int64_t num_head = query.size(2);
int64_t headSize = query.size(3);
bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
if (has_attn_mask) {
reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
}
// Strides
int64_t qStrideB = query.stride(0);
int64_t qStrideM = query.stride(1);
int64_t qStrideH = query.stride(2);
int64_t kStrideB = key.stride(0);
int64_t kStrideN = key.stride(1);
int64_t kStrideH = key.stride(2);
int64_t vStrideB = value.stride(0);
int64_t vStrideN = value.stride(1);
int64_t vStrideH = value.stride(2);
int64_t oStrideB = out.stride(0);
int64_t oStrideM = out.stride(1);
int64_t oStrideH = out.stride(2);
int64_t lStrideB = logsumexp.stride(0);
int64_t lStrideM = logsumexp.stride(1);
int64_t lStrideH = logsumexp.stride(2);
int64_t mStrideB =
(has_attn_mask && attn_mask.value().size(0) > 1)
? attn_mask.value().stride(0)
: 0;
int64_t mStrideH =
(has_attn_mask && attn_mask.value().size(1) > 1)
? attn_mask.value().stride(1)
: 0;
int64_t mStrideM =
(has_attn_mask && attn_mask.value().size(2) > 1)
? attn_mask.value().stride(2)
: 0;
int64_t mStrideN =
(has_attn_mask && attn_mask.value().size(3) > 1)
? attn_mask.value().stride(3)
: 0;
int64_t grad_qStrideB = grad_q.stride(0);
int64_t grad_qStrideM = grad_q.stride(1);
int64_t grad_qStrideH = grad_q.stride(2);
int64_t grad_kStrideB = grad_k.stride(0);
int64_t grad_kStrideN = grad_k.stride(1);
int64_t grad_kStrideH = grad_k.stride(2);
int64_t grad_vStrideB = grad_v.stride(0);
int64_t grad_vStrideN = grad_v.stride(1);
int64_t grad_vStrideH = grad_v.stride(2);
int64_t grad_oStrideB = grad_out.stride(0);
int64_t grad_oStrideM = grad_out.stride(1);
int64_t grad_oStrideH = grad_out.stride(2);
int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
int64_t num_thread = at::get_num_threads();
const auto dtype = query.scalar_type();
const auto accumulate_dtype = toOpMathType(dtype);
// allocate per thread temp buf (accumulate type)
int64_t size_per_thread =
/* attn */ qSplitSize * kvSplitSize +
/* grad_attn */ qSplitSize * kvSplitSize;
at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
// allocate per thread temp buf_reduced (scalar type)
// buf2 is only needed for bfloat16 and float16
int64_t size_per_thread_reduced =
/* attn_reduced */ qSplitSize * kvSplitSize +
/* grad_attn_reduced */ qSplitSize * kvSplitSize;
at::Tensor buf_reduced = at::empty({num_thread, is_reduced_type ? size_per_thread_reduced : 0}, query.options());
scalar_t* grad_q_data = grad_q.data_ptr<scalar_t>();
scalar_t* grad_k_data = grad_k.data_ptr<scalar_t>();
scalar_t* grad_v_data = grad_v.data_ptr<scalar_t>();
const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
const scalar_t* k_data = key.const_data_ptr<scalar_t>();
const scalar_t* v_data = value.const_data_ptr<scalar_t>();
mask_t* mask_data = has_attn_mask
? attn_mask.value().data_ptr<mask_t>()
: nullptr;
const scalar_t* out_data = out.const_data_ptr<scalar_t>();
const accum_t* lse_data = logsumexp.const_data_ptr<accum_t>();
accum_t* buf_data = buf.data_ptr<accum_t>();
scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;
at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0;
data_index_init(begin, i, batchSize, j, num_head);
int ompIdx = at::get_thread_num();
accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
accum_t* attn_data = buf_ptr;
accum_t* grad_attn_data = attn_data + qSplitSize * kvSplitSize;
scalar_t* buf_reduced_ptr = is_reduced_type ? buf_reduced_data + ompIdx * size_per_thread_reduced : nullptr;
scalar_t* attn_reduced_data = is_reduced_type ? buf_reduced_ptr : nullptr;
scalar_t* grad_attn_reduced_data = is_reduced_type ? attn_reduced_data + qSplitSize * kvSplitSize : nullptr;
at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype));
accum_t* dsum_data = dsum.data_ptr<accum_t>();
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
// rowsum of grad_out * out
for (int64_t m = 0; m < qSize; m += qSplitSize) {
int64_t qBlockSize = std::min(qSplitSize, qSize - m);
// dsum <- rowsum(grad_out * out)
for (const auto row : c10::irange(qBlockSize)) {
*(dsum_data + row) = vec::map2_reduce_all<scalar_t>(
[](Vec x, Vec y) { return x * y; },
[](Vec x, Vec y) { return x + y; },
grad_out_data + i * grad_oStrideB + j * grad_oStrideH + (m + row) * grad_oStrideM,
out_data + i * oStrideB + j * oStrideH + (m + row) * oStrideM,
headSize);
}
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
// attn <- scale * q @ k.T
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
qBlockSize,
headSize,
scaling_factor,
k_data + i * kStrideB + j * kStrideH +
n * kStrideN,
kStrideN,
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
qStrideM,
static_cast<accum_t>(0),
attn_data,
kvBlockSize);
// attn <- attn + mask
if (has_attn_mask) {
accum_t one = accum_t(1);
for (const auto row : c10::irange(qBlockSize)) {
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
_scale_attn_mask_fusion_kernel(
attn_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
kvBlockSize,
attn_data + row * kvBlockSize,
one,
mStrideN == 0);
#else
if (mStrideN == 0) {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
attn_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM,
kvBlockSize,
attn_data + row * kvBlockSize,
one);
} else {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
attn_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
attn_data + row * kvBlockSize,
one);
}
#endif
}
}