Skip to content

Commit

Permalink
add return_logits to vllm generator
Browse files (browse the repository at this point in the history)
  • Loading branch information
Малахов Алексей Павлович committed Sep 12, 2024
1 parent 3a234e4 commit d0df9ef
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
1 change: 0 additions & 1 deletion turbo_alignment/generators/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -105,7 +105,6 @@ def __init__(
)

self._custom_generation_settings = custom_generation_settings
self._return_logits = return_logits

@abstractmethod
def _generate_from_single_record(
Expand Down
27 changes: 17 additions & 10 deletions turbo_alignment/generators/vllm_chat.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
model: LLM,
tokenizer: PreTrainedTokenizerBase,
batch: int,
return_logits: bool = False,
):
model.set_tokenizer(tokenizer)
super().__init__(model, tokenizer, batch=batch)
Expand Down Expand Up @@ -51,6 +52,8 @@ def __init__(
**beam_search_params,
)

self._return_logits = return_logits

def _generate_from_batch(
self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
) -> list[ChatInferenceOutput]:
Expand All @@ -64,22 +67,26 @@ def _generate_from_batch(
outputs = []
for i, request_output in enumerate(request_outputs):
original_record = original_records[i]
answers = []
for a in request_output.outputs:
ans_msg = AnswerMessage(
id=str(a.index),
content=a.text,
sequence_score=a.cumulative_logprob,
)
if self._return_logits:
ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)

answers.append(ans_msg)

outputs.append(
ChatInferenceOutput(
id=original_record.id,
dataset_name=dataset_name,
messages=original_record.messages,
label=original_record.label,
answers=[
AnswerMessage(
id=str(a.index),
content=a.text,
input_token_ids=torch.tensor(request_output.prompt_token_ids).unsqueeze(0),
answer_token_ids=torch.tensor(a.token_ids).unsqueeze(0),
sequence_score=a.cumulative_logprob,
)
for a in request_output.outputs
],
answers=answers,
)
)
return outputs

0 comments on commit d0df9ef

Please sign in to comment.