Tile scale and bias for block quantization #3718

Draft · wants to merge 1 commit into base: develop
4 changes: 0 additions & 4 deletions src/include/migraphx/op/dequantizelinear.hpp
@@ -50,10 +50,6 @@
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.same_dims().has(2, 3);
-       if(inputs.size() == 3 and inputs[0].type() != inputs[2].type())
-       {
-           MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type.");
-       }
        return inputs[0].with_lens(inputs[1].type(), inputs[0].lens());
    }

5 changes: 2 additions & 3 deletions src/module.cpp
@@ -373,8 +373,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
    auto outputs = ins->outputs();
    for(auto out : outputs)
    {
-       // TODO: Check for possible cycles
-       if(out != rep)
+       if(out != rep and not reaches(out, rep))
        {
            instruction::replace_argument(out, ins, rep);
        }
@@ -385,7 +384,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
    // Output of the original instruction should only be the replacement or empty
    assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
                                                 ins->outputs().end(),
-                                                [&](auto i) { return i == rep; }));
+                                                [&](auto i) { return i == rep or reaches(i, rep); }));
    assert(ins->valid(begin()));
    assert(rep->valid(begin()));
    return rep;
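The dropped TODO ("Check for possible cycles") explains the intent of the new reaches() guard: if the replacement already consumes an output of the old instruction, rewiring that output to read from the replacement would close a cycle in the program graph. A minimal, self-contained sketch of that kind of reachability walk is below; the node type and traversal direction are illustrative assumptions, not MIGraphX's actual reaches implementation.

#include <unordered_set>
#include <vector>

// Illustrative stand-in for an IR instruction: only the def-use (output)
// edges matter for this check.
struct node
{
    std::vector<node*> outputs;
};

// Returns true if `target` can be reached from `start` by following output
// (use) edges. Skipping the rewrite when such a path exists keeps an
// instruction from being given an input that already depends on it, which
// would create a cycle.
bool reaches(const node* start, const node* target)
{
    std::unordered_set<const node*> visited;
    std::vector<const node*> stack{start};
    while(not stack.empty())
    {
        const node* n = stack.back();
        stack.pop_back();
        if(n == target)
            return true;
        if(not visited.insert(n).second)
            continue;
        for(const node* out : n->outputs)
            stack.push_back(out);
    }
    return false;
}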
56 changes: 56 additions & 0 deletions src/simplify_qdq.cpp
@@ -456,6 +456,60 @@ void add_int4_pack_unpack_pair(module& m)
    }
}

struct match_block_dequantize
{
    std::size_t pack_by = 32;

    static auto reshape_broadcast(const std::string& name)
    {
        auto broadcast_input = match::name("multibroadcast", "broadcast")(match::arg(0)(match::is_constant().bind(name)));
        return match::name("reshape")(match::arg(0)(broadcast_input));
    }

    auto matcher() const
    {
        return match::name("dequantizelinear")(match::arg(1)(reshape_broadcast("scale")), match::arg(2)(reshape_broadcast("zero")));
    }

    bool can_pack(const shape& s) const
    {
        return (s.lens().front() % pack_by) == 0;
    }

    void pack_layout(module& m, instruction_ref ins) const
    {
        auto dims = ins->get_shape().lens();

        dims.front() /= pack_by;
        dims.insert(dims.begin() + 1, pack_by);

        std::vector<std::size_t> perm(dims.size());
        std::iota(perm.begin() + 1, perm.end() - 1, 2);
        perm.back() = 1;

        auto insert_ins = std::next(ins);
        auto reshape1   = m.insert_instruction(insert_ins, make_op("reshape", {{"dims", dims}}), ins);
        auto layout     = m.insert_instruction(insert_ins, make_op("layout", {{"permutation", perm}}), reshape1);
        auto reshape2   = m.insert_instruction(insert_ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), layout);
        m.replace_instruction(ins, reshape2);
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto scale_ins = r.instructions["scale"];
        auto zero_ins  = r.instructions["zero"];

        if(not can_pack(scale_ins->get_shape()))
            return;
        if(not can_pack(zero_ins->get_shape()))
            return;

        pack_layout(m, scale_ins);
        pack_layout(m, zero_ins);
    }
};

} // namespace

void simplify_qdq::apply(module& m) const
@@ -468,6 +522,8 @@ void simplify_qdq::apply(module& m) const
    migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
    match::find_matches(m, match_qlinear_reused{});
    migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
    match::find_matches(m, match_block_dequantize{});
    migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
    remove_zero_point(m);
}

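The new matcher targets dequantizelinear instructions whose scale and zero point arrive as reshape(broadcast(constant)), i.e. per-block values expanded out to the element shape, and pack_layout retiles those constants through a reshape/layout/reshape chain built from the dims and permutation it computes. A standalone sketch of that shape arithmetic follows (plain C++, no MIGraphX types; the {4096, 128} scale shape is a made-up example):

#include <cstddef>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Mirrors the dims/permutation computed in match_block_dequantize::pack_layout.
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
pack_dims(std::vector<std::size_t> dims, std::size_t pack_by = 32)
{
    // Split the leading axis into {blocks, pack_by}.
    dims.front() /= pack_by;
    dims.insert(dims.begin() + 1, pack_by);

    // Build a permutation that moves the pack_by axis to the last position;
    // for a 3-D result this is {0, 2, 1}.
    std::vector<std::size_t> perm(dims.size(), 0);
    std::iota(perm.begin() + 1, perm.end() - 1, 2);
    perm.back() = 1;
    return {dims, perm};
}

int main()
{
    auto [dims, perm] = pack_dims({4096, 128}); // hypothetical scale shape
    for(auto d : dims)
        std::cout << d << ' '; // prints: 128 32 128
    std::cout << '\n';
    for(auto p : perm)
        std::cout << p << ' '; // prints: 0 2 1
    std::cout << '\n';
}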
5 changes: 2 additions & 3 deletions src/targets/gpu/target.cpp
@@ -177,17 +177,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        dead_code_elimination{},
        normalize_ops{},
        dead_code_elimination{},
        simplify_reshapes{},
        eliminate_identity{},
        eliminate_pad{},
        dead_code_elimination{},
        simplify_qdq{},
        enable_pass(not mlir_enabled(), rewrite_quantization{}),
        dead_code_elimination{},
        // workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling
        eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}},
        eliminate_data_type{unsupported_types, shape::type_t::float_type},
        simplify_reshapes{},
        eliminate_identity{},
        eliminate_pad{},
        dead_code_elimination{},
        insert_pad{{"convolution"}},
        dead_code_elimination{},