diff --git a/configs/exp/train/classification/classification.json b/configs/exp/train/classification/classification.json
index af851ed..5257230 100755
--- a/configs/exp/train/classification/classification.json
+++ b/configs/exp/train/classification/classification.json
@@ -86,6 +86,11 @@
}
},
"tokenizer_settings": {},
+ "special_tokens_settings": {
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": ""
+ },
"trainer_settings": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
diff --git a/configs/exp/train/dpo/dpo.json b/configs/exp/train/dpo/dpo.json
index 4e4dc21..61bb787 100755
--- a/configs/exp/train/dpo/dpo.json
+++ b/configs/exp/train/dpo/dpo.json
@@ -96,6 +96,10 @@
}
},
"tokenizer_settings": {},
+ "special_tokens_settings": {
+ "bos_token": "<|begin_of_text|>",
+ "eos_token": "<|end_of_text|>"
+ },
"trainer_settings": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
diff --git a/configs/exp/train/multimodal/c_abs.json b/configs/exp/train/multimodal/c_abs.json
index 713abeb..46f00f7 100644
--- a/configs/exp/train/multimodal/c_abs.json
+++ b/configs/exp/train/multimodal/c_abs.json
@@ -105,6 +105,10 @@
"tokenizer_settings": {
"tokenizer_path": "/from_s3/model"
},
+ "special_tokens_settings": {
+ "bos_token": "",
+ "eos_token": ""
+ },
"trainer_settings": {
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
diff --git a/configs/exp/train/multimodal/llava.json b/configs/exp/train/multimodal/llava.json
deleted file mode 100644
index e69de29..0000000
diff --git a/configs/exp/train/multimodal/mlp.json b/configs/exp/train/multimodal/mlp.json
index afb2f2f..68a8925 100644
--- a/configs/exp/train/multimodal/mlp.json
+++ b/configs/exp/train/multimodal/mlp.json
@@ -105,6 +105,10 @@
"tokenizer_settings": {
"tokenizer_path": "/from_s3/model"
},
+ "special_tokens_settings": {
+ "bos_token": "",
+ "eos_token": ""
+ },
"trainer_settings": {
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
diff --git a/configs/exp/train/rag/end2end_rag.json b/configs/exp/train/rag/end2end_rag.json
index 28c109a..e0b5cad 100755
--- a/configs/exp/train/rag/end2end_rag.json
+++ b/configs/exp/train/rag/end2end_rag.json
@@ -112,6 +112,10 @@
"metric_settings": []
},
"tokenizer_settings": {},
+ "special_tokens_settings": {
+ "bos_token": "",
+ "eos_token": ""
+ },
"trainer_settings": {
"evaluation_strategy": "steps",
"save_strategy": "steps",
diff --git a/configs/exp/train/rm/rm.json b/configs/exp/train/rm/rm.json
index f46d2f0..b9bd08c 100755
--- a/configs/exp/train/rm/rm.json
+++ b/configs/exp/train/rm/rm.json
@@ -83,6 +83,10 @@
}
},
"tokenizer_settings": {},
+ "special_tokens_settings": {
+ "bos_token": "<|begin_of_text|>",
+ "eos_token": "<|end_of_text|>"
+ },
"trainer_settings": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
diff --git a/configs/exp/train/sft/sft.json b/configs/exp/train/sft/sft.json
index 14f5606..f746e83 100755
--- a/configs/exp/train/sft/sft.json
+++ b/configs/exp/train/sft/sft.json
@@ -104,6 +104,10 @@
]
},
"tokenizer_settings": {},
+ "special_tokens_settings": {
+ "bos_token": "<|begin_of_text|>",
+ "eos_token": "<|end_of_text|>"
+  },
"trainer_settings": {
"evaluation_strategy": "steps",
"save_total_limit": 5,
diff --git a/turbo_alignment/common/tf/special_tokens_setter.py b/turbo_alignment/common/tf/special_tokens_setter.py
index d15eda4..b1aaeb6 100755
--- a/turbo_alignment/common/tf/special_tokens_setter.py
+++ b/turbo_alignment/common/tf/special_tokens_setter.py
@@ -13,11 +13,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, special_tokens_settings:
self._special_tokens_settings = special_tokens_settings
self._special_tokens_already_set: bool = False
- def setBOS(self, bos_token: str | None) -> None:
- if self._tokenizer.bos_token_id is None and bos_token is None:
- logger.info('Skip adding bos_token_id')
- return None
-
+ def setBOS(self, bos_token: str) -> None:
if self._tokenizer.bos_token_id is None:
logger.info('Model does not have bos_token_id')
self._tokenizer.add_special_tokens(special_tokens_dict={'bos_token': bos_token})
@@ -26,13 +22,7 @@ def setBOS(self, bos_token: str | None) -> None:
else:
logger.info(f'Model has bos_token_id = {self._tokenizer.bos_token_id}')
- return None
-
- def setEOS(self, eos_token: str | None) -> None:
- if self._tokenizer.eos_token_id is None and eos_token is None:
- logger.info('Skip adding eos_token_id')
- return None
-
+ def setEOS(self, eos_token: str) -> None:
if self._tokenizer.eos_token_id is None:
logger.info('Model does not have eos_token_id')
self._tokenizer.add_special_tokens(special_tokens_dict={'eos_token': eos_token})
@@ -41,8 +31,6 @@ def setEOS(self, eos_token: str | None) -> None:
else:
logger.info(f'Model has eos_token_id = {self._tokenizer.eos_token_id}')
- return None
-
def setPAD(self, pad_token: str | None) -> None:
if self._tokenizer.pad_token_id is None and pad_token is None:
logger.info('Skip adding pad_token_id')
@@ -105,12 +93,11 @@ def set_custom_tokens(self, tokens: list[str]) -> None:
added_tokens = self._tokenizer.add_special_tokens({'additional_special_tokens': tokens})
assert added_tokens == len(tokens)
- def setup_model_config(self, model: PreTrainedModel):
- if self._tokenizer.bos_token_id is None:
- model.config.bos_token_id = self._tokenizer.bos_token_id
- if self._tokenizer.eos_token_id is None:
- model.config.eos_token_id = self._tokenizer.eos_token_id
- if self._tokenizer.pad_token_id is None:
+ def setup_model_config(self, model: PreTrainedModel) -> None:
+ model.config.bos_token_id = self._tokenizer.bos_token_id
+ model.config.eos_token_id = self._tokenizer.eos_token_id
+
+ if self._tokenizer.pad_token_id is not None:
model.config.pad_token_id = self._tokenizer.pad_token_id
- if self._tokenizer.sep_token_id is None:
+ if self._tokenizer.sep_token_id is not None:
model.config.sep_token_id = self._tokenizer.sep_token_id
diff --git a/turbo_alignment/settings/tf/special_tokens_setter.py b/turbo_alignment/settings/tf/special_tokens_setter.py
index 90f6ab9..fc6ab7e 100755
--- a/turbo_alignment/settings/tf/special_tokens_setter.py
+++ b/turbo_alignment/settings/tf/special_tokens_setter.py
@@ -2,8 +2,8 @@
class SpecialTokensSettings(ExtraFieldsNotAllowedBaseModel):
- bos_token: str = ''
- eos_token: str = ''
+ bos_token: str
+ eos_token: str
pad_token: str | None = None
unk_token: str | None = None
sep_token: str | None = None