From a924f45eda4e94b45602f29758a147d9d4f90cf8 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Thu, 7 Mar 2024 16:36:03 +0000 Subject: [PATCH] Added a test with non-standard bias shape --- .../migraphx/gpu/gemm_softmax_gemm.hpp | 2 +- test/verify/gemm_softmax_gemm_relu.cpp | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp index 8ce4ec69b61..ec8561e779c 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp @@ -70,7 +70,7 @@ struct gemm_softmax_gemm if(is_bias_enabled) { auto bias_shape = inputs[2]; - if(bias_shape != gemm0_shape) + if(bias_shape.lens() != gemm0_shape.lens()) { std::stringstream err_msg; err_msg << name() << ": has inconsistent bias size" diff --git a/test/verify/gemm_softmax_gemm_relu.cpp b/test/verify/gemm_softmax_gemm_relu.cpp index c6201f2f5d1..d3f08d88108 100644 --- a/test/verify/gemm_softmax_gemm_relu.cpp +++ b/test/verify/gemm_softmax_gemm_relu.cpp @@ -27,8 +27,9 @@ #include #include -template -struct gemm_softmax_gemm_relu : verify_program> +template +struct gemm_softmax_gemm_relu + : verify_program> { migraphx::program create_program() const { @@ -50,7 +51,14 @@ struct gemm_softmax_gemm_relu : verify_program> std::optional add_bias{std::nullopt}; if constexpr(WithBias) { - auto bias = mm->add_parameter("4", m1_shape); + auto bias_strides = m1_shape.strides(); + if(not WithStandardBiasShape) + { + const auto last_index{bias_strides.size() - 1}; + std::swap(bias_strides[last_index], bias_strides[last_index - 1]); + } + migraphx::shape bias_shape{m1_shape.type(), m1_shape.lens(), bias_strides}; + auto bias = mm->add_parameter("4", bias_shape); add_bias = mm->add_instruction(migraphx::make_op("add"), scale, bias); } @@ -63,5 +71,6 @@ struct gemm_softmax_gemm_relu : verify_program> std::string section() const { return "gemm"; } }; -template struct gemm_softmax_gemm_relu; -template struct gemm_softmax_gemm_relu; +template struct gemm_softmax_gemm_relu; +template struct gemm_softmax_gemm_relu; +template struct gemm_softmax_gemm_relu;