Skip to content

Commit

Permalink
Enable int16_t support for CMSIS-NN LSTM kernel
Browse files Browse the repository at this point in the history
Change-Id: Icee7e3d4448883a68df011746dfd6f15e8445ef2
  • Loading branch information
AdrianLundell committed Apr 23, 2024
1 parent a17682d commit 4f25468
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ void CMSIS_NN_VectorSum(int32_t* kernel_sum, const int32_t size1,
arm_vector_sum_s8(kernel_sum, size1, size2, weights, offset, biases);
}

// Overload of CMSIS_NN_VectorSum for 64-bit kernel sums and biases, used by
// the int16-activation LSTM path (which accumulates into int64). Forwards to
// the CMSIS-NN int8-weights / int64-bias vector-sum helper.
void CMSIS_NN_VectorSum(int64_t* kernel_sum, const int32_t size1,
                        const int32_t size2, const int8_t* weights,
                        const int32_t offset, const int64_t* biases) {
  arm_vector_sum_s8_s64(kernel_sum, size1, size2, weights, offset, biases);
}

template <typename BiasType>
TfLiteStatus CMSIS_NN_PortOpData(TfLiteContext* context, OpDataLSTM* params_ref,
const LSTMKernelContents& kernel_content,
Expand Down Expand Up @@ -289,6 +295,32 @@ TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm(
return kTfLiteOk;
}

// Evaluates a 16(activation)x8(weight)->16(cell) LSTM using the optimized
// CMSIS-NN s16 kernel. The quantization parameters are taken from
// op_data.params_cmsis_nn, which is populated during Prepare
// (CMSIS_NN_PortOpData<int64_t>). Always returns kTfLiteOk; the underlying
// CMSIS-NN call's status is not propagated here.
TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm(
    const OpData& op_data, const LSTMKernelContents& kernel_content,
    const LSTMBuffers<int16_t>& buffers) {
  // The kernel only handles rank-2 or rank-3 input tensors.
  TFLITE_DCHECK(
      kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >=
          2 &&
      kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size <=
          3);

  // Input is optional-tensor accessed; may be nullptr for an absent tensor.
  const int16_t* input = tflite::micro::GetOptionalTensorData<int16_t>(
      kernel_content.GetInternalTensor(tflite::kLstmInputTensor));
  int16_t* output =
      tflite::micro::GetTensorData<int16_t>(kernel_content.output_tensor);

  // Create lstm buffer struct. The three scratch buffers were requested in
  // Prepare; the casts adapt the LSTMBuffers<int16_t> storage to the
  // cmsis_nn_lstm_context field types (presumably int16_t* already —
  // NOTE(review): confirm against the cmsis_nn_lstm_context declaration).
  cmsis_nn_lstm_context cmsis_buffers;
  cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0);
  cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1);
  cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2);

  arm_lstm_unidirectional_s16(input, output, &op_data.params_cmsis_nn,
                              &cmsis_buffers);

  return kTfLiteOk;
}

