Skip to content

Commit

Permalink
comment out qlinear_reuse matcher and test
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 committed Jul 10, 2024
1 parent d15a23a commit 56bd330
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 95 deletions.
68 changes: 35 additions & 33 deletions src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,38 +324,39 @@ bool precedes(Iterator x, Iterator y, Iterator last)
return any_of(iterator_for(r), [&](auto it) { return it == y; });
}

struct match_qlinear_reused
{
auto matcher() const
{
return match::name("quantizelinear")(
match::used_once(), match::arg(0)(match::none_of(match::used_once()).bind("x")));
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
assert(ins != x_ins);

auto dq_inputs = ins->inputs();
dq_inputs[0] = ins;
auto outputs = x_ins->outputs();
if(outputs.size() != 2)
return;
for(auto output : outputs)
{
if(output->name() == "quantizelinear")
continue;
if(not output->get_operator().attributes().contains("pointwise"))
continue;
if(not precedes(ins, output, m.end()))
continue;
auto dq = m.insert_instruction(std::next(ins), make_op("dequantizelinear"), dq_inputs);
instruction::replace_argument(output, x_ins, dq);
}
}
};
// TODO: disabled for 6.2 release due to accuracy bug for quantized resnet50
// struct match_qlinear_reused
//{
// auto matcher() const
// {
// return match::name("quantizelinear")(
// match::used_once(), match::arg(0)(match::none_of(match::used_once()).bind("x")));
// }
//
// void apply(module& m, const match::matcher_result& r) const
// {
// auto ins = r.result;
// auto x_ins = r.instructions["x"];
// assert(ins != x_ins);
//
// auto dq_inputs = ins->inputs();
// dq_inputs[0] = ins;
// auto outputs = x_ins->outputs();
// if(outputs.size() != 2)
// return;
// for(auto output : outputs)
// {
// if(output->name() == "quantizelinear")
// continue;
// if(not output->get_operator().attributes().contains("pointwise"))
// continue;
// if(not precedes(ins, output, m.end()))
// continue;
// auto dq = m.insert_instruction(std::next(ins), make_op("dequantizelinear"),
// dq_inputs); instruction::replace_argument(output, x_ins, dq);
// }
// }
// };

bool is_same_value(instruction_ref a, instruction_ref b)
{
Expand Down Expand Up @@ -401,7 +402,8 @@ void simplify_qdq::apply(module& m) const
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_qdq_pairs(m);
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
// TODO: disabled for 6.2 release due to accuracy bug for quantized resnet50
// match::find_matches(m, match_qlinear_reused{});
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
125 changes: 63 additions & 62 deletions test/simplify_qdq_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1478,68 +1478,69 @@ TEST_CASE(dot_correctness)
EXPECT(migraphx::verify::verify_rms_range(rv1, rv2));
}

TEST_CASE(dot_reused)
{
migraphx::shape sh{migraphx::shape::float_type, {256, 256}};

migraphx::module m1;
{
auto x = m1.add_parameter("x", sh);
auto y = m1.add_parameter("y", sh);
auto w1 = m1.add_parameter("w1", sh);
auto w2 = m1.add_parameter("w2", sh);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});

auto q1 = add_quantize_op(m1, "quantizelinear", x, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", w1, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
auto add1 = m1.add_instruction(migraphx::make_op("add"), dot1, y);

auto q3 = add_quantize_op(m1, "quantizelinear", add1, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto q4 = add_quantize_op(m1, "quantizelinear", w2, scale, zero);
auto d4 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), d3, d4);
auto add2 = m1.add_instruction(migraphx::make_op("add"), dot2, add1);
m1.add_return({add2});
}

migraphx::module m2;
{
auto x = m2.add_parameter("x", sh);
auto y = m2.add_parameter("y", sh);
auto w1 = m2.add_parameter("w1", sh);
auto w2 = m2.add_parameter("w2", sh);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero2 = m2.add_literal(std::int32_t{0});

auto q1 = add_quantize_op(m2, "quantizelinear", x, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", w1, scale, zero);

auto dot1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens());
auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1, zero2);
auto add1 = m2.add_instruction(migraphx::make_op("add"), d1, y);

auto q3 = add_quantize_op(m2, "quantizelinear", add1, scale, zero);
auto q4 = add_quantize_op(m2, "quantizelinear", w2, scale, zero);
auto dot2 = m2.add_instruction(migraphx::make_op("quant_dot"), q3, q4);
auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens());
auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2, zero2);
auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1], q3->inputs()[2]);
auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3);
m2.add_return({add2});
}

run_pass(m1);
run_cse(m1);
run_cse(m2);
EXPECT(m1.sort() == m2.sort());
}
// TODO: disabled for 6.2 release due to accuracy bug for quantized resnet50
// TEST_CASE(dot_reused)
//{
// migraphx::shape sh{migraphx::shape::float_type, {256, 256}};
//
// migraphx::module m1;
// {
// auto x = m1.add_parameter("x", sh);
// auto y = m1.add_parameter("y", sh);
// auto w1 = m1.add_parameter("w1", sh);
// auto w2 = m1.add_parameter("w2", sh);
// auto scale = m1.add_literal(0.5f);
// auto zero = m1.add_literal(std::int8_t{0});
//
// auto q1 = add_quantize_op(m1, "quantizelinear", x, scale, zero);
// auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
// auto q2 = add_quantize_op(m1, "quantizelinear", w1, scale, zero);
// auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
// auto dot1 = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
// auto add1 = m1.add_instruction(migraphx::make_op("add"), dot1, y);
//
// auto q3 = add_quantize_op(m1, "quantizelinear", add1, scale, zero);
// auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
// auto q4 = add_quantize_op(m1, "quantizelinear", w2, scale, zero);
// auto d4 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
// auto dot2 = m1.add_instruction(migraphx::make_op("dot"), d3, d4);
// auto add2 = m1.add_instruction(migraphx::make_op("add"), dot2, add1);
// m1.add_return({add2});
// }
//
// migraphx::module m2;
// {
// auto x = m2.add_parameter("x", sh);
// auto y = m2.add_parameter("y", sh);
// auto w1 = m2.add_parameter("w1", sh);
// auto w2 = m2.add_parameter("w2", sh);
// auto scale = m2.add_literal(0.5f);
// auto zero = m2.add_literal(std::int8_t{0});
// auto zero2 = m2.add_literal(std::int32_t{0});
//
// auto q1 = add_quantize_op(m2, "quantizelinear", x, scale, zero);
// auto q2 = add_quantize_op(m2, "quantizelinear", w1, scale, zero);
//
// auto dot1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
// auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens());
// auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1, zero2);
// auto add1 = m2.add_instruction(migraphx::make_op("add"), d1, y);
//
// auto q3 = add_quantize_op(m2, "quantizelinear", add1, scale, zero);
// auto q4 = add_quantize_op(m2, "quantizelinear", w2, scale, zero);
// auto dot2 = m2.add_instruction(migraphx::make_op("quant_dot"), q3, q4);
// auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens());
// auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2, zero2);
// auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1],
// q3->inputs()[2]); auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3);
// m2.add_return({add2});
// }
//
// run_pass(m1);
// run_cse(m1);
// run_cse(m2);
// EXPECT(m1.sort() == m2.sort());
// }

TEST_CASE(dot_asymmetric_correctness)
{
Expand Down

0 comments on commit 56bd330

Please sign in to comment.