Skip to content

Commit

Permalink
api: optimize zeromq frontend performance (#951)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlpinDale authored Dec 21, 2024
1 parent cef6da8 commit 39b2e83
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 57 deletions.
81 changes: 39 additions & 42 deletions aphrodite/endpoints/openai/rpc/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import pickle
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncGenerator, Iterator, Optional
from uuid import uuid4

import cloudpickle
import zmq
import zmq.asyncio
from loguru import logger
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket

from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
Expand Down Expand Up @@ -103,16 +106,19 @@ def __init__(self, rpc_path: str):
# not run into ulimit issues)
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(APHRODITE_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)
# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER)
self.from_api_server: Socket = self.context.socket(
zmq.constants.ROUTER)
self.from_api_server.set_hwm(APHRODITE_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)
# Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task(
self.proxy_in_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
self.proxy_out_task = asyncio.create_task(
self.run_proxy(self.to_rpc_server, self.from_api_server))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in Aphrodite w. frontend
# multiprocessing. This value is used by uvicorn to launch
Expand All @@ -121,20 +127,11 @@ def __init__(self, rpc_path: str):
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2

async def run_proxy(self, socket_from, socket_to):
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
"""Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events = await poller.poll()
events = dict(events)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])
frames = await socket_from.recv_multipart(copy=False)
await socket_to.send_multipart(frames, copy=False)

async def setup(self):
"""Setup the client before it starts sending server requests."""
Expand Down Expand Up @@ -165,7 +162,7 @@ def close(self):


@contextmanager
def to_proxy_socket(self):
def to_proxy_socket(self) -> Iterator[Socket]:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
Expand All @@ -191,15 +188,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

with self.to_proxy_socket() as socket:
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ),
copy=False)

# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
data = pickle.loads(frame.buffer)

if isinstance(data, Exception):
# Re-raise exceptions returned by the server
Expand All @@ -217,23 +217,23 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

return data

async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[zmq.asyncio.Socket] = None):
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[Socket] = None):
"""Send one-way RPC request to trigger an action."""

async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE):
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):

await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ))

if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

return cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
return pickle.loads(frame.buffer)

# Make a new socket connection.
if socket is None:
Expand Down Expand Up @@ -358,20 +358,19 @@ async def generate(
with self.to_proxy_socket() as socket:

# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request))
])
await socket.send_multipart((cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)), ))

# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv()
request_output = cloudpickle.loads(message)
message = await socket.recv(copy=False)
assert isinstance(message, Frame)
request_output = pickle.loads(message.buffer)

if isinstance(request_output, Exception):
# On exception, check if the server is still healthy
Expand All @@ -393,9 +392,7 @@ async def generate(
if not finished and not self._errored:
await self.abort(request_id)

async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
async def check_health(self, socket: Optional[Socket] = None) -> None:
"""Raise if unhealthy"""

await self._send_one_way_rpc_request(
Expand Down
36 changes: 21 additions & 15 deletions aphrodite/endpoints/openai/rpc/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import pickle
import signal
from typing import Any, Coroutine, Union

Expand All @@ -8,6 +9,8 @@
import zmq.asyncio
from loguru import logger
from typing_extensions import Never
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket

from aphrodite import AsyncAphrodite, AsyncEngineArgs
from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
Expand Down Expand Up @@ -39,7 +42,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs, rpc_path: str):
self.context = zmq.asyncio.Context()

# Init socket.
self.socket = self.context.socket(zmq.constants.DEALER)
self.socket: Socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(APHRODITE_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)

Expand Down Expand Up @@ -67,23 +70,24 @@ async def get_config(self, identity, request):
else:
raise ValueError(f"Unknown Config Request: {request}")

await self.socket.send_multipart(
[identity, cloudpickle.dumps(config)])
await self.socket.send_multipart((identity, pickle.dumps(config)),
copy=False)

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
(identity, pickle.dumps(APHRODITE_RPC_SUCCESS_STR)))

async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart(
[identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
(identity, pickle.dumps(APHRODITE_RPC_SUCCESS_STR)))

async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
Expand All @@ -93,7 +97,7 @@ async def abort(self, identity, request: RPCAbortRequest):
result: Union[str, Exception] = APHRODITE_RPC_SUCCESS_STR
except Exception as e:
result = e
await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
await self.socket.send_multipart((identity, pickle.dumps(result)))

async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
Expand All @@ -106,25 +110,27 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):

async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
(identity, pickle.dumps(request_output)), copy=False)

except Exception as e:
### Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
(identity, pickle.dumps(APHRODITE_RPC_SUCCESS_STR)))

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)

def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
message: Frame) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""

request = cloudpickle.loads(message)
request = cloudpickle.loads(message.buffer)

if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
Expand Down Expand Up @@ -161,7 +167,7 @@ async def run_server_loop(self):
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
identity, message = await self.socket.recv_multipart(copy=False)

# Process the request async.
task = asyncio.create_task(
Expand Down

0 comments on commit 39b2e83

Please sign in to comment.