diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index c81856b8ba..a68a611c52 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -76,6 +76,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: DatasetTooSmallError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, + InvalidConversationError, InvalidExampleTypeError, InvalidFileExtensionError, InvalidLastChatMessageRoleError, @@ -270,17 +271,17 @@ def slice_out_last_turn( if conversation_through_previous_turn != full_conversation[:len( conversation_through_previous_turn, )]: - raise ValueError( + raise InvalidConversationError( f'The full conversation must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {full_conversation=}', ) if conversation_through_previous_turn != prompt_with_history[:len( conversation_through_previous_turn, )]: - raise ValueError( + raise InvalidConversationError( f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}', ) if prompt_with_history != full_conversation[:len(prompt_with_history)]: - raise ValueError( + raise InvalidConversationError( f'prompt_with_history must be the first part of the full conversation. {prompt_with_history=}, {full_conversation=}', ) prompt = prompt_with_history[len(conversation_through_previous_turn):] diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 4a4321637f..81cfb21d11 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -497,3 +497,18 @@ def __init__(self, files_searched: list[str]) -> None: message, files_searched=files_searched, ) + + +class InvalidConversationError(UserError): + """Error thrown when the conversation is invalid.""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(message) + + def __reduce__(self): + # Return a tuple of class, a tuple of arguments, and optionally state + return (InvalidConversationError, (self.message,)) + + def __str__(self): + return self.message