From c6d04e170dc94238b79310b2b4eca84a4e2e3243 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 16 Dec 2024 15:35:06 -0600 Subject: [PATCH] Tile scale and bias for block quantization --- src/include/migraphx/op/dequantizelinear.hpp | 4 -- src/module.cpp | 5 +- src/simplify_qdq.cpp | 56 ++++++++++++++++++++ src/targets/gpu/target.cpp | 5 +- 4 files changed, 60 insertions(+), 10 deletions(-) diff --git a/src/include/migraphx/op/dequantizelinear.hpp b/src/include/migraphx/op/dequantizelinear.hpp index 60500b168d6..8934ca2979f 100644 --- a/src/include/migraphx/op/dequantizelinear.hpp +++ b/src/include/migraphx/op/dequantizelinear.hpp @@ -50,10 +50,6 @@ struct dequantizelinear shape compute_shape(std::vector<shape> inputs) const { check_shapes{inputs, *this}.same_dims().has(2, 3); - if(inputs.size() == 3 and inputs[0].type() != inputs[2].type()) - { - MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type."); - } return inputs[0].with_lens(inputs[1].type(), inputs[0].lens()); } diff --git a/src/module.cpp b/src/module.cpp index 7e02478b385..471bb8c9934 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -373,8 +373,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref auto outputs = ins->outputs(); for(auto out : outputs) { - // TODO: Check for possible cycles - if(out != rep) + if(out != rep and not reaches(out, rep)) { instruction::replace_argument(out, ins, rep); } @@ -385,7 +384,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref // Output of the original instruction should only be the replacement or empty assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(), ins->outputs().end(), - [&](auto i) { return i == rep; })); + [&](auto i) { return i == rep or reaches(i, rep); })); assert(ins->valid(begin())); assert(rep->valid(begin())); return rep; } diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 86c2100a995..44e28a2f244 100644 --- a/src/simplify_qdq.cpp +++ 
b/src/simplify_qdq.cpp @@ -456,6 +456,60 @@ void add_int4_pack_unpack_pair(module& m) } } +struct match_block_dequantize +{ + std::size_t pack_by = 32; + + static auto reshape_broadcast(const std::string& name) + { + auto broadcast_input = match::name("multibroadcast", "broadcast")(match::arg(0)(match::is_constant().bind(name))); + return match::name("reshape")(match::arg(0)(broadcast_input)); + } + + auto matcher() const + { + return match::name("dequantizelinear")(match::arg(1)(reshape_broadcast("scale")), match::arg(2)(reshape_broadcast("zero"))); + } + + + bool can_pack(const shape& s) const + { + return (s.lens().front() % pack_by) == 0; + } + + void pack_layout(module& m, instruction_ref ins) const + { + auto dims = ins->get_shape().lens(); + + dims.front() /= pack_by; + dims.insert(dims.begin() + 1, pack_by); + + std::vector<int64_t> perm(dims.size()); + std::iota(perm.begin() + 1, perm.end() - 1, 2); + perm.back() = 1; + + auto insert_ins = std::next(ins); + auto reshape1 = m.insert_instruction(insert_ins, make_op("reshape", {{"dims", dims}}), ins); + auto layout = m.insert_instruction(insert_ins, make_op("layout", {{"permutation", perm}}), reshape1); + auto reshape2 = m.insert_instruction(insert_ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), layout); + m.replace_instruction(ins, reshape2); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto scale_ins = r.instructions["scale"]; + auto zero_ins = r.instructions["zero"]; + + if(not can_pack(scale_ins->get_shape())) + return; + if(not can_pack(zero_ins->get_shape())) + return; + + pack_layout(m, scale_ins); + pack_layout(m, zero_ins); + } +}; + } // namespace void simplify_qdq::apply(module& m) const @@ -468,6 +522,8 @@ void simplify_qdq::apply(module& m) const migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); match::find_matches(m, match_qlinear_reused{}); migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + match::find_matches(m, 
match_block_dequantize{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); remove_zero_point(m); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index ad98fb680fe..469481f899d 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -177,7 +177,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, normalize_ops{}, dead_code_elimination{}, + simplify_reshapes{}, eliminate_identity{}, + eliminate_pad{}, dead_code_elimination{}, simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), @@ -185,9 +187,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti // workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}}, eliminate_data_type{unsupported_types, shape::type_t::float_type}, - simplify_reshapes{}, - eliminate_identity{}, - eliminate_pad{}, dead_code_elimination{}, insert_pad{{"convolution"}}, dead_code_elimination{},