Commit 20029cc: support vllm infer
Jintao-Huang committed Oct 15, 2024
1 parent ed06221
Showing 3 changed files with 132 additions and 125 deletions.
127 changes: 6 additions & 121 deletions swift/llm/infer/base.py
@@ -3,7 +3,7 @@
import asyncio
import time
from abc import ABC, abstractmethod
-from queue import Queue
+from queue import Empty, Queue
from threading import Thread
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union

@@ -16,125 +16,6 @@
from .protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, RequestConfig


class InferTools:

@staticmethod
def _is_chinese_char(cp: int) -> bool:
"""Checks whether CP is the codepoint of a CJK character."""
        # Copied from transformers.generation.streamers.TextStreamer
if ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0x20000 <= cp <= 0x2A6DF)
or (0x2A700 <= cp <= 0x2B73F) or (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF)
or (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F)):
return True

return False

@staticmethod
def safe_decode(template: Template, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> str:
tokenizer = template.tokenizer

if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
generate_ids = generate_ids[:-1]
response = tokenizer.decode(generate_ids, **decode_kwargs)
if len(response) == 0 or response.endswith('\n'):
return response
# skip suffix and eos_token
template_suffix = template.suffix[-1]
if isinstance(template_suffix, list):
template_suffix = tokenizer.decode(template_suffix)
        len_suffix = len(template_suffix)
        if not is_finished or (is_finished and response[-len_suffix:] == template_suffix):
            # Avoid the response length becoming shorter than the previous response length during streaming.
            # TODO: check
# idx = max(len(response) - len_suffix, 0, self.print_idx)
response = response[:-len_suffix]
return response


class InferStreamer(InferTools):

def __init__(self, template, input_delta: bool = False, **decode_kwargs):
self.template = template
self.tokenizer = template.tokenizer

        self.token_cache = []  # Cache token ids to reduce the cost of repeated tokenizer.decode calls.
self.input_delta = input_delta
self.cache_idx = 0 # This parameter takes effect when input_delta is set to False.
self.print_idx = 0
self.decode_kwargs = decode_kwargs
self.first_num_space = -1 # The number of whitespace characters before the first token.

def _align_blank_suffix(self, response: str) -> str:
        # Avoid repeated words in the sentence.
cur_num_space = len(response) - len(response.lstrip(' '))
if self.first_num_space == -1:
self.first_num_space = cur_num_space
elif cur_num_space < self.first_num_space:
response = ' ' * (self.first_num_space - cur_num_space) + response
elif cur_num_space > self.first_num_space:
response = response[cur_num_space - self.first_num_space:]
return response

def _get_printable_text(self, response: str) -> str:
# After the symbol for a new line, we flush the cache.
if response.endswith('\n'):
printable_text = response[self.print_idx:]
if self.input_delta:
self.token_cache = []
else:
self.cache_idx += len(self.token_cache)
self.print_idx = 0
# If the last token is a CJK character, we print the characters.
elif len(response) > 0 and self._is_chinese_char(ord(response[-1])):
printable_text = response[self.print_idx:]
self.print_idx += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = response[self.print_idx:response.rfind(' ') + 1]
self.print_idx += len(printable_text)
return printable_text

def put(self, tokens: List[int], is_finished: bool) -> str:
if self.input_delta:
self.token_cache += tokens
else:
self.token_cache = tokens[self.cache_idx:]
response = self.safe_decode(self.template, self.token_cache, is_finished, **self.decode_kwargs)
response = self._align_blank_suffix(response)
return self._get_printable_text(response)


class InferStats(Metric):

def __init__(self):
super().__init__()
self.add_state('start_runtime', default_factory=lambda: time.perf_counter())
self.add_state('num_prompt_tokens', default_factory=dict)
self.add_state('num_generated_tokens', default_factory=dict)
self.add_state('num_samples', 0)

def update(self, output):
if output is None:
self.num_samples += 1
return
id_ = output.id
self.num_prompt_tokens[id_] = output.usage.prompt_tokens
self.num_generated_tokens[id_] = output.usage.completion_tokens

def compute(self):
runtime = time.perf_counter() - self.start_runtime
num_generated_tokens = sum(self.num_generated_tokens.values())
return {
'num_prompt_tokens': sum(self.num_prompt_tokens.values()),
'num_generated_tokens': num_generated_tokens,
'num_samples': self.num_samples,
'runtime': runtime,
'samples/s': self.num_samples / runtime,
'tokens/s': num_generated_tokens / runtime,
}


class InferEngine(ABC):

def _prepare_model_tokenizer(self,
Expand Down Expand Up @@ -201,14 +82,18 @@ def _infer_stream(tasks,

while n_finished < len(new_tasks):
for i, queue in enumerate(queue_list):
-                output = queue.get(block=False)
+                try:
+                    output = queue.get(block=False)
+                except Empty:
+                    continue
if output is None:
n_finished += 1
prog_bar.update()
else:
outputs[i] = output
for metric in metrics:
metric.update(output)
+            time.sleep(0.01)
yield outputs

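For context, the polling loop above now treats a momentarily empty queue as "check again next round" rather than an error: reads are non-blocking, a None sentinel marks a finished task, and a short sleep between rounds avoids busy-waiting. Below is a minimal, self-contained sketch of the same pattern; the producer names and data are illustrative, not from this commit.

import time
from queue import Empty, Queue
from threading import Thread

def producer(queue: Queue, text: str) -> None:
    for ch in text:
        queue.put(ch)
    queue.put(None)  # None is the sentinel for a finished task

queue_list = [Queue(), Queue()]
Thread(target=producer, args=(queue_list[0], 'abc')).start()
Thread(target=producer, args=(queue_list[1], 'xyz')).start()

outputs = [None] * len(queue_list)
n_finished = 0
while n_finished < len(queue_list):
    for i, queue in enumerate(queue_list):
        try:
            output = queue.get(block=False)  # never block on a slow producer
        except Empty:
            continue
        if output is None:
            n_finished += 1
        else:
            outputs[i] = output
    time.sleep(0.01)  # yield the CPU between polling rounds
print('last outputs:', outputs)  # ['c', 'z']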
126 changes: 126 additions & 0 deletions swift/llm/infer/utils.py
@@ -0,0 +1,126 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import time
from typing import List

from swift.plugin import Metric
from ..template import Template


class InferTools:

@staticmethod
def _is_chinese_char(cp: int) -> bool:
"""Checks whether CP is the codepoint of a CJK character."""
        # Copied from transformers.generation.streamers.TextStreamer
if ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0x20000 <= cp <= 0x2A6DF)
or (0x2A700 <= cp <= 0x2B73F) or (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF)
or (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F)):
return True

return False

@staticmethod
def safe_decode(template: Template, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> str:
tokenizer = template.tokenizer

if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
generate_ids = generate_ids[:-1]
response = tokenizer.decode(generate_ids, **decode_kwargs)
if len(response) == 0 or response.endswith('\n'):
return response
# skip suffix and eos_token
template_suffix = template.suffix[-1]
if isinstance(template_suffix, list):
template_suffix = tokenizer.decode(template_suffix)
        len_suffix = len(template_suffix)
        if not is_finished or (is_finished and response[-len_suffix:] == template_suffix):
            # Avoid the response length becoming shorter than the previous response length during streaming.
            # TODO: check
# idx = max(len(response) - len_suffix, 0, self.print_idx)
response = response[:-len_suffix]
return response


class InferStreamer(InferTools):

def __init__(self, template, input_delta: bool = False, **decode_kwargs):
self.template = template
self.tokenizer = template.tokenizer

        self.token_cache = []  # Cache token ids to reduce the cost of repeated tokenizer.decode calls.
self.input_delta = input_delta
self.cache_idx = 0 # This parameter takes effect when input_delta is set to False.
self.print_idx = 0
self.decode_kwargs = decode_kwargs
self.first_num_space = -1 # The number of whitespace characters before the first token.

def _align_blank_suffix(self, response: str) -> str:
        # Avoid repeated words in the sentence.
cur_num_space = len(response) - len(response.lstrip(' '))
if self.first_num_space == -1:
self.first_num_space = cur_num_space
elif cur_num_space < self.first_num_space:
response = ' ' * (self.first_num_space - cur_num_space) + response
elif cur_num_space > self.first_num_space:
response = response[cur_num_space - self.first_num_space:]
return response

def _get_printable_text(self, response: str) -> str:
# After the symbol for a new line, we flush the cache.
if response.endswith('\n'):
printable_text = response[self.print_idx:]
if self.input_delta:
self.token_cache = []
else:
self.cache_idx += len(self.token_cache)
self.print_idx = 0
# If the last token is a CJK character, we print the characters.
elif len(response) > 0 and self._is_chinese_char(ord(response[-1])):
printable_text = response[self.print_idx:]
self.print_idx += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = response[self.print_idx:response.rfind(' ') + 1]
self.print_idx += len(printable_text)
return printable_text

def put(self, tokens: List[int], is_finished: bool) -> str:
if self.input_delta:
self.token_cache += tokens
else:
self.token_cache = tokens[self.cache_idx:]
response = self.safe_decode(self.template, self.token_cache, is_finished, **self.decode_kwargs)
response = self._align_blank_suffix(response)
return self._get_printable_text(response)
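
A hypothetical driver for InferStreamer, shown for illustration only (it is not part of this commit). The stub template and the gpt2 tokenizer are assumptions; real callers pass a swift Template.

from transformers import AutoTokenizer

class _StubTemplate:
    # Minimal stand-in for swift's Template: just a tokenizer plus a suffix
    # that ends with the eos token, which is all InferStreamer touches here.

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.suffix = [[tokenizer.eos_token_id]]

tokenizer = AutoTokenizer.from_pretrained('gpt2')
streamer = InferStreamer(_StubTemplate(tokenizer), input_delta=True)
token_ids = tokenizer.encode('Hello world, this is a streaming test.\n')
token_ids.append(tokenizer.eos_token_id)
for i, token in enumerate(token_ids):
    # Feed one token at a time; put returns only the newly printable delta.
    delta = streamer.put([token], is_finished=i == len(token_ids) - 1)
    print(delta, end='', flush=True)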


class InferStats(Metric):

def __init__(self):
super().__init__()
self.add_state('start_runtime', default_factory=lambda: time.perf_counter())
self.add_state('num_prompt_tokens', default_factory=dict)
self.add_state('num_generated_tokens', default_factory=dict)
self.add_state('num_samples', 0)

def update(self, output):
if output is None:
self.num_samples += 1
return
id_ = output.id
self.num_prompt_tokens[id_] = output.usage.prompt_tokens
self.num_generated_tokens[id_] = output.usage.completion_tokens

def compute(self):
runtime = time.perf_counter() - self.start_runtime
num_generated_tokens = sum(self.num_generated_tokens.values())
return {
'num_prompt_tokens': sum(self.num_prompt_tokens.values()),
'num_generated_tokens': num_generated_tokens,
'num_samples': self.num_samples,
'runtime': runtime,
'samples/s': self.num_samples / runtime,
'tokens/s': num_generated_tokens / runtime,
}
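
A hypothetical smoke test for InferStats, not part of this commit: SimpleNamespace stands in for the engine's response objects, since .id and .usage are the only fields that update reads.

import time
from types import SimpleNamespace

stats = InferStats()
for i in range(3):
    time.sleep(0.1)  # stand-in for generation latency
    stats.update(SimpleNamespace(
        id=f'chatcmpl-{i}',
        usage=SimpleNamespace(prompt_tokens=10, completion_tokens=32)))
    stats.update(None)  # a None output marks one completed sample
print(stats.compute())
# num_prompt_tokens=30, num_generated_tokens=96, num_samples=3, plus runtime and throughput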
4 changes: 0 additions & 4 deletions swift/plugin/metric.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod



class Metric(ABC):

def __init__(self):
@@ -20,18 +19,15 @@ def add_state(self, name: str, default=None, default_factory=None) -> None:
assert name not in self._default_factory, f'self._default_factory: {self._default_factory}'
setattr(self, name, default)


def reset(self):
for k, v in self._default.items():
setattr(self, k, v)
for k, v in self._default_factory.items():
setattr(self, k, v())


def update(self, *args, **kwargs):
pass

@abstractmethod
def compute(self):
pass
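
For reference, a hypothetical subclass illustrating the contract Metric defines: register resettable state with add_state, accumulate in update, and report in compute. The metric itself is illustrative.

class MeanLatency(Metric):

    def __init__(self):
        super().__init__()
        self.add_state('total_latency', default=0.)
        self.add_state('num_requests', default=0)

    def update(self, latency: float) -> None:
        self.total_latency += latency
        self.num_requests += 1

    def compute(self):
        return {'mean_latency': self.total_latency / max(self.num_requests, 1)}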
