diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py
index 5dc21d2ba..c5b08daf1 100644
--- a/swift/llm/__init__.py
+++ b/swift/llm/__init__.py
@@ -6,7 +6,7 @@ if TYPE_CHECKING:
     # Recommend using `xxx_main`
     from .app_ui import gradio_chat_demo, gradio_generation_demo, app_ui_main
-    from .infer import (deploy_main, infer_main, merge_lora_main, merge_lora, VLLMFramework, VllmGenerationConfig,
+    from .infer import (deploy_main, infer_main, merge_lora_main, merge_lora, VllmEngine, VllmGenerationConfig,
                         LMDeployFramework, LmdeployGenerationConfig, TransformersFramework, InferRequest)
     from .export import export_main
     from .eval import eval_main
@@ -31,7 +31,7 @@
     'app_ui': ['gradio_chat_demo', 'gradio_generation_demo', 'app_ui_main'],
     'rlhf': ['rlhf_main'],
     'infer': [
-        'deploy_main', 'merge_lora', 'infer_main', 'merge_lora_main', 'VLLMFramework', 'VllmGenerationConfig',
+        'deploy_main', 'merge_lora', 'infer_main', 'merge_lora_main', 'VllmEngine', 'VllmGenerationConfig',
         'LMDeployFramework', 'LmdeployGenerationConfig', 'TransformersFramework', 'InferRequest'
     ],
     'export': ['export_main'],
diff --git a/swift/llm/infer/__init__.py b/swift/llm/infer/__init__.py
index 2ce569243..2eaeddb12 100644
--- a/swift/llm/infer/__init__.py
+++ b/swift/llm/infer/__init__.py
@@ -6,7 +6,7 @@ if TYPE_CHECKING:
     from .deploy import deploy_main
     from .infer import infer_main, merge_lora_main, merge_lora
-    from .vllm import VLLMFramework, VllmGenerationConfig
+    from .vllm import VllmEngine
     from .lmdeploy import LMDeployFramework, LmdeployGenerationConfig
     from .transformers import TransformersFramework
     from .protocol import InferRequest
@@ -15,7 +15,7 @@
 _import_structure = {
     'deploy': ['deploy_main'],
     'infer': ['infer_main', 'merge_lora_main', 'merge_lora'],
-    'vllm': ['VLLMFramework', 'VllmGenerationConfig'],
+    'vllm': ['VllmEngine'],
     'lmdeploy': ['LMDeployFramework', 'LmdeployGenerationConfig'],
     'transformers': ['TransformersFramework'],
     'protocol': ['InferRequest'],
diff --git a/swift/llm/infer/base.py b/swift/llm/infer/base.py
index bf4520d26..99ac08aa0 100644
--- a/swift/llm/infer/base.py
+++ b/swift/llm/infer/base.py
@@ -133,13 +133,17 @@ async def _run_infer(i, task, queue, stream: bool = False):
         await task
         queue.put((i, None))

+    @staticmethod
+    async def _batch_run(tasks):
+        return await asyncio.gather(*tasks)
+
     @staticmethod
     def _infer_stream(tasks,
                       use_tqdm: bool = True,
                       stream: bool = True) -> Iterator[List[ChatCompletionStreamResponse]]:
         queue = Queue()
         new_tasks = [InferEngine._run_infer(i, task, queue, stream) for i, task in enumerate(tasks)]
-        thread = Thread(target=asyncio.run(asyncio.gather(*new_tasks)))
+        thread = Thread(target=lambda: asyncio.run(InferEngine._batch_run(new_tasks)))
         thread.start()
         prog_bar = tqdm(total=len(new_tasks), dynamic_ncols=True, disable=not use_tqdm)
         n_finished = 0
@@ -163,7 +167,6 @@ async def infer_async(self,
                           template: Template,
                           infer_request: InferRequest,
                           request_config: Optional[RequestConfig] = None,
-                          *,
                           request_id: Optional[str] = None,
                           **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
         pass
@@ -176,7 +179,7 @@ def infer(
         self,
         template: Template,
         request_config: Optional[RequestConfig] = None,
         *,
         use_tqdm: bool = True,
-        **kwargs,
+        **kwargs
     ) -> Union[List[ChatCompletionResponse], Iterator[List[ChatCompletionStreamResponse]]]:
         request_config = request_config or RequestConfig()
diff --git a/swift/llm/infer/deploy.py b/swift/llm/infer/deploy.py
index 212771208..3a16d500c 100644
--- a/swift/llm/infer/deploy.py
+++ b/swift/llm/infer/deploy.py
@@ -29,7 +29,7 @@
                        CompletionStreamResponse, DeltaMessage, Function, Model, ModelList, UsageInfo,
                        messages_join_observation, random_uuid)
 from swift.llm.infer.transformers import TransformersFramework
-from swift.llm.infer.vllm import VLLMFramework
+from swift.llm.infer.vllm import VllmEngine
 from swift.llm.utils import set_generation_config
 from swift.utils import get_logger, get_main, get_seed, seed_everything
 from swift.utils.utils import split_action_action_input
@@ -892,7 +892,7 @@ def llm_deploy(args: DeployArguments) -> None:
     if args.merge_lora:
         merge_lora(args, device_map=args.merge_device_map)
     if args.infer_backend == 'vllm':
-        framework = VLLMFramework(args, use_async=True)
+        framework = VllmEngine(args, use_async=True)
     elif args.infer_backend == 'lmdeploy':
         framework = LMDeployFramework(args, use_async=True)
     else:
diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 00c23563f..82115081f 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -344,8 +344,8 @@ def llm_infer(args: InferArguments) -> Dict[str, List[Dict[str, Any]]]:
         merge_lora(args, device_map=args.merge_device_map)

     if args.infer_backend == 'vllm':
-        from .vllm import VLLMFramework
-        framework = VLLMFramework(args)
+        from .vllm import VllmEngine
+        framework = VllmEngine(args)
     elif args.infer_backend == 'lmdeploy':
         from .lmdeploy import LMDeployFramework
         framework = LMDeployFramework(args)
diff --git a/swift/llm/infer/vllm.py b/swift/llm/infer/vllm.py
index 7fa42d7cc..715530185 100644
--- a/swift/llm/infer/vllm.py
+++ b/swift/llm/infer/vllm.py
@@ -8,7 +8,7 @@
 from modelscope import GenerationConfig
 from packaging import version
 from transformers import PreTrainedTokenizerBase
-from vllm import AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, SamplingParams
+from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

 from swift.utils import get_logger
 from ..template import Template, split_action_action_input
@@ -72,7 +72,7 @@ def __init__(

     def _prepare_engine(self) -> None:
         with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config):
-            engine = AsyncLLMEngine.from_engine_args()
+            engine = AsyncLLMEngine.from_engine_args(self.engine_args)
         self.engine = engine

     def _prepare_engine_kwargs(
@@ -139,28 +139,20 @@ def _init_env() -> None:
         os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

     def _fix_vllm_bug(self) -> None:
-        _engine = self.engine.engine
-        # compatible with vllm==0.3.*
-        if version.parse(vllm.__version__) >= version.parse('0.3'):
-            assert isinstance(_engine.tokenizer.tokenizer, PreTrainedTokenizerBase)
-            _engine.tokenizer.tokenizer = self.tokenizer
+        # fix vllm==0.4 bug (very slow)
+        tokenizer = self.tokenizer
+        if version.parse(vllm.__version__) >= version.parse('0.4') and not tokenizer.__class__.__name__.startswith("Cached"):
+            _tokenizer_len = len(tokenizer)
+            __old_len__ = tokenizer.__class__.__len__

-        # fix vllm==0.4 bug (very slow)
-        if version.parse(vllm.__version__) >= version.parse('0.4'):
-            _tokenizer_len = len(self.tokenizer)
-            __old_len__ = self.tokenizer.__class__.__len__
+            def __len__(self) -> int:
+                if self is tokenizer:
+                    return _tokenizer_len
+                else:
+                    return __old_len__(self)

-            def __len__(self) -> int:
-                if self is self.tokenizer:
-                    return _tokenizer_len
-                else:
-                    return __old_len__(self)
-
-            self.tokenizer.__class__.__len__ = __len__
+            tokenizer.__class__.__len__ = __len__
-        else:
-            assert isinstance(_engine.tokenizer, PreTrainedTokenizerBase)
-            _engine.tokenizer = self.tokenizer

     def _load_generation_config(self) -> None:
         generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
@@ -351,7 +343,7 @@ async def infer_async(
         created_time = int(time.time())
         request_config = request_config or RequestConfig()

-        inputs = template.encode(infer_request)
+        inputs = template.encode(infer_request)  # TODO: `{}`
         request_config.stop = self._get_stop_words(request_config, template)
         generation_config = self._prepare_generation_config(request_config)
         result_generator = self._add_vllm_request(inputs, generation_config, request_id, lora_request=lora_request)
diff --git a/swift/llm/template/utils.py b/swift/llm/template/utils.py
index 91e4e34ef..5fd9dd775 100644
--- a/swift/llm/template/utils.py
+++ b/swift/llm/template/utils.py
@@ -27,7 +27,7 @@ class StopWordsCriteria(StoppingCriteria):
     Like suffixes and chat seps in the template.
     """

-    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: List[StopWord], **tokenizer_kwargs) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: List[Word], **tokenizer_kwargs) -> None:
         self.tokenizer = tokenizer
         self.stop_words = stop_words
         self.tokenizer_kwargs = tokenizer_kwargs