Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scheduler methods #13

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
default_language_version:
python: python3.9
python: python3.10
josephydu marked this conversation as resolved.
Show resolved Hide resolved

repos:
- repo: https://github.com/PyCQA/isort
Expand Down
179 changes: 174 additions & 5 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

import logging
import multiprocessing as mp
import multiprocessing.connection
from enum import Enum, auto

import zmq

from sglang.srt.managers.io_struct import (
ControllerInfo,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
Expand All @@ -37,12 +39,42 @@

logger = logging.getLogger(__name__)

import random


# for pre radix scheduler
def _key_match(key0, key1):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
break
i += 1
return i


def get_match_len(node, key, match_length: int) -> int:
    """Walk the radix tree from *node* and return the total number of leading
    elements of *key* found along a root-to-node path, added to the running
    *match_length* accumulator.

    Descends through ``node.children`` keyed by the first element of the
    remaining key; stops at the first missing child or partial edge match.
    """
    while key:
        child = node.children.get(key[0])
        if child is None:
            break
        # Longest common prefix of the child's edge label and the key
        # (inlined equivalent of the module-level _key_match helper).
        prefix_len = 0
        for a, b in zip(child.key, key):
            if a != b:
                break
            prefix_len += 1
        match_length += prefix_len
        if prefix_len < len(child.key):
            # Key diverges in the middle of this edge; no deeper match.
            break
        node = child
        key = key[prefix_len:]
    return match_length


class LoadBalanceMethod(Enum):
"""Load balance method."""

ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
RESOURCES_AWARE = auto()
PRE_RADIX = auto()

@classmethod
def from_str(cls, method: str):
Expand Down Expand Up @@ -74,9 +106,26 @@ def __init__(self, server_args, port_args) -> None:
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
LoadBalanceMethod.RESOURCES_AWARE: self.resources_aware_scheduler,
LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]

# For resources aware
self.dp_size = server_args.dp_size
self.controller_info = ControllerInfo(server_args.dp_size)
self.pre_available_kv_cache = []
self.main_available_kv_cache = []

self.pre_num_running_req = []
self.main_num_running_req = []

self.pre_num_waiting_req = []
self.main_num_waiting_req = []

# For pre_radix
self.zmq_raidx = server_args.load_balance_method == "pre_radix"

# Start data parallel workers
base_gpu_id = 0
self.workers = []
Expand All @@ -85,21 +134,32 @@ def __init__(self, server_args, port_args) -> None:
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
server_args, tmp_port_args, base_gpu_id, dp_rank, self.controller_info
)

self.workers.append(send_to)
base_gpu_id += server_args.tp_size

if self.zmq_raidx:
import threading

self.newest_tree_cache = {}

self.recv_tree_cache_lock = threading.Lock()
self.recv_tree_cache_thread = threading.Thread(
target=self.loop_for_recv_tree_cache
)
else:
self.newest_tree_cache = None
self.recv_tree_cache_thread = None

def launch_tensor_parallel_group(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
controller_info: ControllerInfo,
):
# Launch tensor parallel scheduler processes
scheduler_procs = []
Expand All @@ -114,7 +174,15 @@ def launch_tensor_parallel_group(
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
args=(
server_args,
port_args,
gpu_id,
tp_rank,
dp_rank,
writer,
controller_info,
),
)
proc.start()
scheduler_procs.append(proc)
Expand All @@ -129,10 +197,108 @@ def launch_tensor_parallel_group(

return send_to

def loop_for_recv_tree_cache(self):
    # Thread target: run the radix-cache receiver forever.
    # NOTE(review): recv_tree_cache() itself already contains a `while True`
    # loop and never returns, so this outer loop is redundant but harmless.
    while True:
        self.recv_tree_cache()

def recv_tree_cache(self):
    """Continuously drain radix-cache snapshots from the shared queue and
    keep only the newest snapshot per GPU in ``self.newest_tree_cache``.

    Blocks on ``radix_queue.get()``; never returns. Runs on the background
    thread started in ``run_data_parallel_controller_process``.
    """
    while True:
        recv_radix_cache = self.controller_info.radix_queue.get()
        if not recv_radix_cache:
            continue
        gpu_id = recv_radix_cache.gpu_id
        # Hold the lock across the whole check-and-replace: the original
        # compared timestamps outside the lock, so the scheduler thread
        # could race with the replacement (check-then-act race).
        with self.recv_tree_cache_lock:
            current = self.newest_tree_cache.get(gpu_id)
            if current is None or recv_radix_cache.time > current.time:
                self.newest_tree_cache[gpu_id] = recv_radix_cache

def round_robin_scheduler(self, req):
    """Dispatch *req* to workers in strict rotation."""
    worker_idx = self.round_robin_counter
    self.workers[worker_idx].send_pyobj(req)
    # Advance the cursor, wrapping around the worker list.
    self.round_robin_counter = (worker_idx + 1) % len(self.workers)

def update_memory_and_requests(self):
    """Refresh the controller-local snapshots of each worker's free KV-cache
    slots, running-request count, and waiting-request count.

    For each metric two lists are kept: ``pre_*`` mirrors the last shared
    values seen, and ``main_*`` is the working copy the schedulers mutate
    locally (see ``allocate_gpu``). When the shared values change, both are
    reset to the fresh snapshot; otherwise local mutations are preserved.
    """

    def sync(pre, main, current):
        # One shared copy of the triplicated original logic: seed empty
        # lists, then reset both snapshots whenever the shared values moved.
        if not pre:
            pre = current.copy()
        if not main:
            main = current.copy()
        if pre != current:
            pre = current.copy()
            main = current.copy()
        return pre, main

    available_mem = [k.value for k in self.controller_info.available_kv_cache]
    num_reqs_running = [k.value for k in self.controller_info.running_reqs]
    num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs]

    self.pre_available_kv_cache, self.main_available_kv_cache = sync(
        self.pre_available_kv_cache, self.main_available_kv_cache, available_mem
    )
    self.pre_num_running_req, self.main_num_running_req = sync(
        self.pre_num_running_req, self.main_num_running_req, num_reqs_running
    )
    self.pre_num_waiting_req, self.main_num_waiting_req = sync(
        self.pre_num_waiting_req, self.main_num_waiting_req, num_reqs_waiting
    )

