Skip to content

Commit

Permalink
Allow specifying the BERT text classifier's model output tensor name …
Browse files Browse the repository at this point in the history
…via options.

PiperOrigin-RevId: 638652416
  • Loading branch information
tensorflower-gardener authored and tflite-support-robot committed May 30, 2024
1 parent 3247254 commit 8ed4a7b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
}
const TfLiteTensor* scores = FindTensorByName(
output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
kScoreTensorName);
options_->has_output_tensor_name() ? options_->output_tensor_name()
: kScoreTensorName);

// optional labels extracted from metadata
return BuildResults(scores, /*labels=*/nullptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package tflite.task.text;
import "tensorflow_lite_support/cc/task/core/proto/base_options.proto";

// Options for setting up a BertNLClassifier.
// Next Id: 3
// Next Id: 4
message BertNLClassifierOptions {
// Base options for configuring BertNLClassifier, such as specifying the
// TfLite model file with metadata, accelerator options, etc.
Expand All @@ -31,4 +31,9 @@ message BertNLClassifierOptions {
// Deprecated: max_seq_len is now read from the model (i.e. input tensor size)
// automatically.
optional int32 max_seq_len = 2 [default = 128];

// The name of the output tensor.
//
// If not provided, defaults to "probability".
optional string output_tensor_name = 3;
}

0 comments on commit 8ed4a7b

Please sign in to comment.