Skip to content

Commit

Permalink
Added a test with non-standard bias shape
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Mar 7, 2024
1 parent da734c6 commit bfa1fbc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct gemm_softmax_gemm
if(is_bias_enabled)
{
auto bias_shape = inputs[2];
if(bias_shape != gemm0_shape)
if(bias_shape.lens() != gemm0_shape.lens())
{
std::stringstream err_msg;
err_msg << name() << ": has inconsistent bias size"
Expand Down
23 changes: 15 additions & 8 deletions test/verify/gemm_softmax_gemm_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <bool WithBias>
struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<WithBias>>
template <bool WithBias, bool WithStandardBiasShape>
struct gemm_softmax_gemm_relu
: verify_program<gemm_softmax_gemm_relu<WithBias, WithStandardBiasShape>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = m2_shape.elements();
auto m2_elements = m1_shape.elements();
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
auto eight = mm->add_literal(migraphx::literal{m1_shape, eights});

b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
Expand All @@ -50,7 +50,13 @@ struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<WithBias>>
std::optional<migraphx::instruction_ref> add_bias{std::nullopt};
if constexpr(WithBias)
{
auto bias = mm->add_parameter("4", m1_shape);
auto bias_shape = m1_shape;
if(not WithStandardBiasShape)
{
bias_shape = migraphx::shape::from_permutation(
bias_shape.type(), bias_shape.lens(), {0, 1, 3, 2});
}
auto bias = mm->add_parameter("4", bias_shape);
add_bias = mm->add_instruction(migraphx::make_op("add"), scale, bias);
}

Expand All @@ -63,5 +69,6 @@ struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu<WithBias>>
std::string section() const { return "gemm"; }
};

template struct gemm_softmax_gemm_relu<false>;
template struct gemm_softmax_gemm_relu<true>;
template struct gemm_softmax_gemm_relu<false, true>;
template struct gemm_softmax_gemm_relu<true, true>;
template struct gemm_softmax_gemm_relu<true, false>;

0 comments on commit bfa1fbc

Please sign in to comment.