From b3f1e095684ab0599923b032a755d8a71fa5887d Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Sun, 29 Sep 2024 18:18:29 +0800 Subject: [PATCH 01/16] Initial updates for vllm 0.6.2 --- .../src/ipex_llm/vllm/xpu/engine/__init__.py | 3 +- .../src/ipex_llm/vllm/xpu/engine/engine.py | 26 + .../vllm/xpu/entrypoints/openai/api_server.py | 451 +++++++++++++----- .../vllm/xpu/entrypoints/openai/cli_args.py | 85 +++- .../vllm/xpu/entrypoints/openai/rpc/server.py | 221 --------- .../src/ipex_llm/vllm/xpu/model_convert.py | 5 +- 6 files changed, 433 insertions(+), 358 deletions(-) delete mode 100644 python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/__init__.py b/python/llm/src/ipex_llm/vllm/xpu/engine/__init__.py index 7b653c9b729..a3cec88f2b0 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/__init__.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/__init__.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass +from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine __all__ = [ "IPEXLLMAsyncLLMEngine", "IPEXLLMLLMEngine", "IPEXLLMClass", + "run_mp_engine", ] diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index 0a3a87414a0..3d3c1ee3152 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -22,6 +22,8 @@ from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert from vllm.usage.usage_lib import UsageContext from vllm.engine.metrics import StatLoggerBase +from vllm.engine.multiprocessing.engine import MQLLMEngine +import signal class IPEXLLMAsyncLLMEngine(AsyncLLMEngine): @@ -117,3 +119,27 @@ def from_engine_args( # Create the engine configs. 
_ipex_llm_convert(load_in_low_bit) return super().from_engine_args(engine_args, usage_context, stat_loggers) + + +class IPEXLLMMQLLMEngine(MQLLMEngine): + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str, load_in_low_bit: str): + _ipex_llm_convert(load_in_low_bit) + return super().from_engine_args(engine_args, usage_context, ipc_path) + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, + ipc_path: str, load_in_low_bit: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm + raise KeyboardInterrupt("MQLLMEngine terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path, + load_in_low_bit=load_in_low_bit) + engine.start() diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py index b6ea51c1010..45d4ed2ea5c 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py @@ -1,25 +1,34 @@ import asyncio import importlib import inspect +import multiprocessing +import os import re +import signal +import socket +import tempfile from argparse import Namespace from contextlib import asynccontextmanager +from functools import partial from http import HTTPStatus -from multiprocessing import Process from typing import AsyncIterator, Set +import uvloop from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import make_asgi_app +from starlette.datastructures import State from starlette.routing import Mount +from typing_extensions import assert_never import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from ipex_llm.vllm.xpu.engine import run_mp_engine +from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser @@ -28,154 +37,268 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, + CompletionResponse, DetokenizeRequest, DetokenizeResponse, - EmbeddingRequest, ErrorResponse, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + LoadLoraAdapterRequest, TokenizeRequest, - TokenizeResponse) -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from ipex_llm.vllm.xpu.entrypoints.openai.rpc.server import run_rpc_server + TokenizeResponse, + UnloadLoraAdapterRequest) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.logger import init_logger from vllm.usage.usage_lib 
import UsageContext -from vllm.utils import FlexibleArgumentParser, get_open_port +from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds -async_engine_client: AsyncEngineClient -engine_args: AsyncEngineArgs -openai_serving_chat: OpenAIServingChat -openai_serving_completion: OpenAIServingCompletion -openai_serving_embedding: OpenAIServingEmbedding -openai_serving_tokenization: OpenAIServingTokenization +prometheus_multiproc_dir: tempfile.TemporaryDirectory +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.openai.api_server') _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool: - return ModelConfig(model=model_name, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - seed=0, - dtype="float16").embedding_mode - - @asynccontextmanager async def lifespan(app: FastAPI): + try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state + - async def _force_log(): - while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() +@asynccontextmanager +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[EngineClient]: - if not engine_args.disable_log_stats: - task = asyncio.create_task(_force_log()) - _running_tasks.add(task) - task.add_done_callback(_running_tasks.remove) + # Context manager to handle engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) - yield + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine: + yield engine @asynccontextmanager -async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: - # Context manager to handle async_engine_client lifecycle - # Ensures everything is shutdown and cleaned up on error/exit - global engine_args - engine_args = AsyncEngineArgs.from_cli_args(args) +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + load_in_low_bit: str = 'sym_int4', +) -> AsyncIterator[EngineClient]: + """ + Create EngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + + # Fall back + # TODO: fill out feature matrix. 
+ if (MQLLMEngineClient.is_unsupported_config(engine_args) + or disable_frontend_multiprocessing): + engine_config = engine_args.create_engine_config() + uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), + "uses_ray", False) + + build_engine = partial(AsyncLLMEngine.from_engine_args, + engine_args=engine_args, + engine_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER) + if uses_ray: + # Must run in main thread with ray for its signal handlers to work + engine_client = build_engine() + else: + engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_engine) - # Backend itself still global for the silly lil' health handler - global async_engine_client - - # If manually triggered or embedding model, use AsyncLLMEngine in process. - # TODO: support embedding model via RPC. - if (model_is_embedding(args.model, args.trust_remote_code) - or args.disable_frontend_multiprocessing): - async_engine_client = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER, - load_in_low_bit=args.load_in_low_bit) - yield async_engine_client + yield engine_client return # Otherwise, use the multiprocessing AsyncLLMEngine. else: - # Start RPCServer in separate process (holds the AsyncLLMEngine). - port = get_open_port(envs.VLLM_RPC_PORT) - load_in_low_bit = args.load_in_low_bit - rpc_server_process = Process(target=run_rpc_server, - args=(engine_args, - UsageContext.OPENAI_API_SERVER, - port, load_in_low_bit)) - rpc_server_process.start() - - # Build RPCClient, which conforms to AsyncEngineClient Protocol. - async_engine_client = AsyncEngineRPCClient(port) - await async_engine_client.setup() + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + # Select random path for IPC. + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) + + # Start RPCServer in separate process (holds the LLMEngine). + # the current process might have CUDA context, + # so we need to spawn a new process + context = multiprocessing.get_context("spawn") + + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path, + load_in_low_bit)) + engine_process.start() + logger.info("Started engine process with PID %d", engine_process.pid) + + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. 
We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) try: - yield async_engine_client + while True: + try: + await mp_engine_client.setup() + break + except TimeoutError: + if not engine_process.is_alive(): + raise RuntimeError( + "Engine process failed to start") from None + + yield mp_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - async_engine_client.close() + mp_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() - # Wait for server process to join - rpc_server_process.join() + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) router = APIRouter() def mount_metrics(app: FastAPI): - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app()) + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) + if prometheus_multiproc_dir_path is not None: + logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + else: + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile('^/metrics(?P.*)$') + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") app.routes.append(metrics_route) +def chat(request: Request) -> OpenAIServingChat: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> OpenAIServingCompletion: + return request.app.state.openai_serving_completion + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def embedding(request: Request) -> OpenAIServingEmbedding: + return request.app.state.openai_serving_embedding + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + @router.get("/health") -async def health() -> Response: +async def health(raw_request: Request) -> Response: """Health check.""" - await async_engine_client.check_health() + await engine_client(raw_request).check_health() return Response(status_code=200) @router.post("/tokenize") -async def tokenize(request: TokenizeRequest): - generator = await openai_serving_tokenization.create_tokenize(request) +async def tokenize(request: TokenizeRequest, raw_request: Request): + generator = await 
tokenization(raw_request).create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: - assert isinstance(generator, TokenizeResponse) + elif isinstance(generator, TokenizeResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + @router.post("/detokenize") -async def detokenize(request: DetokenizeRequest): - generator = await openai_serving_tokenization.create_detokenize(request) +async def detokenize(request: DetokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: - assert isinstance(generator, DetokenizeResponse) + elif isinstance(generator, DetokenizeResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + @router.get("/v1/models") -async def show_available_models(): - models = await openai_serving_completion.show_available_models() +async def show_available_models(raw_request: Request): + models = await completion(raw_request).show_available_models() return JSONResponse(content=models.model_dump()) @@ -188,46 +311,110 @@ async def show_version(): @router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): - generator = await openai_serving_chat.create_chat_completion( + + generator = await chat(raw_request).create_chat_completion( request, raw_request) + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, - media_type="text/event-stream") - else: - assert isinstance(generator, ChatCompletionResponse) + + elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): - generator = await openai_serving_completion.create_completion( + generator = await completion(raw_request).create_completion( request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, - media_type="text/event-stream") - else: + elif isinstance(generator, CompletionResponse): return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): - generator = await openai_serving_embedding.create_embedding( + generator = await embedding(raw_request).create_embedding( request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: + elif isinstance(generator, EmbeddingResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + + +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. 
This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await completion(raw_request).load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await completion(raw_request).unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + def build_app(args: Namespace) -> FastAPI: - app = FastAPI(lifespan=lifespan) + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path @@ -243,7 +430,8 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - err = openai_serving_chat.create_error_response(message=str(exc)) + chat = app.state.openai_serving_chat + err = chat.create_error_response(message=str(exc)) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -275,74 +463,87 @@ async def authentication(request: Request, call_next): return app -async def init_app( - async_engine_client: AsyncEngineClient, +def init_app_state( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, args: Namespace, -) -> FastAPI: - app = build_app(args) - +) -> None: if args.served_model_name is not None: served_model_names = args.served_model_name else: served_model_names = [args.model] - model_config = await async_engine_client.get_model_config() - if args.disable_log_requests: request_logger = None else: request_logger = RequestLogger(max_log_len=args.max_log_len) - global openai_serving_chat - global openai_serving_completion - global openai_serving_embedding - global openai_serving_tokenization + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] - openai_serving_chat = OpenAIServingChat( - async_engine_client, + state.engine_client = engine_client + 
state.log_stats = not args.disable_log_stats + + state.openai_serving_chat = OpenAIServingChat( + engine_client, model_config, - served_model_names, + base_model_paths, args.response_role, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, chat_template=args.chat_template, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) - openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser) + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) - openai_serving_embedding = OpenAIServingEmbedding( - async_engine_client, + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, model_config, - served_model_names, + base_model_paths, request_logger=request_logger, ) - openai_serving_tokenization = OpenAIServingTokenization( - async_engine_client, + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, chat_template=args.chat_template, ) - app.root_path = args.root_path - - return app async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - async with build_async_engine_client(args) as async_engine_client: - app = await init_app(async_engine_client, args) + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + temp_socket.bind(("", args.port)) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) + + temp_socket.close() shutdown_task = await serve_http( app, @@ -369,4 +570,4 @@ async def run_server(args, **uvicorn_kwargs) -> None: parser = make_arg_parser(parser) args = parser.parse_args() - asyncio.run(run_server(args)) + uvloop.run(run_server(args)) diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py index 70af2a0389d..88a62ba400d 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py @@ -7,6 +7,7 @@ import argparse import json import ssl +from typing import List, Optional, Sequence, Union from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, @@ -16,18 +17,55 @@ class LoRAParserAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - lora_list = [] + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: List[LoRAModulePath] = [] for item in values: - name, path = 
item.split('=') - lora_list.append(LoRAModulePath(name, path)) + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) setattr(namespace, self.dest, lora_list) class PromptAdapterParserAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - adapter_list = [] + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + adapter_list: List[PromptAdapterPath] = [] for item in values: name, path = item.split('=') adapter_list.append(PromptAdapterPath(name, path)) @@ -72,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, nargs='+', action=LoRAParserAction, - help="LoRA module configurations in the format name=path. " - "Multiple modules can be specified.") + help="LoRA module configurations in either 'name=path' format" + "or JSON format. " + "Example (old format): 'name=path' " + "Example (new format): " + "'{\"name\": \"name\", \"local_path\": \"path\", " + "\"base_model_name\": \"id\"}'") parser.add_argument( "--prompt-adapters", type=nullable_str, @@ -139,6 +181,26 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action="store_true", help="If specified, will run the OpenAI frontend server in the same " "process as the model serving engine.") + + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help= + "Enable auto tool choice for supported models. Use --tool-call-parser" + "to specify which parser to use") + + parser.add_argument( + "--tool-call-parser", + type=str, + choices=["mistral", "hermes"], + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for --enable-auto-tool-choice.") + + parser.add_argument( "--load-in-low-bit", type=str, @@ -154,6 +216,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'ID numbers being printed in log.' 
'\n\nDefault: Unlimited') + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" + ) + return parser diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py deleted file mode 100644 index f9c7778c0d3..00000000000 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py +++ /dev/null @@ -1,221 +0,0 @@ -import asyncio -import signal -from typing import Any, Coroutine - -import cloudpickle -import zmq -import zmq.asyncio -from typing_extensions import Never - -from vllm import AsyncEngineArgs -from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine - -from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger(__name__) - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, port: int, load_in_low_bit: str): - # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, - usage_context=usage_context, - load_in_low_bit=load_in_low_bit) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket for readiness state. - self.socket = self.context.socket(zmq.constants.ROUTER) - # Note numeric form of localhost should be used for zmq bind(), - # see https://stackoverflow.com/a/8958414 - self.socket.bind(f"tcp://127.0.0.1:{port}") - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - - async def get_model_config(self, identity): - """Send the ModelConfig""" - model_config = await self.engine.get_model_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(model_config)]) - - async def get_decoding_config(self, identity): - """Send the DecodingConfig""" - decoding_config = await self.engine.get_decoding_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(decoding_config)]) - - async def get_lora_config(self, identity): - lora_config = await self.engine.get_lora_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(lora_config)]) - - async def get_scheduler_config(self, identity): - """Send the SchedulerConfig""" - parallel_config = await self.engine.get_scheduler_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(parallel_config)]) - - async def get_parallel_config(self, identity): - """Send the ParallelConfig""" - parallel_config = await self.engine.get_parallel_config() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(parallel_config)]) - - async def is_tracing_enabled(self, identity): - """Send the is_tracing_enabled flag""" - tracing_flag = await self.engine.is_tracing_enabled() - - await self.socket.send_multipart( - [identity, cloudpickle.dumps(tracing_flag)]) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart([ - identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart([ - identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) - - async def abort(self, 
identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - # Abort the request in the llm engine. - await self.engine.abort(request.request_id) - - # Send confirmation to the client. - await self.socket.send_multipart([ - identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - try: - results_generator = self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - trace_headers=generate_request.trace_headers, - prompt_adapter_request=generate_request.prompt_adapter_request) - - async for request_output in results_generator: - await self.socket.send_multipart( - [identity, cloudpickle.dumps(request_output)]) - - except Exception as e: - # Notify client of all failures - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) - except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) - - def _make_handler_coro(self, identity, - message) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message) - - if isinstance(request, RPCGenerateRequest): - return self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - return self.get_model_config(identity) - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - return self.get_parallel_config(identity) - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - return self.get_decoding_config(identity) - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - return self.get_scheduler_config(identity) - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - return self.get_lora_config(identity) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.CHECK_HEALTH: - return self.check_health(identity) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - return self.is_tracing_enabled(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") # noqa - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") # noqa - - async def run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart() - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - -async def run_server(server: AsyncEngineRPCServer): - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. 
- def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("vLLM ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. - server.cleanup() - - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, port: int, load_in_low_bit: str): - server = AsyncEngineRPCServer(async_engine_args, usage_context, port, load_in_low_bit) - asyncio.run(run_server(server)) diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py index 355f3dc2578..2086319ca6a 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py +++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py @@ -75,14 +75,13 @@ def _ipex_llm_load_model(self) -> None: _model_sample_convert() # from vllm.utils import measure_device_memory - from vllm.utils import CudaMemoryProfiler - with CudaMemoryProfiler() as m: + from vllm.utils import DeviceMemoryProfiler + with DeviceMemoryProfiler() as m: self.model = get_model( model_config=self.model_config, device_config=DeviceConfig("cpu"), load_config=self.load_config, lora_config=self.lora_config, - multimodal_config=self.multimodal_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, cache_config=self.cache_config, From 30aa80344fc9c6b61d51ae0e0988c1699b8f4fd7 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 5 Nov 2024 16:57:08 +0800 Subject: [PATCH 02/16] fix --- python/llm/src/ipex_llm/vllm/xpu/engine/engine.py | 8 ++++---- .../vllm/xpu/entrypoints/openai/cli_args.py | 14 +++++--------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index 3d3c1ee3152..d15359dc223 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -134,12 +134,12 @@ def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, def signal_handler(*_) -> None: # Interrupt server on sigterm - raise KeyboardInterrupt("MQLLMEngine terminated") + raise KeyboardInterrupt("MQLLMEngine terminated") # noqa signal.signal(signal.SIGTERM, signal_handler) engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args, - usage_context=usage_context, - ipc_path=ipc_path, - load_in_low_bit=load_in_low_bit) + usage_context=usage_context, + ipc_path=ipc_path, + load_in_low_bit=load_in_low_bit) engine.start() diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py index 88a62ba400d..4110f11b483 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py @@ -27,7 +27,7 @@ def __call__( if values is None: values = [] if isinstance(values, str): - raise TypeError("Expected values to be a list") + raise TypeError("Expected values to be a list") # noqa lora_list: List[LoRAModulePath] = [] for item in values: @@ -63,7 +63,7 @@ def __call__( if values is None: values = [] if isinstance(values, str): - raise TypeError("Expected values to be a list") + raise TypeError("Expected values to be a list") # noqa adapter_list: List[PromptAdapterPath] = [] for item in values: @@ -133,8 +133,7 @@ def make_arg_parser(parser: 
FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--response-role", type=nullable_str, default="assistant", - help="The role name to return if " - "`request.add_generation_prompt=true`.") + help="The role name to return if `request.add_generation_prompt=true`.") parser.add_argument("--ssl-keyfile", type=nullable_str, default=None, @@ -186,8 +185,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--enable-auto-tool-choice", action="store_true", default=False, - help= - "Enable auto tool choice for supported models. Use --tool-call-parser" + help="Enable auto tool choice for supported models. Use --tool-call-parser" "to specify which parser to use") parser.add_argument( @@ -195,12 +193,10 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, choices=["mistral", "hermes"], default=None, - help= - "Select the tool call parser depending on the model that you're using." + help="Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " "format. Required for --enable-auto-tool-choice.") - parser.add_argument( "--load-in-low-bit", type=str, From f5ee889e1c9aa7402a54e22e69a471b8e221c2ff Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 6 Nov 2024 08:49:14 +0800 Subject: [PATCH 03/16] Change Dockerfile to support v062 --- docker/llm/serving/xpu/docker/Dockerfile | 22 +- .../xpu/docker/benchmark_vllm_throughput.py | 283 +++++++++--------- 2 files changed, 159 insertions(+), 146 deletions(-) diff --git a/docker/llm/serving/xpu/docker/Dockerfile b/docker/llm/serving/xpu/docker/Dockerfile index 3f20fe559a5..d1703f63cff 100644 --- a/docker/llm/serving/xpu/docker/Dockerfile +++ b/docker/llm/serving/xpu/docker/Dockerfile @@ -36,11 +36,22 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO python3 get-pip.py && \ rm get-pip.py && \ pip install --upgrade requests argparse urllib3 && \ + # TODO: only take effect after it has been merged... pip install --pre --upgrade ipex-llm[xpu,serving] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \ + # TODO: remove the following things... 
pip install transformers_stream_generator einops tiktoken && \ pip install --upgrade colorama && \ # Download all-in-one benchmark and examples git clone https://github.com/intel-analytics/ipex-llm && \ + # TODO: remove the following steps + cd ipex-llm && \ + git fetch origin pull/12338/head:local_pr && \ + git checkout local_pr && \ + pip uninstall -y ipex-llm && \ + cd python/llm && \ + python setup.py install && \ + cd ../../../ && \ + # REMOVE END cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \ cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \ # Install vllm dependencies @@ -74,13 +85,16 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO rm -rf /tmp/neo && \ mkdir -p /llm && \ cd /llm && \ - git clone -b 0.5.4 https://github.com/analytics-zoo/vllm.git /llm/vllm && \ + git clone -b 062_test_0929 https://github.com/analytics-zoo/vllm.git /llm/vllm && \ cd /llm/vllm && \ - pip install -r /llm/vllm/requirements-xpu.txt && \ - VLLM_TARGET_DEVICE=xpu python setup.py install && \ + pip install setuptools-scm && \ + pip install --upgrade cmake && \ + VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \ + # pip install -r /llm/vllm/requirements-xpu.txt && \ + # VLLM_TARGET_DEVICE=xpu python setup.py install && \ pip install mpi4py fastapi uvicorn openai && \ pip install gradio==4.43.0 && \ - pip install transformers==4.44.2 && \ + # pip install transformers==4.44.2 && \ # patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \ pip install ray && \ patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch diff --git a/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py b/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py index 28e94da1c36..5702db8907b 100644 --- a/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py +++ b/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py @@ -1,14 +1,22 @@ """Benchmark offline inference throughput.""" import argparse +import dataclasses import json import random import time from typing import List, Optional, Tuple import torch +import uvloop +from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from tqdm import tqdm + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +# from vllm.sampling_params import BeamSearchParams +from vllm.utils import FlexibleArgumentParser, merge_async_iterators def sample_requests( @@ -29,22 +37,23 @@ def sample_requests( dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - if fixed_output_len is not None: - output_len = fixed_output_len - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + # Shuffle the dataset. + random.shuffle(dataset) - # Filter out too long sequences. 
+ # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue @@ -53,51 +62,18 @@ def sample_requests( continue filtered_dataset.append((prompt, prompt_len, output_len)) - # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) - return sampled_requests + return filtered_dataset def run_vllm( requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, n: int, - use_beam_search: bool, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - device: str, - enable_prefix_caching: bool, - gpu_memory_utilization: float = 0.9, - load_in_low_bit: str = "sym_int4", - max_num_batched_tokens: int = 5000, - max_num_seqs: int = 256, + low_bit: str, + engine_args: EngineArgs, ) -> float: from vllm import SamplingParams from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - device=device, - enable_prefix_caching=enable_prefix_caching, - load_in_low_bit=load_in_low_bit, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs,) - + llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=low_bit) # Add the requests to the engine. warm_prompt = "hi " * (1024 - 1) @@ -111,14 +87,14 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=0.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, )) llm.generate(prompts, sampling_params, use_tqdm=True) + # Add the requests to the engine. prompts: List[str] = [] sampling_params: List[SamplingParams] = [] for prompt, _, output_len in requests: @@ -126,29 +102,78 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, )) - start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() + use_beam_search = False + + if not use_beam_search: + start = time.perf_counter() + llm.generate(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + else: + prompts = [prompt for prompt, _, _ in requests] + # output_len should be the same for all requests. 
+ output_len = requests[0][2] + for prompt, input_len, _output_len in requests: + assert _output_len == output_len + start = time.perf_counter() + llm.beam_search(prompts, + beam_width=n, + max_tokens=output_len, + ignore_eos=True) + end = time.perf_counter() return end - start +async def run_vllm_async( + requests: List[Tuple[str, int, int]], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=output_len, + )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + def run_hf( requests: List[Tuple[str, int, int]], model: str, tokenizer: PreTrainedTokenizerBase, n: int, - use_beam_search: bool, max_batch_size: int, trust_remote_code: bool, ) -> float: - assert not use_beam_search llm = AutoModelForCausalLM.from_pretrained( model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": @@ -180,7 +205,7 @@ def run_hf( padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), - do_sample=not use_beam_search, + do_sample=True, num_return_sequences=n, temperature=1.0, top_p=1.0, @@ -205,13 +230,15 @@ def run_mii( tensor_parallel_size: int, output_len: int, ) -> float: - from mii import pipeline - llm = pipeline(model, tensor_parallel=tensor_parallel_size) + from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) prompts = [prompt for prompt, _, _ in requests] start = time.perf_counter() - llm(prompts, max_new_tokens=output_len) + llm.generate(prompts, max_new_tokens=output_len) end = time.perf_counter() + client = client(model) + client.terminate_server() return end - start @@ -224,7 +251,16 @@ def main(args: argparse.Namespace): args.tokenizer, trust_remote_code=args.trust_remote_code) if args.dataset is None: # Synthesize a prompt with the given input length. - prompt = "hi" * (args.input_len - 1) + # As tokenizer may add additional tokens like BOS, we need to try + # different lengths to get the desired input length. 
+ for i in range(-10, 10): + prompt = "hi " * (args.input_len + i) + tokenized_prompt = tokenizer(prompt).input_ids + if len(tokenized_prompt) == args.input_len: + break + else: + raise ValueError( + f"Failed to synthesize a prompt with {args.input_len} tokens.") requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)] else: @@ -232,18 +268,21 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm( - requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, args.device, - args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit, - args.max_num_batched_tokens,args.max_num_seqs) + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + )) + else: + elapsed_time = run_vllm(requests, args.n, args.load_in_low_bit, + EngineArgs.from_cli_args(args)) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.use_beam_search, args.hf_max_batch_size, - args.trust_remote_code) + args.hf_max_batch_size, args.trust_remote_code) elif args.backend == "mii": elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, args.output_len) @@ -251,12 +290,26 @@ def main(args: argparse.Namespace): raise ValueError(f"Unknown backend: {args.backend}") total_num_tokens = sum(prompt_len + output_len for _, prompt_len, output_len in requests) + total_output_tokens = sum(output_len for _, _, output_len in requests) print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} tokens/s") + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument("--backend", type=str, choices=["vllm", "hf", "mii"], @@ -274,89 +327,38 @@ def main(args: argparse.Namespace): default=None, help="Output length for each request. 
Overrides the " "output length from the dataset.") - parser.add_argument("--model", type=str, default="facebook/opt-125m") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=['awq', 'gptq', 'squeezellm', None], - default=None) - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, help="Number of generated sequences per prompt.") - parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") - parser.add_argument("--seed", type=int, default=0) parser.add_argument("--hf-max-batch-size", type=int, default=None, help="Maximum batch size for HF backend.") - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument("--enforce-eager", - action="store_true", - help="enforce eager execution") - parser.add_argument( - "--kv-cache-dtype", - type=str, - choices=["auto", "fp8_e5m2"], - default="auto", - help= - 'Data type for kv cache storage. 
If "auto", will use model data type.') parser.add_argument( - "--device", + '--output-json', type=str, - default="cuda", - choices=["cuda", "xpu"], - help='device type for vLLM execution, supporting CUDA only currently.') - parser.add_argument( - "--enable-prefix-caching", - action='store_true', - help="enable automatic prefix caching for vLLM backend.") + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") parser.add_argument( "--load-in-low-bit", type=str, choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"], default="sym_int4", help="Low-bit format quantization with IPEX-LLM") - - parser.add_argument('--max-num-batched-tokens', - type=int, - default=4096, - help='maximum number of batched tokens per iteration') - - parser.add_argument('--max-num-seqs', - type=int, - default=256, - help='Maximum number of sequences per iteration.') - - + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model @@ -379,8 +381,6 @@ def main(args: argparse.Namespace): raise ValueError("dtype must be auto for MII backend.") if args.n != 1: raise ValueError("n must be 1 for MII backend.") - if args.use_beam_search: - raise ValueError("Beam search is not supported for MII backend.") if args.quantization is not None: raise ValueError("Quantization is only for vLLM backend.") if args.hf_max_batch_size is not None: @@ -388,5 +388,4 @@ def main(args: argparse.Namespace): if args.tokenizer != args.model: raise ValueError("Tokenizer must be the same as the model for MII " "backend.") - main(args) From c9eab78817cfbdaf7f1034b3bc2a457a3b3f9d8a Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 6 Nov 2024 15:51:13 +0800 Subject: [PATCH 04/16] Fix --- docker/llm/serving/xpu/docker/Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/llm/serving/xpu/docker/Dockerfile b/docker/llm/serving/xpu/docker/Dockerfile index d1703f63cff..6bec3980179 100644 --- a/docker/llm/serving/xpu/docker/Dockerfile +++ b/docker/llm/serving/xpu/docker/Dockerfile @@ -5,6 +5,8 @@ ARG https_proxy ENV TZ=Asia/Shanghai ENV PYTHONUNBUFFERED=1 +# To prevent RPC_TIMEOUT ERROR for the first request +ENV VLLM_RPC_TIMEOUT=100000 # Disable pip's cache behavior From 6d14b224de5502f847ec6e656265aa8179617d77 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 6 Nov 2024 16:02:16 +0800 Subject: [PATCH 05/16] fix examples --- python/llm/example/GPU/vLLM-Serving/README.md | 16 ++++++++-------- .../GPU/vLLM-Serving/offline_inference.py | 8 +++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/python/llm/example/GPU/vLLM-Serving/README.md b/python/llm/example/GPU/vLLM-Serving/README.md index e192bc1ecb6..7d6fb013bd9 100644 --- a/python/llm/example/GPU/vLLM-Serving/README.md +++ b/python/llm/example/GPU/vLLM-Serving/README.md @@ -2,7 +2,7 @@ This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bits optimizations). -The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.3.3). +The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2). 
Currently, we support the following models for vLLM engine: @@ -17,7 +17,7 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI- ### 0. Environment -To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.0. Please check the requirements at [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU#requirements). +To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU#requirements). After install the toolkit, run the following commands in your environment before starting vLLM GPU: ```bash @@ -44,14 +44,13 @@ conda create -n ipex-vllm python=3.11 conda activate ipex-vllm # Install dependencies pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +pip install setuptools-scm +pip install --upgrade cmake # cd to your workdir -git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git +# TODO: check this later... for the specific branch +git clone -b https://github.com/analytics-zoo/vllm.git cd vllm -pip install -r requirements-xpu.txt -pip install --no-deps xformers -VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e . -pip install outlines==0.0.34 --no-deps -pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy +VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v . && \ # For Qwen model support pip install transformers_stream_generator einops tiktoken ``` @@ -61,6 +60,7 @@ pip install transformers_stream_generator einops tiktoken ```bash export USE_XETLA=OFF export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +export SYCL_CACHE_PERSISTENT=1 ``` ### 3. Offline inference/Service diff --git a/python/llm/example/GPU/vLLM-Serving/offline_inference.py b/python/llm/example/GPU/vLLM-Serving/offline_inference.py index ae6d6ed8322..4f09483f045 100644 --- a/python/llm/example/GPU/vLLM-Serving/offline_inference.py +++ b/python/llm/example/GPU/vLLM-Serving/offline_inference.py @@ -49,8 +49,10 @@ device="xpu", dtype="float16", enforce_eager=True, - load_in_low_bit="sym_int4", - tensor_parallel_size=1) + load_in_low_bit="fp8", + tensor_parallel_size=1, + max_model_len=2000, + max_num_batched_tokens=2000) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) @@ -58,4 +60,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file From e956b7ba2b65c819c4e7fb81fa27921865a3c250 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 7 Nov 2024 20:02:00 +0800 Subject: [PATCH 06/16] Fix --- python/llm/src/ipex_llm/serving/fastchat/vllm_worker.py | 2 +- python/llm/src/ipex_llm/vllm/xpu/engine/engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/serving/fastchat/vllm_worker.py b/python/llm/src/ipex_llm/serving/fastchat/vllm_worker.py index 482031f97a3..739b7a88634 100644 --- a/python/llm/src/ipex_llm/serving/fastchat/vllm_worker.py +++ b/python/llm/src/ipex_llm/serving/fastchat/vllm_worker.py @@ -93,7 +93,7 @@ async def generate_stream(self, params): request_id = params.pop("request_id") temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) - top_k = params.get("top_k", -1.0) + top_k = params.get("top_k", -1) presence_penalty = float(params.get("presence_penalty", 0.0)) frequency_penalty = float(params.get("frequency_penalty", 0.0)) max_new_tokens = params.get("max_new_tokens", 256) diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index d15359dc223..ab813dc0ad6 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -42,7 +42,7 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. _ipex_llm_convert(load_in_low_bit) - return super().from_engine_args(engine_args, start_engine_loop, usage_context, stat_loggers) + return super().from_engine_args(engine_args, start_engine_loop=start_engine_loop, usage_context=usage_context, stat_loggers=stat_loggers) class IPEXLLMClass(LLM): From f09963b8ba7dba75fc9dbd27f05b2cb4093d0f7d Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 7 Nov 2024 20:04:37 +0800 Subject: [PATCH 07/16] done --- python/llm/example/GPU/vLLM-Serving/README.md | 1 + python/llm/src/ipex_llm/vllm/xpu/engine/engine.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/llm/example/GPU/vLLM-Serving/README.md b/python/llm/example/GPU/vLLM-Serving/README.md index 7d6fb013bd9..0a29b0b69df 100644 --- a/python/llm/example/GPU/vLLM-Serving/README.md +++ b/python/llm/example/GPU/vLLM-Serving/README.md @@ -86,6 +86,7 @@ For vLLM, you can start the service using the following command: #!/bin/bash model="YOUR_MODEL_PATH" served_model_name="YOUR_MODEL_NAME" +export VLLM_RPC_TIMEOUT=100000 # You may need to adjust the value of # --max-model-len, --max-num-batched-tokens, --max-num-seqs diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index ab813dc0ad6..9f87222448b 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -42,7 +42,8 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. 
_ipex_llm_convert(load_in_low_bit) - return super().from_engine_args(engine_args, start_engine_loop=start_engine_loop, usage_context=usage_context, stat_loggers=stat_loggers) + return super().from_engine_args(engine_args, start_engine_loop=start_engine_loop, + usage_context=usage_context, stat_loggers=stat_loggers) class IPEXLLMClass(LLM): From f91b3962db6f05e82ac6e9001cdc7bec30b5b427 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Mon, 11 Nov 2024 12:16:29 +0800 Subject: [PATCH 08/16] fix --- python/llm/src/ipex_llm/vllm/xpu/engine/engine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index 9f87222448b..aaeee24a080 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -19,6 +19,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.llm import LLM from vllm.utils import Counter +from vllm.config import EngineConfig from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert from vllm.usage.usage_lib import UsageContext from vllm.engine.metrics import StatLoggerBase @@ -34,6 +35,7 @@ def __init__(self, *args, **kwargs): def from_engine_args( cls, engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, load_in_low_bit: str = "sym_int4", @@ -42,7 +44,8 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. _ipex_llm_convert(load_in_low_bit) - return super().from_engine_args(engine_args, start_engine_loop=start_engine_loop, + return super().from_engine_args(engine_args=engine_args, engine_config=engine_config, + engine_start_engine_loop=start_engine_loop, usage_context=usage_context, stat_loggers=stat_loggers) From 6ad0f202ebdff3cb2a34e361c0b9b057b85f0778 Mon Sep 17 00:00:00 2001 From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com> Date: Tue, 12 Nov 2024 12:59:11 +0800 Subject: [PATCH 09/16] Update engine.py --- python/llm/src/ipex_llm/vllm/xpu/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index aaeee24a080..c4bed0c52e7 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -45,7 +45,7 @@ def from_engine_args( # Create the engine configs. 
_ipex_llm_convert(load_in_low_bit) return super().from_engine_args(engine_args=engine_args, engine_config=engine_config, - engine_start_engine_loop=start_engine_loop, + start_engine_loop=start_engine_loop, usage_context=usage_context, stat_loggers=stat_loggers) From ec304eb24c0e0f2e17e39ec9fcc895fb98b23cdd Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Nov 2024 15:06:22 +0800 Subject: [PATCH 10/16] Fix Dockerfile to original path --- docker/llm/serving/xpu/docker/Dockerfile | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/docker/llm/serving/xpu/docker/Dockerfile b/docker/llm/serving/xpu/docker/Dockerfile index 6bec3980179..8ffe814f166 100644 --- a/docker/llm/serving/xpu/docker/Dockerfile +++ b/docker/llm/serving/xpu/docker/Dockerfile @@ -38,22 +38,20 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO python3 get-pip.py && \ rm get-pip.py && \ pip install --upgrade requests argparse urllib3 && \ - # TODO: only take effect after it has been merged... pip install --pre --upgrade ipex-llm[xpu,serving] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \ - # TODO: remove the following things... pip install transformers_stream_generator einops tiktoken && \ pip install --upgrade colorama && \ # Download all-in-one benchmark and examples git clone https://github.com/intel-analytics/ipex-llm && \ - # TODO: remove the following steps - cd ipex-llm && \ - git fetch origin pull/12338/head:local_pr && \ - git checkout local_pr && \ - pip uninstall -y ipex-llm && \ - cd python/llm && \ - python setup.py install && \ - cd ../../../ && \ - # REMOVE END + # # TODO: remove the following steps + # cd ipex-llm && \ + # git fetch origin pull/12338/head:local_pr && \ + # git checkout local_pr && \ + # pip uninstall -y ipex-llm && \ + # cd python/llm && \ + # python setup.py install && \ + # cd ../../../ && \ + # # REMOVE END cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \ cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \ # Install vllm dependencies @@ -87,7 +85,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO rm -rf /tmp/neo && \ mkdir -p /llm && \ cd /llm && \ - git clone -b 062_test_0929 https://github.com/analytics-zoo/vllm.git /llm/vllm && \ + git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git /llm/vllm && \ cd /llm/vllm && \ pip install setuptools-scm && \ pip install --upgrade cmake && \ From 290f085d0676652713d2e09cfe3fc3ccf45c386a Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Nov 2024 15:12:16 +0800 Subject: [PATCH 11/16] fix --- docker/llm/serving/xpu/docker/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docker/llm/serving/xpu/docker/Dockerfile b/docker/llm/serving/xpu/docker/Dockerfile index 8ffe814f166..6a2f78b0a44 100644 --- a/docker/llm/serving/xpu/docker/Dockerfile +++ b/docker/llm/serving/xpu/docker/Dockerfile @@ -43,7 +43,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO pip install --upgrade colorama && \ # Download all-in-one benchmark and examples git clone https://github.com/intel-analytics/ipex-llm && \ - # # TODO: remove the following steps + # The following comment segment is used when building from source... 
# cd ipex-llm && \ # git fetch origin pull/12338/head:local_pr && \ # git checkout local_pr && \ @@ -51,7 +51,6 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO # cd python/llm && \ # python setup.py install && \ # cd ../../../ && \ - # # REMOVE END cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \ cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \ # Install vllm dependencies From ab32ca4253de3d8ab01d5c881cae00594d31d3b1 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Nov 2024 15:29:30 +0800 Subject: [PATCH 12/16] add option --- docker/llm/serving/xpu/docker/start-vllm-service.sh | 4 +++- python/llm/example/GPU/vLLM-Serving/README.md | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docker/llm/serving/xpu/docker/start-vllm-service.sh b/docker/llm/serving/xpu/docker/start-vllm-service.sh index 15a252f7f6d..30b8e53102e 100644 --- a/docker/llm/serving/xpu/docker/start-vllm-service.sh +++ b/docker/llm/serving/xpu/docker/start-vllm-service.sh @@ -28,4 +28,6 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \ --max-model-len 2048 \ --max-num-batched-tokens 4000 \ --max-num-seqs 12 \ - --tensor-parallel-size 1 + --tensor-parallel-size 1 \ + --disable-async-output-proc \ + --distributed-executor-backend ray diff --git a/python/llm/example/GPU/vLLM-Serving/README.md b/python/llm/example/GPU/vLLM-Serving/README.md index 0a29b0b69df..8c8c7a628da 100644 --- a/python/llm/example/GPU/vLLM-Serving/README.md +++ b/python/llm/example/GPU/vLLM-Serving/README.md @@ -105,7 +105,8 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \ --max-model-len 4096 \ --max-num-batched-tokens 10240 \ --max-num-seqs 12 \ - --tensor-parallel-size 1 + --tensor-parallel-size 1 \ + --disable-async-output-proc ``` You can tune the service using these four arguments: @@ -201,5 +202,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \ --max-model-len 4096 \ --max-num-batched-tokens 10240 \ --max-num-seqs 12 \ - --tensor-parallel-size 2 + --tensor-parallel-size 2 \ + --distributed-executor-backend ray \ + --disable-async-output-proc ``` From 0f93953fb66a1c79ac5d4b58a1e6c125d8800955 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Nov 2024 16:09:30 +0800 Subject: [PATCH 13/16] fix --- docker/llm/serving/xpu/docker/vllm_offline_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/llm/serving/xpu/docker/vllm_offline_inference.py b/docker/llm/serving/xpu/docker/vllm_offline_inference.py index 4f09483f045..6ee7598deb0 100644 --- a/docker/llm/serving/xpu/docker/vllm_offline_inference.py +++ b/docker/llm/serving/xpu/docker/vllm_offline_inference.py @@ -51,6 +51,8 @@ enforce_eager=True, load_in_low_bit="fp8", tensor_parallel_size=1, + disable_async_output_proc=True, + distributed_executor_backend="ray", max_model_len=2000, max_num_batched_tokens=2000) # Generate texts from the prompts. 
The output is a list of RequestOutput objects From 2eeb2f204e009f997310aaa76a3b4623246c61a8 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Nov 2024 16:31:17 +0800 Subject: [PATCH 14/16] fix --- .../llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py index 45d4ed2ea5c..1bd9835fe03 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py @@ -131,6 +131,7 @@ async def build_async_engine_client_from_engine_args( build_engine = partial(AsyncLLMEngine.from_engine_args, engine_args=engine_args, + load_in_low_bit=load_in_low_bit, engine_config=engine_config, usage_context=UsageContext.OPENAI_API_SERVER) if uses_ray: From 5fd0dce0b9695f369008ab6388f359b46dd2e5d9 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Nov 2024 16:44:27 +0800 Subject: [PATCH 15/16] fix --- python/llm/example/GPU/vLLM-Serving/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/llm/example/GPU/vLLM-Serving/README.md b/python/llm/example/GPU/vLLM-Serving/README.md index 8c8c7a628da..da35e5b79b9 100644 --- a/python/llm/example/GPU/vLLM-Serving/README.md +++ b/python/llm/example/GPU/vLLM-Serving/README.md @@ -17,7 +17,7 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI- ### 0. Environment -To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU#requirements). +To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html). After install the toolkit, run the following commands in your environment before starting vLLM GPU: ```bash @@ -50,7 +50,7 @@ pip install --upgrade cmake # TODO: check this later... for the specific branch git clone -b https://github.com/analytics-zoo/vllm.git cd vllm -VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v . && \ +VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v . # For Qwen model support pip install transformers_stream_generator einops tiktoken ``` @@ -59,7 +59,7 @@ pip install transformers_stream_generator einops tiktoken ```bash export USE_XETLA=OFF -export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2 export SYCL_CACHE_PERSISTENT=1 ``` ### 3. Offline inference/Service From dbbe5cc40ba470361b9a931db47fa80a5b5bd2c7 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 12 Nov 2024 17:03:14 +0800 Subject: [PATCH 16/16] fix --- python/llm/example/GPU/vLLM-Serving/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/llm/example/GPU/vLLM-Serving/README.md b/python/llm/example/GPU/vLLM-Serving/README.md index da35e5b79b9..0e644d8b0e8 100644 --- a/python/llm/example/GPU/vLLM-Serving/README.md +++ b/python/llm/example/GPU/vLLM-Serving/README.md @@ -47,8 +47,7 @@ pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-ex pip install setuptools-scm pip install --upgrade cmake # cd to your workdir -# TODO: check this later... 
for the specific branch -git clone -b https://github.com/analytics-zoo/vllm.git +git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git cd vllm VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v . # For Qwen model support
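
Putting the pieces from the patches above together, here is a minimal offline-inference sketch reflecting the configuration the series converges on (fp8 low-bit weights, Ray executor, async output processing disabled). The model path and prompts are placeholders, and the `IPEXLLMClass` import follows the repository's example scripts; treat this as an illustrative sketch rather than part of any patch.

```python
# Minimal XPU offline-inference sketch mirroring the settings shown in the
# patched examples (offline_inference.py / vllm_offline_inference.py).
# "YOUR_MODEL_PATH" and the prompt are placeholders.
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

prompts = ["What is AI?"]  # placeholder prompt
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="YOUR_MODEL_PATH",            # placeholder path
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="fp8",              # sym_int4 / fp8 / fp8_e4m3 / fp16 / fp6
          tensor_parallel_size=1,
          disable_async_output_proc=True,
          distributed_executor_backend="ray",
          max_model_len=2000,
          max_num_batched_tokens=2000)

# Generate texts from the prompts; each RequestOutput carries the prompt,
# the generated text, and related metadata.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
```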