From 5a46ab0e10d3bceeca5d92c98387055a4fa7a2c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Fri, 1 Nov 2024 20:57:00 +0100 Subject: [PATCH] Add kernels optimized for size flag to FC and SVDF (#2734) The KERNELS_OPTIMIZED_FOR_SIZE flag provides an alternative implementation where size is prioritized over latency. With the size option (the speed option is the default), the CMSIS-NN kernels calculate the kernel sums during inference. BUG=no bug but this will let users prioritize speed vs size even more --- .../micro/kernels/cmsis_nn/fully_connected.cc | 11 +++++-- .../lite/micro/kernels/cmsis_nn/svdf.cc | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc index 8e6fc5a9ccb..7b4e1319532 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc @@ -148,15 +148,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } else if (input->type == kTfLiteInt8) { buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims); - int8_t* filter_data = GetTensorData(filter); data->kernel_sums = nullptr; +#if defined(KERNELS_OPTIMIZED_FOR_SPEED) + const int8_t* filter_data = GetTensorData(filter); + if (buf_size > 0 && filter_data != nullptr) { + const int32_t input_offset = -data->reference_op_data.input_zero_point; + const int32_t filter_offset = + -data->reference_op_data.filter_zero_point; + data->kernel_sums = static_cast( context->AllocatePersistentBuffer(context, buf_size)); - int32_t input_offset = -data->reference_op_data.input_zero_point; - int32_t filter_offset = -data->reference_op_data.filter_zero_point; arm_vector_sum_s8(data->kernel_sums, filter_dims.n, data->output_depth, filter_data, input_offset, filter_offset, tflite::GetTensorData(bias)); @@ -164,6 +168,7 @@ TfLiteStatus Prepare(TfLiteContext* 
context, TfLiteNode* node) { // Do not request a scratch buffer since using persistent memory buf_size = 0; } +#endif } } diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc b/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc index d39ae616c0f..b48dcb4a69d 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc @@ -39,6 +39,9 @@ struct CmsisNnOpDataSvdf { int effective_scale_1_b; int effective_scale_2_b; int scratch_tensor_index; +#if defined(KERNELS_OPTIMIZED_FOR_SIZE) + int scratch_weight_tensor_index; +#endif int scratch_output_tensor_index; // Cached tensor zero point values for quantized operations. @@ -189,6 +192,7 @@ TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) { const int32_t buf_size = arm_svdf_s8_get_buffer_size(&weights_feature_dims); if (buf_size > 0) { +#if defined(KERNELS_OPTIMIZED_FOR_SPEED) data->kernel_sums = static_cast( context->AllocatePersistentBuffer(context, buf_size)); @@ -196,6 +200,17 @@ TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) { GetTensorData(weights_feature), -data->input_zero_point, -data->activation_state_zero_point, nullptr); +#elif defined(KERNELS_OPTIMIZED_FOR_SIZE) + const TfLiteStatus scratch_kernel_status = + context->RequestScratchBufferInArena( + context, buf_size, &(data->scratch_weight_tensor_index)); + TF_LITE_ENSURE_OK(context, scratch_kernel_status); +#else + MicroPrintf( + "Either KERNELS_OPTIMIZED_FOR_SIZE or KERNELS_OPTIMIZED_FOR_SPEED " + "must be defined"); + return kTfLiteError; +#endif } } else { @@ -291,7 +306,21 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, switch (weights_time_tensor->type) { case kTfLiteInt8: { cmsis_nn_context ctx; + +#if defined(KERNELS_OPTIMIZED_FOR_SPEED) ctx.buf = data.kernel_sums; +#elif defined(KERNELS_OPTIMIZED_FOR_SIZE) + ctx.buf = static_cast( + context->GetScratchBuffer(context, data.scratch_weight_tensor_index)); + + const int 
input_size = input_tensor->dims->data[1]; + const int num_filters = weights_feature_tensor->dims->data[0]; + + arm_vector_sum_s8( + static_cast(ctx.buf), input_size, num_filters, + tflite::micro::GetTensorData(weights_feature_tensor), + -data.input_zero_point, -data.activation_state_zero_point, nullptr); +#endif arm_svdf_s8( &ctx, &scratch_ctx, &scratch_output_ctx, &svdf_params,