From 53b8f0fd34c1519b2ad7520c6f5f5e9f6f4f1ce0 Mon Sep 17 00:00:00 2001
From: YouYangxiu
Date: Mon, 21 Oct 2024 16:05:34 +0800
Subject: [PATCH 1/4] add two new load-balance methods

---
 .pre-commit-config.yaml                       |   2 +-
 .../srt/managers/data_parallel_controller.py  | 182 +++++++++++++++++-
 python/sglang/srt/managers/io_struct.py       |  18 ++
 python/sglang/srt/managers/scheduler.py       |  60 +++++-
 python/sglang/srt/mem_cache/radix_cache.py    |   8 +
 python/sglang/srt/server_args.py              |   3 +-
 6 files changed, 264 insertions(+), 9 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7489004bd0..2bd3402981 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.9
+  python: python3.10
 
 repos:
   - repo: https://github.com/PyCQA/isort
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 1b7da747f1..87b2ab926a 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -17,11 +17,13 @@
 
 import logging
 import multiprocessing as mp
+import multiprocessing.connection
 from enum import Enum, auto
 
 import zmq
 
 from sglang.srt.managers.io_struct import (
+    ControllerInfo,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     TokenizedRewardReqInput,
@@ -37,12 +39,42 @@
 
 logger = logging.getLogger(__name__)
 
+import random
+
+
+# helpers for the pre-radix scheduler
+def _key_match(key0, key1):
+    i = 0
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
+            break
+        i += 1
+    return i
+
+
+def get_match_len(node, key, match_length: int) -> int:
+    if len(key) == 0:
+        return match_length
+
+    if key[0] in node.children.keys():
+        child = node.children[key[0]]
+        prefix_len = _key_match(child.key, key)
+        match_length += prefix_len
+        if prefix_len < len(child.key):
+            return match_length
+        else:
+            return get_match_len(child, key[prefix_len:], match_length)
+    else:
+        return match_length
+
 
 class LoadBalanceMethod(Enum):
     """Load balance method."""
 
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    RESOURCES_AWARE = auto()
+    PRE_RADIX = auto()
 
     @classmethod
     def from_str(cls, method: str):
@@ -74,9 +106,29 @@ def __init__(self, server_args, port_args) -> None:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.RESOURCES_AWARE: self.resources_aware_scheduler,
+            LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
+        # For resources-aware scheduling
+        self.dp_size = server_args.dp_size
+        self.controller_info = ControllerInfo(server_args.dp_size)
+        self.pre_available_kv_cache = []
+        self.main_available_kv_cache = []
+
+        self.pre_num_running_req = []
+        self.main_num_running_req = []
+
+        self.pre_num_waiting_req = []
+        self.main_num_waiting_req = []
+
+        # For pre_radix
+        self.choosen_gpu_per_req = {}
+
+        # For zmq_radix
+        self.zmq_raidx = server_args.load_balance_method == "zmq_radix"
+
         # Start data parallel workers
         base_gpu_id = 0
         self.workers = []
@@ -85,21 +137,32 @@ def __init__(self, server_args, port_args) -> None:
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
             send_to = self.launch_tensor_parallel_group(
-                server_args,
-                tmp_port_args,
-                base_gpu_id,
-                dp_rank,
+                server_args, tmp_port_args, base_gpu_id, dp_rank, self.controller_info
             )
             self.workers.append(send_to)
             base_gpu_id += server_args.tp_size
 
+        if self.zmq_raidx:
+            import threading
+
+            self.newest_tree_cache = {}
+
+            self.recv_tree_cache_lock = threading.Lock()
+            self.recv_tree_cache_thread = threading.Thread(
+                target=self.loop_for_recv_tree_cache
+            )
+        else:
+            self.newest_tree_cache = None
+            self.recv_tree_cache_thread = None
+
     def launch_tensor_parallel_group(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
         base_gpu_id: int,
         dp_rank: int,
+        controller_info: ControllerInfo,
     ):
         # Launch tensor parallel scheduler processes
         scheduler_procs = []
@@ -114,7 +177,15 @@ def launch_tensor_parallel_group(
             gpu_id = base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+                args=(
+                    server_args,
+                    port_args,
+                    gpu_id,
+                    tp_rank,
+                    dp_rank,
+                    writer,
+                    controller_info,
+                ),
             )
             proc.start()
             scheduler_procs.append(proc)
@@ -129,10 +200,108 @@ def launch_tensor_parallel_group(
 
         return send_to
 
+    def loop_for_recv_tree_cache(self):
+        while True:
+            self.recv_tree_cache()
+
+    def recv_tree_cache(self):
+        while True:
+            recv_radix_cache = self.controller_info.radix_queue.get()
+            if recv_radix_cache:
+                # logger.info('[recv_tree_cache] receive new data')
+                gpu_id = recv_radix_cache.gpu_id
+                if (
+                    gpu_id not in self.newest_tree_cache
+                    or recv_radix_cache.time > self.newest_tree_cache[gpu_id].time
+                ):
+                    with self.recv_tree_cache_lock:
+                        if gpu_id in self.newest_tree_cache:
+                            del self.newest_tree_cache[gpu_id]
+                        self.newest_tree_cache[gpu_id] = recv_radix_cache
+                del recv_radix_cache
+
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
 
+    def update_memory_and_requests(self):
+        available_mem = [k.value for k in self.controller_info.available_kv_cache]
+        num_reqs_running = [k.value for k in self.controller_info.running_reqs]
+        num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs]
+
+        if not self.pre_available_kv_cache:
+            self.pre_available_kv_cache = available_mem.copy()
+        if not self.main_available_kv_cache:
+            self.main_available_kv_cache = available_mem.copy()
+        if self.pre_available_kv_cache != available_mem:
+            self.pre_available_kv_cache = available_mem.copy()
+            self.main_available_kv_cache = available_mem.copy()
+
+        if not self.pre_num_running_req:
+            self.pre_num_running_req = num_reqs_running.copy()
+        if not self.main_num_running_req:
+            self.main_num_running_req = num_reqs_running.copy()
+        if self.pre_num_running_req != num_reqs_running:
+            self.main_num_running_req = num_reqs_running.copy()
+            self.pre_num_running_req = num_reqs_running.copy()
+
+        if not self.pre_num_waiting_req:
+            self.pre_num_waiting_req = num_reqs_waiting.copy()
+        if not self.main_num_waiting_req:
+            self.main_num_waiting_req = num_reqs_waiting.copy()
+        if self.pre_num_waiting_req != num_reqs_waiting:
+            self.main_num_waiting_req = num_reqs_waiting.copy()
+            self.pre_num_waiting_req = num_reqs_waiting.copy()
+
+    def allocate_gpu(self, req):
+        all_waiting = min(self.main_num_waiting_req) > 0
+        no_waiting = [1 if waiting == 0 else 0 for waiting in self.main_num_waiting_req]
+
+        if all_waiting:
+            ratio = [
+                run / wait
+                for run, wait in zip(
+                    self.main_num_running_req, self.main_num_waiting_req
+                )
+            ]
+            max_ratio = max(ratio)
+            indices = [i for i, x in enumerate(ratio) if x == max_ratio]
+            gpu_idx = random.choice(indices)
+        else:
+            filter_result = [
+                a * b for a, b in zip(no_waiting, self.main_available_kv_cache)
+            ]
+            max_value = max(filter_result)
+            max_indices = [
+                index for index, value in enumerate(filter_result) if value == max_value
+            ]
+            gpu_idx = random.choice(max_indices)
+
+        self.main_num_waiting_req[gpu_idx] += 1
+        self.main_available_kv_cache[gpu_idx] -= len(req.input_ids)
+        return gpu_idx
+
+    def resources_aware_scheduler(self, req):
+        self.update_memory_and_requests()
+        gpu_idx = self.allocate_gpu(req)
+        self.workers[gpu_idx].send_pyobj(req)
+
+    def pre_radix_scheduler(self, req):
+        prefix_lens = [0] * self.dp_size
+
+        with self.recv_tree_cache_lock:
+            for gpu_id, radix_cache in self.newest_tree_cache.items():
+                pre_len = get_match_len(radix_cache.root_node, req.input_ids, 0)
+                prefix_lens[gpu_id] = pre_len
+
+        # NOTE: the 100-token threshold damps the influence of short random overlaps;
+        # e.g. with match lengths [1, 2, 0, 0, 0, 0] the scheduler should fall back to resources-aware.
+        if max(prefix_lens) <= 100:
+            self.resources_aware_scheduler(req)
+        else:
+            gpu_idx = prefix_lens.index(max(prefix_lens))
+            self.workers[gpu_idx].send_pyobj(req)
+
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
 
@@ -144,6 +313,7 @@ def event_loop(self):
             except zmq.ZMQError:
                 break
 
+            # logger.info(f"[event_loop]{type(recv_req)}")
             if isinstance(
                 recv_req,
                 (
@@ -170,6 +340,8 @@ def run_data_parallel_controller_process(
     try:
         controller = DataParallelController(server_args, port_args)
        pipe_writer.send("ready")
+        if controller.recv_tree_cache_thread:
+            controller.recv_tree_cache_thread.start()
         controller.event_loop()
     except Exception:
         msg = get_exception_traceback()
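
As a stand-alone illustration (not part of the patch; pick_gpu and its toy inputs are hypothetical), the decision rule implemented by update_memory_and_requests plus allocate_gpu above boils down to:

    import random

    def pick_gpu(num_running, num_waiting, free_kv):
        # Branch 1: every worker has a backlog, so prefer the worker with the
        # highest running/waiting ratio (draining fastest relative to its
        # queue), breaking ties randomly.
        if min(num_waiting) > 0:
            ratio = [r / w for r, w in zip(num_running, num_waiting)]
            best = max(ratio)
            return random.choice([i for i, x in enumerate(ratio) if x == best])
        # Branch 2: at least one worker has an empty queue, so mask out the
        # busy workers and prefer the most free KV-cache capacity.
        masked = [kv if w == 0 else 0 for kv, w in zip(free_kv, num_waiting)]
        best = max(masked)
        return random.choice([i for i, v in enumerate(masked) if v == best])

    print(pick_gpu([3, 2], [0, 4], [1000, 5000]))  # 0: only worker 0 is idle
    print(pick_gpu([8, 2], [2, 2], [1000, 5000]))  # 0: ratio 4.0 beats 1.0

After choosing, allocate_gpu also charges len(req.input_ids) against the chosen worker's local KV figure, so a burst of requests arriving between two worker reports spreads out instead of all landing on the same winner.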
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 9625ff44eb..3841736497 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -18,9 +18,11 @@
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
+import multiprocessing
 import uuid
 from dataclasses import dataclass
 from enum import Enum
+from multiprocessing import Value
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -353,3 +355,19 @@ class AbortReq:
 class ProfileReq(Enum):
     START_PROFILE = 1
     STOP_PROFILE = 2
+
+
+class ControllerInfo:
+    def __init__(self, dp_size):
+        self.available_kv_cache = []
+        self.running_reqs = []
+        self.waiting_reqs = []
+        self.lock = multiprocessing.Lock()
+
+        # For pre radix
+        self.radix_queue = multiprocessing.Queue()
+
+        for i in range(dp_size):
+            self.available_kv_cache.append(Value("i", 0))
+            self.running_reqs.append(Value("i", 0))
+            self.waiting_reqs.append(Value("i", 0))
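
The mechanics behind ControllerInfo, sketched minimally (illustrative code, not from the patch): multiprocessing.Value places each counter in shared memory, so a scheduler process can update it in place and the controller can read it without any message passing.

    import multiprocessing
    from multiprocessing import Value

    def worker(counter):
        # Scheduler side: write the shared int in place.
        counter.value += 5

    if __name__ == "__main__":
        counters = [Value("i", 0) for _ in range(2)]  # one per DP rank
        procs = [multiprocessing.Process(target=worker, args=(c,)) for c in counters]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        # Controller side: just dereference .value, exactly as
        # update_memory_and_requests does.
        print([c.value for c in counters])  # [5, 5]

radix_queue plays the complementary role for bulky state: a whole radix-tree snapshot cannot live in a shared scalar, so it travels through a multiprocessing.Queue instead.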
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 990fbeaa85..a47f76580d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -17,10 +17,12 @@
 
 import json
 import logging
+import multiprocessing
 import os
 import time
 import warnings
 from collections import deque
+from copy import deepcopy
 from types import SimpleNamespace
 from typing import List, Optional, Union
 
@@ -37,6 +39,7 @@
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    ControllerInfo,
     FlushCacheReq,
     ProfileReq,
     TokenizedEmbeddingReqInput,
@@ -61,7 +64,7 @@
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.mem_cache.radix_cache import RadixCache, RadixCacheSend
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -94,6 +97,7 @@ def __init__(
         gpu_id: int,
         tp_rank: int,
         dp_rank: Optional[int],
+        controller_info: Optional[ControllerInfo] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -265,6 +269,22 @@ def __init__(
             with_stack=True,
         )
 
+        # init controller info
+        if controller_info and self.tp_rank == 0:
+            self.controller_info = controller_info
+            self.gpu_id = gpu_id
+            self.controller_info.available_kv_cache[self.gpu_id].value = (
+                self.token_to_kv_pool.available_size()
+            )
+            if self.server_args.load_balance_method == "zmq_radix":
+                self.pre_radix = True
+                import threading
+
+                self.change_cnt_lock = threading.Lock()
+                threading.Thread(target=self.loop_for_send_tree_cache).start()
+        else:
+            self.controller_info = None
+
     @torch.inference_mode()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
@@ -323,6 +343,25 @@ def event_loop_overlap(self):
 
             self.last_batch = batch
 
+    def loop_for_send_tree_cache(self):
+        while True:
+            self.send_tree_cache_to_queue()
+            time.sleep(1)
+
+    def send_tree_cache_to_queue(self):
+        if self.pre_radix:
+            try:
+                node = deepcopy(self.tree_cache.root_node)
+                send_data = RadixCacheSend(
+                    gpu_id=self.gpu_id, root_node=node, time=time.time()
+                )
+                del node
+                self.controller_info.radix_queue.put(send_data)
+                # logger.info("[send_tree_cache_to_queue] has sent new data")
+            except Exception as e:
+                # logger.info(f"[send_tree_cache_to_queue] error: {e}")
+                return
+
     def recv_requests(self):
         if self.tp_rank == 0:
             recv_reqs = []
@@ -749,6 +788,20 @@ def process_batch_result(self, batch: ScheduleBatch, result):
         else:
             self.process_batch_result_prefill(batch, result)
 
+        # update controller info
+        if self.controller_info:
+            with self.controller_info.lock:
+                self.controller_info.available_kv_cache[self.gpu_id].value = (
+                    self.token_to_kv_pool.available_size()
+                    + self.tree_cache.evictable_size()
+                )
+                self.controller_info.running_reqs[self.gpu_id].value = (
+                    len(self.running_batch.reqs) if self.running_batch else 0
+                )
+                self.controller_info.waiting_reqs[self.gpu_id].value = len(
+                    self.waiting_queue
+                )
+
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -1113,6 +1166,7 @@ def run_scheduler_process(
     tp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
+    controller_info,
 ):
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -1122,7 +1176,9 @@ def run_scheduler_process(
     suppress_other_loggers()
 
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(
+            server_args, port_args, gpu_id, tp_rank, dp_rank, controller_info
+        )
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
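
The send loop added above is a snapshot-publish pattern. A condensed model of it (Snapshot and publish_once are illustrative stand-ins, not sglang APIs):

    import time
    from copy import deepcopy
    from dataclasses import dataclass
    from multiprocessing import Queue
    from typing import Any

    @dataclass
    class Snapshot:
        gpu_id: int
        payload: Any
        time: float

    def publish_once(gpu_id, live_state, queue):
        # deepcopy so the consumer never observes the structure mid-mutation;
        # the cost is one full copy per publish interval.
        queue.put(Snapshot(gpu_id, deepcopy(live_state), time.time()))

    if __name__ == "__main__":
        q = Queue()
        publish_once(0, {"cached_tokens": [1, 2, 3]}, q)
        snap = q.get()
        print(snap.gpu_id, snap.payload)  # 0 {'cached_tokens': [1, 2, 3]}

On the controller side, recv_tree_cache keeps only the newest snapshot per GPU, so a slow consumer degrades to staler placement decisions rather than unbounded queue growth.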
""" +import multiprocessing import uuid from dataclasses import dataclass from enum import Enum +from multiprocessing import Value from typing import Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason @@ -353,3 +355,19 @@ class AbortReq: class ProfileReq(Enum): START_PROFILE = 1 STOP_PROFILE = 2 + + +class ControllerInfo: + def __init__(self, dp_size): + self.available_kv_cache = [] + self.running_reqs = [] + self.waiting_reqs = [] + self.lock = multiprocessing.Lock() + + # For pre radix + self.radix_queue = multiprocessing.Queue() + + for i in range(dp_size): + self.available_kv_cache.append(Value("i", 0)) + self.running_reqs.append(Value("i", 0)) + self.waiting_reqs.append(Value("i", 0)) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 990fbeaa85..a47f76580d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -17,10 +17,12 @@ import json import logging +import multiprocessing import os import time import warnings from collections import deque +from copy import deepcopy from types import SimpleNamespace from typing import List, Optional, Union @@ -37,6 +39,7 @@ AbortReq, BatchEmbeddingOut, BatchTokenIDOut, + ControllerInfo, FlushCacheReq, ProfileReq, TokenizedEmbeddingReqInput, @@ -61,7 +64,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.mem_cache.chunk_cache import ChunkCache -from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.mem_cache.radix_cache import RadixCache, RadixCacheSend from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( broadcast_pyobj, @@ -94,6 +97,7 @@ def __init__( gpu_id: int, tp_rank: int, dp_rank: Optional[int], + controller_info: Optional[ControllerInfo] = None, ): # Parse args self.server_args = server_args @@ -265,6 +269,22 @@ def __init__( with_stack=True, ) + # init controller info + if controller_info and self.tp_rank == 0: + self.controller_info = controller_info + self.gpu_id = gpu_id + self.controller_info.available_kv_cache[self.gpu_id].value = ( + self.token_to_kv_pool.available_size() + ) + if self.server_args.load_balance_method == "zmq_radix": + self.pre_radix = True + import threading + + self.change_cnt_lock = threading.Lock() + threading.Thread(target=self.loop_for_send_tree_cache).start() + else: + self.controller_info = None + @torch.inference_mode() def event_loop_normal(self): """A normal blocking scheduler loop.""" @@ -323,6 +343,25 @@ def event_loop_overlap(self): self.last_batch = batch + def loop_for_send_tree_cache(self): + while True: + self.send_tree_cache_to_queue() + time.sleep(1) + + def send_tree_cache_to_queue(self): + if self.pre_radix: + try: + node = deepcopy(self.tree_cache.root_node) + send_data = RadixCacheSend( + gpu_id=self.gpu_id, root_node=node, time=time.time() + ) + del node + self.controller_info.radix_queue.put(send_data) + # logger.info("[send_tree_cache_to_queue] has send new data") + except Exception as e: + # logger.info(f"[send_tree_cache_to_queue]error:{e}") + return + def recv_requests(self): if self.tp_rank == 0: recv_reqs = [] @@ -749,6 +788,20 @@ def process_batch_result(self, batch: ScheduleBatch, result): else: self.process_batch_result_prefill(batch, result) + # update controller info + if self.controller_info: + with self.controller_info.lock: + self.controller_info.available_kv_cache[self.gpu_id].value = ( + 
From 13387d30421bdfc63c98b20f3a59f9e77095c96e Mon Sep 17 00:00:00 2001
From: YouYangxiu
Date: Mon, 21 Oct 2024 16:11:07 +0800
Subject: [PATCH 2/4] change the method name from zmq_radix to pre_radix

---
 python/sglang/srt/managers/data_parallel_controller.py | 5 +----
 python/sglang/srt/managers/scheduler.py                | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 87b2ab926a..794fe01579 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -124,10 +124,7 @@ def __init__(self, server_args, port_args) -> None:
         self.main_num_waiting_req = []
 
         # For pre_radix
-        self.choosen_gpu_per_req = {}
-
-        # For zmq_radix
-        self.zmq_raidx = server_args.load_balance_method == "zmq_radix"
+        self.zmq_raidx = server_args.load_balance_method == "pre_radix"
 
         # Start data parallel workers
         base_gpu_id = 0
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a47f76580d..5a17286302 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -276,7 +276,7 @@ def __init__(
             self.controller_info.available_kv_cache[self.gpu_id].value = (
                 self.token_to_kv_pool.available_size()
             )
-            if self.server_args.load_balance_method == "zmq_radix":
+            if self.server_args.load_balance_method == "pre_radix":
                 self.pre_radix = True
                 import threading
 
From f2119176e98b2c5f7493617cf255c5b9c2b40787 Mon Sep 17 00:00:00 2001
From: YouYangxiu
Date: Wed, 23 Oct 2024 18:27:34 +0800
Subject: [PATCH 3/4] fix bugs pointed out in review comments

---
 python/sglang/srt/managers/data_parallel_controller.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 794fe01579..ae2b76d827 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -124,7 +124,7 @@ def __init__(self, server_args, port_args) -> None:
         self.main_num_waiting_req = []
 
         # For pre_radix
-        self.zmq_raidx = server_args.load_balance_method == "pre_radix"
+        self.pre_raidx = server_args.load_balance_method == "pre_radix"
 
         # Start data parallel workers
         base_gpu_id = 0
@@ -140,7 +140,7 @@ def __init__(self, server_args, port_args) -> None:
             self.workers.append(send_to)
             base_gpu_id += server_args.tp_size
 
-        if self.zmq_raidx:
+        if self.pre_raidx:
             import threading
 
             self.newest_tree_cache = {}
@@ -231,7 +231,7 @@ def update_memory_and_requests(self):
         if not self.main_available_kv_cache:
             self.main_available_kv_cache = available_mem.copy()
         if self.pre_available_kv_cache != available_mem:
-            self.pre_available_kv_cache = available_mem.copy()
+            self.pre_available_kv_cache = available_mem
             self.main_available_kv_cache = available_mem.copy()
 
         if not self.pre_num_running_req:
@@ -239,7 +239,7 @@ def update_memory_and_requests(self):
         if not self.main_num_running_req:
             self.main_num_running_req = num_reqs_running.copy()
         if self.pre_num_running_req != num_reqs_running:
-            self.main_num_running_req = num_reqs_running.copy()
+            self.main_num_running_req = num_reqs_running
             self.pre_num_running_req = num_reqs_running.copy()
 
         if not self.pre_num_waiting_req:
@@ -247,7 +247,7 @@ def update_memory_and_requests(self):
         if not self.main_num_waiting_req:
             self.main_num_waiting_req = num_reqs_waiting.copy()
         if self.pre_num_waiting_req != num_reqs_waiting:
-            self.main_num_waiting_req = num_reqs_waiting.copy()
+            self.main_num_waiting_req = num_reqs_waiting
             self.pre_num_waiting_req = num_reqs_waiting.copy()
 
     def allocate_gpu(self, req):

From a0220441e4d6bca905a53acf1e29552a403c0cf3 Mon Sep 17 00:00:00 2001
From: YouYangxiu
Date: Fri, 25 Oct 2024 11:20:28 +0800
Subject: [PATCH 4/4] fix bookkeeping bug in allocate_gpu

---
 python/sglang/srt/managers/data_parallel_controller.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index ae2b76d827..7a40b17247 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -274,7 +274,6 @@ def allocate_gpu(self, req):
             ]
             gpu_idx = random.choice(max_indices)
 
-        self.main_num_waiting_req[gpu_idx] += 1
         self.main_available_kv_cache[gpu_idx] -= len(req.input_ids)
         return gpu_idx
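
A closing note on the two fix-ups, as read from the diffs: the source lists in update_memory_and_requests (available_mem and friends) are rebuilt fresh on every call, so patch 3 keeps exactly one .copy() per pair, which is all that is needed to stop pre_* and main_* from aliasing each other. In miniature:

    fresh = [10, 20]      # like available_mem: a brand-new list each call
    pre = fresh           # aliasing a fresh list is safe; nothing else holds it
    main = fresh.copy()   # independent copy: allocate_gpu mutates main_* in place

    main[0] -= 3
    print(pre, main)      # [10, 20] [7, 20]  (the reference snapshot is intact)

Patch 4 then drops the local bump of main_num_waiting_req in allocate_gpu, leaving the waiting counts to come solely from the schedulers' own reports.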