diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py
index c4e3d1204d..2941ba8423 100644
--- a/keras_nlp/models/bart/bart_tokenizer.py
+++ b/keras_nlp/models/bart/bart_tokenizer.py
@@ -44,6 +44,9 @@ class BartTokenizer(BytePairTokenizer):
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
Examples:
@@ -77,6 +80,7 @@ def __init__(
self,
vocabulary=None,
merges=None,
+ special_tokens_in_strings=False,
**kwargs,
):
self.start_token = "<s>"
@@ -86,11 +90,12 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=[
+ special_tokens=[
self.start_token,
self.pad_token,
self.end_token,
],
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -98,15 +103,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if vocabulary is not None:
- # Check for necessary special tokens.
- for token in [self.start_token, self.pad_token, self.end_token]:
- if token not in self.vocabulary:
- raise ValueError(
- f"Cannot find token `'{token}'` in the provided "
- f"`vocabulary`. Please provide `'{token}'` in your "
- "`vocabulary` or use a pretrained `vocabulary` name."
- )
-
self.start_token_id = self.token_to_id(self.start_token)
self.pad_token_id = self.token_to_id(self.pad_token)
self.end_token_id = self.token_to_id(self.end_token)
@@ -117,8 +113,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
+ del config["special_tokens"] # Not configurable; set in __init__.
return config
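A minimal usage sketch of what the new flag enables, built on toy fixtures like the ones in the unit test below; the vocabulary, merges, and printed ids are illustrative, not the real BART assets:

```python
import keras_nlp

# Toy assets mirroring the unit test; real usage would load a preset instead.
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "air": 3}
vocab.update({"Ġair": 4, "plane": 5, "Ġat": 6, "port": 7})
merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e", "Ġa t", "p o", "r t"]
merges += ["Ġt h", "ai r", "pl a", "po rt", "Ġai r", "Ġa i", "pla ne"]

tokenizer = keras_nlp.models.BartTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens_in_strings=True,  # keep "<s>", "</s>", "<pad>" intact
)
# "</s>" and "<pad>" now map to their own ids instead of being split apart.
print(tokenizer("<s> airplane at airport</s><pad>"))
```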
diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py
index 5a0015357b..7cdd582881 100644
--- a/keras_nlp/models/bart/bart_tokenizer_test.py
+++ b/keras_nlp/models/bart/bart_tokenizer_test.py
@@ -26,7 +26,11 @@ def setUp(self):
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
- self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "special_tokens_in_strings": True,
+ }
self.input_data = [
" airplane at airport",
" airplane airport",
@@ -37,10 +41,9 @@ def test_tokenizer_basics(self):
cls=BartTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- # TODO: </s> should not get tokenized as <s>
- expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]],
+ expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
expected_detokenize_output=[
- "<s> airplane at airport<s><pad>",
+ "<s> airplane at airport</s><pad>",
" airplane airport",
],
)
diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py
index 6c6097e4ce..3d1f646d59 100644
--- a/keras_nlp/models/bloom/bloom_tokenizer.py
+++ b/keras_nlp/models/bloom/bloom_tokenizer.py
@@ -42,6 +42,9 @@ class BloomTokenizer(BytePairTokenizer):
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
Examples:
@@ -69,6 +72,7 @@ def __init__(
self,
vocabulary=None,
merges=None,
+ special_tokens_in_strings=False,
**kwargs,
):
self.start_token = "<s>"
@@ -78,11 +82,12 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=[
+ special_tokens=[
self.start_token,
self.end_token,
self.pad_token,
],
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -90,15 +95,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if vocabulary is not None:
- # Check for necessary special tokens.
- for token in [self.start_token, self.end_token, self.pad_token]:
- if token not in self.get_vocabulary():
- raise ValueError(
- f"Cannot find token `'{token}'` in the provided "
- f"`vocabulary`. Please provide `'{token}'` in "
- "your `vocabulary` or use a pretrained `vocabulary` name."
- )
-
self.start_token_id = self.token_to_id(self.start_token)
self.end_token_id = self.token_to_id(self.end_token)
self.pad_token_id = self.token_to_id(self.pad_token)
@@ -109,8 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
+ del config["special_tokens"] # Not configurable; set in __init__.
return config
diff --git a/keras_nlp/models/bloom/bloom_tokenizer_test.py b/keras_nlp/models/bloom/bloom_tokenizer_test.py
index 9ae9c0cc00..c2ee12e5ca 100644
--- a/keras_nlp/models/bloom/bloom_tokenizer_test.py
+++ b/keras_nlp/models/bloom/bloom_tokenizer_test.py
@@ -26,10 +26,14 @@ def setUp(self):
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
- self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "special_tokens_in_strings": True,
+ }
self.input_data = [
- "<s>airplane at airport<pad>",
- "<s> airplane airport<pad>",
+ "<s>airplane at airport</s><pad>",
+ "<s> airplane airport</s><pad>",
]
def test_tokenizer_basics(self):
@@ -37,7 +41,7 @@ def test_tokenizer_basics(self):
cls=BloomTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[6, 1, 3, 4, 2, 5, 8], [6, 2, 3, 2, 5, 8]],
+ expected_output=[[6, 1, 3, 4, 2, 5, 7, 8], [6, 2, 3, 2, 5, 7, 8]],
)
def test_errors_missing_special_tokens(self):
diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py
index 80d7334fe7..46b6193197 100644
--- a/keras_nlp/models/falcon/falcon_tokenizer.py
+++ b/keras_nlp/models/falcon/falcon_tokenizer.py
@@ -42,6 +42,9 @@ class FalconTokenizer(BytePairTokenizer):
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
Examples:
@@ -69,6 +72,7 @@ def __init__(
self,
vocabulary=None,
merges=None,
+ special_tokens_in_strings=False,
**kwargs,
):
# Falcon uses the same start as end token, i.e., "<|endoftext|>".
@@ -77,7 +81,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=[self.end_token],
+ special_tokens=[self.end_token],
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -85,14 +90,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if vocabulary is not None:
- # Check for necessary special tokens.
- if self.end_token not in self.get_vocabulary():
- raise ValueError(
- f"Cannot find token `'{self.end_token}'` in the provided "
- f"`vocabulary`. Please provide `'{self.end_token}'` in "
- "your `vocabulary` or use a pretrained `vocabulary` name."
- )
-
self.end_token_id = self.token_to_id(self.end_token)
self.start_token_id = self.end_token_id
self.pad_token_id = 0
@@ -103,8 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
+ del config["special_tokens"] # Not configurable; set in __init__.
return config
diff --git a/keras_nlp/models/falcon/falcon_tokenizer_test.py b/keras_nlp/models/falcon/falcon_tokenizer_test.py
index 735bcac4b6..6ee2a19a0d 100644
--- a/keras_nlp/models/falcon/falcon_tokenizer_test.py
+++ b/keras_nlp/models/falcon/falcon_tokenizer_test.py
@@ -25,7 +25,11 @@ def setUp(self):
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
- self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "special_tokens_in_strings": True,
+ }
self.input_data = [
" airplane at airport<|endoftext|>",
" airplane airport",
diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer.py b/keras_nlp/models/gpt2/gpt2_tokenizer.py
index 4a585c3176..aeff9d97cf 100644
--- a/keras_nlp/models/gpt2/gpt2_tokenizer.py
+++ b/keras_nlp/models/gpt2/gpt2_tokenizer.py
@@ -42,6 +42,9 @@ class GPT2Tokenizer(BytePairTokenizer):
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
Examples:
@@ -69,6 +72,7 @@ def __init__(
self,
vocabulary=None,
merges=None,
+ special_tokens_in_strings=False,
**kwargs,
):
# GPT2 uses the same start as end token, i.e., "<|endoftext|>".
@@ -77,7 +81,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=[self.end_token],
+ special_tokens=[self.end_token],
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -85,14 +90,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if vocabulary is not None:
- # Check for necessary special tokens.
- if self.end_token not in self.get_vocabulary():
- raise ValueError(
- f"Cannot find token `'{self.end_token}'` in the provided "
- f"`vocabulary`. Please provide `'{self.end_token}'` in "
- "your `vocabulary` or use a pretrained `vocabulary` name."
- )
-
self.end_token_id = self.token_to_id(self.end_token)
self.start_token_id = self.end_token_id
self.pad_token_id = 0
@@ -103,8 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
+ del config["special_tokens"] # Not configurable; set in __init__.
return config
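To make the behavioral change concrete, here is a hedged sketch contrasting the default with the new flag; the vocabulary and merges are invented for illustration and are not the real GPT-2 assets:

```python
import keras_nlp

# Illustrative fixtures only.
vocab = {"<|endoftext|>": 0, "air": 1, "Ġair": 2, "plane": 3, "Ġat": 4, "port": 5}
merges = ["Ġ a", "a i", "p l", "n e", "Ġa t", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "Ġai r", "Ġa i", "pla ne"]

# Default: "<|endoftext|>" inside a string is pre-split like ordinary
# punctuation, so it does not come out as the single id 0.
plain = keras_nlp.models.GPT2Tokenizer(vocabulary=vocab, merges=merges)

# With the flag, the special token is kept whole and mapped to id 0.
aware = keras_nlp.models.GPT2Tokenizer(
    vocabulary=vocab, merges=merges, special_tokens_in_strings=True
)
print(aware(" airplane at airport<|endoftext|>"))
```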
diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py
index 026392fd25..237cb661aa 100644
--- a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py
+++ b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py
@@ -26,7 +26,11 @@ def setUp(self):
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
- self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "special_tokens_in_strings": True,
+ }
self.input_data = [
" airplane at airport<|endoftext|>",
" airplane airport",
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py
index d109c5849d..84eac197d9 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py
@@ -41,12 +41,16 @@ class GPTNeoXTokenizer(BytePairTokenizer):
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
"""
def __init__(
self,
vocabulary=None,
merges=None,
+ special_tokens_in_strings=False,
**kwargs,
):
# GPTNeoX uses the same start as end token, i.e., "<|endoftext|>".
@@ -55,7 +59,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=[self.end_token],
+ special_tokens=[self.end_token],
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -63,14 +68,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if vocabulary is not None:
- # Check for necessary special tokens.
- if self.end_token not in self.get_vocabulary():
- raise ValueError(
- f"Cannot find token `'{self.end_token}'` in the provided "
- f"`vocabulary`. Please provide `'{self.end_token}'` in "
- "your `vocabulary` or use a pretrained `vocabulary` name."
- )
-
self.end_token_id = self.token_to_id(self.end_token)
self.start_token_id = self.end_token_id
self.pad_token_id = 0
@@ -81,8 +78,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
+ del config["special_tokens"] # Not configurable; set in __init__.
return config
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
index c23b7dd44d..284ae3e733 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
@@ -24,7 +24,11 @@ def setUp(self):
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
- self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "special_tokens_in_strings": True,
+ }
self.input_data = [
" airplane at airport<|endoftext|>",
" airplane airport",
diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py
index addcd0c01f..c22e2baedf 100644
--- a/keras_nlp/models/opt/opt_tokenizer.py
+++ b/keras_nlp/models/opt/opt_tokenizer.py
@@ -41,6 +41,9 @@ class OPTTokenizer(BytePairTokenizer):
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
Examples:
```python
@@ -69,6 +72,7 @@ def __init__(
self,
vocabulary=None,
merges=None,
+ special_tokens_in_strings=False,
**kwargs,
):
self.start_token = "</s>"
@@ -78,11 +82,12 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=[
+ special_tokens=[
self.start_token,
self.pad_token,
self.end_token,
],
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -90,15 +95,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if vocabulary is not None:
- # Check for necessary special tokens.
- for token in [self.start_token, self.pad_token, self.end_token]:
- if token not in self.vocabulary:
- raise ValueError(
- f"Cannot find token `'{token}'` in the provided "
- f"`vocabulary`. Please provide `'{token}'` in your "
- "`vocabulary` or use a pretrained `vocabulary` name."
- )
-
self.start_token_id = self.token_to_id(self.start_token)
self.pad_token_id = self.token_to_id(self.pad_token)
self.end_token_id = self.token_to_id(self.end_token)
@@ -109,8 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
+ del config["special_tokens"] # Not configurable; set in __init__.
return config
diff --git a/keras_nlp/models/opt/opt_tokenizer_test.py b/keras_nlp/models/opt/opt_tokenizer_test.py
index 4b52ef1aed..dfda855462 100644
--- a/keras_nlp/models/opt/opt_tokenizer_test.py
+++ b/keras_nlp/models/opt/opt_tokenizer_test.py
@@ -25,7 +25,11 @@ def setUp(self):
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
- self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "special_tokens_in_strings": True,
+ }
self.input_data = [
" airplane at airport",
" airplane airport",
diff --git a/keras_nlp/models/roberta/roberta_tokenizer.py b/keras_nlp/models/roberta/roberta_tokenizer.py
index acf7f0aef9..642c618b0b 100644
--- a/keras_nlp/models/roberta/roberta_tokenizer.py
+++ b/keras_nlp/models/roberta/roberta_tokenizer.py
@@ -43,6 +43,9 @@ class RobertaTokenizer(BytePairTokenizer):
merges: A list of merge rules or a string file path, If passing a file
path. the file should have one merge rule per line. Every merge
rule contains merge entities separated by a space.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
Examples:
```python
@@ -76,6 +79,7 @@ def __init__(
self,
vocabulary=None,
merges=None,
+ special_tokens_in_strings=False,
**kwargs,
):
self.start_token = "<s>"
@@ -86,12 +90,13 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=[
+ special_tokens=[
self.start_token,
self.pad_token,
self.end_token,
self.mask_token,
],
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -99,20 +104,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if vocabulary is not None:
- # Check for necessary special tokens.
- for token in [
- self.start_token,
- self.pad_token,
- self.end_token,
- self.mask_token,
- ]:
- if token not in self.vocabulary:
- raise ValueError(
- f"Cannot find token `'{token}'` in the provided "
- f"`vocabulary`. Please provide `'{token}'` in your "
- "`vocabulary` or use a pretrained `vocabulary` name."
- )
-
self.start_token_id = self.token_to_id(self.start_token)
self.pad_token_id = self.token_to_id(self.pad_token)
self.end_token_id = self.token_to_id(self.end_token)
@@ -125,8 +116,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
+ del config["special_tokens"] # Not configurable; set in __init__.
return config
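The per-model vocabulary checks deleted above are not lost: the base class now validates every entry of `special_tokens` against the vocabulary. A quick sketch from the RoBERTa side, with a toy vocabulary that deliberately omits `<mask>` (the error text comes from the base class):

```python
import keras_nlp

# "<mask>" is intentionally absent, so construction should fail in the
# base-class check rather than in RoBERTa-specific code.
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "air": 3, "Ġair": 4}
merges = ["a i", "ai r", "Ġ a", "Ġa i", "Ġai r"]

try:
    keras_nlp.models.RobertaTokenizer(vocabulary=vocab, merges=merges)
except ValueError as err:
    print(err)  # Cannot find token `'<mask>'` in the provided `vocabulary`...
```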
diff --git a/keras_nlp/models/roberta/roberta_tokenizer_test.py b/keras_nlp/models/roberta/roberta_tokenizer_test.py
index 3b2305608d..c35bffb609 100644
--- a/keras_nlp/models/roberta/roberta_tokenizer_test.py
+++ b/keras_nlp/models/roberta/roberta_tokenizer_test.py
@@ -26,9 +26,13 @@ def setUp(self):
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
- self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "special_tokens_in_strings": True,
+ }
self.input_data = [
- "<s> airplane at airport</s><pad>",
+ "<s> airplane at airport<mask></s><pad>",
" airplane airport",
]
@@ -37,10 +41,9 @@ def test_tokenizer_basics(self):
cls=RobertaTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- # TODO: </s> should not get tokenized as <s>
- expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]],
+ expected_output=[[0, 4, 5, 6, 4, 7, 8, 2, 1], [4, 5, 4, 7]],
expected_detokenize_output=[
- "<s> airplane at airport<s><pad>",
+ "<s> airplane at airport<mask></s><pad>",
" airplane airport",
],
)
diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py
index a78ea96639..305ed40a74 100644
--- a/keras_nlp/models/whisper/whisper_preprocessor.py
+++ b/keras_nlp/models/whisper/whisper_preprocessor.py
@@ -205,9 +205,9 @@ def build(self, input_shape):
bos_tokens += [self.tokenizer.language_tokens[self.language]]
if self.task == "transcribe":
- bos_tokens += [self.tokenizer.special_tokens["<|transcribe|>"]]
+ bos_tokens += [self.tokenizer._special_tokens["<|transcribe|>"]]
elif self.task == "translate":
- bos_tokens += [self.tokenizer.special_tokens["<|translate|>"]]
+ bos_tokens += [self.tokenizer._special_tokens["<|translate|>"]]
else:
if self.language is not None:
logging.info(
diff --git a/keras_nlp/models/whisper/whisper_tokenizer.py b/keras_nlp/models/whisper/whisper_tokenizer.py
index f14fd1ee98..ee77f1a830 100644
--- a/keras_nlp/models/whisper/whisper_tokenizer.py
+++ b/keras_nlp/models/whisper/whisper_tokenizer.py
@@ -45,6 +45,10 @@ class WhisperTokenizer(BytePairTokenizer):
language_tokens: string or dict, maps language tokens to integer IDs. If
not None, the tokenizer will be assumed to be a multilingual
tokenizer.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
+
"""
def __init__(
@@ -53,6 +57,7 @@ def __init__(
merges=None,
special_tokens=None,
language_tokens=None,
+ special_tokens_in_strings=False,
**kwargs,
):
special_tokens = _load_dict(special_tokens)
@@ -94,7 +99,8 @@ def __init__(
self.translate_token_id = special_tokens[self.translate_token]
self.transcribe_token_id = special_tokens[self.transcribe_token]
- self.special_tokens = special_tokens
+ # Underscore to distinguish it from `self.special_tokens` in base class.
+ self._special_tokens = special_tokens
self.language_tokens = language_tokens
# TODO: Add language tokens to `unsplittable_tokens` once we figure
@@ -104,7 +110,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
- unsplittable_tokens=unsplittable_tokens,
+ special_tokens=unsplittable_tokens,
+ special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)
@@ -140,7 +147,7 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
self.translate_token,
self.transcribe_token,
]:
- vocabulary[token] = self.special_tokens[token]
+ vocabulary[token] = self._special_tokens[token]
else:
self._initial_vocabulary = None
@@ -148,15 +155,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
def get_config(self):
config = super().get_config()
-
- # In the constructor, we pass the list of special tokens to the
- # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
- # delete it from the config here.
- del config["unsplittable_tokens"]
-
+ del config["special_tokens"] # Not configurable; set in __init__.
config.update(
{
- "special_tokens": self.special_tokens,
+ "special_tokens": self._special_tokens,
"language_tokens": self.language_tokens,
}
)
diff --git a/keras_nlp/models/whisper/whisper_tokenizer_test.py b/keras_nlp/models/whisper/whisper_tokenizer_test.py
index 16fab2e34a..84a900104c 100644
--- a/keras_nlp/models/whisper/whisper_tokenizer_test.py
+++ b/keras_nlp/models/whisper/whisper_tokenizer_test.py
@@ -42,6 +42,7 @@ def setUp(self):
"merges": self.merges,
"special_tokens": self.special_tokens,
"language_tokens": self.language_tokens,
+ "special_tokens_in_strings": True,
}
self.input_data = [
" airplane at airport<|endoftext|>",
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index cc549c28e0..f073c6a5f4 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -59,17 +59,10 @@
SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
-def create_alts_for_unsplittable_tokens(unsplittable_tokens):
- # Create alternates for all special tokens that will be not split during
- # tokenization.
- alts = []
- prefix = "Ĵ"
- # Trim out splitters.
- replace_pattern = r"'|\s+|[^\p{L}\p{N}]+"
- for token in unsplittable_tokens:
- token = re.sub(replace_pattern, "", token)
- alts.append(prefix + token)
- return alts
+def get_special_tokens_pattern(special_tokens):
+ if special_tokens is None or len(special_tokens) == 0:
+ return None
+ return r"|".join([re.escape(token) for token in special_tokens])
def bytes_to_unicode():
@@ -104,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
return result
-def split_strings_for_bpe(inputs, unsplittable_tokens=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
# We need to recreate the exact behavior of token presplitting in the
# original gpt2 tokenizer which uses a lookahead. As re2 does not
# support lookahead match, we are using an alternative insert a special
@@ -116,24 +109,35 @@ def split_strings_for_bpe(inputs, unsplittable_tokens=None):
inputs = tf.strings.regex_replace(
inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
)
- if unsplittable_tokens:
- alts = create_alts_for_unsplittable_tokens(unsplittable_tokens)
- for token, alt in zip(unsplittable_tokens, alts):
- escaped_token = re.escape(token)
- inputs = tf_text.regex_split(inputs, escaped_token, escaped_token)
- inputs = tf.strings.regex_replace(inputs, escaped_token, alt)
- raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1)
+
+ if special_tokens_pattern is not None:
+ # First split the special tokens from the input.
+ raw_tokens = tf_text.regex_split(
+ inputs, special_tokens_pattern, special_tokens_pattern
+ )
+ # Then split using both `special_tokens_pattern` and
+ # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
+ # affecting the special tokens.
+ # We split the special tokens first and then apply this split, rather
+ # than applying this split directly, because otherwise the pattern
+ # ` ?[^\s\p{L}\p{N}{special_spaces}]+` would not split special tokens
+ # from the inputs properly.
+ # e.g., [" <s>"] will be [" <", "s", ">"] instead of [" ", "<s>"]
+ raw_tokens = tf_text.regex_split(
+ raw_tokens,
+ r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+ r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+ )
+ raw_tokens = raw_tokens.merge_dims(-2, -1)
+ else:
+ raw_tokens = tf_text.regex_split(
+ inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1
+ )
+
# Second pass splits out the last whitespace char or "६".
raw_tokens = tf_text.regex_split(
raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2
)
- if unsplittable_tokens:
- # Replace special tokens alternate with originals.
- for token, alt in zip(unsplittable_tokens, alts):
- escaped_alt = re.escape(alt)
- raw_tokens = tf.strings.regex_replace(
- raw_tokens, escaped_alt, token
- )
while raw_tokens.shape.rank > 2:
raw_tokens = raw_tokens.merge_dims(1, 2)
return remove_strings_from_inputs(raw_tokens, "६")
@@ -234,12 +238,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
a prefix space to the first word will cause it to be tokenized
equivalently to all subsequent words in the sequence.
Defaults to `False`.
- unsplittable_tokens: list. A list of strings that will
- never be split during the word-level splitting applied before the
- byte-pair encoding. This can be used to ensure special tokens map to
- unique indices in the vocabulary, even if these special tokens
- contain splittable characters such as punctuation. Special tokens
- must still be included in `vocabulary`. Defaults to `None`.
+ special_tokens: list. A list of special tokens. When
+ `special_tokens_in_strings` is set to `True`, special
+ tokens will never be split during the word-level splitting applied
+ before the byte-pair encoding. This can be used to ensure special
+ tokens map to unique indices in the vocabulary, even if these
+ special tokens contain splittable characters such as
+ punctuation. Special tokens must still be included in
+ `vocabulary`. Defaults to `None`.
+ special_tokens_in_strings: bool. Whether the tokenizer should expect
+ special tokens in input strings, so that they are tokenized and mapped
+ correctly to their ids. Defaults to `False`.
Examples:
@@ -278,7 +287,8 @@ def __init__(
merges=None,
sequence_length=None,
add_prefix_space=False,
- unsplittable_tokens=None,
+ special_tokens=None,
+ special_tokens_in_strings=False,
dtype="int32",
**kwargs,
) -> None:
@@ -293,7 +303,12 @@ def __init__(
super().__init__(dtype=dtype, **kwargs)
self.sequence_length = sequence_length
self.add_prefix_space = add_prefix_space
- self.unsplittable_tokens = unsplittable_tokens
+ self.special_tokens = special_tokens
+ self._special_tokens_pattern = None
+ if special_tokens_in_strings:
+ self._special_tokens_pattern = get_special_tokens_pattern(
+ special_tokens
+ )
# Create byte <=> unicode mapping. This is useful for handling
# whitespace tokens.
@@ -345,6 +360,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
"token to int ids. Received: "
f"`type(vocabulary)={type(vocabulary)}`."
)
+
+ # Check for special tokens in vocabulary.
+ if self.special_tokens is not None:
+ for token in self.special_tokens:
+ if token not in self.get_vocabulary():
+ raise ValueError(
+ f"Cannot find token `'{token}'` in the provided "
+ f"`vocabulary`. Please provide `'{token}'` in your "
+ "`vocabulary` or use a pretrained `vocabulary` name."
+ )
+
if isinstance(merges, str):
with open(merges, encoding="utf-8") as f:
self.merges = [bp.rstrip() for bp in f]
@@ -357,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
)
self.cache = BytePairTokenizerCache()
- if self.unsplittable_tokens:
+ if self.special_tokens and self._special_tokens_pattern is not None:
# Put special tokens into cache, so it won't be further split and
# merged.
- self.cache.insert(
- self.unsplittable_tokens, self.unsplittable_tokens
- )
+ self.cache.insert(self.special_tokens, self.special_tokens)
# Create mapping between string tokens to int ids, and vice versa.
byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -540,7 +564,7 @@ def tokenize(self, inputs):
if scalar_input:
inputs = tf.expand_dims(inputs, 0)
- raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+ raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
token_row_splits = raw_tokens.row_splits
flat_tokens = raw_tokens.flat_values
@@ -634,7 +658,7 @@ def get_config(self):
{
"sequence_length": self.sequence_length,
"add_prefix_space": self.add_prefix_space,
- "unsplittable_tokens": self.unsplittable_tokens,
+ "special_tokens": self.special_tokens,
}
)
return config
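For intuition about the core change, here is a stripped-down sketch of the two-pass split that `split_strings_for_bpe` performs when a special-tokens pattern is supplied. The word pattern below is a simplified stand-in for `SPLIT_PATTERN_1` (which also handles contractions and special whitespace), so treat it as an approximation:

```python
import re

import tensorflow as tf
import tensorflow_text as tf_text

special_tokens = ["<s>", "</s>"]  # illustrative
special_pattern = r"|".join(re.escape(t) for t in special_tokens)
word_pattern = r" ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+"  # simplified

inputs = tf.constant(["<s> airplane at airport</s>"])
# Pass 1: isolate the special tokens so later splitting cannot break them.
tokens = tf_text.regex_split(inputs, special_pattern, special_pattern)
# Pass 2: split the remaining pieces word by word while keeping the special
# tokens whole, then flatten back to one ragged dimension per example.
combined = r"|".join([special_pattern, word_pattern])
tokens = tf_text.regex_split(tokens, combined, combined)
tokens = tokens.merge_dims(-2, -1)
print(tokens)  # [[b'<s>', b' airplane', b' at', b' airport', b'</s>']]
```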
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
index 00f8f9b87f..790bc4837c 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -67,19 +67,40 @@ def test_tokenize_with_special_tokens(self):
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
- unsplittable_tokens=["s", "p"],
+ special_tokens=["s", "p"],
+ special_tokens_in_strings=True,
)
output = tokenizer("sp")
self.assertAllEqual(output, [1, 2])
- # If not setting special tokens, "sp" is one token.
+ # If special_tokens_in_strings isn't `True`, "sp" is one token.
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
+ special_tokens=["s", "p"],
)
output = tokenizer("sp")
self.assertAllEqual(output, [0])
+ # Test real-world special tokens, e.g. <s> and </s>.
+ vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
+ merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+ merges += ["Ġ f", "o x", "Ġf ox"]
+ tokenizer = BytePairTokenizer(
+ vocabulary=vocab,
+ merges=merges,
+ special_tokens=["<s>", "</s>"],
+ special_tokens_in_strings=True,
+ )
+ output = tokenizer("<s>a quick fox</s>")
+ self.assertAllEqual(output, [0, 2, 3, 4, 1])
+
+ def test_errors_missing_special_tokens(self):
+ with self.assertRaises(ValueError):
+ BytePairTokenizer(
+ vocabulary=["a", "b", "c"], merges=[], special_tokens=["d"]
+ )
+
def test_tokenize_prefix_space(self):
input_data = ["brown.", "black."]
tokenizer = BytePairTokenizer(