diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index ca7db32dc9..a95225ff87 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -394,7 +394,7 @@ def tokenize_prompt(self, prompt):
                     LOG.warning(f"assistant turn has empty text: {prompt}")
                 res = self._tokenize(
                     turn,
-                    add_eos_token=False if conversation.name == 'chatml' else True,
+                    add_eos_token=conversation.name != "chatml",
                     strip_bos_token=True,
                 )
                 role_res = self._tokenize(
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 9e7373809d..09eb6111eb 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -74,9 +74,11 @@ def train(
     resume_from_checkpoint = cfg.resume_from_checkpoint

     if dist.get_rank() == 0:
-        print('\n\n*********** INPUT SANITY CHECK ***********')
-        print(tokenizer.decode(train_dataset[0]['input_ids'], skip_special_tokens=False))
-        print('******************************************\n\n')
+        print("\n\n*********** INPUT SANITY CHECK ***********")
+        print(
+            tokenizer.decode(train_dataset[0]["input_ids"], skip_special_tokens=False)
+        )
+        print("******************************************\n\n")

     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
@@ -171,7 +173,7 @@ def terminate_handler(_, __, model):
     if not cfg.hub_model_id:
         trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
     else:
-        dataset = [d['path'] for d in cfg.datasets]
+        dataset = [d["path"] for d in cfg.datasets]
         trainer.push_to_hub(dataset=dataset, dataset_tags=dataset)

     return model, tokenizer
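
Note on the `add_eos_token` cleanup in the first hunk: `False if conversation.name == 'chatml' else True` reduces to `conversation.name != "chatml"` (EOS is added for every template *except* chatml), not `== "chatml"`, which would invert the behavior. A minimal sketch checking the equivalence; the `SimpleNamespace` stand-in for fastchat's `Conversation` object is an assumption for illustration, since only `.name` matters here:

```python
from types import SimpleNamespace

# Hypothetical stand-in for a fastchat Conversation; only .name is used.
for name in ("chatml", "vicuna_v1.1", "llama-2"):
    conversation = SimpleNamespace(name=name)
    old = False if conversation.name == "chatml" else True  # removed line
    new = conversation.name != "chatml"                     # replacement line
    assert old == new, (name, old, new)
    print(f"{name}: add_eos_token={new}")
```

The same assertion fails if `new` is written with `==`, which is why the hunk above uses `!=`.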