Address ZeroK case for Gemm for CPU and CUDA #22111

Merged
merged 12 commits on Sep 21, 2024
4 changes: 1 addition & 3 deletions cmake/onnxruntime_python.cmake
@@ -71,9 +71,7 @@ onnxruntime_add_shared_library_module(onnxruntime_pybind11_state ${onnxruntime_p

if(MSVC)
target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
-  if(onnxruntime_ENABLE_TRAINING)
-    target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
-  endif()
+  target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj")
endif()
if(HAS_CAST_FUNCTION_TYPE)
target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type")
59 changes: 37 additions & 22 deletions onnxruntime/core/providers/cpu/math/gemm.cc
@@ -154,6 +154,14 @@ void Gemm<T>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
// Broadcast the bias as needed if bias is given
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);

+  if (K == 0) {
+    if (beta == 0 || c_data == nullptr) {
+      auto output_span = gsl::make_span(y_data, SafeInt<size_t>(M) * N);
+      std::fill(output_span.begin(), output_span.end(), T{});
+    }
+    return;
+  }

math::Gemm<T>(trans_a, trans_b,
M, N, K,
alpha,
@@ -179,16 +187,18 @@ void Gemm<MLFloat16>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans
if (M == 0 || N == 0)
return;

-#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wclass-memaccess"
-#endif
-  // MLFloat16's constructor is explicit, so here we need to use memset
+  if (K == 0) {
+    if (beta != onnxruntime::MLFloat16::Zero && c_data != nullptr) {
+      GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);
+    } else {
+      auto output_span = gsl::make_span(y_data, SafeInt<size_t>(M) * N);
+      std::fill(output_span.begin(), output_span.end(), onnxruntime::MLFloat16::Zero);
+    }
+    return;
+  }
+
   if (c_data == nullptr)
-    memset(&beta, 0, sizeof(MLFloat16));
-#if defined(__GNUC__) && defined(HAS_CLASS_MEMACCESS)
-#pragma GCC diagnostic pop
-#endif
+    beta = onnxruntime::MLFloat16::Zero;
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
bool support_mlas = false;
if (c_shape == nullptr) {
@@ -413,19 +423,24 @@ Status Gemm<float>::Compute(OpKernelContext* context) const {
c_data, c_shape, y_data, thread_pool);
} else {
GemmBroadcastBias(M, N, beta_, c_data, c_shape, y_data);
-    MlasGemm(
-        trans_A_,
-        static_cast<size_t>(M),
-        static_cast<size_t>(N),
-        static_cast<size_t>(K),
-        alpha_,
-        A->Data<float>(),
-        static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
-        packed_b_.get(),
-        c_data != nullptr ? beta_ : 0.0f,
-        y_data,
-        static_cast<size_t>(N),
-        thread_pool);
+    if (K > 0) {
+      MlasGemm(
+          trans_A_,
+          static_cast<size_t>(M),
+          static_cast<size_t>(N),
+          static_cast<size_t>(K),
+          alpha_,
+          A->Data<float>(),
+          static_cast<size_t>(trans_A_ != CblasNoTrans ? M : K),
+          packed_b_.get(),
+          c_data != nullptr ? beta_ : 0.0f,
+          y_data,
+          static_cast<size_t>(N),
+          thread_pool);
+    } else {
+      // K == 0 and no retained bias: fill the float output with zeros
+      auto output_span = Y->MutableDataAsSpan<float>();
+      std::fill(output_span.begin(), output_span.end(), 0.0f);
+    }
}

ComputeActivation(y_data, SafeInt<size_t>(M) * N, thread_pool);
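For reference, the zero-K semantics these CPU changes implement: when K == 0 the reduction in alpha * op(A) * op(B) is empty, so that term is an M x N zero matrix and the output collapses to beta * broadcast(C), or to zeros when beta == 0 or C is absent. A minimal standalone sketch of that rule (not ONNX Runtime code; GemmZeroK and the length-N bias layout are illustrative assumptions):

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical reference helper: the Gemm output for the K == 0 case only.
// y = beta * broadcast(c) when a bias is present and beta != 0, else zeros.
std::vector<float> GemmZeroK(std::size_t M, std::size_t N, float beta,
                             const std::vector<float>* c /* length-N bias or nullptr */) {
  std::vector<float> y(M * N, 0.0f);  // the alpha * A * B term contributes nothing
  if (beta != 0.0f && c != nullptr) {
    for (std::size_t i = 0; i < M; ++i)
      for (std::size_t j = 0; j < N; ++j)
        y[i * N + j] = beta * (*c)[j];  // broadcast the bias across rows
  }
  return y;
}

int main() {
  std::vector<float> bias(4, 1.0f);
  auto y = GemmZeroK(4, 4, 1.0f, &bias);  // all ones, matching ZeroKWithBias below
  std::printf("y[0] = %f\n", y[0]);
  return 0;
}

This mirrors the order of operations in ComputeGemm above: the bias is broadcast first, then the K == 0 early-out decides whether to keep it or zero the output.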
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cpu/math/gemm_helper.h
@@ -56,7 +56,8 @@ class GemmHelper {
status_ = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast");

// it is possible the input is empty tensor, for example the output of roipool in fast rcnn.
-  ORT_ENFORCE(M_ >= 0 && K_ > 0 && N_ >= 0);
+  // it is also possible that K == 0
+  ORT_ENFORCE(M_ >= 0 && K_ >= 0 && N_ >= 0);
}

ptrdiff_t M() const { return M_; }
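The relaxed check admits shapes whose reduction dimension is empty. A sketch of how M, K, N fall out of the 2-D input shapes under the transA/transB attributes (hypothetical helper for illustration, not the GemmHelper API):

#include <cstdint>

struct GemmDims {
  int64_t M, K, N;
};

GemmDims ComputeGemmDims(int64_t a0, int64_t a1, bool trans_a,
                         int64_t b0, int64_t b1, bool trans_b) {
  const int64_t M = trans_a ? a1 : a0;
  const int64_t K = trans_a ? a0 : a1;  // must equal trans_b ? b1 : b0
  const int64_t N = trans_b ? b0 : b1;
  // A of shape (4, 0) with B of shape (0, 4) gives M = 4, K = 0, N = 4,
  // which the relaxed ORT_ENFORCE(M_ >= 0 && K_ >= 0 && N_ >= 0) now accepts.
  return {M, K, N};
}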
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/cuda/math/gemm.cc
@@ -137,6 +137,16 @@ Status Gemm<T>::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const
}
}

if (K == 0) {
if (beta_ == 0 || B == nullptr) {
// When we have (M, 0, N) then the output should be filled out with zeros
// unless we have a bias
Fill<CudaT>(Stream(ctx), reinterpret_cast<CudaT*>(Y->MutableData<T>()), CudaT(0.f),
Y->Shape().Size());
}
return Status::OK();
}

CudaT alpha = ToCudaType<T>::FromFloat(alpha_);
CudaT beta = ToCudaType<T>::FromFloat(beta_);
// Gemm, note that CUDA assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
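A side note on the zero-fill above: because the fill value is all-zero bits, the same effect could in principle come from a byte-wise memset on the compute stream. The sketch below (an assumption about an alternative, not what the PR uses) relies only on the standard CUDA runtime call cudaMemsetAsync; a typed Fill kernel remains the general tool since memset can only produce repeated-byte patterns:

#include <cstddef>
#include <cuda_runtime.h>

// Sketch of an alternative zero-fill (assumption, not the PR's code):
// IEEE float/half types read as +0 when all bits are zero, so a byte-wise
// memset on the stream reproduces Fill(..., CudaT(0.f), ...) for this case.
cudaError_t ZeroFillAsync(void* device_ptr, std::size_t element_count,
                          std::size_t element_size, cudaStream_t stream) {
  return cudaMemsetAsync(device_ptr, 0, element_count * element_size, stream);
}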
41 changes: 41 additions & 0 deletions onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -641,6 +641,47 @@ TYPED_TEST(GemmOpTypedTests, GemmEmptyTensor) {
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) {
OpTester test("Gemm", 13);

test.AddAttribute("transA", static_cast<int64_t>(0));
test.AddAttribute("transB", static_cast<int64_t>(0));
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", 1.0f);

test.AddInput<TypeParam>("A", {4, 0}, {});
test.AddInput<TypeParam>("B", {0, 4}, {});
test.AddInput<TypeParam>("C", {4}, std::vector<TypeParam>(4, static_cast<TypeParam>(1.0f)));
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(1.0f)));

test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
kOpenVINOExecutionProvider})
.Config(run_with_tunable_op)
.RunWithConfig();
}

TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) {
OpTester test("Gemm", 13);

test.AddAttribute("transA", static_cast<int64_t>(0));
test.AddAttribute("transB", static_cast<int64_t>(0));
test.AddAttribute("alpha", 1.0f);
test.AddAttribute("beta", .0f);

test.AddInput<TypeParam>("A", {4, 0}, {});
test.AddInput<TypeParam>("B", {0, 4}, {});
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(0.0f)));

test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
kOpenVINOExecutionProvider})
.Config(run_with_tunable_op)
.RunWithConfig();
}
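The two tests cover (bias, beta == 1) and (no bias, beta == 0). The remaining zero-K combination, C present but beta == 0, should also produce all zeros under the new code paths; a hedged sketch in the same OpTester style (the test name and its inclusion are assumptions, not part of the PR):

TYPED_TEST(GemmOpTypedTests, ZeroKWithBiasZeroBeta) {
  OpTester test("Gemm", 13);

  test.AddAttribute("transA", static_cast<int64_t>(0));
  test.AddAttribute("transB", static_cast<int64_t>(0));
  test.AddAttribute("alpha", 1.0f);
  test.AddAttribute("beta", 0.0f);

  test.AddInput<TypeParam>("A", {4, 0}, {});
  test.AddInput<TypeParam>("B", {0, 4}, {});
  test.AddInput<TypeParam>("C", {4}, std::vector<TypeParam>(4, static_cast<TypeParam>(1.0f)));
  // beta == 0 discards C, so the output is all zeros.
  test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(0.0f)));

  test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
                         kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
                         kOpenVINOExecutionProvider})
      .Config(run_with_tunable_op)
      .RunWithConfig();
}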


TYPED_TEST(GemmOpTypedTests, MissingBias) {
OpTester test("Gemm", 11);