/*Kernel functions*/
void* UnidirectionalSequenceLstmInit(TfLiteContext* context, const char* buffer,
size_t length) {
Expand Down Expand Up @@ -351,6 +383,12 @@ TfLiteStatus UnidirectionalSequenceLstmPrepare(TfLiteContext* context,
number_of_buffers = 3;
CMSIS_NN_PortOpData<int32_t>(context, op_data_lstm, kernel_content,
&op_data->params_cmsis_nn);
} else if (activation_type == kTfLiteInt16 &&
cell_state_type == kTfLiteInt16) {
auto kernel_content = CreateLSTMKernelContent(context, node);
number_of_buffers = 3;
CMSIS_NN_PortOpData<int64_t>(context, op_data_lstm, kernel_content,
&op_data->params_cmsis_nn);
} else {
number_of_buffers = 4;
}
Expand Down Expand Up @@ -394,8 +432,7 @@ TfLiteStatus UnidirectionalSequenceLstmEval(TfLiteContext* context,
// 8(activation)x8(weight)->16(cell) LSTM with 32 bits bias
LSTMBuffers<int16_t> buffers =
CMSIS_NN_CreateLSTMBuffers(context, op_data_lstm.buffer_indices);
return CMSIS_NN_EvalInteger8x8_16Lstm(op_data, kernel_content,
buffers);
CMSIS_NN_EvalInteger8x8_16Lstm(op_data, kernel_content, buffers);
break;
}
default: {
Expand All @@ -411,9 +448,8 @@ TfLiteStatus UnidirectionalSequenceLstmEval(TfLiteContext* context,
case kTfLiteInt8: {
// 16(activation)x8(weight)->16(cell) LSTM with 64 bits bias
LSTMBuffers<int16_t> buffers =
CreateLSTMBuffers<int16_t>(context, op_data_lstm.buffer_indices);
EvalLstm<int16_t, int8_t, int16_t, int64_t>(op_data_lstm,
kernel_content, buffers);
CMSIS_NN_CreateLSTMBuffers(context, op_data_lstm.buffer_indices);
CMSIS_NN_EvalInteger16x8_16Lstm(op_data, kernel_content, buffers);
break;
}
default: {
Expand Down Expand Up @@ -460,6 +496,33 @@ TfLiteStatus UnidirectionalSequenceLstmEvalInt8(TfLiteContext* context,
return kTfLiteOk;
}

// Latency-optimized Eval for the int16-only registration
// (Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT16). Supports only the
// 16(activation)x8(weight)->16(cell) configuration; any other activation
// type fails at runtime with kTfLiteError.
TfLiteStatus UnidirectionalSequenceLstmEvalInt16(TfLiteContext* context,
                                                 TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& op_data = *reinterpret_cast<const OpData*>(node->user_data);
  const OpDataLSTM& op_data_lstm = op_data.params_ref;
  auto kernel_content = CreateLSTMKernelContent(context, node);
  const auto activation_type =
      kernel_content.internal_tensors[kLstmInputTensor]->type;
  const auto weight_type =
      kernel_content.internal_tensors[kLstmInputToInputWeightsTensor]->type;

  // The 16x8 CMSIS-NN LSTM consumes int8 weights (this registration is
  // documented as "int16 activations and int8 weights"), so the filter type
  // must be int8 — not int16 as the original check asserted.
  TFLITE_DCHECK(weight_type == kTfLiteInt8 &&
                "Only int8 filter type supported.");

  if (activation_type == kTfLiteInt16) {
    LSTMBuffers<int16_t> buffers =
        CMSIS_NN_CreateLSTMBuffers(context, op_data_lstm.buffer_indices);

    return CMSIS_NN_EvalInteger16x8_16Lstm(op_data, kernel_content, buffers);
  }

  // Unsupported activation type: report and fail. (Both paths above return,
  // so the original trailing `return kTfLiteOk;` was unreachable and has
  // been removed.)
  MicroPrintf("Input type %s (%d) not supported.",
              TfLiteTypeGetName(activation_type), activation_type);
  return kTfLiteError;
}

} // namespace

TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
Expand All @@ -474,4 +537,10 @@ TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT8() {
UnidirectionalSequenceLstmEvalInt8);
}

// Registration for the int16-activation-only kernel variant: shares Init and
// Prepare with the generic registration, but dispatches to the int16-specific
// Eval entry point.
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT16() {
  return tflite::micro::RegisterOp(UnidirectionalSequenceLstmInit,
                                   UnidirectionalSequenceLstmPrepare,
                                   UnidirectionalSequenceLstmEvalInt16);
}

} // namespace tflite
11 changes: 10 additions & 1 deletion tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,10 +36,19 @@ TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
// implementations.
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT8();

// Returns a TFLMRegistration struct for kernel variant that only supports
// int16 activations and int8 weights and uses the latency optimized
// implementations.
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT16();

#else
// Without optimized kernels the int8-specific registration is just an alias
// for the generic reference implementation.
inline TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT8() {
  return Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
}

// Without optimized kernels the int16-specific registration is just an alias
// for the generic reference implementation.
inline TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT16() {
  return Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
}
#endif

} // namespace tflite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ if [ -d ${DOWNLOADED_CMSIS_NN_PATH} ]; then
echo >&2 "${DOWNLOADED_CMSIS_NN_PATH} already exists, skipping the download."
else

ZIP_PREFIX_NN="6cc31fb36fa330325b2bb0ffde3a7288384e58ab"
CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/6cc31fb36fa330325b2bb0ffde3a7288384e58ab.zip"
CMSIS_NN_MD5="42000f264b93b7b6cd60c1b507792daf"
# Pin the CMSIS-NN checkout to a specific commit so the MD5 check below is
# stable. Fetch over HTTPS: the plain-http URL relied on GitHub's redirect
# and left the initial request unauthenticated.
ZIP_PREFIX_NN="8492d82a1a81651977c5f5128492b4a0f0cf6715"
CMSIS_NN_URL="https://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
CMSIS_NN_MD5="2cb03e4f044b78af6751009cd53247a8"

# wget is much faster than git clone of the entire repo. So we wget a specific
# version and can then apply a patch, as needed.
Expand Down

0 comments on commit 4f25468

Please sign in to comment.