diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py
index 8a9d36e..5415daf 100755
--- a/turbo_alignment/generators/base.py
+++ b/turbo_alignment/generators/base.py
@@ -105,7 +105,6 @@ def __init__(
         )
 
         self._custom_generation_settings = custom_generation_settings
-        self._return_logits = return_logits
 
     @abstractmethod
     def _generate_from_single_record(
diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index 513c5e7..ff52de3 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -1,5 +1,6 @@
 from typing import Any
 
+import torch
 from transformers import PreTrainedTokenizerBase
 from vllm import LLM, SamplingParams
 
@@ -21,6 +22,7 @@ def __init__(
         model: LLM,
         tokenizer: PreTrainedTokenizerBase,
         batch: int,
+        return_logits: bool = False,
     ):
         model.set_tokenizer(tokenizer)
         super().__init__(model, tokenizer, batch=batch)
@@ -50,26 +52,41 @@ def __init__(
             **beam_search_params,
         )
 
+        self._return_logits = return_logits
+
     def _generate_from_batch(
         self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
     ) -> list[ChatInferenceOutput]:
         input_ids = [record['input_ids'].tolist() for record in records]
-        prompts = self._tokenizer.batch_decode(sequences=input_ids, skip_special_tokens=False)
-        request_outputs = self._model.generate(prompts, self._sampling_params)
+        request_outputs = self._model.generate(
+            prompts=None,
+            prompt_token_ids=input_ids,
+            sampling_params=self._sampling_params,
+        )
 
         outputs = []
         for i, request_output in enumerate(request_outputs):
             original_record = original_records[i]
+            answers = []
+            for a in request_output.outputs:
+                ans_msg = AnswerMessage(
+                    id=str(a.index),
+                    content=a.text,
+                    sequence_score=a.cumulative_logprob,
+                )
+                if self._return_logits:
+                    ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
+                    ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)
+
+                answers.append(ans_msg)
+
             outputs.append(
                 ChatInferenceOutput(
                     id=original_record.id,
                     dataset_name=dataset_name,
                     messages=original_record.messages,
                     label=original_record.label,
-                    answers=answers,
+                    answers=answers,
                 )
             )
         return outputs
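
Note for reviewers: a minimal sketch of how the new `return_logits` flag is expected to surface token ids on each answer. `VLLMChatGenerator` is the assumed class name in `vllm_chat.py`, and the constructor almost certainly takes more arguments than the ones visible in this hunk; treat both as assumptions rather than the exact API. Only `return_logits`, `_generate_from_batch`, and the `input_token_ids`/`answer_token_ids` fields come from this diff.

```python
# Hedged usage sketch -- `VLLMChatGenerator` and the abbreviated constructor
# are assumptions; only `return_logits` and the answer fields come from this diff.
from turbo_alignment.generators.vllm_chat import VLLMChatGenerator

generator = VLLMChatGenerator(
    model=llm,            # a vllm.LLM instance
    tokenizer=tokenizer,  # a transformers PreTrainedTokenizerBase
    batch=8,
    return_logits=True,   # new flag introduced by this PR
)

outputs = generator._generate_from_batch(records, original_records, dataset_name='chat')
for output in outputs:
    for answer in output.answers:
        # Populated only when return_logits=True: (1, seq_len) tensors built from
        # vLLM's request_output.prompt_token_ids and each completion's token_ids.
        print(answer.input_token_ids.shape, answer.answer_token_ids.shape)
```

Passing `prompt_token_ids` to `LLM.generate` instead of round-tripping through `batch_decode` also means vLLM consumes exactly the ids produced by the dataset tokenization, avoiding any decode/re-encode mismatch.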