Commit

🫣 Fix vllm duplicating <bos> token
alekseymalakhov11 authored Sep 13, 2024
2 parents 184e19c + d0df9ef commit d03c83f
Showing 2 changed files with 23 additions and 7 deletions.
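Why the duplication happened: the old vllm_chat generator decoded the already-tokenized prompts back into strings (keeping special tokens) and handed those strings to vLLM, which tokenizes string prompts itself and prepends another <bos>. The fix passes prompt_token_ids directly, skipping the second tokenization. A minimal sketch of the failure mode, assuming a Llama-style tokenizer that adds <bos> by default; the checkpoint name is illustrative, not taken from the repo:

# Sketch of the <bos> duplication this commit fixes (illustrative checkpoint, not from the repo).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

input_ids = tokenizer('Hello', add_special_tokens=True).input_ids   # [<bos>, Hello, ...]
prompt = tokenizer.decode(input_ids, skip_special_tokens=False)     # '<s> Hello'
reencoded = tokenizer(prompt, add_special_tokens=True).input_ids    # [<bos>, <bos>, Hello, ...]
# vLLM tokenizes string prompts the same way, so decoded prompts arrive with a doubled <bos>;
# passing prompt_token_ids (as the diff below does) avoids the extra tokenization pass.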
1 change: 0 additions & 1 deletion turbo_alignment/generators/base.py
@@ -105,7 +105,6 @@ def __init__(
         )
 
         self._custom_generation_settings = custom_generation_settings
-        self._return_logits = return_logits
 
     @abstractmethod
     def _generate_from_single_record(
29 changes: 23 additions & 6 deletions turbo_alignment/generators/vllm_chat.py
@@ -1,5 +1,6 @@
 from typing import Any
 
+import torch
 from transformers import PreTrainedTokenizerBase
 from vllm import LLM, SamplingParams
 
@@ -21,6 +22,7 @@ def __init__(
         model: LLM,
         tokenizer: PreTrainedTokenizerBase,
         batch: int,
+        return_logits: bool = False,
     ):
         model.set_tokenizer(tokenizer)
         super().__init__(model, tokenizer, batch=batch)
@@ -50,26 +52,41 @@ def __init__(
             **beam_search_params,
         )
 
+        self._return_logits = return_logits
+
     def _generate_from_batch(
         self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
     ) -> list[ChatInferenceOutput]:
         input_ids = [record['input_ids'].tolist() for record in records]
-        prompts = self._tokenizer.batch_decode(sequences=input_ids, skip_special_tokens=False)
-        request_outputs = self._model.generate(prompts, self._sampling_params)
+        request_outputs = self._model.generate(
+            prompts=None,
+            prompt_token_ids=input_ids,
+            sampling_params=self._sampling_params,
+        )
 
         outputs = []
         for i, request_output in enumerate(request_outputs):
             original_record = original_records[i]
+            answers = []
+            for a in request_output.outputs:
+                ans_msg = AnswerMessage(
+                    id=str(a.index),
+                    content=a.text,
+                    sequence_score=a.cumulative_logprob,
+                )
+                if self._return_logits:
+                    ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
+                    ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)
+
+                answers.append(ans_msg)
+
             outputs.append(
                 ChatInferenceOutput(
                     id=original_record.id,
                     dataset_name=dataset_name,
                     messages=original_record.messages,
                     label=original_record.label,
-                    answers=[
-                        AnswerMessage(id=str(a.index), content=a.text, sequence_score=a.cumulative_logprob)
-                        for a in request_output.outputs
-                    ],
+                    answers=answers,
                 )
             )
         return outputs

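For completeness, a standalone sketch of the fixed call pattern, hedged: the model name is illustrative, and the vLLM version is assumed to match the one this repo targets, where LLM.generate() still accepts prompt_token_ids as in the hunk above.

# Standalone sketch (illustrative model; assumes a 2024-era vLLM where generate()
# accepts prompt_token_ids, matching the call in the diff above).
from vllm import LLM, SamplingParams

llm = LLM(model='meta-llama/Llama-2-7b-hf')
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

token_ids = [[1, 15043, 29892]]  # already carries a single <bos>; vLLM will not add another
request_outputs = llm.generate(prompts=None, prompt_token_ids=token_ids, sampling_params=sampling_params)
for request_output in request_outputs:
    print(request_output.prompt_token_ids, request_output.outputs[0].text)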