Skip to content

Commit

Permalink
Fuse output reshapes
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Jan 9, 2025
1 parent 5c41972 commit 134979f
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,32 @@ struct find_unpack_int4_mlir_op
}
};

// Folds a reshaper instruction that directly consumes a `gpu::mlir_op` into
// that op's submodule, so the reshape becomes part of the fused module.
struct find_mlir_reshape_ops
{
    // Matches any reshaper-named instruction (other than "slice") whose
    // argument 0 is a `gpu::mlir_op`; both the mlir op and the reshaper
    // itself are constrained with `used_once()`.
    auto matcher() const
    {
        auto names = reshaper_names();
        // slice is not supported
        names.erase("slice");
        auto single_use_mlir_op = match::name("gpu::mlir_op")(match::used_once());
        return match::name(names)(match::arg(0)(single_use_mlir_op), match::used_once());
    }

    // Creates a new module holding the original mlir submodule followed by
    // the matched reshaper, then replaces the reshaper instruction with the
    // mlir operator applied to the new module.
    //
    // mpm: pass manager used to create the new module and to access the
    //      module being rewritten.
    // r:   match result; `r.result` is the reshaper instruction.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reshape_ins = r.result;
        auto mlir_op_ins = reshape_ins->inputs().front();
        auto* orig_mod   = mlir_op_ins->module_inputs().front();

        // The new module is named "<original module name>:<reshaper op name>".
        module_ref fused_mod =
            mpm.create_module(orig_mod->name() + ":" + reshape_ins->name());
        fused_mod->set_bypass();

        // Bring the original submodule into the new module, wired to the mlir
        // op's inputs, and feed what `fuse` returns to the reshaper operator.
        auto fused_outputs = fused_mod->fuse(*orig_mod, mlir_op_ins->inputs());
        auto reshaped =
            fused_mod->add_instruction(reshape_ins->get_operator(), fused_outputs);
        fused_mod->add_return({reshaped});

        // The reshaper is replaced in place by the mlir operator, keeping the
        // mlir op's original inputs but pointing at the new module.
        mpm.get_module().replace_instruction(
            reshape_ins, mlir_op_ins->get_operator(), mlir_op_ins->inputs(), {fused_mod});
    }
};

struct find_convolution_reshape
{
auto matcher() const
Expand Down Expand Up @@ -1069,6 +1095,7 @@ struct find_convolution_reshape
if(out_dims[1] < 4)
return;
auto reshape = mpm.get_module().insert_instruction(ins,ins->get_operator(), ins->inputs());
// auto t2 = mpm.get_module().insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 3, 4, 2}}}), reshape);
auto t1 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), reshape);
auto c = mpm.get_module().insert_instruction(ins, make_op("contiguous"), t1);
auto t2 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), c);
Expand Down Expand Up @@ -1097,9 +1124,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
return std::max(m1, m2);
};

mpm.get_module().debug_print();
match::find_matches(mpm, find_convolution_reshape{});
mpm.get_module().debug_print();
// Attention offloads; default disabled
if(mlir_attention_enabled(ctx) or enable_extra)
{
Expand Down Expand Up @@ -1131,6 +1156,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const

match::find_matches(mpm, find_pointwise_mlir{});
match::find_matches(mpm, find_unpack_int4_mlir_op{});

for(int i=0;i<4;i++)
match::find_matches(mpm, find_mlir_reshape_ops{});

#else
(void)mpm;
Expand Down

0 comments on commit 134979f

Please sign in to comment.