def allocate_gpu(self, req):
    """Choose a worker index for *req* from the local resource snapshots,
    then book-keep the choice (one more waiting request; KV-cache budget
    reduced by the request's input length). Ties are broken randomly.
    """
    if min(self.main_num_waiting_req) > 0:
        # Every worker has a backlog: prefer the highest running/waiting
        # ratio (relatively fewer queued requests per running one).
        scores = [
            running / waiting
            for running, waiting in zip(
                self.main_num_running_req, self.main_num_waiting_req
            )
        ]
    else:
        # At least one worker is idle: consider only idle workers (busy
        # ones are masked to 0) and prefer the most free KV cache.
        scores = [
            mem if waiting == 0 else 0
            for mem, waiting in zip(
                self.main_available_kv_cache, self.main_num_waiting_req
            )
        ]

    top = max(scores)
    candidates = [idx for idx, score in enumerate(scores) if score == top]
    gpu_idx = random.choice(candidates)

    self.main_num_waiting_req[gpu_idx] += 1
    self.main_available_kv_cache[gpu_idx] -= len(req.input_ids)
    return gpu_idx

def resources_aware_scheduler(self, req):
    """Refresh resource snapshots, pick a worker, and send *req* to it."""
    self.update_memory_and_requests()
    target = self.allocate_gpu(req)
    self.workers[target].send_pyobj(req)

def pre_radix_scheduler(self, req):
    """Route *req* to the worker whose published radix-tree snapshot shares
    the longest prefix with ``req.input_ids``; fall back to the
    resources-aware policy when the best match is too short to be a signal.
    """
    prefix_lens = [0] * self.dp_size

    # Read snapshots under the lock: the receiver thread replaces entries
    # of newest_tree_cache concurrently.
    with self.recv_tree_cache_lock:
        for gpu_id, radix_cache in self.newest_tree_cache.items():
            prefix_lens[gpu_id] = get_match_len(radix_cache.root_node, req.input_ids, 0)

    best_len = max(prefix_lens)  # hoisted: was evaluated up to three times

    # NOTE: 100 is used to reduce the influence of random input,
    # e.g. if the match lengths are [1, 2, 0, 0, 0, 0] the prefix signal is
    # noise, so the scheduling method should be resources-aware instead.
    if best_len <= 100:
        self.resources_aware_scheduler(req)
    else:
        self.workers[prefix_lens.index(best_len)].send_pyobj(req)

def shortest_queue_scheduler(self, input_requests):
    # Placeholder: shortest-queue dispatch is not implemented; selecting
    # LoadBalanceMethod.SHORTEST_QUEUE raises at dispatch time.
    raise NotImplementedError()

Expand All @@ -144,6 +310,7 @@ def event_loop(self):
except zmq.ZMQError:
break

# logger.info(f"[event_loop]{type(recv_req)}")
if isinstance(
recv_req,
(
Expand All @@ -170,6 +337,8 @@ def run_data_parallel_controller_process(
try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send("ready")
if controller.recv_tree_cache_thread:
controller.recv_tree_cache_thread.start()
controller.event_loop()
except Exception:
msg = get_exception_traceback()
Expand Down
18 changes: 18 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import multiprocessing
import uuid
from dataclasses import dataclass
from enum import Enum
from multiprocessing import Value
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
Expand Down Expand Up @@ -353,3 +355,19 @@ class AbortReq:
class ProfileReq(Enum):
START_PROFILE = 1
STOP_PROFILE = 2


class ControllerInfo:
    """Cross-process shared state for the data-parallel controller.

    Holds one shared integer counter per data-parallel worker for free
    KV-cache slots, running requests, and waiting requests, plus a shared
    lock and a queue used by the pre-radix scheduler to receive radix-tree
    snapshots.
    """

    def __init__(self, dp_size):
        # One shared "i" (int) counter per worker, all starting at 0.
        self.available_kv_cache = [Value("i", 0) for _ in range(dp_size)]
        self.running_reqs = [Value("i", 0) for _ in range(dp_size)]
        self.waiting_reqs = [Value("i", 0) for _ in range(dp_size)]
        self.lock = multiprocessing.Lock()

        # For pre radix
        self.radix_queue = multiprocessing.Queue()
Loading
Loading