Skip to content

Commit

Permalink
Fix issue with transposing shape in CMSIS-NN batch matmul
Browse files Browse the repository at this point in the history
BUG=#2740

Signed-off-by: Ryan O'Shea <[email protected]>
  • Loading branch information
ArmRyan committed Oct 30, 2024
1 parent 389e775 commit 1420149
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ inline TfLiteStatus PopulateEvalData(
RuntimeShape tmp_r = SwapRowColumnDims(*rhs_shape);
rhs_shape->ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData());
}
if (!params->adj_x) {
// ReferenceOps and CMSIS-NN have different requirements for when the
// lhs shape should be transposed, so we have to treat float differently.
if (!params->adj_x && original_lhs_input->type == kTfLiteFloat32) {
RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape);
lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
} else if (params->adj_x && original_lhs_input->type != kTfLiteFloat32) {
RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape);
lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
}
Expand Down

0 comments on commit 1420149

Please sign in to comment.