From 0db67f64595bb4e08bbaf8af1f3887088a957775 Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Fri, 21 Jun 2024 11:55:13 -0400 Subject: [PATCH] Fixed template imports Signed-off-by: Mustafa Eyceoz --- .../training/chat_templates/ibm_generic_tmpl.py | 4 ++-- .../training/chat_templates/mistral_tmpl.py | 4 ++-- src/instructlab/training/utils.py | 15 +++++++++------ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/instructlab/training/chat_templates/ibm_generic_tmpl.py b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py index 6d4b37d7..87bfdb0a 100644 --- a/src/instructlab/training/chat_templates/ibm_generic_tmpl.py +++ b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py @@ -1,5 +1,5 @@ -# Third Party -from tokenizer_utils import SpecialTokens +# First Party +from instructlab.training.tokenizer_utils import SpecialTokens SPECIAL_TOKENS = SpecialTokens( system="<|system|>", diff --git a/src/instructlab/training/chat_templates/mistral_tmpl.py b/src/instructlab/training/chat_templates/mistral_tmpl.py index 753d5559..965823f2 100644 --- a/src/instructlab/training/chat_templates/mistral_tmpl.py +++ b/src/instructlab/training/chat_templates/mistral_tmpl.py @@ -1,5 +1,5 @@ -# Third Party -from tokenizer_utils import SpecialTokens +# First Party +from instructlab.training.tokenizer_utils import SpecialTokens SPECIAL_TOKENS = SpecialTokens( bos="", diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index ce935c41..9e57dc68 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -29,12 +29,15 @@ def retrieve_chat_template(chat_tmpl_path): import importlib.util import sys - spec = importlib.util.spec_from_file_location("spcl_chat_tmpl", chat_tmpl_path) - module = importlib.util.module_from_spec(spec) - sys.modules["spcl_chat_tmpl"] = module - spec.loader.exec_module(module) - SPECIAL_TOKENS = module.SPECIAL_TOKENS - CHAT_TEMPLATE = module.CHAT_TEMPLATE + try: + spec = importlib.util.spec_from_file_location("spcl_chat_tmpl", chat_tmpl_path) + module = importlib.util.module_from_spec(spec) + sys.modules["spcl_chat_tmpl"] = module + spec.loader.exec_module(module) + SPECIAL_TOKENS = module.SPECIAL_TOKENS + CHAT_TEMPLATE = module.CHAT_TEMPLATE + except: + sys.exit(f"Invalid chat template path: {chat_tmpl_path}") return CHAT_TEMPLATE, SPECIAL_TOKENS