Skip to content

Commit

Permalink
Added a test with non-standard bias shape
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Mar 7, 2024
1 parent da734c6 commit a924f45
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct gemm_softmax_gemm
if(is_bias_enabled)
{
auto bias_shape = inputs[2];
if(bias_shape != gemm0_shape)
if(bias_shape.lens() != gemm0_shape.lens())
{
std::stringstream err_msg;
err_msg << name() << ": has inconsistent bias size"
Expand Down
19 changes: 14 additions & 5 deletions test/verify/gemm_softmax_gemm_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <bool WithBias>
struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<WithBias>>
template <bool WithBias, bool WithStandardBiasShape>
struct gemm_softmax_gemm_relu
: verify_program<gemm_softmax_gemm_relu<WithBias, WithStandardBiasShape>>
{
migraphx::program create_program() const
{
Expand All @@ -50,7 +51,14 @@ struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<WithBias>>
std::optional<migraphx::instruction_ref> add_bias{std::nullopt};
if constexpr(WithBias)
{
auto bias = mm->add_parameter("4", m1_shape);
auto bias_strides = m1_shape.strides();
if(not WithStandardBiasShape)
{
const auto last_index{bias_strides.size() - 1};
std::swap(bias_strides[last_index], bias_strides[last_index - 1]);
}
migraphx::shape bias_shape{m1_shape.type(), m1_shape.lens(), bias_strides};
auto bias = mm->add_parameter("4", bias_shape);
add_bias = mm->add_instruction(migraphx::make_op("add"), scale, bias);
}

Expand All @@ -63,5 +71,6 @@ struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<WithBias>>
std::string section() const { return "gemm"; }
};

template struct gemm_softmax_gemm_relu<false>;
template struct gemm_softmax_gemm_relu<true>;
template struct gemm_softmax_gemm_relu<false, true>;
template struct gemm_softmax_gemm_relu<true, true>;
template struct gemm_softmax_gemm_relu<true, false>;

0 comments on commit a924f45

Please sign in to comment.