From d0df9efb5f650c6817a3b5285e70274a6f58370f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=BB=D0=B0=D1=85=D0=BE=D0=B2=20=D0=90=D0=BB?=
 =?UTF-8?q?=D0=B5=D0=BA=D1=81=D0=B5=D0=B9=20=D0=9F=D0=B0=D0=B2=D0=BB=D0=BE?=
 =?UTF-8?q?=D0=B2=D0=B8=D1=87?=
Date: Thu, 12 Sep 2024 13:34:39 +0000
Subject: [PATCH] add return_logits to vllm generator

---
 turbo_alignment/generators/base.py      |  1 -
 turbo_alignment/generators/vllm_chat.py | 27 ++++++++++++++++---------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py
index 8a9d36e..5415daf 100755
--- a/turbo_alignment/generators/base.py
+++ b/turbo_alignment/generators/base.py
@@ -105,7 +105,6 @@ def __init__(
         )
 
         self._custom_generation_settings = custom_generation_settings
-        self._return_logits = return_logits
 
     @abstractmethod
     def _generate_from_single_record(
diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index c2d6e6e..ff52de3 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -22,6 +22,7 @@ def __init__(
         model: LLM,
         tokenizer: PreTrainedTokenizerBase,
         batch: int,
+        return_logits: bool = False,
     ):
         model.set_tokenizer(tokenizer)
         super().__init__(model, tokenizer, batch=batch)
@@ -51,6 +52,8 @@ def __init__(
             **beam_search_params,
         )
 
+        self._return_logits = return_logits
+
     def _generate_from_batch(
         self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
     ) -> list[ChatInferenceOutput]:
@@ -64,22 +67,26 @@ def _generate_from_batch(
         outputs = []
         for i, request_output in enumerate(request_outputs):
             original_record = original_records[i]
+            answers = []
+            for a in request_output.outputs:
+                ans_msg = AnswerMessage(
+                    id=str(a.index),
+                    content=a.text,
+                    sequence_score=a.cumulative_logprob,
+                )
+                if self._return_logits:
+                    ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
+                    ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)
+
+                answers.append(ans_msg)
+
             outputs.append(
                 ChatInferenceOutput(
                     id=original_record.id,
                     dataset_name=dataset_name,
                     messages=original_record.messages,
                     label=original_record.label,
-                    answers=[
-                        AnswerMessage(
-                            id=str(a.index),
-                            content=a.text,
-                            input_token_ids=torch.tensor(request_output.prompt_token_ids).unsqueeze(0),
-                            answer_token_ids=torch.tensor(a.token_ids).unsqueeze(0),
-                            sequence_score=a.cumulative_logprob,
-                        )
-                        for a in request_output.outputs
-                    ],
+                    answers=answers,
                 )
             )
         return outputs