Skip to content

Commit

Permalink
fix keep end
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Sep 16, 2024
1 parent 99d32f4 commit 3f00b2b
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions turbo_alignment/dataset/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def _truncate_and_merge(
]
right_bound = random.choice(bot_indices) if bot_indices else right_bound

input_ids = np.array([self.tokenizer.bos_token_id])
labels = np.array([DISABLE_LOSS_LABEL])
input_ids = np.array([])
labels = np.array([])

truncated_conversation_messages = conversation.messages[left_bound:right_bound]
truncated_tokenized_replicas = tokenized_replicas[left_bound:right_bound]
Expand All @@ -205,8 +205,8 @@ def _truncate_and_merge(

for ind, (message, tokenized_replica) in enumerate(
zip(
truncated_conversation_messages[left_bound:right_bound],
truncated_tokenized_replicas[left_bound:right_bound],
truncated_conversation_messages,
truncated_tokenized_replicas,
)
):
prefix_tokens = role_prefix_tokens[message.role]
Expand Down Expand Up @@ -236,6 +236,20 @@ def _truncate_and_merge(
input_ids = np.append(input_ids, self.tokenizer.eos_token_id)
labels = np.append(labels, DISABLE_LOSS_LABEL)

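# FIXME: the first token of the bot role prefix is used as the marker for where a bot replica starts; presumably that token id could also occur inside message text - confirm.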
start_replica_token_id = role_prefix_tokens[ChatMessageRole.BOT][0].item()

# Keep at most (max_tokens_count - 1) trailing tokens, reserving one position for the BOS token.
input_ids = input_ids[-(self.settings.max_tokens_count - 1) :] # type: ignore[operator]
replica_start_token_inds = np.where(input_ids == start_replica_token_id)[0]
if len(replica_start_token_inds) != 0:
cut_index = replica_start_token_inds[0]
input_ids = input_ids[cut_index:]

labels = labels[-len(input_ids) :]
input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids))
labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))

return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound)

def _encode(
Expand Down

0 comments on commit 3f00b2b

Please sign in to comment.