From 1420149f54174576ce7bc79d794ec3b64b7e6583 Mon Sep 17 00:00:00 2001 From: Ryan O'Shea Date: Wed, 30 Oct 2024 17:07:44 +0100 Subject: [PATCH] Fix issue with transposing shape in CMSIS-NN batch matmul BUG=#2740 Signed-off-by: Ryan O'Shea --- tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc b/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc index f73ceed57c2..7c0cb69b5bf 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc @@ -121,7 +121,12 @@ inline TfLiteStatus PopulateEvalData( RuntimeShape tmp_r = SwapRowColumnDims(*rhs_shape); rhs_shape->ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData()); } - if (!params->adj_x) { + // ReferenceOps and CMSIS-NN have different requirements for when the + // lhs shape should be transposed, so we have to treat float differently. + if (!params->adj_x && original_lhs_input->type == kTfLiteFloat32) { + RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape); + lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData()); + } else if (params->adj_x && original_lhs_input->type != kTfLiteFloat32) { RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape); lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData()); }