Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar committed Jul 4, 2024
1 parent 24e21c0 commit 3a5ea24
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,97 @@ TEST_CASE(simplify_split_add_relu)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_split_add_flipped_input)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m1.add_literal(1);
auto oneb = m1.add_instruction(b, one);
auto two = m1.add_literal(2);
auto twob = m1.add_instruction(b, two);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), twob, y);
auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2);
auto add = m1.add_instruction(migraphx::make_op("add"), relu1, relu2);
m1.add_instruction(pass_op{}, add);
}
run_pass(m1);

migraphx::module m2;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = m2.add_instruction(b, concat);
auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = m2.add_instruction(migraphx::make_op("relu"), sum);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto add = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_split_non_comm_flipped_input)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m1.add_literal(1);
auto oneb = m1.add_instruction(b, one);
auto two = m1.add_literal(2);
auto twob = m1.add_instruction(b, two);
auto sum1 = m1.add_instruction(migraphx::make_op("sub"), x, oneb);
auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = m1.add_instruction(migraphx::make_op("sub"), twob, y);
auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2);
auto add = m1.add_instruction(migraphx::make_op("sub"), relu1, relu2);
m1.add_instruction(pass_op{}, add);
}
run_pass(m1);

migraphx::module m2;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m2.add_parameter("input", s);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m2.add_literal(1);
auto neg_one = m2.add_instruction(migraphx::make_op("neg"), one);
auto oneb = m2.add_instruction(b, neg_one);
auto two = m2.add_literal(2);
auto twob = m2.add_instruction(b, two);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = m2.add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = m2.add_instruction(migraphx::make_op("sub"), twob, y);
auto relu2 = m2.add_instruction(migraphx::make_op("relu"), sum2);
auto add = m2.add_instruction(migraphx::make_op("sub"), relu1, relu2);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_split_reduce0)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
Expand Down

0 comments on commit 3a5ea24

Please sign in to comment.