From eb2a4bc8e2ca982427983f999dc61d7c0861088f Mon Sep 17 00:00:00 2001 From: Anatol Liu Date: Tue, 10 Oct 2023 10:35:54 -0400 Subject: [PATCH] [TFLite][Frontend] Fix test failures caused by div-by-zero (#15844) * [TFLite][Frontend] Support quantized floor_mod * [TVM][Frontend] Fix zero-point issues in quantized div/floor_div * [TVM][Frontend] Fix zero-point issues in quantized div/floor_div --- tests/python/frontend/tflite/test_forward.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index d9a53fd51719..f60166702454 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2452,6 +2452,7 @@ def _test_elemwise( qnn_op=None, same_qnn_params=False, comparison_op=False, + exclude_zero_point=False, ): """One iteration of elemwise""" @@ -2480,6 +2481,16 @@ def __test_elemwise(in_data): inq0_min, inq0_max = (out_min, out_max) inq1_min, inq1_max = (out_min, out_max) + if exclude_zero_point: + if inq1_max == inq1_min: + raise ZeroDivisionError("Input range is 0.") + + # only compute for rhs. + quant_scale = 255 / (inq1_max - inq1_min) + zero_point = int(round(-inq1_min * quant_scale)) + data[1][data[1] == zero_point] += 1 + data[1][data[1] == 0] += 1 + # fake_quant will keep the tensors in float32 until the conversion in the session inq_data = [ tf.quantization.fake_quant_with_min_max_args( @@ -2619,6 +2630,7 @@ def _test_div(data, fused_activation_function=None, quantized=False, qnn_op=None quantized, qnn_op, same_qnn_params=True, + exclude_zero_point=True, ) @@ -2802,6 +2814,7 @@ def _test_floor_divide(data, fused_activation_function=None, quantized=False, qn quantized, qnn_op, same_qnn_params=True, + exclude_zero_point=True, ) @@ -2882,7 +2895,7 @@ def _test_elemwise_qnn_out_range(qnn_op): def test_all_elemwise(): - """All_elewise""" + """All_elemwise""" _test_forward_elemwise(_test_add) _test_forward_elemwise_quantized(_test_add) _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))