From b55691e9edf0ab53645f4e7c28323e6c3ca8ae4b Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Wed, 9 Aug 2023 22:46:47 +0530 Subject: [PATCH 1/6] pre-init only licenses added --- keras_nlp/models/xlnet/xlnet_tokenizer.py | 17 +++++++++++++++++ keras_nlp/models/xlnet/xlnet_tokenizer_test.py | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 keras_nlp/models/xlnet/xlnet_tokenizer.py create mode 100644 keras_nlp/models/xlnet/xlnet_tokenizer_test.py diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer.py b/keras_nlp/models/xlnet/xlnet_tokenizer.py new file mode 100644 index 0000000000..1cb43edf86 --- /dev/null +++ b/keras_nlp/models/xlnet/xlnet_tokenizer.py @@ -0,0 +1,17 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""XLNET tokenizer.""" + + diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py new file mode 100644 index 0000000000..5afd3eeed5 --- /dev/null +++ b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py @@ -0,0 +1,17 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for XLNET tokenizer.""" + + From 9ef8875e779a61257d3fd84ebc716ddbb6f51b22 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Thu, 10 Aug 2023 12:56:29 +0530 Subject: [PATCH 2/6] tokenizer encode and decode done --- keras_nlp/models/__init__.py | 1 + keras_nlp/models/xlnet/xlnet_tokenizer.py | 154 ++++++++++++++++++++++ 2 files changed, 155 insertions(+) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index eb4e74be3a..30616a0626 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -127,3 +127,4 @@ XLMRobertaTokenizer, ) from keras_nlp.models.xlnet.xlnet_backbone import XLNetBackbone +from keras_nlp.models.xlnet.xlnet_tokenizer import XLNetTokenizer diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer.py b/keras_nlp/models/xlnet/xlnet_tokenizer.py index 1cb43edf86..07869ece6d 100644 --- a/keras_nlp/models/xlnet/xlnet_tokenizer.py +++ b/keras_nlp/models/xlnet/xlnet_tokenizer.py @@ -14,4 +14,158 @@ """XLNET tokenizer.""" +import tensorflow as tf + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer + +try: + import unicodedata +except ImportError: + unicodedata = None + + +@keras_nlp_export("keras_nlp.models.XLNetTokenizer") +class XLNetTokenizer(SentencePieceTokenizer): + """XLNET tokenizer layer based on SentencePiece. 
+ + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + XLNET models and provides a `from_preset()` method to automatically + download a matching vocabulary for a ALBERT preset. + + This tokenizer does not provide truncation or padding of inputs. It can be + combined with a `keras_nlp.models.XLNetPreprocessor` layer for input + packing. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.XLNetTokenizer(proto="") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. + bytes_io = io.BytesIO() + ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."]) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=ds.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=10, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + sep_piece="", + unk_piece="", + bos_piece="", + eos_piece="", + user_defined_symbols="[MASK]", + ) + tokenizer = keras_nlp.models.AlbertTokenizer( + proto=bytes_io.getvalue(), + ) + tokenizer("The quick brown fox jumped.") + ``` + """ + + def __init__(self, proto, **kwargs): + super().__init__(proto=proto, **kwargs) + + # Check for necessary special tokens. + cls_token = "" + sep_token = "" + pad_token = "" + mask_token = "" + bos_token = "" + eos_token = "" + unk_token = "" + + for token in [cls_token, sep_token, pad_token, mask_token, bos_token, eos_token, unk_token]: + if token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{token}'` in the provided " + f"`vocabulary`. Please provide `'{token}'` in your " + "`vocabulary` or use a pretrained `vocabulary` name." + ) + + self.cls_token_id = self.token_to_id(cls_token) + self.sep_token_id = self.token_to_id(sep_token) + self.pad_token_id = self.token_to_id(pad_token) + self.mask_token_id = self.token_to_id(mask_token) + self.bos_token_id = self.token_to_id(bos_token) + self.eos_token_id = self.token_to_id(eos_token) + self.unk_token_id = self.token_to_id(unk_token) + + def preprocess_text(self, inputs): + """Preprocesses the text. 
This method removes spaces and accents.""" + + # remove space + outputs = " ".join(inputs.strip().split()) + outputs = outputs.replace("``", '"').replace("''", '"') + + # remove accents + outputs = unicodedata.normalize("NFKD", outputs) + outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) + + return outputs + + def tokenize(self, text): + """Tokenize a string.""" + + # check if there are multiple batches present or not + is_batched = isinstance(text, list) + if not is_batched: + text = [text] + + tokenized_text = [] + for each_text in text: + each_text = self.preprocess_text(each_text) + pieces = [self.id_to_token(token_id) for token_id in super().tokenize(each_text)] + + new_pieces = [] + for piece in pieces: + if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit(): + cur_pieces = [self.id_to_token(cur_piece_id) for cur_piece_id in super().tokenize(piece[:-1].replace("▁", ""))] + if piece[0] != "▁" and cur_pieces[0][0] == "▁": + if len(cur_pieces[0]) == 1: + cur_pieces = cur_pieces[1:] + else: + cur_pieces[0] = cur_pieces[0][1:] + cur_pieces.append(piece[-1]) + new_pieces.extend(cur_pieces) + else: + new_pieces.append(piece) + + new_pieces = [self.token_to_id(new_piece_token) for new_piece_token in new_pieces] + # add sep_token and cls_token in the end. + new_pieces.extend([self.sep_token_id, self.cls_token_id]) + + tokenized_text.append(new_pieces) + + if is_batched: + return tf.ragged.constant(tokenized_text) + + return tokenized_text[0] From f1fe2d8ad722a7c4284b7af7a0ada33ca650d3be Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Thu, 10 Aug 2023 14:07:39 +0530 Subject: [PATCH 3/6] tests all green! --- .../models/xlnet/xlnet_tokenizer_test.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py index 5afd3eeed5..428fdd8c61 100644 --- a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py +++ b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py @@ -14,4 +14,90 @@ """Tests for XLNET tokenizer.""" +import io +import sentencepiece +import tensorflow as tf +from keras_nlp.backend import keras +from keras_nlp.models.xlnet.xlnet_tokenizer import XLNetTokenizer +from keras_nlp.tests.test_case import TestCase + + +class XLNetTokenizerTest(TestCase): + def setUp(self): + bytes_io = io.BytesIO() + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=14, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + user_defined_symbols=["", "", ""] + ) + self.proto = bytes_io.getvalue() + + self.tokenizer = XLNetTokenizer(proto=self.proto) + print(self.tokenizer.get_vocabulary()) + + def test_tokenize(self): + input_data = "the quick brown fox" + output = self.tokenizer(input_data) + self.assertAllEqual(output, [7, 12, 8, 10, 6, 5]) + + def test_tokenize_batch(self): + input_data = ["the quick brown fox", "the earth is round"] + output = self.tokenizer(input_data) + self.assertAllEqual(output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]]) + + def test_detokenize(self): + input_data = [[7, 12, 8, 10, 6, 5]] + output = self.tokenizer.detokenize(input_data) + self.assertEqual(output, ["the quick brown fox"]) + + def test_detokenize_mask_token(self): + input_data = [[7, 12, 8, 10, 6, 5, 
self.tokenizer.mask_token_id]] + output = self.tokenizer.detokenize(input_data) + self.assertEqual(output, ["the quick brown fox"]) + + def test_vocabulary_size(self): + self.assertEqual(self.tokenizer.vocabulary_size(), 14) + + def test_get_vocabulary_mask_token(self): + self.assertEqual(self.tokenizer.get_vocabulary()[4], "") + + def test_id_to_token_mask_token(self): + self.assertEqual(self.tokenizer.id_to_token(4), "") + + def test_token_to_id_mask_token(self): + self.assertEqual(self.tokenizer.token_to_id(""), 4) + + def test_errors_missing_special_tokens(self): + bytes_io = io.BytesIO() + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=iter(["abc"]), + model_writer=bytes_io, + vocab_size=5, + pad_id=-1, + eos_id=-1, + bos_id=-1, + ) + with self.assertRaises(ValueError): + XLNetTokenizer(proto=bytes_io.getvalue()) + + def test_serialization(self): + config = keras.saving.serialize_keras_object(self.tokenizer) + new_tokenizer = keras.saving.deserialize_keras_object(config) + self.assertEqual( + new_tokenizer.get_config(), + self.tokenizer.get_config(), + ) From 983d93d4ba4bc6e5dfdea8376c7aa5cbdbe7734a Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Thu, 10 Aug 2023 15:08:06 +0530 Subject: [PATCH 4/6] style fix --- keras_nlp/models/xlnet/xlnet_tokenizer.py | 98 +++++++++++++------ .../models/xlnet/xlnet_tokenizer_test.py | 12 ++- 2 files changed, 73 insertions(+), 37 deletions(-) diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer.py b/keras_nlp/models/xlnet/xlnet_tokenizer.py index 07869ece6d..2344f9868c 100644 --- a/keras_nlp/models/xlnet/xlnet_tokenizer.py +++ b/keras_nlp/models/xlnet/xlnet_tokenizer.py @@ -33,7 +33,7 @@ class XLNetTokenizer(SentencePieceTokenizer): is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the underlying tokenizer, it will check for all special tokens needed by XLNET models and provides a `from_preset()` method to automatically - download a matching vocabulary for a ALBERT preset. + download a matching vocabulary for a XLNET preset. This tokenizer does not provide truncation or padding of inputs. It can be combined with a `keras_nlp.models.XLNetPreprocessor` layer for input @@ -68,22 +68,21 @@ class XLNetTokenizer(SentencePieceTokenizer): bytes_io = io.BytesIO() ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."]) sentencepiece.SentencePieceTrainer.train( - sentence_iterator=ds.as_numpy_iterator(), + sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=14, model_type="WORD", pad_id=0, - unk_id=1, - bos_id=2, - eos_id=3, + bos_id=1, + eos_id=2, + unk_id=3, pad_piece="", - sep_piece="", - unk_piece="", bos_piece="", eos_piece="", - user_defined_symbols="[MASK]", + unk_piece="", + user_defined_symbols=["", "", ""] ) - tokenizer = keras_nlp.models.AlbertTokenizer( + tokenizer = keras_nlp.models.XLNetTokenizer( proto=bytes_io.getvalue(), ) tokenizer("The quick brown fox jumped.") @@ -94,15 +93,23 @@ def __init__(self, proto, **kwargs): super().__init__(proto=proto, **kwargs) # Check for necessary special tokens. 
- cls_token = "" - sep_token = "" - pad_token = "" - mask_token = "" - bos_token = "" - eos_token = "" - unk_token = "" - - for token in [cls_token, sep_token, pad_token, mask_token, bos_token, eos_token, unk_token]: + self.cls_token = "" + self.sep_token = "" + self.pad_token = "" + self.mask_token = "" + self.bos_token = "" + self.eos_token = "" + self.unk_token = "" + + for token in [ + self.cls_token, + self.sep_token, + self.pad_token, + self.mask_token, + self.bos_token, + self.eos_token, + self.unk_token, + ]: if token not in self.get_vocabulary(): raise ValueError( f"Cannot find token `'{token}'` in the provided " @@ -110,16 +117,16 @@ def __init__(self, proto, **kwargs): "`vocabulary` or use a pretrained `vocabulary` name." ) - self.cls_token_id = self.token_to_id(cls_token) - self.sep_token_id = self.token_to_id(sep_token) - self.pad_token_id = self.token_to_id(pad_token) - self.mask_token_id = self.token_to_id(mask_token) - self.bos_token_id = self.token_to_id(bos_token) - self.eos_token_id = self.token_to_id(eos_token) - self.unk_token_id = self.token_to_id(unk_token) + self.cls_token_id = self.token_to_id(self.cls_token) + self.sep_token_id = self.token_to_id(self.sep_token) + self.pad_token_id = self.token_to_id(self.pad_token) + self.mask_token_id = self.token_to_id(self.mask_token) + self.bos_token_id = self.token_to_id(self.bos_token) + self.eos_token_id = self.token_to_id(self.eos_token) + self.unk_token_id = self.token_to_id(self.unk_token) def preprocess_text(self, inputs): - """Preprocesses the text. This method removes spaces and accents.""" + """Preprocesses the text. This method removes spaces and accents from the text.""" # remove space outputs = " ".join(inputs.strip().split()) @@ -134,7 +141,7 @@ def preprocess_text(self, inputs): def tokenize(self, text): """Tokenize a string.""" - # check if there are multiple batches present or not + # check if there are multiple examples present or not is_batched = isinstance(text, list) if not is_batched: text = [text] @@ -142,12 +149,24 @@ def tokenize(self, text): tokenized_text = [] for each_text in text: each_text = self.preprocess_text(each_text) - pieces = [self.id_to_token(token_id) for token_id in super().tokenize(each_text)] + pieces = [ + self.id_to_token(token_id) + for token_id in super().tokenize(each_text) + ] new_pieces = [] for piece in pieces: - if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit(): - cur_pieces = [self.id_to_token(cur_piece_id) for cur_piece_id in super().tokenize(piece[:-1].replace("▁", ""))] + if ( + len(piece) > 1 + and piece[-1] == str(",") + and piece[-2].isdigit() + ): + cur_pieces = [ + self.id_to_token(cur_piece_id) + for cur_piece_id in super().tokenize( + piece[:-1].replace("▁", "") + ) + ] if piece[0] != "▁" and cur_pieces[0][0] == "▁": if len(cur_pieces[0]) == 1: cur_pieces = cur_pieces[1:] @@ -158,14 +177,29 @@ def tokenize(self, text): else: new_pieces.append(piece) - new_pieces = [self.token_to_id(new_piece_token) for new_piece_token in new_pieces] + new_pieces = [ + self.token_to_id(new_piece_token) + for new_piece_token in new_pieces + ] # add sep_token and cls_token in the end. new_pieces.extend([self.sep_token_id, self.cls_token_id]) tokenized_text.append(new_pieces) + # if there are multiple examples present, then return a `tf.RaggedTensor`. 
if is_batched: return tf.ragged.constant(tokenized_text) return tokenized_text[0] + def detokenize(self, inputs): + """Detokenize the input_ids into text.""" + + outputs = super().detokenize(inputs) + + outputs = tf.strings.regex_replace(outputs, self.cls_token, "") + outputs = tf.strings.regex_replace(outputs, self.sep_token, "") + outputs = tf.strings.regex_replace(outputs, self.mask_token, "") + outputs = tf.strings.regex_replace(outputs, self.pad_token, "") + + return outputs diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py index 428fdd8c61..0474a543b6 100644 --- a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py +++ b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py @@ -15,6 +15,7 @@ """Tests for XLNET tokenizer.""" import io + import sentencepiece import tensorflow as tf @@ -42,12 +43,11 @@ def setUp(self): bos_piece="", eos_piece="", unk_piece="", - user_defined_symbols=["", "", ""] + user_defined_symbols=["", "", ""], ) self.proto = bytes_io.getvalue() self.tokenizer = XLNetTokenizer(proto=self.proto) - print(self.tokenizer.get_vocabulary()) def test_tokenize(self): input_data = "the quick brown fox" @@ -57,17 +57,19 @@ def test_tokenize(self): def test_tokenize_batch(self): input_data = ["the quick brown fox", "the earth is round"] output = self.tokenizer(input_data) - self.assertAllEqual(output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]]) + self.assertAllEqual( + output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]] + ) def test_detokenize(self): input_data = [[7, 12, 8, 10, 6, 5]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, ["the quick brown fox"]) + self.assertEqual(output, ["the quick brown fox"]) def test_detokenize_mask_token(self): input_data = [[7, 12, 8, 10, 6, 5, self.tokenizer.mask_token_id]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, ["the quick brown fox"]) + self.assertEqual(output, ["the quick brown fox"]) def test_vocabulary_size(self): self.assertEqual(self.tokenizer.vocabulary_size(), 14) From e03e9bbbe90ad576086ade7f60042db2efd3e5c2 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Thu, 10 Aug 2023 16:31:13 +0530 Subject: [PATCH 5/6] tests fix --- keras_nlp/models/xlnet/xlnet_tokenizer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py index 0474a543b6..49ada980df 100644 --- a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py +++ b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py @@ -50,9 +50,9 @@ def setUp(self): self.tokenizer = XLNetTokenizer(proto=self.proto) def test_tokenize(self): - input_data = "the quick brown fox" + input_data = ["the quick brown fox"] output = self.tokenizer(input_data) - self.assertAllEqual(output, [7, 12, 8, 10, 6, 5]) + self.assertAllEqual(output, [[7, 12, 8, 10, 6, 5]]) def test_tokenize_batch(self): input_data = ["the quick brown fox", "the earth is round"] From e43119698ff15786eb68714c4a70cc8ca397bbd4 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Sat, 12 Aug 2023 15:57:03 +0530 Subject: [PATCH 6/6] tokenizer now working with ds.map(...) 
--- keras_nlp/models/xlnet/xlnet_tokenizer.py | 67 ++++++++++++------- .../models/xlnet/xlnet_tokenizer_test.py | 16 +++++ 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer.py b/keras_nlp/models/xlnet/xlnet_tokenizer.py index 2344f9868c..ae6dfff295 100644 --- a/keras_nlp/models/xlnet/xlnet_tokenizer.py +++ b/keras_nlp/models/xlnet/xlnet_tokenizer.py @@ -18,11 +18,12 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.tensor_utils import assert_tf_text_installed try: - import unicodedata + import tensorflow_text as tf_text except ImportError: - unicodedata = None + tf_text = None @keras_nlp_export("keras_nlp.models.XLNetTokenizer") @@ -90,6 +91,8 @@ class XLNetTokenizer(SentencePieceTokenizer): """ def __init__(self, proto, **kwargs): + assert_tf_text_installed(self.__class__.__name__) + super().__init__(proto=proto, **kwargs) # Check for necessary special tokens. @@ -129,30 +132,28 @@ def preprocess_text(self, inputs): """Preprocesses the text. This method removes spaces and accents from the text.""" # remove space - outputs = " ".join(inputs.strip().split()) - outputs = outputs.replace("``", '"').replace("''", '"') + outputs = tf.strings.split(tf.strings.strip(inputs), sep=" ") + outputs = tf.strings.reduce_join( + outputs, separator=" ", keepdims=True, axis=-1 + ) + outputs = tf.strings.regex_replace(outputs, pattern="``", rewrite='"') + outputs = tf.strings.regex_replace(outputs, pattern="''", rewrite='"') # remove accents - outputs = unicodedata.normalize("NFKD", outputs) - outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) + outputs = tf_text.normalize_utf8(outputs, "nfkd") return outputs - def tokenize(self, text): - """Tokenize a string.""" - - # check if there are multiple examples present or not - is_batched = isinstance(text, list) - if not is_batched: - text = [text] + def postprocess(self, batch_token_ids): + batch_token_ids = ( + tf.squeeze(batch_token_ids, -2) + if tf.rank(batch_token_ids) > 2 + else batch_token_ids + ) tokenized_text = [] - for each_text in text: - each_text = self.preprocess_text(each_text) - pieces = [ - self.id_to_token(token_id) - for token_id in super().tokenize(each_text) - ] + for each_token_ids in batch_token_ids: + pieces = [self.id_to_token(token_id) for token_id in each_token_ids] new_pieces = [] for piece in pieces: @@ -164,7 +165,9 @@ def tokenize(self, text): cur_pieces = [ self.id_to_token(cur_piece_id) for cur_piece_id in super().tokenize( - piece[:-1].replace("▁", "") + tf.strings.regex_replace( + piece[:-1], pattern="▁", rewrite="" + ) ) ] if piece[0] != "▁" and cur_pieces[0][0] == "▁": @@ -183,14 +186,28 @@ def tokenize(self, text): ] # add sep_token and cls_token in the end. new_pieces.extend([self.sep_token_id, self.cls_token_id]) - tokenized_text.append(new_pieces) - # if there are multiple examples present, then return a `tf.RaggedTensor`. 
- if is_batched: - return tf.ragged.constant(tokenized_text) + tokenized_text = tf.ragged.constant(tokenized_text) + + return tokenized_text + + def tokenize(self, text): + """Tokenize a string.""" + + text = self.preprocess_text(text) + token_ids = super().tokenize(text) + token_ids = tf.py_function( + func=self.postprocess, + inp=[token_ids], + Tout=tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32), + ) + + # if there is only one example in the batch then output tf.Tensor otherwise tf.RaggedTensor + if isinstance(text, str): + return token_ids.to_tensor() - return tokenized_text[0] + return token_ids def detokenize(self, inputs): """Detokenize the input_ids into text.""" diff --git a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py index 49ada980df..2d677658a9 100644 --- a/keras_nlp/models/xlnet/xlnet_tokenizer_test.py +++ b/keras_nlp/models/xlnet/xlnet_tokenizer_test.py @@ -61,6 +61,22 @@ def test_tokenize_batch(self): output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]] ) + def test_tokenize_ds(self): + input_ds = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + input_ds = input_ds.map(self.tokenizer) + outputs = [] + for each_item in input_ds.take(2): + self.assertTrue(isinstance(each_item, tf.RaggedTensor)) + outputs.append(each_item.to_tensor()) + + outputs = tf.squeeze(tf.convert_to_tensor(outputs), 1) + self.assertAllEqual( + outputs, + tf.convert_to_tensor([[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]]), + ) + def test_detokenize(self): input_data = [[7, 12, 8, 10, 6, 5]] output = self.tokenizer.detokenize(input_data)
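
After the final patch, preprocessing runs on `tf.strings`/`tensorflow_text` ops and the id-level piece fix-up is wrapped in `tf.py_function`, so the layer can be mapped directly over a `tf.data.Dataset`. The sketch below ties the series together end to end. It mirrors the toy WORD-level vocabulary built in the unit tests; the special-piece strings (`<pad>`, `<s>`, `</s>`, `<unk>`, `<mask>`, `<cls>`, `<sep>`) and their order in `user_defined_symbols` are assumptions based on the standard XLNet vocabulary, not taken verbatim from the patches.

import io

import sentencepiece
import tensorflow as tf

import keras_nlp

# Train the same kind of toy WORD-level SentencePiece vocabulary the tests use.
bytes_io = io.BytesIO()
vocab_data = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "the earth is round"]
)
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=vocab_data.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=14,
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    # Assumed XLNet special pieces; adjust to match your own vocabulary.
    pad_piece="<pad>",
    bos_piece="<s>",
    eos_piece="</s>",
    unk_piece="<unk>",
    user_defined_symbols=["<mask>", "<cls>", "<sep>"],
)

tokenizer = keras_nlp.models.XLNetTokenizer(proto=bytes_io.getvalue())

# Batched inputs; every tokenized sequence ends with <sep> then <cls>.
print(tokenizer(["the quick brown fox", "the earth is round"]))

# Round trip; detokenize() strips the cls/sep/mask/pad tokens again.
print(tokenizer.detokenize(tokenizer(["the quick brown fox"])))

# The layer now also runs inside a tf.data input pipeline.
ds = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "the earth is round"]
)
for token_ids in ds.map(tokenizer):
    print(token_ids)  # one ragged batch of token ids per example

With the toy vocabulary above, the batched call should produce sequences ending in the `<sep>` and `<cls>` ids (6 and 5 here), matching the expectations in xlnet_tokenizer_test.py.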