Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce prod porting #2532

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow/lite/kernels/internal/reference/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,20 @@ inline bool QuantizedReduceProd(const T* input_data, int32_t input_zero_point,
return true;
}

template <typename T, typename U>
inline bool QuantizedProdExtraArgs(
const T* input_data, int32_t input_zero_point, float input_scale,
const int* input_dims, const int input_num_dims, T* output_data,
float output_scale, int32_t output_multiplier, int output_shift,
int32_t output_zero_point, const int* output_dims,
const int output_num_dims, const int* axis, const int num_axis_dimensions,
bool keep_dims, int* temp_index, int* resolved_axis, U* temp_prod) {

return QuantizedReduceProd<T> (input_data, input_zero_point, RuntimeShape(input_num_dims, input_dims), output_data,
output_zero_point, RuntimeShape(output_num_dims, output_dims), axis, num_axis_dimensions, keep_dims, temp_index,
resolved_axis, temp_prod, output_multiplier, output_shift);
}

} // namespace reference_ops

} // namespace tflite
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/micro_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ TFLMRegistration Register_PRELU();
TFLMRegistration Register_QUANTIZE();
TFLMRegistration Register_READ_VARIABLE();
TFLMRegistration Register_REDUCE_MAX();
TFLMRegistration Register_REDUCE_PROD();
TFLMRegistration Register_RELU();
TFLMRegistration Register_RELU6();
TFLMRegistration Register_RESHAPE();
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/lite/micro/kernels/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
static_cast<OpDataReduce*>(node->user_data));
}

// Prepare callback for REDUCE_PROD: forwards to the shared helper with the
// per-op state stashed in node->user_data during Init.
TfLiteStatus PrepareProd(TfLiteContext* context, TfLiteNode* node) {
  OpDataReduce* op_data = static_cast<OpDataReduce*>(node->user_data);
  return PrepareProdHelper(context, node, op_data);
}

TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
return EvalMeanHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
Expand All @@ -59,6 +64,11 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
static_cast<OpDataReduce*>(node->user_data));
}

// Eval callback for REDUCE_PROD: forwards to the shared helper with the
// per-op state stashed in node->user_data during Init.
TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
  OpDataReduce* op_data = static_cast<OpDataReduce*>(node->user_data);
  return EvalProdHelper(context, node, op_data);
}

TFLMRegistration Register_MEAN() {
return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalMean);
}
Expand All @@ -71,4 +81,8 @@ TFLMRegistration Register_SUM() {
return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalSum);
}

// Registration entry point for the REDUCE_PROD kernel.
TFLMRegistration Register_PROD() {
  return tflite::micro::RegisterOp(InitReduce, PrepareProd, EvalProd);
}

} // namespace tflite
6 changes: 6 additions & 0 deletions tensorflow/lite/micro/kernels/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,25 @@ TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);

TfLiteStatus PrepareProdHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);

TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalProdHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);

void ReduceResolveAxis(const int* axis_data, int axis_count,
MeanParams* op_params);

TFLMRegistration Register_MEAN();
TFLMRegistration Register_REDUCE_MAX();
TFLMRegistration Register_SUM();
TFLMRegistration Register_PROD();

} // namespace tflite

Expand Down
122 changes: 122 additions & 0 deletions tensorflow/lite/micro/kernels/reduce_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,63 @@ TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

// Returns the per-multiplication rescale factor for a quantized product
// reduction. Naively, the product of N quantized values would require a
// final rescale of (input_scale^N) / output_scale; to keep the accumulator
// from overflowing, each individual multiplication is instead rescaled by
// input_scale / output_scale^(1/N).
double GetQuantProdScaling(double input_scale, double output_scale,
                           int reduced_axis_size) {
  const double per_step_output_scale =
      std::pow(output_scale, 1.0 / static_cast<double>(reduced_axis_size));
  return input_scale / per_step_output_scale;
}

