Skip to content

Commit

Permalink
Keep output shape from pointwise instruction after fusing with concat
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Mar 13, 2024
1 parent 4def653 commit 4bedf9e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ struct find_contiguous
}
};

struct find_contiguous_pointwise
struct find_pointwise_contiguous
{
auto matcher() const
{
Expand Down Expand Up @@ -842,15 +842,17 @@ struct find_concat_pointwise
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());

auto op = concat->get_operator();
op.from_value({{"additional_args", ins->inputs().size() - 1}, {"ignore_modules", true}});
op.from_value({{"additional_args", ins->inputs().size() - 1},
{"ignore_modules", true},
{"output_shape", to_value(ins->get_shape())}});

m.replace_instruction(ins, op, inputs, {pm});
}
};

void fuse_ops::apply(module& m) const
{
match::find_matches(m, find_contiguous_pointwise{});
match::find_matches(m, find_pointwise_contiguous{});
run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx});
run_passes(m, {dead_code_elimination{}});
Expand Down
67 changes: 66 additions & 1 deletion test/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ TEST_CASE(layernorm_pointwise)
}
}

TEST_CASE(contiguous_pointwise)
TEST_CASE(pointwise_contiguous)
{
migraphx::shape s1{migraphx::shape::float_type, {128, 4, 196, 32}};
migraphx::shape s2{migraphx::shape::float_type, {128, 196, 4, 32}};
Expand Down Expand Up @@ -160,4 +160,69 @@ TEST_CASE(contiguous_pointwise)
EXPECT(p1 == p2);
}

// Checks the concat + pointwise fusion when the pointwise instruction's output
// shape differs from the layout of its inputs.
//
// The inputs (s1, s2) are built with migraphx::shape::from_permutation using
// the permutation {0, 2, 1, 3}, while the pointwise precompile_op declares the
// plain shape s3 as its "output_shape". After run_pass, the program is expected
// to equal one in which a single concat precompile_op carries the pointwise
// module and is tagged with "output_shape" = s3.
// NOTE(review): run_pass is defined outside this view; presumably it applies
// the GPU fuse_ops pass followed by cleanup -- confirm against the test header.
TEST_CASE(concat_pointwise_contiguous)
{
    // Concat inputs: two of these are joined along axis 1 (2 + 2 = 4).
    migraphx::shape s1 = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {128, 2, 196, 32}, {0, 2, 1, 3});
    // Shape of the concat result and of the second pointwise input `y`.
    migraphx::shape s2 = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {128, 4, 196, 32}, {0, 2, 1, 3});
    // Same dimensions as s2, constructed without a permutation; used as the
    // explicit output shape of the pointwise instruction.
    migraphx::shape s3{migraphx::shape::float_type, {128, 4, 196, 32}};

    // Builds the unfused program: concat -> pointwise add -> reshape_lazy.
    auto create_program = [=]() {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x1  = mm->add_parameter("x1", s1);
        auto x2  = mm->add_parameter("x2", s1);
        auto y   = mm->add_parameter("y", s2);

        auto concat_op = migraphx::make_op("concat", {{"axis", 1}});
        auto concat_precompile_op =
            migraphx::make_op("gpu::precompile_op", {{"op", migraphx::to_value(concat_op)}});
        auto x_alloc = mm->add_instruction(
            migraphx::make_op("allocate", {{"shape", migraphx::to_value(s2)}}));
        auto x     = mm->add_instruction(concat_precompile_op, {x1, x2, x_alloc});
        auto alloc = migraphx::make_op("allocate", {{"shape", migraphx::to_value(s3)}});
        auto alloc_ins = mm->add_instruction(alloc);
        auto* pw_add1 =
            create_pointwise_module(p, "main:pointwise0", {x, y}, single_pointwise("add"));

        // The pointwise precompile_op pins its output shape to s3 rather than
        // leaving it to be derived from the inputs.
        auto pw_op       = migraphx::make_op("pointwise");
        auto pre_comp_op = migraphx::make_op(
            "gpu::precompile_op",
            {{"op", migraphx::to_value(pw_op)}, {"output_shape", migraphx::to_value(s3)}});
        auto add1 = mm->add_instruction(pre_comp_op, {x, y, alloc_ins}, {pw_add1});
        auto rsp =
            mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {25088, 128}}}), add1);
        mm->add_return({rsp});
        return p;
    };

    // Builds the expected program: one concat precompile_op that takes the
    // pointwise module, the extra pointwise input `y` and the output buffer.
    auto create_fused_program = [=]() {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x1  = mm->add_parameter("x1", s1);
        auto x2  = mm->add_parameter("x2", s1);
        auto y   = mm->add_parameter("y", s2);

        auto concat_op = migraphx::make_op("concat", {{"axis", 1}});
        // additional_args = 2 matches the two trailing inputs (y, alloc_ins)
        // that follow the concat inputs in the instruction below.
        auto concat_precompile_op =
            migraphx::make_op("gpu::precompile_op",
                              {{"op", migraphx::to_value(concat_op)},
                               {"additional_args", 2},
                               {"ignore_modules", true},
                               {"output_shape", migraphx::to_value(s3)}});
        auto alloc = migraphx::make_op("allocate", {{"shape", migraphx::to_value(s3)}});
        auto alloc_ins = mm->add_instruction(alloc);
        // use y's input shape for creating pointwise module for both the params
        auto* pw_add1 =
            create_pointwise_module(p, "main:pointwise0", {y, y}, single_pointwise("add"));
        auto x = mm->add_instruction(concat_precompile_op, {x1, x2, y, alloc_ins}, {pw_add1});
        auto rsp =
            mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {25088, 128}}}), x);
        mm->add_return({rsp});
        return p;
    };

    migraphx::program p1 = create_program();
    run_pass(p1);
    migraphx::program p2 = create_fused_program();
    EXPECT(p1 == p2);
}

// Entry point of the test executable: forwards the command-line arguments to
// the test framework's runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}

0 comments on commit 4bedf9e

Please sign in to comment.