From 9a7d5514c4a0762f90aa4388a4326db5d193d084 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 30 Dec 2024 16:00:59 -0800 Subject: [PATCH] feat: introduce MQAphroditeEngine (#1056) * feat: introduce MQAphroditeEngine * add `dead_error` property to engine client * fix model unload endpoint * add a simple model load endpoint * take more args in model load field * take yaml config in model load endpoint * inline model switching --- aphrodite/common/envs.py | 6 +- aphrodite/endpoints/openai/api_server.py | 709 +++++++++--------- aphrodite/endpoints/openai/rpc/__init__.py | 44 -- aphrodite/endpoints/openai/rpc/client.py | 420 ----------- aphrodite/endpoints/openai/rpc/server.py | 234 ------ aphrodite/endpoints/openai/serving_chat.py | 17 +- .../endpoints/openai/serving_completions.py | 17 +- .../endpoints/openai/serving_embedding.py | 12 +- aphrodite/endpoints/openai/serving_engine.py | 8 +- .../endpoints/openai/serving_tokenization.py | 10 +- aphrodite/engine/aphrodite_engine.py | 17 + aphrodite/engine/async_aphrodite.py | 9 +- aphrodite/engine/multiprocessing/__init__.py | 83 ++ aphrodite/engine/multiprocessing/client.py | 460 ++++++++++++ aphrodite/engine/multiprocessing/engine.py | 325 ++++++++ aphrodite/engine/protocol.py | 7 +- aphrodite/executor/cpu_executor.py | 1 + aphrodite/executor/multiproc_worker_utils.py | 4 + aphrodite/server/launch.py | 31 +- tests/benchmarks/engine/throughput.py | 6 +- 20 files changed, 1296 insertions(+), 1124 deletions(-) delete mode 100644 aphrodite/endpoints/openai/rpc/__init__.py delete mode 100644 aphrodite/endpoints/openai/rpc/client.py delete mode 100644 aphrodite/endpoints/openai/rpc/server.py create mode 100644 aphrodite/engine/multiprocessing/__init__.py create mode 100644 aphrodite/engine/multiprocessing/client.py create mode 100644 aphrodite/engine/multiprocessing/engine.py diff --git a/aphrodite/common/envs.py b/aphrodite/common/envs.py index f6c873eab..922068a1b 100644 --- a/aphrodite/common/envs.py +++ b/aphrodite/common/envs.py @@ -54,7 +54,7 @@ APHRODITE_DYNAMIC_ROPE_SCALING: bool = False APHRODITE_TEST_FORCE_FP8_MARLIN: bool = False APHRODITE_PLUGINS: Optional[List[str]] = None - APHRODITE_RPC_GET_DATA_TIMEOUT_MS: int = 5000 + APHRODITE_RPC_TIMEOUT: int = 5000 APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE: bool = False APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE: int = 0 APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE: int = 0 @@ -383,8 +383,8 @@ def get_default_config_root(): # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "APHRODITE_RPC_GET_DATA_TIMEOUT_MS": - lambda: int(os.getenv("APHRODITE_RPC_GET_DATA_TIMEOUT_MS", "5000")), + "APHRODITE_RPC_TIMEOUT": + lambda: int(os.getenv("APHRODITE_RPC_TIMEOUT", "5000")), # a list of plugin names to load, separated by commas. 
# if this is not set, it means all plugins will be loaded diff --git a/aphrodite/endpoints/openai/api_server.py b/aphrodite/endpoints/openai/api_server.py index 7e2096982..395a7198e 100644 --- a/aphrodite/endpoints/openai/api_server.py +++ b/aphrodite/endpoints/openai/api_server.py @@ -1,10 +1,10 @@ import asyncio -import copy import importlib import inspect import json import multiprocessing import os +import pickle import re import signal import tempfile @@ -13,11 +13,10 @@ from distutils.util import strtobool from functools import partial from http import HTTPStatus -from typing import (Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, - Set, Tuple) +from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple import yaml -from fastapi import APIRouter, FastAPI, Request, UploadFile +from fastapi import APIRouter, FastAPI, Form, Request, UploadFile from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import (HTMLResponse, JSONResponse, Response, @@ -35,7 +34,6 @@ random_uuid) from aphrodite.endpoints.logger import RequestLogger from aphrodite.endpoints.openai.args import make_arg_parser -# yapf: disable from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, @@ -46,9 +44,6 @@ KAIGenerationInputSchema, TokenizeRequest, TokenizeResponse) -from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient -from aphrodite.endpoints.openai.rpc.server import run_rpc_server -# yapf: enable from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat from aphrodite.endpoints.openai.serving_completions import ( OpenAIServingCompletion) @@ -59,7 +54,11 @@ OpenAIServingTokenization) from aphrodite.engine.args_tools import AsyncEngineArgs from aphrodite.engine.async_aphrodite import AsyncAphrodite -from aphrodite.engine.protocol import AsyncEngineClient +from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR, + RPCShutdownRequest) +from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient +from aphrodite.engine.multiprocessing.engine import run_mp_engine +from aphrodite.engine.protocol import EngineClient from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml from aphrodite.server import serve_http from aphrodite.transformers_utils.tokenizer import get_tokenizer @@ -80,32 +79,20 @@ sampler_json = "" gen_cache: dict = {} prometheus_multiproc_dir: tempfile.TemporaryDirectory -model_is_loaded = True _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str]) -> bool: - return ModelConfig(model=model_name, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - quantization=quantization, - seed=0, - dtype="auto").embedding_mode - - @asynccontextmanager async def lifespan(app: FastAPI): try: if app.state.log_stats: - async_engine_client = app.state.engine_client + engine_client: EngineClient = app.state.engine_client async def _force_log(): while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() + await asyncio.sleep(10.) 
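+                # Sleep first, then log: one stats flush per 10-second interval,
+                # and nothing is emitted before startup completes.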
+ await engine_client.do_log_stats() task = asyncio.create_task(_force_log()) _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) @@ -122,36 +109,35 @@ async def _force_log(): @asynccontextmanager -async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: +async def build_engine_client( + args: Namespace) -> AsyncIterator[Optional[EngineClient]]: - # Context manager to handle async_engine_client lifecycle + # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) - async with build_async_engine_client_from_engine_args( + async with build_engine_client_from_engine_args( engine_args, args.disable_frontend_multiprocessing) as engine: yield engine @asynccontextmanager -async def build_async_engine_client_from_engine_args( +async def build_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[AsyncEngineClient]]: +) -> AsyncIterator[Optional[EngineClient]]: """ - Create AsyncEngineClient, either: + Create EngineClient, either: - in-process using the AsyncAphrodite Directly - multiprocess using AsyncAphrodite RPC Returns the Client or None if the creation failed. """ - # If manually triggered or embedding model, use AsyncAphrodite in process. - # TODO: support embedding model via RPC. - if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization) + # Fall back + # TODO: fill out feature matrix. + if (MQAphroditeEngineClient.is_unsupported_config(engine_args) or disable_frontend_multiprocessing): engine_config = engine_args.create_engine_config() uses_ray = getattr(AsyncAphrodite._get_executor_cls(engine_config), @@ -186,224 +172,58 @@ async def build_async_engine_client_from_engine_args( "and Aphrodite will properly handle cleanup.") # Select random path for IPC. - rpc_path = get_open_zmq_ipc_path() - logger.info(f"Multiprocessing frontend to use {rpc_path} for RPC Path." - ) - - # Build RPCClient, which conforms to AsyncEngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) + ipc_path = get_open_zmq_ipc_path() + logger.info( + f"Multiprocessing frontend to use {ipc_path} for IPC Path.") - # Start RPCServer in separate process (holds the AsyncAphrodite). - context = multiprocessing.get_context("spawn") + # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, rpc_path)) - rpc_server_process.start() - logger.info( - f"Started engine process with PID {rpc_server_process.pid}") + context = multiprocessing.get_context("spawn") + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + ipc_path)) + engine_process.start() + logger.info(f"Started engine process with PID {engine_process.pid}") + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. 
We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQAphroditeEngineClient(ipc_path, engine_config) try: while True: try: - await rpc_client.setup() + await mp_engine_client.setup() break except TimeoutError: - if not rpc_server_process.is_alive(): - logger.error( - "RPCServer process died before responding " - "to readiness probe") + if not engine_process.is_alive(): + logger.error("Engine process died before responding " + "to readiness probe") yield None return - yield rpc_client # type: ignore[misc] + yield mp_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - rpc_client.close() + mp_engine_client.close() - # Wait for server process to join - rpc_server_process.join() + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. # See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) - - -async def _maybe_switch_model( - request_model: str, app_state, - raw_request: Request) -> Optional[ErrorResponse]: - """Switch to requested model if different from currently loaded one.""" - global model_is_loaded, async_engine_client, engine_args, served_model_names - - if not model_is_loaded: - return None - - models = await completion(raw_request).show_available_models() + multiprocess.mark_process_dead(engine_process.pid) - for model in models.data: - if request_model in (model.id, model.root): - return None - - if not app_state.args.allow_inline_model_loading: - return JSONResponse( - content={ - "error": { - "message": "Requested model is not currently loaded. " - "Inline model loading is disabled. Enable it with " - "--allow-inline-model-loading.", - "type": "invalid_request_error", - "code": "model_not_loaded" - } - }, - status_code=400 - ) # type: ignore - - # Authentication checks - api_key = envs.APHRODITE_API_KEY or app_state.args.api_keys - admin_key = envs.APHRODITE_ADMIN_KEY or app_state.args.admin_key - - if api_key: - api_key_header = raw_request.headers.get("x-api-key") - auth_header = raw_request.headers.get("Authorization") - - if not admin_key: - return JSONResponse( - content={ - "error": { - "message": "Admin key not configured. 
" - "Inline model loading is disabled.", - "type": "invalid_request_error", - "code": "admin_key_required" - } - }, - status_code=401 - ) # type: ignore - - if not (api_key_header == admin_key or - auth_header == f"Bearer {admin_key}"): - return JSONResponse( - content={ - "error": { - "message": "Admin privileges required for inline " - "model loading.", - "type": "invalid_request_error", - "code": "unauthorized" - } - }, - status_code=401 - ) # type: ignore - - logger.info(f"Switching from {served_model_names[0]} to {request_model}") - - try: - args = app_state.args - current_client = engine_client(raw_request) - - # First shut down the current engine - if not args.disable_frontend_multiprocessing: - await current_client.kill() - else: - await current_client.shutdown_background_loop() - - model_is_loaded = False - - yaml_config = get_model_config_yaml(request_model, args.download_dir) - - if yaml_config: - parser = FlexibleArgumentParser() - parser = make_arg_parser(parser) - engine_args = parser.parse_args([]) # empty args - - for key, value in yaml_config.items(): - if hasattr(engine_args, key): - setattr(engine_args, key, value) - - engine_args.model = request_model - engine_args = AsyncEngineArgs.from_cli_args(engine_args) - else: - # Fallback to minimal config - engine_args = AsyncEngineArgs(model=request_model) - - # Create new engine client without context manager - if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization) - or args.disable_frontend_multiprocessing): - new_engine_client = AsyncAphrodite.from_engine_args(engine_args) - await new_engine_client.setup() - else: - if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: - global prometheus_multiproc_dir - prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ[ - "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name - - rpc_path = get_open_zmq_ipc_path() - logger.info( - f"Multiprocessing frontend to use {rpc_path} for RPC Path.") - - rpc_client = AsyncEngineRPCClient(rpc_path) - - context = multiprocessing.get_context("spawn") - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, rpc_path)) - rpc_server_process.start() - logger.info( - f"Started engine process with PID {rpc_server_process.pid}") - - try: - while True: - try: - await rpc_client.setup() - break - except TimeoutError as e: - if not rpc_server_process.is_alive(): - raise RuntimeError( - "RPC Server died before responding to " - "readiness probe") from e - - new_engine_client = rpc_client - model_config = await new_engine_client.get_model_config() - new_args = copy.deepcopy(args) - new_args.model = request_model - - init_app_state( - new_engine_client, model_config, - raw_request.app.state, new_args) - - served_model_names = [request_model] - model_is_loaded = True - return None - - except Exception as e: - # Clean up RPC resources on error - rpc_server_process.terminate() - rpc_client.close() - rpc_server_process.join() - raise e - - except Exception as e: - error_msg = f"Error while switching models: {str(e)}" - logger.error(error_msg) - return JSONResponse( - content={ - "error": { - "message": error_msg, - "type": "invalid_request_error", - "code": "model_load_error" - } - }, - status_code=500 - ) # type: ignore def mount_metrics(app: FastAPI): # Lazy import for prometheus multiprocessing. 
@@ -428,6 +248,55 @@ def mount_metrics(app: FastAPI): app.routes.append(metrics_route) +async def _handle_model_switch( + raw_request: Request, + requested_model: str +) -> Optional[JSONResponse]: + """Helper function to handle model switching if needed. + Returns error response if something went wrong, None if successful.""" + + if not raw_request.app.state.args.allow_inline_model_loading: + return None + + if not raw_request.app.state.model_is_loaded: + config = get_model_config_yaml(requested_model) + request_data = {"model": requested_model} + if config: + config.pop("model", None) + request_data.update(config) + + load_response = await load_model( + raw_request, + request=json.dumps(request_data) + ) + if load_response.status_code != 200: + return load_response + return None + + current_model = raw_request.app.state.current_model + if current_model == requested_model: + return None + + unload_response = await unload_model(raw_request) + if unload_response.status_code != 200: + return unload_response + + config = get_model_config_yaml(requested_model) + request_data = {"model": requested_model} + if config: + config.pop("model", None) + request_data.update(config) + + load_response = await load_model( + raw_request, + request=json.dumps(request_data) + ) + if load_response.status_code != 200: + return load_response + + return None + + def chat(request: Request) -> OpenAIServingChat: return request.app.state.openai_serving_chat @@ -440,148 +309,205 @@ def tokenization(request: Request) -> OpenAIServingTokenization: def embedding(request: Request) -> OpenAIServingEmbedding: return request.app.state.openai_serving_embedding -def engine_client(request: Request) -> AsyncEngineClient: +def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @router.delete("/v1/model/unload") async def unload_model(raw_request: Request): - """Unload the current model and shut down the server.""" - logger.info("Received request to unload model.") - - try: - args = raw_request.app.state.args - if not args.disable_frontend_multiprocessing: - await engine_client(raw_request).kill() - else: - await engine_client(raw_request).shutdown_background_loop() - - global model_is_loaded - model_is_loaded = False + """Unload the model and shut down the engine process.""" + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ - "status": "success", - "message": "Model unloaded successfully" - } + "status": "error", + "message": "No model loaded." 
+ }, + status_code=500 ) + client = raw_request.app.state.engine_client + + if isinstance(client, MQAphroditeEngineClient): + try: + shutdown_req = RPCShutdownRequest() + await client.input_socket.send_multipart( + (pickle.dumps(shutdown_req),), copy=False + ) - except Exception as e: - error_msg = f"Error while unloading model: {str(e)}" - logger.error(error_msg) + response = await client.output_socket.recv_multipart() + if pickle.loads(response[0]) != APHRODITE_RPC_SUCCESS_STR: + raise RuntimeError("Engine shutdown failed") + + client.output_loop.cancel() + if client.health_loop is not None: + client.health_loop.cancel() + + client.close() + + raw_request.app.state.engine_client = None + raw_request.app.state.openai_serving_chat = None + raw_request.app.state.openai_serving_completion = None + raw_request.app.state.openai_serving_embedding = None + raw_request.app.state.openai_serving_tokenization = None + raw_request.app.state.model_is_loaded = False + + return JSONResponse(content={"status": "success"}) + + except Exception as e: + return JSONResponse( + content={ + "status": "error", + "message": f"Failed to shutdown engine: {str(e)}" + }, + status_code=500 + ) + else: return JSONResponse( - content={"status": "error", "message": error_msg}, - status_code=500 + content={ + "status": "error", + "message": "Model unloading only supported with multiprocessing" + " backend" + }, + status_code=400 ) - @router.post("/v1/model/load") -async def load_model(config_file: UploadFile, raw_request: Request): - """Load a model using a YAML configuration file.""" - global model_is_loaded, async_engine_client, engine_args - - if model_is_loaded: +async def load_model( + raw_request: Request, + config_file: Optional[UploadFile] = None, + request: Optional[str] = Form(None) +): + """Load a new model after unloading the previous one. + Accept either a config file, a JSON request body, or both.""" + if raw_request.app.state.model_is_loaded: return JSONResponse( content={ - "error": { - "message": "A model is already loaded. " - "Please unload it first.", - "type": "invalid_request_error", - "code": "model_already_loaded" - } + "status": "error", + "message": "A model is already loaded. Please unload it first." 
}, status_code=400 ) try: - config_text = await config_file.read() - config: Dict[Any, Any] = yaml.safe_load(config_text) - - args = [] - for key, value in config.items(): - key = key.replace('_', '-') - - if isinstance(value, bool): - if value: - args.append(f"--{key}") - elif isinstance(value, (list, tuple)): - if key in ['lora-modules', 'prompt-adapters']: - for item in value: - args.append(f"--{key}") - args.append(f"{item['name']}={item['path']}") - else: - for item in value: - args.append(f"--{key}") - args.append(str(item)) - else: - args.append(f"--{key}") - args.append(str(value)) - parser = FlexibleArgumentParser() parser = make_arg_parser(parser) - parsed_args = parser.parse_args(args) - engine_args = AsyncEngineArgs.from_cli_args(parsed_args) - - # Create new engine client without context manager - if (model_is_embedding(engine_args.model, - engine_args.trust_remote_code, - engine_args.quantization) - or parsed_args.disable_frontend_multiprocessing): - new_engine_client = AsyncAphrodite.from_engine_args(engine_args) - await new_engine_client.setup() + new_args = parser.parse_args([]) + + original_args = api_server_args + essential_params = [ + 'host', 'port', 'api_keys', 'admin_key', + 'disable_frontend_multiprocessing', 'root_path', + 'ssl_keyfile', 'ssl_certfile' + ] + for param in essential_params: + if hasattr(original_args, param): + setattr(new_args, param, getattr(original_args, param)) + + if config_file: + yaml_content = await config_file.read() + config_args = yaml.safe_load(yaml_content) + if config_args: + for key, value in config_args.items(): + if hasattr(new_args, key): + setattr(new_args, key, value) + + json_args = None + if request: + try: + json_args = json.loads(request) + except json.JSONDecodeError: + return JSONResponse( + content={ + "status": "error", + "message": "Invalid JSON in request form field." + }, + status_code=400 + ) else: - if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: - global prometheus_multiproc_dir - prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ[ - "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name - - rpc_path = get_open_zmq_ipc_path() - logger.info( - f"Multiprocessing frontend to use {rpc_path} for RPC Path.") - - rpc_client = AsyncEngineRPCClient(rpc_path) - new_engine_client = rpc_client - - context = multiprocessing.get_context("spawn") - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, rpc_path)) - rpc_server_process.start() - logger.info( - f"Started engine process with PID {rpc_server_process.pid}") + try: + json_args = await raw_request.json() + except Exception: + if not config_file: + return JSONResponse( + content={ + "status": "error", + "message": "Must provide either config_file or " + "valid JSON request body." + }, + status_code=400 + ) + if json_args: + for key, value in json_args.items(): + if hasattr(new_args, key): + setattr(new_args, key, value) + + if not hasattr(new_args, 'model') or not new_args.model: + return JSONResponse( + content={ + "status": "error", + "message": "No model specified in config or request body." + }, + status_code=400 + ) + + engine_args = AsyncEngineArgs.from_cli_args(new_args) + + if (MQAphroditeEngineClient.is_unsupported_config(engine_args) + or new_args.disable_frontend_multiprocessing): + return JSONResponse( + content={ + "status": "error", + "message": "Model loading only supported with " + "multiprocessing backend." 
+ }, + status_code=400 + ) + + ipc_path = get_open_zmq_ipc_path() + context = multiprocessing.get_context("spawn") + engine_process = context.Process( + target=run_mp_engine, + args=(engine_args, ipc_path) + ) + engine_process.start() + + engine_config = engine_args.create_engine_config() + engine_client = MQAphroditeEngineClient(ipc_path, engine_config) + + try: while True: try: - await new_engine_client.setup() + await engine_client.setup() break - except TimeoutError as e: - if not rpc_server_process.is_alive(): - raise RuntimeError( - "RPC Server died before responding to readiness " - "probe") from e - - model_config = await engine_client(raw_request).get_model_config() - init_app_state(engine_client(raw_request), model_config, - raw_request.app.state, parsed_args) - - model_is_loaded = True - return JSONResponse( - content={ - "status": "success", - "message": "Model loaded successfully" - } - ) + except TimeoutError: + if not engine_process.is_alive(): + return JSONResponse( + content={ + "status": "error", + "message": "Engine process died before " + "responding to readiness probe." + }, + status_code=500 + ) + + model_config = await engine_client.get_model_config() + init_app_state( + engine_client, model_config, raw_request.app.state, new_args) + raw_request.app.state.model_is_loaded = True + raw_request.app.state.current_model = new_args.model + + return JSONResponse(content={"status": "success"}) + + except Exception as e: + engine_process.terminate() + engine_client.close() + raise e except Exception as e: - error_msg = f"Error while loading model: {str(e)}" - logger.error(error_msg) return JSONResponse( content={ - "error": { - "message": error_msg, - "type": "invalid_request_error", - "code": "model_load_error" - } + "status": "error", + "message": f"Failed to load model: {str(e)}" }, status_code=500 ) @@ -595,7 +521,12 @@ async def health(raw_request: Request) -> Response: @router.post("/v1/tokenize") async def tokenize(request: TokenizeRequest, raw_request: Request): - if not model_is_loaded: + if hasattr(request, "model"): + error_response = await _handle_model_switch(raw_request, request.model) + if error_response is not None: + return error_response + + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ "status": "error", @@ -614,7 +545,13 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post("/v1/detokenize") async def detokenize(request: DetokenizeRequest, raw_request: Request): - if not model_is_loaded: + if hasattr(request, "model"): + error_response = await _handle_model_switch( + raw_request, request.model) + if error_response is not None: + return error_response + + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ "status": "error", @@ -633,6 +570,14 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): @router.get("/v1/models") async def show_available_models(raw_request: Request): + if not raw_request.app.state.model_is_loaded: + return JSONResponse( + content={ + "status": "error", + "message": "No model loaded." 
+ }, + status_code=500 + ) models = await completion(raw_request).show_available_models() return JSONResponse(content=models.model_dump()) @@ -677,10 +622,19 @@ async def serviceinfo(): @router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): - error_check_ret = await _maybe_switch_model( - request.model, raw_request.app.state, raw_request) - if error_check_ret is not None: - return error_check_ret + if hasattr(request, "model"): + error_response = await _handle_model_switch(raw_request, request.model) + if error_response is not None: + return error_response + + if not raw_request.app.state.model_is_loaded: + return JSONResponse( + content={ + "status": "error", + "message": "No model loaded." + }, + status_code=500 + ) generator = await chat(raw_request).create_chat_completion( request, raw_request) if isinstance(generator, ErrorResponse): @@ -696,10 +650,19 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): - error_check_ret = await _maybe_switch_model( - request.model, raw_request.app.state, raw_request) - if error_check_ret is not None: - return error_check_ret + if hasattr(request, "model"): + error_response = await _handle_model_switch(raw_request, request.model) + if error_response is not None: + return error_response + + if not raw_request.app.state.model_is_loaded: + return JSONResponse( + content={ + "status": "error", + "message": "No model loaded." + }, + status_code=500 + ) generator = await completion(raw_request).create_completion( request, raw_request) if isinstance(generator, ErrorResponse): @@ -714,7 +677,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): - if not model_is_loaded: + if hasattr(request, "model"): + error_response = await _handle_model_switch(raw_request, request.model) + if error_response is not None: + return error_response + + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ "status": "error", @@ -733,7 +701,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): @router.post("/v1/lora/load") async def load_lora(lora: LoRAModulePath, raw_request: Request): - if not model_is_loaded: + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ "status": "error", @@ -742,7 +710,7 @@ async def load_lora(lora: LoRAModulePath, raw_request: Request): status_code=500 ) completion(raw_request).add_lora(lora) - if engine_args.enable_lora is False: + if args.enable_lora is False: logger.error("LoRA is not enabled in the engine. 
" "Please start the server with the " "--enable-lora flag!") @@ -751,7 +719,7 @@ async def load_lora(lora: LoRAModulePath, raw_request: Request): @router.delete("/v1/lora/unload") async def unload_lora(lora_name: str, raw_request: Request): - if not model_is_loaded: + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ "status": "error", @@ -766,7 +734,7 @@ async def unload_lora(lora_name: str, raw_request: Request): @router.post("/v1/soft_prompt/load") async def load_soft_prompt(soft_prompt: PromptAdapterPath, raw_request: Request): - if not model_is_loaded: + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ "status": "error", @@ -775,7 +743,7 @@ async def load_soft_prompt(soft_prompt: PromptAdapterPath, status_code=500 ) completion(raw_request).add_prompt_adapter(soft_prompt) - if engine_args.enable_prompt_adapter is False: + if args.enable_prompt_adapter is False: logger.error("Prompt Adapter is not enabled in the engine. " "Please start the server with the " "--enable-prompt-adapter flag!") @@ -783,7 +751,7 @@ async def load_soft_prompt(soft_prompt: PromptAdapterPath, @router.delete("/v1/soft_prompt/unload") async def unload_soft_prompt(soft_prompt_name: str, raw_request: Request): - if not model_is_loaded: + if not raw_request.app.state.model_is_loaded: return JSONResponse( content={ "status": "error", @@ -996,7 +964,7 @@ async def get_max_length() -> JSONResponse: @kai_api.get("/config/max_context_length") @extra_api.get("/true_max_context_length") async def get_max_context_length() -> JSONResponse: - max_context_length = engine_args.max_model_len + max_context_length = args.max_model_len return JSONResponse({"value": max_context_length}) @@ -1040,6 +1008,7 @@ def build_app(args: Namespace) -> FastAPI: app.include_router(router) app.root_path = args.root_path app.state.args = args + app.state.model_is_loaded = False if args.launch_kobold_api: logger.warning("Kobold API is now enabled by default. 
" "This flag will be removed in the future.") @@ -1123,7 +1092,7 @@ async def authentication(request: Request, call_next): def init_app_state( - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, state: State, args: Namespace, @@ -1149,11 +1118,12 @@ def init_app_state( else: request_logger = RequestLogger(max_log_len=args.max_log_len) - state.engine_client = async_engine_client + state.engine_client = engine_client state.log_stats = not args.disable_log_stats + state.current_model = args.model state.openai_serving_chat = OpenAIServingChat( - async_engine_client, + engine_client, model_config, served_model_names, args.response_role, @@ -1166,7 +1136,7 @@ def init_app_state( tool_parser=args.tool_call_parser ) state.openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -1175,13 +1145,13 @@ def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) state.openai_serving_embedding = OpenAIServingEmbedding( - async_engine_client, + engine_client, model_config, served_model_names, request_logger=request_logger, ) state.openai_serving_tokenization = OpenAIServingTokenization( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -1207,19 +1177,21 @@ def signal_handler(*_) -> None: raise KeyboardInterrupt("terminated") signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as async_engine_client: + async with build_engine_client(args) as engine_client: # If None, creation of the client failed and we exit. - if async_engine_client is None: + if engine_client is None: return app = build_app(args) - model_config = await async_engine_client.get_model_config() - init_app_state(async_engine_client, model_config, app.state, args) + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) protocol = "https" if args.ssl_certfile else "http" root_path = args.root_path.rstrip("/") if args.root_path else "" host_name = args.host if args.host else "localhost" port_str = str(args.port) + app.state.model_is_loaded = True + if SERVE_KOBOLD_LITE_UI: ui_url = f"{protocol}://{host_name}:{port_str}{root_path}/" @@ -1233,7 +1205,6 @@ def signal_handler(*_) -> None: shutdown_task = await serve_http( app, - limit_concurrency=async_engine_client.limit_concurrency, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/aphrodite/endpoints/openai/rpc/__init__.py b/aphrodite/endpoints/openai/rpc/__init__.py deleted file mode 100644 index e7a3c240f..000000000 --- a/aphrodite/endpoints/openai/rpc/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - -from aphrodite.common.sampling_params import SamplingParams -from aphrodite.inputs import PromptInputs -from aphrodite.lora.request import LoRARequest -from aphrodite.prompt_adapter.request import PromptAdapterRequest - -# Success string used for RPC instructions. -APHRODITE_RPC_SUCCESS_STR = "SUCCESS" -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -APHRODITE_RPC_SOCKET_LIMIT_CUTOFF = 2000 -# HWM is set to Infinity. 
-APHRODITE_RPC_ZMQ_HWM = 0 - - -@dataclass -class RPCGenerateRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - lora_request: Optional[LoRARequest] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCUtilityRequest(Enum): - IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - SHUTDOWN_SERVER = 9 - -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] diff --git a/aphrodite/endpoints/openai/rpc/client.py b/aphrodite/endpoints/openai/rpc/client.py deleted file mode 100644 index cdb1e625c..000000000 --- a/aphrodite/endpoints/openai/rpc/client.py +++ /dev/null @@ -1,420 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Optional -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from loguru import logger -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from aphrodite.common.envs import APHRODITE_RPC_GET_DATA_TIMEOUT_MS -from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput -from aphrodite.common.sampling_params import SamplingParams -from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_SOCKET_LIMIT_CUTOFF, - APHRODITE_RPC_SUCCESS_STR, - APHRODITE_RPC_ZMQ_HWM, - RPC_REQUEST_TYPE, RPCAbortRequest, - RPCGenerateRequest, - RPCUtilityRequest) -from aphrodite.inputs import PromptInputs -from aphrodite.lora.request import LoRARequest -from aphrodite.prompt_adapter.request import PromptAdapterRequest -from aphrodite.transformers_utils.tokenizer_group import ( - init_tokenizer_from_configs) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. 
- - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._data_timeout = APHRODITE_RPC_GET_DATA_TIMEOUT_MS - self._errored = False - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - if socket_limit < APHRODITE_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests Aphrodite can process. " - "Launch Aphrodite with --disable-frontend-multiprocessing and " - "open a GitHub issue so we can investigate.") - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(APHRODITE_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - # In process proxy to RPC Server (uses memory-based messaging). 
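-        # (The ROUTER side records each connecting DEALER's identity,
-        # which is how replies are routed back to the right request socket.)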
- self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(APHRODITE_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - # Asyncio background task for the proxy. - self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in Aphrodite w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 - - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" - while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() - self.context.destroy() - - - @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(APHRODITE_RPC_ZMQ_HWM) - try: - socket.connect(INPROC_PROXY_PATH) - yield socket - finally: - socket.close(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.to_proxy_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart((cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - # Await the data from the Server. 
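-            # recv(copy=False) returns a zero-copy zmq Frame; the payload
-            # is unpickled straight from its buffer.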
- frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data - - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): - """Send one-way RPC request to trigger an action.""" - - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) - - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - return pickle.loads(frame.buffer) - - # Make a new socket connection. - if socket is None: - with self.to_proxy_socket() as socket: - response = await do_rpc_call(socket, request) - - # Use existing socket connection. - else: - response = await do_rpc_call(socket, request) - - if not isinstance( - response, str) or response != APHRODITE_RPC_SUCCESS_STR: - if isinstance(response, Exception): - logger.error(error_message) - raise response - raise ValueError(error_message) - - async def get_tokenizer(self, lora_request: LoRARequest): - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def _wait_for_server_rpc(self): - """Wait for the RPCServer to start up.""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _get_model_config_rpc(self) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") - - async def _get_decoding_config_rpc(self) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") - - async def _get_parallel_config_rpc(self) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") - - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server") - - async def _get_lora_config_rpc(self) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") - - async def abort(self, request_id: str): - """Send an ABORT_REQUEST signal to the 
RPC Server""" - # Suppress timeouts as well. - # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") - - async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") - - @property - def is_running(self) -> bool: - return not self._errored - - @property - def is_stopped(self) -> bool: - return self._errored - - @property - def errored(self) -> bool: - return self._errored - - async def generate( - self, - inputs: PromptInputs, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - finished = False - try: - with self.to_proxy_socket() as socket: - - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request)), )) - - # Stream back the results from the RPC Server. - while not finished: - message = await socket.recv(copy=False) - assert isinstance(message, Frame) - request_output = pickle.loads(message.buffer) - - if isinstance(request_output, Exception): - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health(socket=socket) - except Exception as e: - self._errored = True - logger.exception(repr(e)) - # NB: do before raising here so that the flag is set - # by the time the caller receives this exception - raise request_output - - finished = request_output.finished - yield request_output - finally: - # Request was canceled by the client. 
- if not finished and not self._errored: - await self.abort(request_id) - - async def check_health(self, socket: Optional[Socket] = None) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") - - async def kill(self): - """Cleanly shut down the RPC client and engine.""" - try: - # Send shutdown signal to RPC server - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.SHUTDOWN_SERVER, - error_message="Failed to send shutdown signal to RPC server" - ) - except Exception as e: - logger.error(f"Error while shutting down RPC server: {str(e)}") - finally: - # Close local resources - self.close() \ No newline at end of file diff --git a/aphrodite/endpoints/openai/rpc/server.py b/aphrodite/endpoints/openai/rpc/server.py deleted file mode 100644 index a06595ead..000000000 --- a/aphrodite/endpoints/openai/rpc/server.py +++ /dev/null @@ -1,234 +0,0 @@ -import asyncio -import os -import pickle -import signal -from typing import Any, Coroutine, Union - -import cloudpickle -import zmq -import zmq.asyncio -from loguru import logger -from typing_extensions import Never -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from aphrodite import AsyncAphrodite, AsyncEngineArgs -from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from aphrodite.common.utils import in_windows -from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_SUCCESS_STR, - APHRODITE_RPC_ZMQ_HWM, - RPCAbortRequest, - RPCGenerateRequest, - RPCUtilityRequest) - -if in_windows(): - import winloop as uvloop -else: - import uvloop - - -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, rpc_path: str): - # Initialize engine first. - self.engine = AsyncAphrodite.from_engine_args(async_engine_args) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.DEALER) - self.socket.set_hwm(APHRODITE_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - # Clear the engine reference so that it can be GC'ed. 
- self.engine = None - - async def get_config(self, identity, request): - try: - config: CONFIG_TYPE - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - config = await self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - config = await self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - config = await self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - config = await self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - config = await self.engine.get_parallel_config() - else: - raise ValueError(f"Unknown Config Request: {request}") - - await self.socket.send_multipart((identity, pickle.dumps(config)), - copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart( - (identity, pickle.dumps(APHRODITE_RPC_SUCCESS_STR))) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart( - (identity, pickle.dumps(APHRODITE_RPC_SUCCESS_STR))) - - async def abort(self, identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - try: - # Abort the request in the llm engine. - await self.engine.abort(request.request_id) - result: Union[str, Exception] = APHRODITE_RPC_SUCCESS_STR - except Exception as e: - result = e - await self.socket.send_multipart((identity, pickle.dumps(result))) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - try: - results_generator = self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - prompt_adapter_request=generate_request.prompt_adapter_request) - - async for request_output in results_generator: - await self.socket.send_multipart( - (identity, pickle.dumps(request_output)), copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - (identity, pickle.dumps(APHRODITE_RPC_SUCCESS_STR))) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - def _make_handler_coro(self, identity, - message: Frame) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCGenerateRequest): - return self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - return self.get_config(identity, request) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: - return self.check_health(identity) - elif request == 
RPCUtilityRequest.SHUTDOWN_SERVER: - return self.shutdown(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") - - async def run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart(copy=False) - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - async def shutdown(self, identity): - """Handle shutdown request from client.""" - try: - # Clean shutdown of engine - self.engine.shutdown_background_loop() - await self.socket.send_multipart( - [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)] - ) - except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) - finally: - # Schedule server shutdown - asyncio.create_task(self._delayed_shutdown()) - - async def _delayed_shutdown(self): - """Helper to shut down server after response is sent""" - await asyncio.sleep(1) - self.cleanup() - # Force exit the process - os._exit(0) - - - -async def run_server(server: AsyncEngineRPCServer): - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. - def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("Aphrodite ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. 
- server.cleanup() - - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - rpc_path: str): - def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing - raise KeyboardInterrupt("AsyncEngineRPCServer terminated") - signal.signal(signal.SIGTERM, signal_handler) - server = AsyncEngineRPCServer(async_engine_args, rpc_path) - uvloop.run(run_server(server)) diff --git a/aphrodite/endpoints/openai/serving_chat.py b/aphrodite/endpoints/openai/serving_chat.py index 8d0dd5b87..c0696886a 100644 --- a/aphrodite/endpoints/openai/serving_chat.py +++ b/aphrodite/endpoints/openai/serving_chat.py @@ -33,7 +33,7 @@ from aphrodite.endpoints.openai.tool_parsers import (Hermes2ProToolParser, MistralToolParser, ToolParser) -from aphrodite.engine.protocol import AsyncEngineClient +from aphrodite.engine.protocol import EngineClient from aphrodite.inputs import TokensPrompt from aphrodite.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer) @@ -42,7 +42,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -54,7 +54,7 @@ def __init__(self, return_tokens_as_token_ids: bool = False, enable_auto_tools: bool = False, tool_parser: Optional[str] = None): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -102,6 +102,12 @@ async def create_chat_completion( logger.error(f"Error with model {error_check_ret}") return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + try: ( lora_request, @@ -109,8 +115,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) @@ -204,7 +209,7 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - result_generator = self.async_engine_client.generate( + result_generator = self.engine_client.generate( engine_inputs, sampling_params, request_id, diff --git a/aphrodite/endpoints/openai/serving_completions.py b/aphrodite/endpoints/openai/serving_completions.py index 5906be6e7..dfefcd3cc 100644 --- a/aphrodite/endpoints/openai/serving_completions.py +++ b/aphrodite/endpoints/openai/serving_completions.py @@ -19,7 +19,7 @@ from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) -from aphrodite.engine.protocol import AsyncEngineClient +from aphrodite.engine.protocol import EngineClient from aphrodite.transformers_utils.tokenizer import AnyTokenizer TypeTokenIDs = List[int] @@ -32,7 +32,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -41,7 +41,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -67,6 +67,12 @@ async def create_completion( if error_check_ret is not None: return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + # Return error for unsupported features. 
if request.suffix is not None: return self.create_error_response( @@ -84,8 +90,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -113,7 +118,7 @@ async def create_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - generator = self.async_engine_client.generate( + generator = self.engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, diff --git a/aphrodite/endpoints/openai/serving_embedding.py b/aphrodite/endpoints/openai/serving_embedding.py index 644b6a8d7..5ddc5f020 100644 --- a/aphrodite/endpoints/openai/serving_embedding.py +++ b/aphrodite/endpoints/openai/serving_embedding.py @@ -17,7 +17,7 @@ EmbeddingResponseData, ErrorResponse, UsageInfo) from aphrodite.endpoints.openai.serving_engine import OpenAIServing -from aphrodite.engine.protocol import AsyncEngineClient +from aphrodite.engine.protocol import EngineClient TypeTokenIDs = List[int] @@ -59,13 +59,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -106,9 +106,7 @@ async def create_embedding( lora_request, prompt_adapter_request, ) = self._maybe_get_adapters(request) - - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() prompts = list( @@ -132,7 +130,7 @@ async def create_embedding( "Prompt adapter is not supported " "for embedding models") - generator = self.async_engine_client.encode( + generator = self.engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, diff --git a/aphrodite/endpoints/openai/serving_engine.py b/aphrodite/endpoints/openai/serving_engine.py index a5cb69eda..b020cb45c 100644 --- a/aphrodite/endpoints/openai/serving_engine.py +++ b/aphrodite/endpoints/openai/serving_engine.py @@ -27,7 +27,7 @@ TokenizeCompletionRequest, TokenizeRequest) # yapf: enable -from aphrodite.engine.protocol import AsyncEngineClient +from aphrodite.engine.protocol import EngineClient from aphrodite.inputs.parse import parse_and_batch_prompt from aphrodite.lora.request import LoRARequest from aphrodite.modeling.guided_decoding import ( @@ -62,7 +62,7 @@ class OpenAIServing: def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -73,7 +73,7 @@ def __init__( ): super().__init__() - self.async_engine_client = async_engine_client + self.engine_client = engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -156,7 +156,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessorFunc]: - decoding_config = 
await self.async_engine_client.get_decoding_config() + decoding_config = await self.engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/aphrodite/endpoints/openai/serving_tokenization.py b/aphrodite/endpoints/openai/serving_tokenization.py index 4db0bf792..4adda64fa 100644 --- a/aphrodite/endpoints/openai/serving_tokenization.py +++ b/aphrodite/endpoints/openai/serving_tokenization.py @@ -20,7 +20,7 @@ # yapf: enable from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) -from aphrodite.engine.protocol import AsyncEngineClient +from aphrodite.engine.protocol import EngineClient from aphrodite.transformers_utils.tokenizer import MistralTokenizer @@ -28,7 +28,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -36,7 +36,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -65,7 +65,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) prompt: Union[str, List[int]] if isinstance(request, TokenizeChatRequest): @@ -131,7 +131,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/aphrodite/engine/aphrodite_engine.py b/aphrodite/engine/aphrodite_engine.py index b0d7920c0..f7a2dc887 100644 --- a/aphrodite/engine/aphrodite_engine.py +++ b/aphrodite/engine/aphrodite_engine.py @@ -1185,6 +1185,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # torch.distributed ops which may otherwise timeout, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. 
+ logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() return ctx.request_outputs @@ -1498,6 +1499,22 @@ def check_health(self) -> None: self.tokenizer.check_health() self.model_executor.check_health() + def shutdown(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() + if hasattr(self, 'tokenizer') and self.tokenizer is not None: + self.tokenizer = None + if hasattr(self, 'scheduler'): + self.scheduler.clear() + if hasattr(self, 'cached_scheduler_outputs'): + self.cached_scheduler_outputs.clear() + if hasattr(self, 'scheduler_contexts'): + self.scheduler_contexts.clear() + if hasattr(self, 'stat_loggers'): + self.stat_loggers.clear() + if hasattr(self, 'model_executor'): + self.model_executor.shutdown() + + def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() diff --git a/aphrodite/engine/async_aphrodite.py b/aphrodite/engine/async_aphrodite.py index aeaade8b0..c3c205830 100644 --- a/aphrodite/engine/async_aphrodite.py +++ b/aphrodite/engine/async_aphrodite.py @@ -592,9 +592,12 @@ def errored(self) -> bool: return self._errored_with is not None @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/aphrodite/engine/multiprocessing/__init__.py b/aphrodite/engine/multiprocessing/__init__.py new file mode 100644 index 000000000..5d5d20b7b --- /dev/null +++ b/aphrodite/engine/multiprocessing/__init__.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Mapping, Optional, Union + +from aphrodite.common.outputs import RequestOutput +from aphrodite.common.sampling_params import SamplingParams +from aphrodite.inputs import PromptInputs +from aphrodite.lora.request import LoRARequest +from aphrodite.prompt_adapter.request import PromptAdapterRequest + +APHRODITE_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCHealthRequest: + pass + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +@dataclass +class RPCShutdownRequest: + pass + + +RPC_REQUEST_T = Union[ + RPCGenerateRequest, + RPCAbortRequest, + RPCHealthRequest, + RPCStartupRequest, + RPCShutdownRequest, +] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. 
Inspect the stacktrace to " + "find the original error.") + + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/aphrodite/engine/multiprocessing/client.py b/aphrodite/engine/multiprocessing/client.py new file mode 100644 index 000000000..64de3c23e --- /dev/null +++ b/aphrodite/engine/multiprocessing/client.py @@ -0,0 +1,460 @@ +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) + +import cloudpickle +import zmq +import zmq.asyncio +from loguru import logger +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from aphrodite.common.config import DecodingConfig, EngineConfig, ModelConfig +from aphrodite.common.envs import APHRODITE_RPC_TIMEOUT +from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput +from aphrodite.common.sampling_params import SamplingParams +from aphrodite.engine.args_tools import AsyncEngineArgs +from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR, + ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + RPCAbortRequest, RPCError, + RPCGenerateRequest, + RPCHealthRequest, + RPCStartupRequest, + RPCStartupResponse) +from aphrodite.inputs import PromptInputs +from aphrodite.lora.request import LoRARequest +from aphrodite.prompt_adapter.request import PromptAdapterRequest +from aphrodite.transformers_utils.tokenizer_group import ( + init_tokenizer_from_configs) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQAphroditeEngineClient: + """A client wrapper for MQAphroditeEngine that conforms to the + EngineClient protocol. + + MQAphroditeEngine and MQAphroditeEngineClient are intended to run in + separate processes communicating via zeromq ipc sockets. + + The entrypoint to MQAphroditeEngineClient is through the generate() + method. On generate(), MQAphroditeEngineClient does three things: + - Creates an asyncio output queue + - Sends an RPCGenerateRequest to the MQAphroditeEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQAphroditeEngineClient runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQAphroditeEngine via zmq (each list is the output + of one engine_step in the AphroditeEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy. + """ + + def __init__(self, ipc_path: str, engine_config: EngineConfig): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + ) + + # Send RPCGenerateRequest to the MQAphroditeEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQAphroditeEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for ack of check_health requests. + self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to check health of the AphroditeEngine periodically. + # Started after the MQAphroditeEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + if engine_args.pipeline_parallel_size > 1: + return True + + is_embedding = ModelConfig( + model=engine_args.model, + revision=engine_args.revision, + tokenizer=engine_args.model, + tokenizer_mode="auto", + trust_remote_code=engine_args.trust_remote_code, + quantization=engine_args.quantization, + seed=0, + dtype="auto").embedding_mode + + return is_embedding + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_check_health_loop(self, timeout: int): + """Background loop that continually probes the RPCServer for health. + + The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which + the MQAphroditeEngine server is blocking on. + + The Server replies on the HEALTH_SOCKET (rather than on the + OUTPUT_SOCKET) so that the messages are not intermingled with + output streaming. + """ + + try: + while True: + if await self.health_socket.poll(timeout=timeout) == 0: + # Wake up every N seconds and do a health probe. + await self._send_one_way_rpc_request( + RPCHealthRequest(), self.input_socket) + + # Wait for ack from the health socket. + await self._await_ack(error_message="Health check failed.", + socket=self.health_socket) + else: + # Server sent a health status message unprompted. + await self._check_success( + error_message="Health check failed.", + socket=self.health_socket) + + logger.debug("Health probe successful.") + + except asyncio.CancelledError: + logger.debug( + "Shutting down MQAphroditeEngineClient check health loop.") + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll( + timeout=APHRODITE_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQAphroditeEngine.") + + # If errored, alert all running requests. 
+ if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MQAphroditeEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + f"Received Exception {error} rather than RPCError " + "from MQAphroditeEngine. This should never happen.") + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + + if request_id is None: + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + else: + # Put each output into the appropriate stream. + for request_output in request_outputs: + queue = self.output_queues.get( + request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.debug( + "Shutting down MQAphroditeEngineClient output handler.") + + async def setup(self): + """Set up the client before it starts sending server requests.""" + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_check_health_loop(timeout=APHRODITE_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=APHRODITE_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{APHRODITE_RPC_TIMEOUT} ms") + + # Await the data from the Server. 
+ frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=APHRODITE_RPC_TIMEOUT) == 0: + raise TimeoutError("MQAphroditeEngine didn't reply within " + f"{APHRODITE_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has an APHRODITE_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != APHRODITE_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats(self): + """Ignore do_log_stats (handled on MQAphroditeEngine polling)""" + pass + + async def check_health(self): + """ + The check health loop probes the Engine's health every + N seconds and sets _errored_with if the engine is unhealthy. + """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. 
+ if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # 1) Create output queue for this request. + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if sampling_params.logits_processors: + # Defensive shallow copy + sampling_params = copy.copy(sampling_params) + logits_processors = sampling_params.logits_processors + sampling_params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + + # 3) Send the RPCGenerateRequest to the MQAphroditeEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") \ No newline at end of file diff --git a/aphrodite/engine/multiprocessing/engine.py b/aphrodite/engine/multiprocessing/engine.py new file mode 100644 index 000000000..836a25fe9 --- /dev/null +++ b/aphrodite/engine/multiprocessing/engine.py @@ -0,0 +1,325 @@ +import os +import pickle +import signal +import sys +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq +from loguru import logger + +from aphrodite import AphroditeEngine, AsyncEngineArgs +from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from aphrodite.common.outputs import RequestOutput +from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR, + ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, + REQUEST_OUTPUTS_T, + RPCAbortRequest, RPCError, + RPCGenerateRequest, + RPCHealthRequest, + RPCShutdownRequest, + RPCStartupRequest, + RPCStartupResponse) + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(APHRODITE_RPC_SUCCESS_STR), ) + + +class MQAphroditeEngine: + """A multiprocessing wrapper for :class:`AphroditeEngine`. + + This class is used to wrap the :class:`AphroditeEngine` class to enable its + use in a concurrent manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`AphroditeEngine.generate` is kicked off when a new + RPCGenerateRequest is received by the input_socket. 
+ + The run_engine_loop method checks the input_socket for new requests, + adds them to the AphroditeEngine if there are any, calls the internal + :class:`AphroditeEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the AphroditeEngine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`AphroditeEngine`. + **kwargs: Arguments for :class:`AphroditeEngine`. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + self.engine = AphroditeEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, ipc_path: str): + """Creates an MQAphroditeEngine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + executor_class = AphroditeEngine._get_executor_cls(engine_config) + + return cls( + ipc_path=ipc_path, + use_async_sockets=engine_config.model_config.use_async_output_proc, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQAphroditeEngine.") + finally: + logger.debug("MQAphroditeEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
+ self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + response = RPCStartupResponse( + tracing_enabled=False) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the AphroditeEngine.""" + + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + + try: + return self.engine.step() + except SystemExit: + raise + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket.""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCGenerateRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + lprocs = cloudpickle.loads(frames[1].buffer) + request.sampling_params.logits_processors = lprocs + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCHealthRequest): + self._handle_health_request() + elif isinstance(request, RPCShutdownRequest): + self.engine.shutdown() + self._send_outputs(APHRODITE_RPC_SUCCESS_STR) + break + else: + raise ValueError(f"Unknown RPCRequest Type: {request}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + def _handle_generate_request(self, request: RPCGenerateRequest): + """Handle RPCGenerateRequest by adding it to the AphroditeEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + prompt_adapter_request=request.prompt_adapter_request) + + if self.log_requests: + logger.info(f"Added request {request.request_id}.") + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than 
an issue with the engine itself. + is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info(f"Aborted request {request.request_id}.") + + def _handle_health_request(self): + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + + # Raises error if unhealthy. + self.engine.check_health() + self._send_healthy() + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send List of RequestOutput to RPCClient.""" + if outputs: + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + error_bytes = pickle.dumps(error) + self.health_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + +def run_mp_engine(engine_args: AsyncEngineArgs, ipc_path: str): + def signal_handler(*_) -> None: + with open(os.devnull, 'w') as devnull: + sys.stderr = devnull + raise KeyboardInterrupt("MQAphroditeEngine terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + try: + engine = MQAphroditeEngine.from_engine_args(engine_args=engine_args, + ipc_path=ipc_path) + engine.start() + except KeyboardInterrupt as e: + if str(e) == "MQAphroditeEngine terminated": + pass + else: + raise diff --git a/aphrodite/engine/protocol.py b/aphrodite/engine/protocol.py index 2f6451620..50b4026e3 100644 --- a/aphrodite/engine/protocol.py +++ b/aphrodite/engine/protocol.py @@ -14,8 +14,8 @@ @runtime_checkable -class AsyncEngineClient(Protocol): - """Protocol class for Clients to AsyncAphrodite""" +class EngineClient(Protocol): + """Protocol class for Clients to Engine""" @property def is_running(self) -> bool: @@ -30,8 +30,7 @@ def errored(self) -> bool: ... @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" + def dead_error(self) -> BaseException: ... 
def generate( diff --git a/aphrodite/executor/cpu_executor.py b/aphrodite/executor/cpu_executor.py index 0b42a3de2..1bcd82437 100644 --- a/aphrodite/executor/cpu_executor.py +++ b/aphrodite/executor/cpu_executor.py @@ -106,6 +106,7 @@ def _init_executor(self) -> None: )) for rank in range(1, world_size) ] + self.worker_monitor = None if world_size != 1 or is_async: if is_async: async_worker_list = self.workers + [self.driver_worker] diff --git a/aphrodite/executor/multiproc_worker_utils.py b/aphrodite/executor/multiproc_worker_utils.py index 3d75c10ed..5dd42a2e3 100644 --- a/aphrodite/executor/multiproc_worker_utils.py +++ b/aphrodite/executor/multiproc_worker_utils.py @@ -169,6 +169,8 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], self.tasks[task_id] = future try: self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise except BaseException as e: del self.tasks[task_id] raise ChildProcessError("worker died") from e @@ -223,6 +225,8 @@ def _run_worker_process( try: executor = getattr(worker, method) output = executor(*args, **kwargs) + except SystemExit: + raise except KeyboardInterrupt: break except BaseException as e: diff --git a/aphrodite/server/launch.py b/aphrodite/server/launch.py index ef6453133..0a975c3f0 100644 --- a/aphrodite/server/launch.py +++ b/aphrodite/server/launch.py @@ -1,7 +1,7 @@ import asyncio import signal from http import HTTPStatus -from typing import Any, Optional +from typing import Any import uvicorn from fastapi import FastAPI, Request, Response @@ -10,23 +10,13 @@ import aphrodite.common.envs as envs from aphrodite.common.utils import find_process_using_port, in_windows from aphrodite.engine.async_aphrodite import AsyncEngineDeadError +from aphrodite.engine.multiprocessing import MQEngineDeadError APHRODITE_KEEP_ALIVE_ON_ENGINE_DEATH = ( envs.APHRODITE_KEEP_ALIVE_ON_ENGINE_DEATH) -async def serve_http(app: FastAPI, limit_concurrency: Optional[int], - **uvicorn_kwargs: Any): - - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency " - f"{limit_concurrency}. " - f"To avoid this limit at the expense of performance run with " - "--disable-frontend-multiprocessing", limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = limit_concurrency +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) @@ -62,7 +52,7 @@ async def dummy_shutdown() -> None: logger.info( f"port {port} is used by process {process} launched with " f"command:\n{' '.join(process.cmdline())}") - logger.info("Gracefully stopping http server") + logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() @@ -89,7 +79,7 @@ async def runtime_error_handler(request: Request, __): return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) @app.exception_handler(AsyncEngineDeadError) - async def engine_dead_handler(_, __): + async def async_engine_dead_handler(_, __): """Kill the server if the async engine is already dead. 
It will not handle any further requests.""" if not APHRODITE_KEEP_ALIVE_ON_ENGINE_DEATH: @@ -98,3 +88,14 @@ async def engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. It will + not handle any further requests.""" + if not envs.APHRODITE_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.error("MQAphroditeEngine is already dead, terminating " + "server process") + server.should_exit = True + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/tests/benchmarks/engine/throughput.py b/tests/benchmarks/engine/throughput.py index 113b44005..d5f0fb3dd 100644 --- a/tests/benchmarks/engine/throughput.py +++ b/tests/benchmarks/engine/throughput.py @@ -14,7 +14,7 @@ from aphrodite.common.utils import (FlexibleArgumentParser, merge_async_iterators) from aphrodite.endpoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_engine_client_from_engine_args) from aphrodite.engine.args_tools import (DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs) from aphrodite.quantization import QUANTIZATION_METHODS @@ -192,12 +192,10 @@ async def run_aphrodite_async( num_scheduler_steps=num_scheduler_steps, use_v2_block_manager=use_v2_block_manager, disable_async_output_proc=disable_async_output_proc, - worker_use_ray=False, - engine_use_ray=False, disable_log_requests=True, ) - async with build_async_engine_client_from_engine_args( + async with build_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: # Add the requests to the engine.
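For reviewers unfamiliar with the transport this patch introduces: the client and engine exchange pickled objects over plain zeromq PUSH/PULL ipc sockets, as the MQAphroditeEngine and MQAphroditeEngineClient docstrings describe. The following is a minimal, self-contained sketch of that pattern, not the actual MQAphroditeEngine API: ToyRequest, ToyOutput, the ipc path, and the two-process harness are illustrative stand-ins; only the socket types, the "_input_socket"/"_output_socket" suffixes, and the pickle framing mirror the code in this patch.

import multiprocessing
import pickle
from dataclasses import dataclass

import zmq

IPC_PATH = "ipc:///tmp/toy_mq_engine"  # arbitrary base path (illustrative)


@dataclass
class ToyRequest:  # stand-in for RPCGenerateRequest
    request_id: str
    prompt: str


@dataclass
class ToyOutput:  # stand-in for RequestOutput
    request_id: str
    text: str
    finished: bool


def engine_proc() -> None:
    """Toy 'engine': binds the sockets and streams outputs for one request."""
    ctx = zmq.Context()
    input_socket = ctx.socket(zmq.PULL)   # client PUSHes requests here
    input_socket.bind(f"{IPC_PATH}_input_socket")
    output_socket = ctx.socket(zmq.PUSH)  # outputs stream back to the client
    output_socket.bind(f"{IPC_PATH}_output_socket")

    request: ToyRequest = pickle.loads(input_socket.recv_multipart()[0])
    # Stream incremental outputs, marking the last one finished, the way
    # the engine loop pushes RequestOutputs back after each engine step.
    for i, finished in ((0, False), (1, True)):
        out = ToyOutput(request.request_id, f"token{i}", finished)
        output_socket.send_multipart((pickle.dumps(out),))
    ctx.destroy()  # default linger flushes pending outputs before exit


def main() -> None:
    """Toy 'client': connects, sends one request, drains the output stream."""
    engine = multiprocessing.Process(target=engine_proc)
    engine.start()

    ctx = zmq.Context()
    input_socket = ctx.socket(zmq.PUSH)
    input_socket.connect(f"{IPC_PATH}_input_socket")
    output_socket = ctx.socket(zmq.PULL)
    output_socket.connect(f"{IPC_PATH}_output_socket")

    input_socket.send_multipart((pickle.dumps(ToyRequest("req-0", "hi")),))
    while True:
        out: ToyOutput = pickle.loads(output_socket.recv_multipart()[0])
        print(out.request_id, out.text, "finished" if out.finished else "")
        if out.finished:
            break

    ctx.destroy(linger=0)
    engine.join()


if __name__ == "__main__":
    main()

On top of this basic pattern, the real implementation adds a DEALER/ROUTER data socket for the startup handshake and a dedicated health socket so that health acks never interleave with output streaming, as seen in the client and engine constructors above.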