diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py
index 8a9d36e..5415daf 100755
--- a/turbo_alignment/generators/base.py
+++ b/turbo_alignment/generators/base.py
@@ -105,7 +105,6 @@ def __init__(
         )
         self._custom_generation_settings = custom_generation_settings
-        self._return_logits = return_logits
 
     @abstractmethod
     def _generate_from_single_record(
diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index c2d6e6e..ff52de3 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -22,6 +22,7 @@ def __init__(
         model: LLM,
         tokenizer: PreTrainedTokenizerBase,
         batch: int,
+        return_logits: bool = False,
     ):
         model.set_tokenizer(tokenizer)
         super().__init__(model, tokenizer, batch=batch)
@@ -51,6 +52,8 @@ def __init__(
             **beam_search_params,
         )
 
+        self._return_logits = return_logits
+
     def _generate_from_batch(
         self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
     ) -> list[ChatInferenceOutput]:
@@ -64,22 +67,26 @@ def _generate_from_batch(
         outputs = []
         for i, request_output in enumerate(request_outputs):
             original_record = original_records[i]
+            answers = []
+            for a in request_output.outputs:
+                ans_msg = AnswerMessage(
+                    id=str(a.index),
+                    content=a.text,
+                    sequence_score=a.cumulative_logprob,
+                )
+                if self._return_logits:
+                    ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
+                    ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)
+
+                answers.append(ans_msg)
+
             outputs.append(
                 ChatInferenceOutput(
                     id=original_record.id,
                     dataset_name=dataset_name,
                     messages=original_record.messages,
                     label=original_record.label,
-                    answers=[
-                        AnswerMessage(
-                            id=str(a.index),
-                            content=a.text,
-                            input_token_ids=torch.tensor(request_output.prompt_token_ids).unsqueeze(0),
-                            answer_token_ids=torch.tensor(a.token_ids).unsqueeze(0),
-                            sequence_score=a.cumulative_logprob,
-                        )
-                        for a in request_output.outputs
-                    ],
+                    answers=answers,
                 )
             )
         return outputs
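
A minimal sketch of how the new opt-in flag might be exercised from a call site. Only the `return_logits` parameter, its `False` default, the `_generate_from_batch` signature, and the conditionally populated `AnswerMessage` token-id fields come from the patch above; the generator class name (`VLLMChatGenerator`), the model id, and the empty record lists are assumptions for illustration.

# Hedged usage sketch, assuming the vllm_chat.py class is named VLLMChatGenerator.
from transformers import AutoTokenizer
from vllm import LLM

from turbo_alignment.generators.vllm_chat import VLLMChatGenerator  # assumed class name

model = LLM(model='some-hf-model')                          # hypothetical model id
tokenizer = AutoTokenizer.from_pretrained('some-hf-model')  # hypothetical model id

generator = VLLMChatGenerator(
    model=model,
    tokenizer=tokenizer,
    batch=8,
    return_logits=True,  # default is False; then token-id tensors are left unset
)

# Record construction is elided: the dataset record types are not part of this
# patch, so empty batches stand in for real inputs here.
records: list[dict] = []
original_records: list = []

outputs = generator._generate_from_batch(records, original_records, dataset_name='chat_test')
for out in outputs:
    for answer in out.answers:
        # Populated only because return_logits=True; unsqueeze(0) in the patch
        # gives each tensor shape (1, sequence_length).
        print(answer.answer_token_ids.shape)

With `return_logits=False`, each `AnswerMessage` carries only `id`, `content`, and `sequence_score`, which avoids materializing per-answer token-id tensors when callers don't need them.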