Commit: update
Jintao-Huang committed Oct 12, 2024
1 parent ffd3c5d commit 6000ebf
Showing 7 changed files with 29 additions and 34 deletions.
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -6,7 +6,7 @@
 if TYPE_CHECKING:
     # Recommend using `xxx_main`
     from .app_ui import gradio_chat_demo, gradio_generation_demo, app_ui_main
-    from .infer import (deploy_main, infer_main, merge_lora_main, merge_lora, VLLMFramework, VllmGenerationConfig,
+    from .infer import (deploy_main, infer_main, merge_lora_main, merge_lora, VllmEngine, VllmGenerationConfig,
                         LMDeployFramework, LmdeployGenerationConfig, TransformersFramework, InferRequest)
     from .export import export_main
     from .eval import eval_main
@@ -31,7 +31,7 @@
     'app_ui': ['gradio_chat_demo', 'gradio_generation_demo', 'app_ui_main'],
     'rlhf': ['rlhf_main'],
     'infer': [
-        'deploy_main', 'merge_lora', 'infer_main', 'merge_lora_main', 'VLLMFramework', 'VllmGenerationConfig',
+        'deploy_main', 'merge_lora', 'infer_main', 'merge_lora_main', 'VllmEngine', 'VllmGenerationConfig',
         'LMDeployFramework', 'LmdeployGenerationConfig', 'TransformersFramework', 'InferRequest'
     ],
     'export': ['export_main'],
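The rename has to land in two places because swift's __init__.py modules use a lazy-import pattern: the TYPE_CHECKING block declares names for static type checkers, while _import_structure drives the real import on first attribute access. A minimal sketch of that pattern, assuming a _LazyModule helper like the one in transformers.utils (module and symbol names here are illustrative):

import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers and IDEs only; nothing is imported at runtime.
    from .vllm import VllmEngine
else:
    from transformers.utils import _LazyModule

    _import_structure = {'vllm': ['VllmEngine']}
    # Replace this module with a proxy that imports submodules lazily.
    sys.modules[__name__] = _LazyModule(__name__, globals()['__file__'], _import_structure)

Keeping both declarations in sync is exactly what this hunk does for the VLLMFramework -> VllmEngine rename.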
4 changes: 2 additions & 2 deletions swift/llm/infer/__init__.py
@@ -6,7 +6,7 @@
 if TYPE_CHECKING:
     from .deploy import deploy_main
     from .infer import infer_main, merge_lora_main, merge_lora
-    from .vllm import VLLMFramework, VllmGenerationConfig
+    from .vllm import VllmEngine
     from .lmdeploy import LMDeployFramework, LmdeployGenerationConfig
     from .transformers import TransformersFramework
     from .protocol import InferRequest
@@ -15,7 +15,7 @@
 _import_structure = {
     'deploy': ['deploy_main'],
     'infer': ['infer_main', 'merge_lora_main', 'merge_lora'],
-    'vllm': ['VLLMFramework', 'VllmGenerationConfig'],
+    'vllm': ['VllmEngine'],
     'lmdeploy': ['LMDeployFramework', 'LmdeployGenerationConfig'],
     'transformers': ['TransformersFramework'],
     'protocol': ['InferRequest'],
9 changes: 6 additions & 3 deletions swift/llm/infer/base.py
@@ -133,13 +133,17 @@ async def _run_infer(i, task, queue, stream: bool = False):
         await task
         queue.put((i, None))
 
+    @staticmethod
+    async def _batch_run(tasks):
+        return await asyncio.gather(*tasks)
+
     @staticmethod
     def _infer_stream(tasks,
                       use_tqdm: bool = True,
                       stream: bool = True) -> Iterator[List[ChatCompletionStreamResponse]]:
         queue = Queue()
         new_tasks = [InferEngine._run_infer(i, task, queue, stream) for i, task in enumerate(tasks)]
-        thread = Thread(target=asyncio.run(asyncio.gather(*new_tasks)))
+        thread = Thread(target=asyncio.run(InferEngine._batch_run(new_tasks)))
         thread.start()
         prog_bar = tqdm(total=len(new_tasks), dynamic_ncols=True, disable=not use_tqdm)
         n_finished = 0
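For context, _infer_stream runs a batch of coroutines on a background thread and drains their results from a thread-safe queue, so callers can iterate over streaming responses synchronously. A self-contained sketch of the pattern (names are illustrative, not the ms-swift API); note that Thread(target=...) expects a callable, so this sketch wraps asyncio.run in a lambda to defer it to the worker thread:

import asyncio
from queue import Queue
from threading import Thread

async def _wrap(i, coro, queue):
    result = await coro        # run one inference coroutine
    queue.put((i, result))     # hand the result to the consuming thread

async def _batch_run(tasks):
    return await asyncio.gather(*tasks)

def infer_stream(coros):
    queue = Queue()
    tasks = [_wrap(i, c, queue) for i, c in enumerate(coros)]
    # The lambda defers asyncio.run to the new thread; calling
    # asyncio.run(...) inline would block the current thread instead.
    thread = Thread(target=lambda: asyncio.run(_batch_run(tasks)))
    thread.start()
    for _ in range(len(tasks)):
        yield queue.get()      # blocks until the next task finishes
    thread.join()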
@@ -163,7 +167,6 @@ async def infer_async(self,
                           template: Template,
                           infer_request: InferRequest,
                           request_config: Optional[RequestConfig] = None,
-                          *,
                           request_id: Optional[str] = None,
                           **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
         pass
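The deleted `*,` is Python's keyword-only marker: parameters after it can only be passed by name. Dropping it lets callers pass request_id positionally as well. A two-line illustration:

def f(a, *, b=None):   # b is keyword-only
    return a, b

f(1, b=2)   # OK
# f(1, 2)   # TypeError: f() takes 1 positional argument but 2 were given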
@@ -176,7 +179,7 @@ def infer(
             request_config: Optional[RequestConfig] = None,
             *,
             use_tqdm: bool = True,
-            **kwargs,
+            **kwargs
     ) -> Union[List[ChatCompletionResponse], Iterator[List[ChatCompletionStreamResponse]]]:
         request_config = request_config or RequestConfig()
4 changes: 2 additions & 2 deletions swift/llm/infer/deploy.py
@@ -29,7 +29,7 @@
                                   CompletionStreamResponse, DeltaMessage, Function, Model, ModelList, UsageInfo,
                                   messages_join_observation, random_uuid)
 from swift.llm.infer.transformers import TransformersFramework
-from swift.llm.infer.vllm import VLLMFramework
+from swift.llm.infer.vllm import VllmEngine
 from swift.llm.utils import set_generation_config
 from swift.utils import get_logger, get_main, get_seed, seed_everything
 from swift.utils.utils import split_action_action_input
@@ -892,7 +892,7 @@ def llm_deploy(args: DeployArguments) -> None:
     if args.merge_lora:
         merge_lora(args, device_map=args.merge_device_map)
     if args.infer_backend == 'vllm':
-        framework = VLLMFramework(args, use_async=True)
+        framework = VllmEngine(args, use_async=True)
     elif args.infer_backend == 'lmdeploy':
         framework = LMDeployFramework(args, use_async=True)
     else:
4 changes: 2 additions & 2 deletions swift/llm/infer/infer.py
@@ -344,8 +344,8 @@ def llm_infer(args: InferArguments) -> Dict[str, List[Dict[str, Any]]]:
         merge_lora(args, device_map=args.merge_device_map)
 
     if args.infer_backend == 'vllm':
-        from .vllm import VLLMFramework
-        framework = VLLMFramework(args)
+        from .vllm import VllmEngine
+        framework = VllmEngine(args)
     elif args.infer_backend == 'lmdeploy':
         from .lmdeploy import LMDeployFramework
         framework = LMDeployFramework(args)
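Note that the backend is imported inside the branch rather than at module top level, so only the selected backend's heavy dependency (vllm or lmdeploy) needs to be installed. A hedged sketch of that dispatch shape, with argument handling simplified and the transformers fallback assumed:

def build_framework(infer_backend: str, args):
    # Lazy imports: a transformers-only install never touches vllm/lmdeploy.
    if infer_backend == 'vllm':
        from .vllm import VllmEngine
        return VllmEngine(args)
    elif infer_backend == 'lmdeploy':
        from .lmdeploy import LMDeployFramework
        return LMDeployFramework(args)
    else:
        from .transformers import TransformersFramework
        return TransformersFramework(args)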
36 changes: 14 additions & 22 deletions swift/llm/infer/vllm.py
@@ -8,7 +8,7 @@
 from modelscope import GenerationConfig
 from packaging import version
 from transformers import PreTrainedTokenizerBase
-from vllm import AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, SamplingParams
+from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
 
 from swift.utils import get_logger
 from ..template import Template, split_action_action_input
@@ -72,7 +72,7 @@ def __init__(

     def _prepare_engine(self) -> None:
         with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config):
-            engine = AsyncLLMEngine.from_engine_args()
+            engine = AsyncLLMEngine.from_engine_args(self.engine_args)
         self.engine = engine
 
     def _prepare_engine_kwargs(
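The fix passes the stored engine arguments to from_engine_args, which previously received nothing. For reference, this mirrors the standard vLLM construction path; a minimal standalone sketch (the model name and memory setting are placeholders):

from vllm import AsyncEngineArgs, AsyncLLMEngine

engine_args = AsyncEngineArgs(model='Qwen/Qwen2-7B-Instruct', gpu_memory_utilization=0.9)
engine = AsyncLLMEngine.from_engine_args(engine_args)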
@@ -139,28 +139,20 @@ def _init_env() -> None:
         os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
 
     def _fix_vllm_bug(self) -> None:
-        _engine = self.engine.engine
-        # compatible with vllm==0.3.*
-        if version.parse(vllm.__version__) >= version.parse('0.3'):
-            assert isinstance(_engine.tokenizer.tokenizer, PreTrainedTokenizerBase)
-            _engine.tokenizer.tokenizer = self.tokenizer
-        else:
-            assert isinstance(_engine.tokenizer, PreTrainedTokenizerBase)
-            _engine.tokenizer = self.tokenizer
-
         # fix vllm==0.4 bug (very slow)
-        if version.parse(vllm.__version__) >= version.parse('0.4'):
-            _tokenizer_len = len(self.tokenizer)
-            __old_len__ = self.tokenizer.__class__.__len__
+        tokenizer = self.tokenizer
+        if version.parse(vllm.__version__) >= version.parse('0.4') and not tokenizer.__class__.__name__.startswith("Cached"):
+            _tokenizer_len = len(tokenizer)
+            __old_len__ = tokenizer.__class__.__len__
 
             def __len__(self) -> int:
-                if self is self.tokenizer:
+                if self is tokenizer:
                     return _tokenizer_len
                 else:
                     return __old_len__(self)
 
-            self.tokenizer.__class__.__len__ = __len__
+            tokenizer.__class__.__len__ = __len__
 
     def _load_generation_config(self) -> None:
         generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
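The workaround patches __len__ at the class level, since Python looks special methods up on the type rather than the instance, and caches the length for one specific tokenizer while other instances fall through to the original implementation. A self-contained sketch of the technique:

class Tok:
    def __len__(self) -> int:
        return expensive_len(self)   # pretend this is slow

def expensive_len(tok) -> int:
    return 100

tok = Tok()
_cached_len = len(tok)   # compute once
_old_len = Tok.__len__

def __len__(self) -> int:
    # Only the patched instance gets the cached value.
    return _cached_len if self is tok else _old_len(self)

Tok.__len__ = __len__
assert len(tok) == 100       # served from the cache
assert len(Tok()) == 100     # other instances take the original path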
@@ -351,7 +343,7 @@ async def infer_async(
         created_time = int(time.time())
         request_config = request_config or RequestConfig()
 
-        inputs = template.encode(infer_request)
+        inputs = template.encode(infer_request)  # TODO: `{}`
         request_config.stop = self._get_stop_words(request_config, template)
         generation_config = self._prepare_generation_config(request_config)
         result_generator = self._add_vllm_request(inputs, generation_config, request_id, lora_request=lora_request)
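Downstream, the request config and template stop words end up in a vLLM SamplingParams before the request is submitted to the engine. A minimal sketch of what that generation config amounts to (all values are placeholders):

from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=512,
    stop=['Observation:'],   # template-derived stop words
)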
2 changes: 1 addition & 1 deletion swift/llm/template/utils.py
@@ -27,7 +27,7 @@ class StopWordsCriteria(StoppingCriteria):
     Like suffixes and chat seps in the template.
     """
 
-    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: List[StopWord], **tokenizer_kwargs) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: List[Word], **tokenizer_kwargs) -> None:
         self.tokenizer = tokenizer
         self.stop_words = stop_words
         self.tokenizer_kwargs = tokenizer_kwargs
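For the transformers backend, stop words are enforced through such a StoppingCriteria. A hedged sketch of how a criteria class like this is typically wired up; the decode-and-compare body is illustrative, not the exact ms-swift implementation:

import torch
from transformers import StoppingCriteria

class SimpleStopWordsCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_words):
        self.tokenizer = tokenizer
        self.stop_words = [w for w in stop_words if isinstance(w, str)]

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool:
        # Decode what has been generated so far and stop when any
        # stop word terminates the text.
        text = self.tokenizer.decode(input_ids[0])
        return any(text.endswith(w) for w in self.stop_words)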
