
Commit

update
Jintao-Huang committed Oct 18, 2024
1 parent a0afc51 commit d31cb0c
Showing 4 changed files with 28 additions and 11 deletions.
9 changes: 8 additions & 1 deletion swift/llm/infer/lmdeploy_engine.py
@@ -114,6 +114,8 @@ def _get_logprobs(tokenizer: PreTrainedTokenizerBase,
top_logprobs: Optional[int] = None) -> Optional[Dict[str, Any]]:
if logprobs_list is None:
return None
+ assert len(token_ids) > 0
+ logprobs_list = logprobs_list[-len(token_ids):]
res = []
for logprobs, token_id in zip(logprobs_list, token_ids):
token = tokenizer.decode(token_id)
@@ -182,6 +184,7 @@ async def _infer_stream_async(
generator = await self._add_request(template, inputs, session_id)

infer_streamer = InferStreamer(template)
+ token_idx = 0
async with self.engine.safe_run(session_id):
async_iter = generator.async_stream_infer(
session_id=session_id, **inputs, stream_output=True, gen_config=generation_config).__aiter__()
@@ -196,13 +199,17 @@ async def _infer_stream_async(
continue

usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
+ logprobs = self._get_logprobs(template.tokenizer, output.logprobs, output.token_ids[token_idx:],
+ generation_config.logprobs)
+ token_idx = len(output.token_ids)
finish_reason = self._get_finish_reason(output, generation_config)
toolcall = self._get_toolcall(output.token_ids, is_finished)
choices = [
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
- finish_reason=finish_reason)
+ finish_reason=finish_reason,
+ logprobs=logprobs)
]
yield ChatCompletionStreamResponse(
model=self.model_dir, choices=choices, usage=usage_info, id=request_id)
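
The lmdeploy hunks above all follow one pattern: the generator reports cumulative token_ids/logprobs, the caller keeps a token_idx cursor so each stream chunk covers only the newly generated tokens, and _get_logprobs trims the cumulative logprobs list to the matching tail. A minimal, self-contained sketch of that pattern follows; the toy data and the simplified get_logprobs helper are illustrative, not the actual lmdeploy objects or the full swift implementation.

from typing import Any, Dict, List, Optional

def get_logprobs(logprobs_list: Optional[List[Dict[int, float]]],
                 token_ids: List[int]) -> Optional[List[Dict[str, Any]]]:
    # Same guard and tail-trim as the hunk above: keep only the entries that
    # correspond to the newly yielded token_ids.
    if logprobs_list is None:
        return None
    assert len(token_ids) > 0
    logprobs_list = logprobs_list[-len(token_ids):]
    return [{'token_id': token_id, 'logprob': logprobs[token_id]}
            for logprobs, token_id in zip(logprobs_list, token_ids)]

# Toy stream: the generator reports cumulative ids/logprobs on every iteration.
cumulative_ids = [[5], [5, 7], [5, 7, 9]]
cumulative_logprobs = [[{5: -0.1}],
                       [{5: -0.1}, {7: -0.4}],
                       [{5: -0.1}, {7: -0.4}, {9: -0.9}]]

token_idx = 0
for ids, logprobs in zip(cumulative_ids, cumulative_logprobs):
    delta_ids = ids[token_idx:]            # only the tokens produced since the last yield
    print(get_logprobs(logprobs, delta_ids))
    token_idx = len(ids)                   # advance the cursor, as in the hunk above
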
17 changes: 13 additions & 4 deletions swift/llm/infer/pt_engine.py
@@ -4,7 +4,7 @@

import json
import torch
- from transformers import GenerationConfig, PreTrainedTokenizerBase, StoppingCriteriaList
+ from transformers import GenerationConfig, PreTrainedTokenizerBase, StoppingCriteriaList, LogitsProcessorList
from transformers.utils import is_torch_npu_available

from swift.plugin import Metric
@@ -15,7 +15,7 @@
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, InferRequest, RequestConfig,
random_uuid)
- from .utils import InferStreamer, InferTools, StopWordsCriteria, TokensIteratorStreamer
+ from .utils import InferStreamer, InferTools, StopWordsCriteria, TokensIteratorStreamer, LogitsStreamer

logger = get_logger()

@@ -94,6 +94,8 @@ def _get_logprobs(tokenizer: PreTrainedTokenizerBase,
top_logprobs: Optional[int] = None) -> Optional[Dict[str, Any]]:
if logits_list is None:
return None
+ assert len(generate_ids) > 0
+ logits_list = logits_list[-len(generate_ids):]
res = []
for logits, token_id in zip(logits_list, generate_ids):
token = tokenizer.decode(token_id)
@@ -167,6 +169,7 @@ def _infer_stream(self,
raise ValueError(error_msg)

streamer = TokensIteratorStreamer()
+ logits_streamer = LogitsStreamer()
thread = Thread(
target=self.__model_generate,
kwargs={
@@ -182,6 +185,7 @@
all_is_finished = False
is_finished = [False] * batch_size
request_id_list = [f'chatcmpl-{random_uuid()}' for _ in range(batch_size)]
+ token_idxs = [0] * batch_size
while not all_is_finished:
try:
tokens = next(streamer)
@@ -205,14 +209,19 @@
res.append(None)
continue
usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
+ logprobs = None
+ # logprobs = self._get_logprobs(self.tokenizer, output.get('logits'), generate_ids[token_idxs[i]:],
+ # generation_config.top_logprobs)
+ token_idxs[i] = len(generate_ids)
finish_reason = self._get_finish_reason(generation_config, num_prompt_tokens, is_finished[i])
toolcall = self._get_toolcall(generate_ids, is_finished[i])

choices = [
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
- finish_reason=finish_reason)
+ finish_reason=finish_reason,
+ logprobs=logprobs)
]
res.append(
ChatCompletionStreamResponse(
@@ -241,7 +250,7 @@ def _infer_full(self,
generate_ids = self._ignore_pad_token(generate_ids, generation_config.pad_token_id)
usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
response = InferTools.safe_decode(template, generate_ids, True)
- logprobs = self._get_logprobs(self.tokenizer, output['logits'], generate_ids,
+ logprobs = self._get_logprobs(self.tokenizer, output.get('logits'), generate_ids,
generation_config.top_logprobs)
toolcall = self._get_toolcall(response, True)
choices = [
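
pt_engine.py's _get_logprobs starts from raw per-step logits rather than ready-made logprob dictionaries, so the same tail-alignment has to be followed by a log-softmax and a top-k. A rough sketch of that conversion, assuming one 1-D logits tensor per generated token; the helper name and output layout are illustrative, not the swift implementation.

from typing import Any, Dict, List, Optional

import torch

def logits_to_logprobs(logits_list: Optional[List[torch.Tensor]],
                       generate_ids: List[int],
                       top_logprobs: int = 5) -> Optional[List[Dict[str, Any]]]:
    if logits_list is None:          # mirrors the output.get('logits') change above
        return None
    assert len(generate_ids) > 0
    logits_list = logits_list[-len(generate_ids):]   # align with the generated ids
    res = []
    for logits, token_id in zip(logits_list, generate_ids):
        logprobs = torch.log_softmax(logits, dim=-1)
        top = torch.topk(logprobs, top_logprobs)
        res.append({
            'token_id': token_id,
            'logprob': logprobs[token_id].item(),
            'top_logprobs': {int(i): float(v) for v, i in zip(top.values, top.indices)},
        })
    return res

# Toy usage: three generation steps over a 32-token vocabulary.
steps = [torch.randn(32) for _ in range(3)]
print(logits_to_logprobs(steps, generate_ids=[1, 17, 4], top_logprobs=3))
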
4 changes: 2 additions & 2 deletions swift/llm/infer/utils.py
@@ -5,7 +5,7 @@
from typing import List

import torch
- from transformers import PreTrainedTokenizerBase, StoppingCriteria
+ from transformers import PreTrainedTokenizerBase, StoppingCriteria, LogitsProcessor
from transformers.generation.streamers import BaseStreamer

from swift.plugin import Metric
@@ -171,7 +171,7 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> t
text = self.tokenizer.decode(input_ids[i, self.start_idx:][-20:], **self.decode_kwargs)
for stop_word in self.stop_words:
if isinstance(stop_word, str) and stop_word in text or isinstance(
- stop_word, list) and input_ids[i][-len(stop_word):].tolist() == stop_word:
+ stop_word, list) and input_ids[i][-len(stop_word):].tolist() == stop_word:
is_finished[i] = True
break
else:
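
The new LogitsProcessor import pairs with the LogitsStreamer that pt_engine.py now pulls from this module, but the class body falls outside the hunk. A speculative sketch of what such a processor could look like, assuming it only records each step's scores for a consumer and leaves generation untouched:

from queue import Queue

import torch
from transformers import LogitsProcessor

class LogitsStreamer(LogitsProcessor):
    """Speculative sketch: capture per-step scores during model.generate()."""

    def __init__(self):
        self.queue: Queue = Queue()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.queue.put(scores)   # hand the raw scores to whoever is streaming
        return scores            # do not alter the distribution
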
9 changes: 5 additions & 4 deletions swift/llm/infer/vllm_engine.py
@@ -208,6 +208,8 @@ def _get_logprobs(tokenizer: PreTrainedTokenizerBase,
top_logprobs: Optional[int] = None) -> Optional[Dict[str, Any]]:
if logprobs_list is None:
return None
+ assert len(token_ids) > 0
+ logprobs_list = logprobs_list[-len(token_ids):]
res = []
for logprobs, token_id in zip(logprobs_list, token_ids):
logprob = logprobs[token_id]
@@ -270,10 +272,9 @@ async def _infer_stream_async(self, template: Template, inputs: Dict[str, Any],
usage_info = self._get_usage_info(len(result.prompt_token_ids), num_generated_tokens)
choices = []
for output in result.outputs:
- token_idx = token_idxs[output.index]
- logprobs = self._get_logprobs(template.tokenizer, output.logprobs[token_idx:], output.token_ids[token_idx:],
- generation_config.logprobs)
- token_idxs[output.index] = len(output.logprobs)
+ logprobs = self._get_logprobs(template.tokenizer, output.logprobs,
+ output.token_ids[token_idxs[output.index]:], generation_config.logprobs)
+ token_idxs[output.index] = len(output.token_ids)
toolcall = self._get_toolcall(output.token_ids, output.is_finished)
choice = ChatCompletionResponseStreamChoice(
index=output.index,
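
The vLLM hunk reorders the same bookkeeping: output.logprobs is now passed whole (the tail-trim inside _get_logprobs does the alignment) and the per-output cursor advances from len(output.token_ids), which is always populated, instead of len(output.logprobs), which is None when logprobs were not requested. A compact sketch of one streaming step under that scheme, using made-up stand-ins rather than vLLM's CompletionOutput:

from typing import Dict, List, Optional

def tail_logprobs(logprobs_list: Optional[List[Dict[int, float]]],
                  token_ids: List[int]) -> Optional[List[float]]:
    # Simplified stand-in for _get_logprobs: align the cumulative list with the delta.
    if logprobs_list is None or not token_ids:
        return None
    return [lp[tid] for lp, tid in zip(logprobs_list[-len(token_ids):], token_ids)]

# One streaming iteration for output.index == 0 (made-up values).
token_idxs = {0: 1}                      # cursor per output index
token_ids = [11, 22, 33]                 # cumulative ids so far
logprobs = None                          # logprobs were not requested this run

new_ids = token_ids[token_idxs[0]:]      # delta since the last yield
print(tail_logprobs(logprobs, new_ids))  # None, no crash on len(None)
token_idxs[0] = len(token_ids)           # advance from token_ids, not logprobs
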
