From fa96a3e8f24934cbe640c5ab78138486d2e2dbf5 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Fri, 6 Oct 2023 09:32:48 +0200 Subject: [PATCH] Use fallback reference for 8-bit fully connnected operators with non-zero filter offset in CMSIS-NN Change-Id: I10a732d7ab0c7bb15ce6517e9ae1cb8a42d50b7e --- .../micro/kernels/cmsis_nn/fully_connected.cc | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc index d9729a5fb51..dbba5b27f58 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc @@ -104,7 +104,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims); - } else if (input->type == kTfLiteInt8) { + } else if (input->type == kTfLiteInt8 && + data->reference_op_data.filter_zero_point == 0) { const RuntimeShape input_shape = GetTensorShape(input); TFLITE_DCHECK_GE(output_dim_count, 2); @@ -374,8 +375,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: { switch (filter_int8.type) { case kTfLiteInt8: - return EvalQuantizedInt8(context, node, data, input, &filter_int8, - bias, output); + if (data.reference_op_data.filter_zero_point == 0) { + return EvalQuantizedInt8(context, node, data, input, &filter_int8, + bias, output); + } else { + tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data.reference_op_data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + return kTfLiteOk; + } default: MicroPrintf("Filter Type %s (%d) not supported.", TfLiteTypeGetName(filter->type), filter->type);