// Prepare-time setup shared by REDUCE_PROD kernels. For quantized (int8/int16)
// tensors this pre-computes the fixed-point rescale (multiplier/shift) from
// the input/output quantization params and reserves a scratch buffer for the
// int32 accumulator; for all types it resolves the axis count and runs the
// common shape/allocation checks via PrepareSimple.
TfLiteStatus PrepareProdHelper(TfLiteContext* context, TfLiteNode* node,
                               OpDataReduce* op_data) {
  MicroContext* micro_context = GetMicroContext(context);
  TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
  TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
  TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);

  op_data->num_axis = NumElements(axis);

  if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
    const int input_size = NumElements(input);
    const int output_size = NumElements(output);
    // Number of elements folded into each output value; determines how the
    // per-multiplication rescale is distributed (see GetQuantProdScaling).
    const int reduced_axis_size = input_size / output_size;
    const double scaling = GetQuantProdScaling(
        static_cast<double>(input->params.scale),
        static_cast<double>(output->params.scale), reduced_axis_size);
    QuantizeMultiplier(scaling, &op_data->multiplier, &op_data->shift);

    // Scratch buffer holds one int32 accumulator per output element.
    TF_LITE_ENSURE_OK(context, context->RequestScratchBufferInArena(
                                   context, output_size * sizeof(int32_t),
                                   &op_data->temp_buffer_idx));
    op_data->input_zp = input->params.zero_point;
    op_data->input_scale = input->params.scale;
    op_data->output_zp = output->params.zero_point;
    op_data->output_scale = output->params.scale;
  }

  TF_LITE_ENSURE_OK(
      context,
      PrepareSimple(context, node, &(op_data->multiplier), &(op_data->shift)));

  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);
  micro_context->DeallocateTempTfLiteTensor(axis);
  return kTfLiteOk;
}

TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);

if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
const double real_multiplier = static_cast<double>(input->params.scale) /
static_cast<double>(output->params.scale);
Expand Down Expand Up @@ -162,6 +213,29 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

// Quantized (int8/int16) REDUCE_PROD evaluation. Fetches the eval tensors
// from the node and dispatches to the reference implementation using the
// rescale parameters (multiplier/shift, zero points, scales) pre-computed in
// op_data at Prepare time. temp_prod is the int32 accumulator scratch buffer.
template <typename T>
TfLiteStatus QuantizedProd(TfLiteContext* context, TfLiteNode* node,
                           int* temp_index, int* resolved_axis,
                           int32_t* temp_prod, OpDataReduce* op_data) {
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);

  const bool ok = reference_ops::QuantizedProdExtraArgs<T, int32_t>(
      tflite::micro::GetTensorData<T>(input), op_data->input_zp,
      op_data->input_scale, &input->dims->data[0], input->dims->size,
      tflite::micro::GetTensorData<T>(output), op_data->output_scale,
      op_data->multiplier, op_data->shift, op_data->output_zp,
      &output->dims->data[0], output->dims->size,
      tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
      params->keep_dims, temp_index, resolved_axis, temp_prod);
  TF_LITE_ENSURE(context, ok);

  return kTfLiteOk;
}

template <typename integer_type>
TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node,
int num_axis, OpDataReduce* op_data,
Expand Down Expand Up @@ -337,4 +411,52 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

// Eval-time dispatch shared by REDUCE_PROD kernels. Supports float32 via the
// generic reference reduction and int8/int16 via the quantized path, which
// uses the scratch buffer reserved at Prepare time as an int32 accumulator.
//
// Fix: the TfLiteStatus returned by QuantizedProd was previously discarded,
// so a failing quantized reduction would still report kTfLiteOk. The calls
// are now checked with TF_LITE_ENSURE_OK.
TfLiteStatus EvalProdHelper(TfLiteContext* context, TfLiteNode* node,
                            OpDataReduce* op_data) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);

  // Interpret an axis tensor with null dimensions as a scalar.
  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int temp_index[kMaxNumberOfAxis];
  int resolved_axis[kMaxNumberOfReducedAxis];

  switch (input->type) {
    case kTfLiteFloat32: {
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_index, resolved_axis, /*init_value=*/1.f,
              [](const float current, const float in) -> float {
                return in * current;
              }));
    } break;
    case kTfLiteInt8: {
      int32_t* temp_prod = static_cast<int32_t*>(
          context->GetScratchBuffer(context, op_data->temp_buffer_idx));
      TF_LITE_ENSURE_OK(
          context, QuantizedProd<int8_t>(context, node, temp_index,
                                         resolved_axis, temp_prod, op_data));
    } break;
    case kTfLiteInt16: {
      int32_t* temp_prod = static_cast<int32_t*>(
          context->GetScratchBuffer(context, op_data->temp_buffer_idx));
      TF_LITE_ENSURE_OK(
          context, QuantizedProd<int16_t>(context, node, temp_index,
                                          resolved_axis, temp_prod, op_data));
    } break;
    default:
      MicroPrintf("Only float32, int8, and int16 types are supported.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

} // namespace tflite
Loading
Loading