From 134979f57fc5be068ce19318b9d6d588db9005c7 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 9 Jan 2025 15:05:55 -0600 Subject: [PATCH] Fuse output reshapes --- src/targets/gpu/fuse_mlir.cpp | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index dce1a30cf0d..bc4e6a63e34 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1040,6 +1040,32 @@ struct find_unpack_int4_mlir_op } }; +struct find_mlir_reshape_ops +{ + auto matcher() const + { + auto reshapes = reshaper_names(); + // slice is not supported + reshapes.erase("slice"); + return match::name(reshapes)(match::arg(0)(match::name("gpu::mlir_op")(match::used_once())), match::used_once()); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto mlir_ins = ins->inputs().front(); + + auto* mm = mlir_ins->module_inputs().front(); + module_ref nm = mpm.create_module(mm->name() + ":" + ins->name()); + nm->set_bypass(); + + auto y = nm->fuse(*mm, mlir_ins->inputs()); + auto ret = nm->add_instruction(ins->get_operator(), y); + nm->add_return({ret}); + mpm.get_module().replace_instruction(ins, mlir_ins->get_operator(), mlir_ins->inputs(), {nm}); + } +}; + struct find_convolution_reshape { auto matcher() const @@ -1069,6 +1095,7 @@ struct find_convolution_reshape if(out_dims[1] < 4) return; auto reshape = mpm.get_module().insert_instruction(ins,ins->get_operator(), ins->inputs()); + // auto t2 = mpm.get_module().insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 3, 4, 2}}}), reshape); auto t1 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), reshape); auto c = mpm.get_module().insert_instruction(ins, make_op("contiguous"), t1); auto t2 = mpm.get_module().insert_instruction(ins, make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), c); @@ -1097,9 +1124,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const return std::max(m1, m2); }; - mpm.get_module().debug_print(); match::find_matches(mpm, find_convolution_reshape{}); - mpm.get_module().debug_print(); // Attention offloads; default disabled if(mlir_attention_enabled(ctx) or enable_extra) { @@ -1131,6 +1156,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches(mpm, find_pointwise_mlir{}); match::find_matches(mpm, find_unpack_int4_mlir_op{}); + + for(int i=0;i<4;i++) + match::find_matches(mpm, find_mlir_reshape_ops{}); #else (void)mpm;