Matmul.cpp (forked from pytorch/pytorch)
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/mkldnn/Matmul.h>
#if !AT_MKLDNN_ENABLED()
namespace at {
namespace native {
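
// Fallback path: when ATen is built without MKLDNN support, these stubs make
// mkldnn_matmul fail loudly and report the MKLDNN bf16 kernels as unavailable.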
void mkldnn_matmul(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result,
    float beta,
    float alpha) {
  TORCH_CHECK(false, "mkldnn_matmul: ATen not compiled with MKLDNN support");
}

bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result_opt) {
  return false;
}

bool mkldnn_bf16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a, int64_t lda,
    const c10::BFloat16 *b, int64_t ldb,
    float beta,
    c10::BFloat16 *c, int64_t ldc) {
  return false;
}

} // namespace native
} // namespace at
#else // AT_MKLDNN_ENABLED
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
namespace at {
namespace native {
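
// The bf16 fast path is taken only when the user has MKLDNN enabled in the
// global context and the CPU passes mkldnn_bf16_device_check().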
static bool use_mkldnn_bf16_matmul() {
  return (
      at::globalContext().userEnabledMkldnn() &&
      mkldnn_bf16_device_check());
}

bool mkldnn_bf16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a_data, int64_t lda,
    const c10::BFloat16 *b_data, int64_t ldb,
    float beta,
    c10::BFloat16 *c_data, int64_t ldc) {
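  // Bail out (return false) for tiny problems or a zero alpha, where dispatching
  // to oneDNN is unlikely to pay off, so the caller can use its default gemm.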
  if (!use_mkldnn_bf16_matmul() ||
      (m * n * k <= 16 * 16 * 16) ||
      (alpha == 0.0f)) {
    return false;
  }

  ideep::attr_t op_attr;
  // Use mkldnn post ops to perform the add.
  if (beta != 0.0f) {
    op_attr = ideep::attr_t::fuse_sum();
  }

  // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
  // Use identity: C = AB <=> C^T = B^T A^T
  ideep::tensor::dims a_strides{{lda, 1}}, b_strides{{ldb, 1}}, c_strides{{ldc, 1}};
  if (transa != TransposeType::NoTranspose) {
    std::swap(a_strides[0], a_strides[1]);
  }
  if (transb != TransposeType::NoTranspose) {
    std::swap(b_strides[0], b_strides[1]);
  }

  ideep::tensor a({
      /*sizes=*/{k, m},
      ideep::tensor::data_type::bf16,
      /*strides=*/a_strides},
    const_cast<c10::BFloat16*>(a_data));
  ideep::tensor b({
      /*sizes=*/{n, k},
      ideep::tensor::data_type::bf16,
      /*strides=*/b_strides},
    const_cast<c10::BFloat16*>(b_data));
  ideep::tensor c({
      /*sizes=*/{n, m},
      ideep::tensor::data_type::bf16,
      /*strides=*/c_strides},
    c_data);
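
  // Because a, b and c are viewed transposed above, computing b * a here yields
  // C^T = B^T * A^T, i.e. the desired C = A * B in the caller's layout.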
  ideep::matmul_forward::compute(
      b, a, c, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);

  if (c.get_data_handle() != c_data) {
    // ideep queries onednn for the expected output format; if the given output
    // format is not the expected one, ideep re-initializes an output buffer.
    // In that case, copy the re-initialized buffer back into the given buffer.
    ideep::tensor real_output({
        /*sizes=*/{n, m},
        ideep::tensor::data_type::bf16,
        /*strides=*/c_strides},
      c_data);
    c.reorder_to(real_output);
  }
  return true;
}

void mkldnn_matmul(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result,
    float beta,
    float alpha) {
  TORCH_CHECK((mat1.dim() == 2 && mat2.dim() == 2) ||  // aten::addmm
              (mat1.dim() == 3 && mat2.dim() == 3) ||  // aten::bmm, aten::baddbmm
              (mat1.dim() == 2 && mat2.dim() == 1) ||  // aten::mv
              (mat1.dim() == 1 && mat2.dim() == 1),    // aten::dot
              "mkldnn_matmul: unsupported dims for mat1 and mat2");
  TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
              mat2.scalar_type() == at::kBFloat16 &&
              result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
  TORCH_CHECK(mkldnn_bf16_device_check(),
              "mkldnn_matmul: mkldnn_matmul bf16 path needs the CPU to support avx512bw, avx512vl and avx512dq");
  auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
  auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
  auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result;

  ideep::attr_t op_attr;
  // "addmm", "addbmm" and "baddbmm" in pytorch allow the bias to be a 2-D or 3-D tensor,
  // but the mkldnn matmul primitive only supports a 1-D bias tensor.
  // To bridge the difference, we use mkldnn post ops to perform a fused "add"
  // after the matrix multiplication is done.
  if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();

  // If alpha == 0, there is no need to actually run the gemm computation.
  if (alpha == 0)
    return;

  auto is_mkldnn_optimized_format = [&](const Tensor& t) {
    if (t.is_contiguous()) return true;
    const auto sizes = t.sizes();
    const auto strides = t.strides();
    if (t.dim() == 2) {
      return strides[0] == 1 && strides[1] == sizes[0];
    } else {
      // dim = 3
      return strides[0] == sizes[1] * sizes[2] && strides[1] == 1 && strides[2] == sizes[1];
    }
  };

  // Mkldnn is only optimized for contiguous or transposed (transpose of the last
  // 2 dims for a 3-D tensor) formats for now.
  // This "contiguous" fallback will be removed once mkldnn fully supports other layouts.
  Tensor mat1_ = is_mkldnn_optimized_format(mat1_unsqueezed) ? mat1_unsqueezed : mat1_unsqueezed.contiguous();
  Tensor mat2_ = is_mkldnn_optimized_format(mat2_unsqueezed) ? mat2_unsqueezed : mat2_unsqueezed.contiguous();

  // mkldnn_matmul only processes CPU tensors.
  const ideep::tensor x = itensor_view_from_dense(mat1_);
  const ideep::tensor w = itensor_view_from_dense(mat2_);
  ideep::tensor y = itensor_view_from_dense(result_unsqueezed);
  ideep::matmul_forward::compute(x, w, y, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);

  if (y.get_data_handle() != result.data_ptr()) {
    // ideep queries onednn for the expected output format; if the given output
    // format is not the expected one, ideep re-initializes an output buffer.
    // In that case, copy the re-initialized buffer back into the given buffer.
    ideep::tensor public_y = itensor_view_from_dense(result);
    y.reorder_to(public_y);
  }

  if (mat1.dim() == 1 && mat2.dim() == 1) {
    // aten::dot
    result.squeeze_();
  }
}
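
// Illustrative call pattern (a sketch, not taken from this file): a caller such
// as an addmm/bmm kernel is expected to gate on use_mkldnn_bf16_matmul() before
// dispatching here, e.g.
//   if (use_mkldnn_bf16_matmul(mat1, mat2, result)) {
//     mkldnn_matmul(mat1, mat2, result, beta, alpha);
//   }
// with `result` preallocated as a bf16 tensor of the output shape.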

inline bool checksize(const Tensor& mat1, const Tensor& mat2) {
  // If dim == 2, mat1's size is (m * n) and mat2's size is (n * k);
  // else if dim == 3, mat1's size is (b * m * n) and mat2's size is (b * n * k);
  // else we were called from aten::mv, where mat1's size is (m * n) and mat2's size is (n).
  // Only when m * n * b * k (if it exists) is large enough do we benefit from the
  // mkldnn-optimized gemm kernel.
  static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
  if (mat1.dim() == 1 && mat2.dim() == 1) {
    // aten::dot
    return mat1.size(0) > mkldnn_gemm_min_size;
  } else if (mat1.dim() == 2 && mat2.dim() == 1) {
    // aten::mv
    return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
  } else if (mat1.dim() == 2 && mat2.dim() == 2) {
    // aten::addmm
    return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
  } else {
    // aten::bmm, aten::baddbmm
    return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
  }
}

bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
  return (
      use_mkldnn_bf16_matmul() &&
      mat1.scalar_type() == kBFloat16 &&
      mat2.scalar_type() == kBFloat16 &&
      (!result.defined() || result.scalar_type() == kBFloat16) &&
      mat1.numel() != 0 &&
      mat2.numel() != 0 &&
      checksize(mat1, mat2));
}

} // namespace native
} // namespace at
#endif // AT_MKLDNN_ENABLED