diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index a1409ba..b722d45 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -193,8 +193,8 @@ def _truncate_and_merge( ] right_bound = random.choice(bot_indices) if bot_indices else right_bound - input_ids = np.array([self.tokenizer.bos_token_id]) - labels = np.array([DISABLE_LOSS_LABEL]) + input_ids = np.array([]) + labels = np.array([]) truncated_conversation_messages = conversation.messages[left_bound:right_bound] truncated_tokenized_replicas = tokenized_replicas[left_bound:right_bound] @@ -205,8 +205,8 @@ def _truncate_and_merge( for ind, (message, tokenized_replica) in enumerate( zip( - truncated_conversation_messages[left_bound:right_bound], - truncated_tokenized_replicas[left_bound:right_bound], + truncated_conversation_messages, + truncated_tokenized_replicas, ) ): prefix_tokens = role_prefix_tokens[message.role] @@ -236,6 +236,20 @@ def _truncate_and_merge( input_ids = np.append(input_ids, self.tokenizer.eos_token_id) labels = np.append(labels, DISABLE_LOSS_LABEL) + # FIXME + start_replica_token_id = role_prefix_tokens[ChatMessageRole.BOT][0].item() + + # -1 for bos token + input_ids = input_ids[-(self.settings.max_tokens_count - 1) :] # type: ignore[operator] + replica_start_token_inds = np.where(input_ids == start_replica_token_id)[0] + if len(replica_start_token_inds) != 0: + cut_index = replica_start_token_inds[0] + input_ids = input_ids[cut_index:] + + labels = labels[-len(input_ids) :] + input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids)) + labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels)) + return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound) def _encode(