diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index d22876a274b..cecacc576e9 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -324,38 +324,39 @@ bool precedes(Iterator x, Iterator y, Iterator last) return any_of(iterator_for(r), [&](auto it) { return it == y; }); } -struct match_qlinear_reused -{ - auto matcher() const - { - return match::name("quantizelinear")( - match::used_once(), match::arg(0)(match::none_of(match::used_once()).bind("x"))); - } - - void apply(module& m, const match::matcher_result& r) const - { - auto ins = r.result; - auto x_ins = r.instructions["x"]; - assert(ins != x_ins); - - auto dq_inputs = ins->inputs(); - dq_inputs[0] = ins; - auto outputs = x_ins->outputs(); - if(outputs.size() != 2) - return; - for(auto output : outputs) - { - if(output->name() == "quantizelinear") - continue; - if(not output->get_operator().attributes().contains("pointwise")) - continue; - if(not precedes(ins, output, m.end())) - continue; - auto dq = m.insert_instruction(std::next(ins), make_op("dequantizelinear"), dq_inputs); - instruction::replace_argument(output, x_ins, dq); - } - } -}; +// TODO: disabled for 6.2 release due to accuracy bug for quantized resnet50 +// struct match_qlinear_reused +//{ +// auto matcher() const +// { +// return match::name("quantizelinear")( +// match::used_once(), match::arg(0)(match::none_of(match::used_once()).bind("x"))); +// } +// +// void apply(module& m, const match::matcher_result& r) const +// { +// auto ins = r.result; +// auto x_ins = r.instructions["x"]; +// assert(ins != x_ins); +// +// auto dq_inputs = ins->inputs(); +// dq_inputs[0] = ins; +// auto outputs = x_ins->outputs(); +// if(outputs.size() != 2) +// return; +// for(auto output : outputs) +// { +// if(output->name() == "quantizelinear") +// continue; +// if(not output->get_operator().attributes().contains("pointwise")) +// continue; +// if(not precedes(ins, output, m.end())) +// continue; +// auto dq = m.insert_instruction(std::next(ins), make_op("dequantizelinear"), +// dq_inputs); instruction::replace_argument(output, x_ins, dq); +// } +// } +// }; bool is_same_value(instruction_ref a, instruction_ref b) { @@ -401,7 +402,8 @@ void simplify_qdq::apply(module& m) const migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); remove_qdq_pairs(m); migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); - match::find_matches(m, match_qlinear_reused{}); + // TODO: disabled for 6.2 release due to accuracy bug for quantized resnet50 + // match::find_matches(m, match_qlinear_reused{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index 86c2a7b082d..1e331a346ff 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -1478,68 +1478,69 @@ TEST_CASE(dot_correctness) EXPECT(migraphx::verify::verify_rms_range(rv1, rv2)); } -TEST_CASE(dot_reused) -{ - migraphx::shape sh{migraphx::shape::float_type, {256, 256}}; - - migraphx::module m1; - { - auto x = m1.add_parameter("x", sh); - auto y = m1.add_parameter("y", sh); - auto w1 = m1.add_parameter("w1", sh); - auto w2 = m1.add_parameter("w2", sh); - auto scale = m1.add_literal(0.5f); - auto zero = m1.add_literal(std::int8_t{0}); - - auto q1 = add_quantize_op(m1, "quantizelinear", x, scale, zero); - auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); - auto q2 = add_quantize_op(m1, "quantizelinear", w1, scale, zero); - auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); - auto dot1 = m1.add_instruction(migraphx::make_op("dot"), d1, d2); - auto add1 = m1.add_instruction(migraphx::make_op("add"), dot1, y); - - auto q3 = add_quantize_op(m1, "quantizelinear", add1, scale, zero); - auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero); - auto q4 = add_quantize_op(m1, "quantizelinear", w2, scale, zero); - auto d4 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); - auto dot2 = m1.add_instruction(migraphx::make_op("dot"), d3, d4); - auto add2 = m1.add_instruction(migraphx::make_op("add"), dot2, add1); - m1.add_return({add2}); - } - - migraphx::module m2; - { - auto x = m2.add_parameter("x", sh); - auto y = m2.add_parameter("y", sh); - auto w1 = m2.add_parameter("w1", sh); - auto w2 = m2.add_parameter("w2", sh); - auto scale = m2.add_literal(0.5f); - auto zero = m2.add_literal(std::int8_t{0}); - auto zero2 = m2.add_literal(std::int32_t{0}); - - auto q1 = add_quantize_op(m2, "quantizelinear", x, scale, zero); - auto q2 = add_quantize_op(m2, "quantizelinear", w1, scale, zero); - - auto dot1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); - auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); - auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1, zero2); - auto add1 = m2.add_instruction(migraphx::make_op("add"), d1, y); - - auto q3 = add_quantize_op(m2, "quantizelinear", add1, scale, zero); - auto q4 = add_quantize_op(m2, "quantizelinear", w2, scale, zero); - auto dot2 = m2.add_instruction(migraphx::make_op("quant_dot"), q3, q4); - auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); - auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2, zero2); - auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1], q3->inputs()[2]); - auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3); - m2.add_return({add2}); - } - - run_pass(m1); - run_cse(m1); - run_cse(m2); - EXPECT(m1.sort() == m2.sort()); -} +// TODO: disabled for 6.2 release due to accuracy bug for quantized resnet50 +// TEST_CASE(dot_reused) +//{ +// migraphx::shape sh{migraphx::shape::float_type, {256, 256}}; +// +// migraphx::module m1; +// { +// auto x = m1.add_parameter("x", sh); +// auto y = m1.add_parameter("y", sh); +// auto w1 = m1.add_parameter("w1", sh); +// auto w2 = m1.add_parameter("w2", sh); +// auto scale = m1.add_literal(0.5f); +// auto zero = m1.add_literal(std::int8_t{0}); +// +// auto q1 = add_quantize_op(m1, "quantizelinear", x, scale, zero); +// auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); +// auto q2 = add_quantize_op(m1, "quantizelinear", w1, scale, zero); +// auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); +// auto dot1 = m1.add_instruction(migraphx::make_op("dot"), d1, d2); +// auto add1 = m1.add_instruction(migraphx::make_op("add"), dot1, y); +// +// auto q3 = add_quantize_op(m1, "quantizelinear", add1, scale, zero); +// auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero); +// auto q4 = add_quantize_op(m1, "quantizelinear", w2, scale, zero); +// auto d4 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); +// auto dot2 = m1.add_instruction(migraphx::make_op("dot"), d3, d4); +// auto add2 = m1.add_instruction(migraphx::make_op("add"), dot2, add1); +// m1.add_return({add2}); +// } +// +// migraphx::module m2; +// { +// auto x = m2.add_parameter("x", sh); +// auto y = m2.add_parameter("y", sh); +// auto w1 = m2.add_parameter("w1", sh); +// auto w2 = m2.add_parameter("w2", sh); +// auto scale = m2.add_literal(0.5f); +// auto zero = m2.add_literal(std::int8_t{0}); +// auto zero2 = m2.add_literal(std::int32_t{0}); +// +// auto q1 = add_quantize_op(m2, "quantizelinear", x, scale, zero); +// auto q2 = add_quantize_op(m2, "quantizelinear", w1, scale, zero); +// +// auto dot1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); +// auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); +// auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1, zero2); +// auto add1 = m2.add_instruction(migraphx::make_op("add"), d1, y); +// +// auto q3 = add_quantize_op(m2, "quantizelinear", add1, scale, zero); +// auto q4 = add_quantize_op(m2, "quantizelinear", w2, scale, zero); +// auto dot2 = m2.add_instruction(migraphx::make_op("quant_dot"), q3, q4); +// auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); +// auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2, zero2); +// auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1], +// q3->inputs()[2]); auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3); +// m2.add_return({add2}); +// } +// +// run_pass(m1); +// run_cse(m1); +// run_cse(m2); +// EXPECT(m1.sort() == m2.sort()); +// } TEST_CASE(dot_asymmetric_correctness) {