support batched pt infer
Jintao-Huang committed Oct 17, 2024
1 parent 939643a commit 797ece3
Showing 10 changed files with 176 additions and 128 deletions.
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -7,7 +7,7 @@
# Recommend using `xxx_main`
from .app_ui import gradio_chat_demo, gradio_generation_demo, app_ui_main
from .infer import (deploy_main, infer_main, merge_lora_main, merge_lora, VllmEngine, InferRequest, RequestConfig,
InferStats, LmdeployEngine)
InferStats, LmdeployEngine, PtEngine)
from .export import export_main
from .eval import eval_main
from .train import sft_main, pt_main, rlhf_main
@@ -32,7 +32,7 @@
'rlhf': ['rlhf_main'],
'infer': [
'deploy_main', 'merge_lora', 'infer_main', 'merge_lora_main', 'VllmEngine', 'InferRequest', 'RequestConfig',
'InferStats', 'LmdeployEngine'
'InferStats', 'LmdeployEngine', 'PtEngine'
],
'export': ['export_main'],
'eval': ['eval_main'],
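With `PtEngine` added to the public exports, batched PyTorch ("pt") inference is reachable through the same entry point as `VllmEngine` and `LmdeployEngine`. The snippet below is a minimal usage sketch, not code from the repository: the model id and the `get_template()` helper are assumptions, while the call shape (template first, then the request list) follows the `infer` signature in the `base.py` diff further down.

```python
from swift.llm import PtEngine, InferRequest, RequestConfig

# Hypothetical model id and template accessor; adjust to this revision's actual API.
engine = PtEngine('Qwen/Qwen2-7B-Instruct')
template = engine.get_template()  # assumption: the engine exposes its default chat template

requests = [
    InferRequest(messages=[{'role': 'user', 'content': 'Who are you?'}]),
    InferRequest(messages=[{'role': 'user', 'content': 'Summarize SSE in one sentence.'}]),
]
# Batched, non-streaming inference: one ChatCompletionResponse per request.
responses = engine.infer(template, requests, RequestConfig(max_tokens=128))
for resp in responses:
    print(resp.choices[0].message.content)
```
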
2 changes: 1 addition & 1 deletion swift/llm/infer/__init__.py
@@ -19,7 +19,7 @@
'infer': ['infer_main', 'merge_lora_main', 'merge_lora'],
'vllm_engine': ['VllmEngine'],
'lmdeploy_engine': ['LmdeployEngine'],
'pt_engine': ['TransformersFramework'],
'pt_engine': ['PtEngine'],
'protocol': ['InferRequest', 'RequestConfig'],
}

58 changes: 35 additions & 23 deletions swift/llm/infer/base.py
@@ -81,12 +81,7 @@ async def __batch_run(tasks):
@staticmethod
def __infer_stream(tasks,
stream: bool = True,
use_tqdm: bool = True,
metrics: Optional[List[Metric]] = None) -> Iterator[List[ChatCompletionStreamResponse]]:
if metrics is None:
metrics = []
for metric in metrics:
metric.reset()
use_tqdm: bool = True) -> Iterator[List[ChatCompletionStreamResponse]]:

queue = Queue()
new_tasks = [InferEngine.__run_infer(i, task, queue, stream) for i, task in enumerate(tasks)]
@@ -107,15 +102,11 @@ def __infer_stream(tasks,
yield outputs
outputs = [None] * len(new_tasks)
outputs[i] = output
for metric in metrics:
metric.update(output)
yield outputs

@staticmethod
def __infer_full(tasks,
use_tqdm: bool = True,
metrics: Optional[List[Metric]] = None) -> List[ChatCompletionResponse]:
for outputs in InferEngine.__infer_stream(tasks, False, use_tqdm, metrics):
def __infer_full(tasks, use_tqdm: bool = True) -> List[ChatCompletionResponse]:
for outputs in InferEngine.__infer_stream(tasks, False, use_tqdm):
pass
return outputs

@@ -127,29 +118,46 @@ def _get_usage_info(num_prompt_tokens: int, num_generated_tokens: int) -> UsageI
total_tokens=num_prompt_tokens + num_generated_tokens,
)

@staticmethod
def _update_metrics_wrapper(gen, metrics: Optional[List[Metric]] = None):
for res in gen:
yield InferEngine._update_metrics(res, metrics)

@staticmethod
def _update_metrics(result, metrics: Optional[List[Metric]] = None):
result_origin = result
if not isinstance(result, (list, tuple)):
result = [result]
for response in result:
for metric in metrics:
metric.update(response)
return result_origin

@torch.inference_mode()
def infer(self,
template: Template,
infer_requests: List[InferRequest],
request_config: Optional[RequestConfig] = None,
metrics: Optional[List[Metric]] = None,
*,
use_tqdm: Optional[bool] = None,
metrics: Optional[List[Metric]] = None,
**kwargs) -> Union[List[ChatCompletionResponse], Iterator[List[ChatCompletionStreamResponse]]]:
tasks = [
self.infer_async(template, infer_request, request_config, **kwargs) for infer_request in infer_requests
]
if use_tqdm is None:
use_tqdm = not request_config.stream
if request_config.stream:
return self.__infer_stream(tasks, True, use_tqdm, metrics)
return self._update_metrics_wrapper(self.__infer_stream(tasks, True, use_tqdm), metrics)
else:
return self.__infer_full(tasks, use_tqdm, metrics)
return self._update_metrics(self.__infer_full(tasks, use_tqdm), metrics)

@staticmethod
def _get_toolcall(response: str, is_finished: bool) -> Optional[List[ChatCompletionMessageToolCall]]:
def _get_toolcall(self, response: Union[str, List[int]],
is_finished: bool) -> Optional[List[ChatCompletionMessageToolCall]]:
if is_finished:
return None
if not isinstance(response, str):
response = self.tokenizer.decode(response)
action, action_input = split_action_action_input(response)
if action is None:
return None
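Metric bookkeeping is no longer threaded through the private `__infer_stream`/`__infer_full` helpers; responses are instead routed through `_update_metrics` (non-streaming) or `_update_metrics_wrapper` (streaming), and `metrics` becomes a keyword-only argument of `infer`. A hedged caller-side sketch, with the same hypothetical constructor and template accessor as in the first example:

```python
from swift.llm import PtEngine, InferRequest, InferStats, RequestConfig

engine = PtEngine('Qwen/Qwen2-7B-Instruct')  # hypothetical model id
template = engine.get_template()             # same assumption as above
requests = [InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])]

stats = InferStats()
responses = engine.infer(template, requests, RequestConfig(max_tokens=64), metrics=[stats])
# Assumption: InferStats follows the Metric interface used here (update per response)
# and exposes an aggregate accessor such as compute(); the exact name may differ.
print(stats.compute())
```
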
@@ -177,23 +185,27 @@ async def infer_async(self,

@staticmethod
def _get_num_tokens(inputs: Dict[str, Any]) -> int:
if 'input_ids' in inputs: # 1d
return len(inputs['input_ids'])
elif 'inputs_embeds' in inputs: # 2d
return inputs['inputs_embeds'].shape[0]
if 'input_ids' in inputs: # 1d or 2d
input_ids = inputs['input_ids']
if isinstance(input_ids, list):
return len(input_ids)
else:
return input_ids.shape[-1]
elif 'inputs_embeds' in inputs: # 2d or 3d
return inputs['inputs_embeds'].shape[-1]
raise ValueError(f'Unable to retrieve input_ids and inputs_embeds. inputs: {inputs}')

def set_default_max_tokens(self,
request_config: RequestConfig,
inputs: Union[Dict[str, Any], List[Dict[str, Any]]],
inputs: Dict[str, Any],
strict: bool = False) -> None:
max_model_len = self.max_model_len
if isinstance(inputs, dict):
inputs = [inputs]
# The num_tokens takes the maximum value from inputs_list.
num_tokens = 0
for inp in inputs:
num_tokens = max(num_tokens, InferEngine._get_num_tokens(inp))
num_tokens = max(num_tokens, self._get_num_tokens(inp))
max_tokens = request_config.max_tokens
if max_model_len is None:
max_model_len = 8192
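The token-count helper above now accepts either an unbatched `input_ids` list or a batched 2-D tensor, which is what lets `set_default_max_tokens` bound the completion budget against `max_model_len` for a whole batch. The following is a simplified, self-contained sketch of the counting plus one plausible clamping rule, not the exact implementation; the 8192 fallback mirrors the default visible in the diff.

```python
from typing import Any, Dict, Optional

import torch


def get_num_tokens(inputs: Dict[str, Any]) -> int:
    """Prompt length for a 1d list of ids or a batched 2d tensor of input_ids."""
    input_ids = inputs['input_ids']
    if isinstance(input_ids, list):
        return len(input_ids)
    return input_ids.shape[-1]  # (batch, seq_len) -> seq_len


def default_max_tokens(num_tokens: int, requested: Optional[int], max_model_len: int = 8192) -> int:
    """Clamp the requested completion budget to what the context window allows (assumed rule)."""
    available = max(max_model_len - num_tokens, 0)
    return available if requested is None else min(requested, available)


assert get_num_tokens({'input_ids': [1, 2, 3]}) == 3                            # unbatched
assert get_num_tokens({'input_ids': torch.zeros(4, 7, dtype=torch.long)}) == 7  # batch of 4, length 7
assert default_max_tokens(100, None) == 8092
assert default_max_tokens(100, 512) == 512
```
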
12 changes: 6 additions & 6 deletions swift/llm/infer/deploy.py
@@ -438,10 +438,10 @@ async def _generate_stream():
choices.append(choice)
response = CompletionStreamResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
yield f'data:{json.dumps(asdict(response), ensure_ascii=False)}\n\n'
yield f'data: {json.dumps(asdict(response), ensure_ascii=False)}\n\n'
if _args.log_interval > 0:
_update_stats(response)
yield 'data:[DONE]\n\n'
yield 'data: [DONE]\n\n'

if request.stream:
return StreamingResponse(_generate_stream())
@@ -618,10 +618,10 @@ async def _generate_stream():
choices = [CompletionResponseStreamChoice(index=0, text=delta_text, finish_reason=finish_reason)]
response = CompletionStreamResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
yield f'data:{json.dumps(asdict(response), ensure_ascii=False)}\n\n'
yield f'data: {json.dumps(asdict(response), ensure_ascii=False)}\n\n'
if _args.log_interval > 0:
_update_stats(response)
yield 'data:[DONE]\n\n'
yield 'data: [DONE]\n\n'

if request.stream:
return StreamingResponse(_generate_stream())
@@ -842,10 +842,10 @@ def _generate_stream():
choices = [CompletionResponseStreamChoice(index=0, text=delta_text, finish_reason=None)]
resp = CompletionStreamResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
yield f'data:{json.dumps(asdict(resp), ensure_ascii=False)}\n\n'
yield f'data: {json.dumps(asdict(resp), ensure_ascii=False)}\n\n'
if _args.log_interval > 0:
_update_stats(resp)
yield 'data:[DONE]\n\n'
yield 'data: [DONE]\n\n'

if request.stream:
return StreamingResponse(_generate_stream())
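All three streaming endpoints now emit a space after `data:` and terminate with `data: [DONE]`, matching the framing produced by the OpenAI API. Spec-compliant SSE parsers strip a single optional space after the colon, so both forms decode identically, but some clients match on the literal prefix `'data: '`, which is what the old output broke. Below is a small client-side sketch that tolerates either form; it is illustrative and not part of the repository.

```python
import json
from typing import Iterable, Iterator


def iter_sse_json(lines: Iterable[bytes]) -> Iterator[dict]:
    """Decode OpenAI-style 'data: {...}' server-sent events, stopping at [DONE]."""
    for raw in lines:
        line = raw.decode('utf-8').strip()
        if not line.startswith('data:'):
            continue  # ignore comments and empty keep-alive lines
        payload = line[len('data:'):].lstrip()  # tolerate 'data:' with or without a space
        if payload == '[DONE]':
            return
        yield json.loads(payload)


# Example framing as produced by the patched endpoints:
frames = [b'data: {"choices": [{"index": 0, "text": "Hello"}]}\n\n', b'data: [DONE]\n\n']
for event in iter_sse_json(frames):
    print(event['choices'][0]['text'])
```
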
23 changes: 11 additions & 12 deletions swift/llm/infer/lmdeploy_engine.py
@@ -16,7 +16,7 @@
from .base import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, UsageInfo)
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, UsageInfo, random_uuid)
from .utils import InferStreamer, InferTools

logger = get_logger()
@@ -178,10 +178,10 @@ async def _infer_stream_async(
self, template: Template, inputs: Dict[str, Any],
generation_config: LmdeployGenerationConfig) -> AsyncIterator[ChatCompletionStreamResponse]:
session_id = time.time_ns()
request_id = f'chatcmpl-{random_uuid()}'
generator = await self._add_request(template, inputs, session_id)

infer_streamer = InferStreamer(template)
total_response = ''
async with self.engine.safe_run(session_id):
async_iter = generator.async_stream_infer(
session_id=session_id, **inputs, stream_output=True, gen_config=generation_config).__aiter__()
@@ -195,17 +195,17 @@ async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
if not delta_text and not is_finished:
continue

usage_info = InferEngine._get_usage_info(len(inputs['input_ids']), output.num_token)
total_response += delta_text
finish_reason = LmdeployEngine._get_finish_reason(output, generation_config)
toolcall = InferEngine._get_toolcall(total_response, is_finished)
usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
finish_reason = self._get_finish_reason(output, generation_config)
toolcall = self._get_toolcall(output.token_ids, is_finished)
choices = [
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
finish_reason=finish_reason)
]
yield ChatCompletionStreamResponse(model=self.model_type, choices=choices, usage=usage_info)
yield ChatCompletionStreamResponse(
model=self.model_dir, choices=choices, usage=usage_info, id=request_id)

async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
generation_config: LmdeployGenerationConfig) -> ChatCompletionResponse:
@@ -217,17 +217,16 @@ async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
session_id=session_id, **inputs, stream_output=False, gen_config=generation_config):
pass

usage_info = InferEngine._get_usage_info(len(inputs['input_ids']), output.num_token)
usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
response = InferTools.safe_decode(template, output.token_ids, True)
logprobs = self._get_logprobs(template.tokenizer, output.logprobs, output.token_ids, generation_config.logprobs)
finish_reason = LmdeployEngine._get_finish_reason(output, generation_config)
toolcall = InferEngine._get_toolcall(response, True)
finish_reason = self._get_finish_reason(output, generation_config)
toolcall = self._get_toolcall(response, True)
choices = [
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
finish_reason=finish_reason,
logprobs=logprobs)
]
response = ChatCompletionResponse(model=self.model_type, choices=choices, usage=usage_info)
return response
return ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info)
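Besides switching tool-call extraction to raw `output.token_ids` (decoded on demand by the reworked `_get_toolcall`), every streamed chunk now reports `model=self.model_dir` and a stable `chatcmpl-` style `id`. A hedged sketch of accumulating such a stream on the caller side; the chunk layout (a list with one entry per request, `None` where nothing new arrived) follows the `__infer_stream` helper in `base.py`.

```python
from typing import Dict, Iterable, List, Optional


def collect_stream(stream: Iterable[List[Optional[object]]]) -> Dict[int, str]:
    """Accumulate streamed delta contents into one full text per request index."""
    texts: Dict[int, str] = {}
    for chunk in stream:
        for i, resp in enumerate(chunk):
            if resp is None:
                continue  # no new tokens for request i in this chunk
            texts[i] = texts.get(i, '') + (resp.choices[0].delta.content or '')
    return texts


# Usage, assuming the hypothetical engine/template/requests from the first sketch:
# for i, text in collect_stream(engine.infer(template, requests,
#                                            RequestConfig(stream=True))).items():
#     print(i, text)
```
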