From 8ed4a7b70df385a253aad7ed7df782439f42da6c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 May 2024 07:55:44 -0700 Subject: [PATCH] Allow specifying the BERT text classifier's model output tensor name via options. PiperOrigin-RevId: 638652416 --- tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc | 3 ++- .../cc/task/text/proto/bert_nl_classifier_options.proto | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc index 52c898dac..a1a9bc5ee 100644 --- a/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc @@ -74,7 +74,8 @@ StatusOr> BertNLClassifier::Postprocess( } const TfLiteTensor* scores = FindTensorByName( output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), - kScoreTensorName); + options_->has_output_tensor_name() ? options_->output_tensor_name() + : kScoreTensorName); // optional labels extracted from metadata return BuildResults(scores, /*labels=*/nullptr); diff --git a/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto b/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto index 505ccef31..9e54bb0e7 100644 --- a/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto +++ b/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto @@ -20,7 +20,7 @@ package tflite.task.text; import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; // Options for setting up a BertNLClassifier. -// Next Id: 3 +// Next Id: 4 message BertNLClassifierOptions { // Base options for configuring BertNLClassifier, such as specifying the // TfLite model file with metadata, accelerator options, etc. @@ -31,4 +31,9 @@ message BertNLClassifierOptions { // Deprecated: max_seq_len is now read from the model (i.e. input tensor size) // automatically. optional int32 max_seq_len = 2 [default = 128]; + + // The name of the output tensor. + // + // If not provided, defaults to "probability". + optional string output_tensor_name = 3; }