Skip to content

Commit

Permalink
[tokenizer] Not returns overflow tokens by default (#2857)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Nov 17, 2023
1 parent e315554 commit 062d395
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public final class HuggingFaceTokenizer extends NativeResource<Long> implements
private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class);

private boolean addSpecialTokens;
private boolean withOverflowingTokens;
private TruncationStrategy truncation;
private PaddingStrategy padding;
private int maxLength;
Expand All @@ -64,6 +65,8 @@ private HuggingFaceTokenizer(long handle, Map<String, String> options) {
if (options != null) {
val = options.getOrDefault("addSpecialTokens", "true");
addSpecialTokens = Boolean.parseBoolean(val);
val = options.getOrDefault("withOverflowingTokens", "false");
withOverflowingTokens = Boolean.parseBoolean(val);
modelMaxLength = ArgumentsUtil.intValue(options, "modelMaxLength", 512);
if (options.containsKey("truncation")) {
truncation = TruncationStrategy.fromValue(options.get("truncation"));
Expand Down Expand Up @@ -203,11 +206,12 @@ public void close() {
* @param text the input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, boolean addSpecialTokens) {
public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens);
return toEncoding(encoding);
return toEncoding(encoding, withOverflowingTokens);
}

/**
Expand All @@ -217,7 +221,7 @@ public Encoding encode(String text, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text) {
return encode(text, addSpecialTokens);
return encode(text, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -227,12 +231,14 @@ public Encoding encode(String text) {
* @param textPair the second input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
public Encoding encode(
String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding =
TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens);
return toEncoding(encoding);
return toEncoding(encoding, withOverflowingTokens);
}

/**
Expand All @@ -243,7 +249,7 @@ public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, String textPair) {
return encode(text, textPair, addSpecialTokens);
return encode(text, textPair, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -252,11 +258,13 @@ public Encoding encode(String text, String textPair) {
* @param inputs the input sentences
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
public Encoding encode(
List<String> inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
String[] array = inputs.toArray(Utils.EMPTY_ARRAY);
return encode(array, addSpecialTokens);
return encode(array, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -266,7 +274,7 @@ public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(List<String> inputs) {
return encode(inputs, addSpecialTokens);
return encode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -275,11 +283,13 @@ public Encoding encode(List<String> inputs) {
* @param inputs the input sentences
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(String[] inputs, boolean addSpecialTokens) {
public Encoding encode(
String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
return toEncoding(encoding);
return toEncoding(encoding, withOverflowingTokens);
}

/**
Expand All @@ -289,7 +299,7 @@ public Encoding encode(String[] inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(String[] inputs) {
return encode(inputs, addSpecialTokens);
return encode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -298,11 +308,13 @@ public Encoding encode(String[] inputs) {
* @param inputs the batch of input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(List<String> inputs, boolean addSpecialTokens) {
public Encoding[] batchEncode(
List<String> inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
String[] array = inputs.toArray(Utils.EMPTY_ARRAY);
return batchEncode(array, addSpecialTokens);
return batchEncode(array, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -312,7 +324,7 @@ public Encoding[] batchEncode(List<String> inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(List<String> inputs) {
return batchEncode(inputs, addSpecialTokens);
return batchEncode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -321,13 +333,15 @@ public Encoding[] batchEncode(List<String> inputs) {
* @param inputs the batch of input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) {
public Encoding[] batchEncode(
String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
ret[i] = toEncoding(encodings[i]);
ret[i] = toEncoding(encodings[i], withOverflowingTokens);
}
return ret;
}
Expand All @@ -339,7 +353,7 @@ public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(String[] inputs) {
return batchEncode(inputs, addSpecialTokens);
return batchEncode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -348,17 +362,21 @@ public Encoding[] batchEncode(String[] inputs) {
* @param inputs the batch of input text pair
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input text pair in batch
*/
public Encoding[] batchEncode(PairList<String, String> inputs, boolean addSpecialTokens) {
public Encoding[] batchEncode(
PairList<String, String> inputs,
boolean addSpecialTokens,
boolean withOverflowingTokens) {
String[] text = inputs.keyArray(Utils.EMPTY_ARRAY);
String[] textPair = inputs.valueArray(Utils.EMPTY_ARRAY);
long[] encodings =
TokenizersLibrary.LIB.batchEncodePair(
getHandle(), text, textPair, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
ret[i] = toEncoding(encodings[i]);
ret[i] = toEncoding(encodings[i], withOverflowingTokens);
}
return ret;
}
Expand All @@ -370,7 +388,7 @@ public Encoding[] batchEncode(PairList<String, String> inputs, boolean addSpecia
* @return the {@code Encoding} of the input text pair in batch
*/
public Encoding[] batchEncode(PairList<String, String> inputs) {
return batchEncode(inputs, addSpecialTokens);
return batchEncode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand Down Expand Up @@ -503,19 +521,25 @@ private void updateTruncationAndPadding() {
}
}

private Encoding toEncoding(long encoding) {
private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding);
long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
String[] tokens = TokenizersLibrary.LIB.getTokens(encoding);
long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding);
long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding);
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);

Encoding[] overflowing = new Encoding[overflowingHandles.length];
for (int i = 0; i < overflowingHandles.length; ++i) {
overflowing[i] = toEncoding(overflowingHandles[i]);
Encoding[] overflowing;
if (withOverflowingTokens) {
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);

overflowing = new Encoding[overflowingHandles.length];
for (int i = 0; i < overflowingHandles.length; ++i) {
overflowing[i] = toEncoding(overflowingHandles[i], true);
}
} else {
overflowing = new Encoding[0];
}

TokenizersLibrary.LIB.deleteEncoding(encoding);
Expand Down Expand Up @@ -651,6 +675,17 @@ public Builder optAddSpecialTokens(boolean addSpecialTokens) {
return this;
}

/**
* Sets if add special tokens.
*
* @param withOverflowingTokens true to return overflowing tokens
* @return this builder
*/
public Builder optWithOverflowingTokens(boolean withOverflowingTokens) {
options.put("withOverflowingTokens", String.valueOf(withOverflowingTokens));
return this;
}

/**
* Enables or Disables default truncation behavior for the tokenizer.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator<String[], C
this.maskToken = maskToken;
this.topK = topK;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false);
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class FillMaskTranslator implements Translator<String, Classifications> {
this.maskToken = maskToken;
this.topK = topK;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false);
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ public void testTruncationStride() throws IOException {
HuggingFaceTokenizer.builder()
.optTokenizerName("bert-base-cased")
.optAddSpecialTokens(false)
.optWithOverflowingTokens(true)
.optTruncation(true)
.optMaxLength(3)
.optStride(1)
Expand All @@ -322,6 +323,7 @@ public void testTruncationStride() throws IOException {
HuggingFaceTokenizer.builder()
.optTokenizerName("bert-base-cased")
.optAddSpecialTokens(false)
.optWithOverflowingTokens(true)
.optTruncation(true)
.optMaxLength(8)
.optStride(2)
Expand Down Expand Up @@ -458,13 +460,13 @@ public void testBatchProcessing() throws IOException {
Assert.assertEquals(outputs, outputsWithSpecialTokens);

// encode with special tokens, decode with special tokens
encodings = tokenizer.batchEncode(inputs, true);
encodings = tokenizer.batchEncode(inputs, true, false);
batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new);
outputs = tokenizer.batchDecode(batchIds, false);
Assert.assertEquals(outputs, outputsWithSpecialTokens);

// encode without special tokens, decode without special tokens
encodings = tokenizer.batchEncode(inputs, false);
encodings = tokenizer.batchEncode(inputs, false, false);
batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new);
outputs = tokenizer.batchDecode(batchIds, true);
Assert.assertEquals(outputs, outputsWithoutSpecialTokens);
Expand Down

0 comments on commit 062d395

Please sign in to comment.