diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 57cce631fa0..dbe2cb22af8 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -8,6 +8,7 @@ cc_library( srcs = ["array.cc"], hdrs = ["array.h"], deps = [ + "//tensorflow/lite/c:common", "//tensorflow/lite/core/c:common", ], ) diff --git a/tensorflow/lite/array.cc b/tensorflow/lite/array.cc index 1b1ff2e4557..21d704a76c4 100644 --- a/tensorflow/lite/array.cc +++ b/tensorflow/lite/array.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/array.h" +#include "tensorflow/lite/c/common.h" + namespace tflite { namespace array_internal { diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h index 767ad6ab0af..d83696219c2 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, const float* scaling_factors, const int32_t* input_offset, int32_t* row_sums, const RuntimeShape& output_shape, float* output_data, - bool* compute_row_sums) { + bool* compute_row_sums, + const float* per_channel_scales) { const RuntimeShape extended_lhs_shape = RuntimeShape::ExtendedShape(5, lhs_shape); const RuntimeShape extended_rhs_shape = @@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, int32_t row_sum = woff_ptr2[i]; total -= row_sum * batch_offset; int idx = lhs_rows * j + i; - out_ptr[idx] += batch_scaling_factor * total; + float scale = batch_scaling_factor; + if (per_channel_scales) { + scale *= per_channel_scales[i]; + } + out_ptr[idx] += scale * total; } } }