From 3a234e471113700126778e28047d1f61ff62ba42 Mon Sep 17 00:00:00 2001
From: Малахов Алексей Павлович
Date: Wed, 11 Sep 2024 12:23:29 +0000
Subject: [PATCH 1/2] fix double eos in vllm

---
 turbo_alignment/generators/vllm_chat.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index 513c5e7..c2d6e6e 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -1,5 +1,6 @@
 from typing import Any
 
+import torch
 from transformers import PreTrainedTokenizerBase
 from vllm import LLM, SamplingParams
 
@@ -54,8 +55,11 @@ def _generate_from_batch(
         self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
     ) -> list[ChatInferenceOutput]:
         input_ids = [record['input_ids'].tolist() for record in records]
-        prompts = self._tokenizer.batch_decode(sequences=input_ids, skip_special_tokens=False)
-        request_outputs = self._model.generate(prompts, self._sampling_params)
+        request_outputs = self._model.generate(
+            prompts=None,
+            prompt_token_ids=input_ids,
+            sampling_params=self._sampling_params,
+        )
 
         outputs = []
         for i, request_output in enumerate(request_outputs):
@@ -67,7 +71,13 @@
                     messages=original_record.messages,
                     label=original_record.label,
                     answers=[
-                        AnswerMessage(id=str(a.index), content=a.text, sequence_score=a.cumulative_logprob)
+                        AnswerMessage(
+                            id=str(a.index),
+                            content=a.text,
+                            input_token_ids=torch.tensor(request_output.prompt_token_ids).unsqueeze(0),
+                            answer_token_ids=torch.tensor(a.token_ids).unsqueeze(0),
+                            sequence_score=a.cumulative_logprob,
+                        )
                         for a in request_output.outputs
                     ],
                 )

From d0df9efb5f650c6817a3b5285e70274a6f58370f Mon Sep 17 00:00:00 2001
From: Малахов Алексей Павлович
Date: Thu, 12 Sep 2024 13:34:39 +0000
Subject: [PATCH 2/2] add return_logits to vllm generator

---
 turbo_alignment/generators/base.py      |  1 -
 turbo_alignment/generators/vllm_chat.py | 27 ++++++++++++++++---------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py
index 8a9d36e..5415daf 100755
--- a/turbo_alignment/generators/base.py
+++ b/turbo_alignment/generators/base.py
@@ -105,7 +105,6 @@ def __init__(
         )
 
         self._custom_generation_settings = custom_generation_settings
-        self._return_logits = return_logits
 
     @abstractmethod
     def _generate_from_single_record(
diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index c2d6e6e..ff52de3 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -22,6 +22,7 @@ def __init__(
         model: LLM,
         tokenizer: PreTrainedTokenizerBase,
         batch: int,
+        return_logits: bool = False,
     ):
         model.set_tokenizer(tokenizer)
         super().__init__(model, tokenizer, batch=batch)
@@ -51,6 +52,8 @@ def __init__(
             **beam_search_params,
         )
 
+        self._return_logits = return_logits
+
     def _generate_from_batch(
         self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
     ) -> list[ChatInferenceOutput]:
@@ -64,22 +67,26 @@
         outputs = []
         for i, request_output in enumerate(request_outputs):
             original_record = original_records[i]
+            answers = []
+            for a in request_output.outputs:
+                ans_msg = AnswerMessage(
+                    id=str(a.index),
+                    content=a.text,
+                    sequence_score=a.cumulative_logprob,
+                )
+                if self._return_logits:
+                    ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
+                    ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)
+
+                answers.append(ans_msg)
+
             outputs.append(
                 ChatInferenceOutput(
                     id=original_record.id,
                     dataset_name=dataset_name,
                     messages=original_record.messages,
                     label=original_record.label,
-                    answers=[
-                        AnswerMessage(
-                            id=str(a.index),
-                            content=a.text,
-                            input_token_ids=torch.tensor(request_output.prompt_token_ids).unsqueeze(0),
-                            answer_token_ids=torch.tensor(a.token_ids).unsqueeze(0),
-                            sequence_score=a.cumulative_logprob,
-                        )
-                        for a in request_output.outputs
-                    ],
+                    answers=answers,
                 )
             )
         return outputs
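
Note on PATCH 1/2, with a minimal sketch of the round-trip bug it fixes.
Before the change, the already-templated input_ids were decoded back to text
and handed to vLLM as prompts, so vLLM re-tokenized them; a tokenizer that
inserts special tokens on encode then adds BOS/EOS a second time, on top of
the ones already embedded in the decoded text. The sketch assumes a
Llama-style tokenizer; the model name is an arbitrary example, not taken
from the patch:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

    ids = tok('Hello').input_ids                       # BOS inserted once: [BOS, ...]
    text = tok.decode(ids, skip_special_tokens=False)  # '<s> Hello'
    round_trip = tok(text).input_ids                   # [BOS, BOS, ...] -- duplicated

The same duplication hits any special token already present in the ids,
including a trailing EOS from a chat template. Passing prompt_token_ids
directly to LLM.generate(), as the patch does, skips re-tokenization, so
every special token appears exactly once.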
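
Note on PATCH 2/2: it moves ownership of return_logits from the base
generator into the vLLM generator and attaches the token-id tensors to
answers only when the flag is set. A hypothetical usage sketch; the class
name VLLMChatGenerator and the surrounding setup are assumptions, since the
class definition is not visible in the diff:

    generator = VLLMChatGenerator(
        model=llm,            # a vllm.LLM instance, constructed elsewhere
        tokenizer=tokenizer,  # a PreTrainedTokenizerBase
        batch=8,
        return_logits=True,   # opt in to token-id tensors on answers
    )
    outputs = generator._generate_from_batch(records, original_records, 'chat')
    # With return_logits=True, every AnswerMessage also carries
    # input_token_ids and answer_token_ids of shape (1, sequence_length);
    # with the default False, those fields are left unset.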