diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7489004bd0..2bd3402981 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.9
+  python: python3.10
 
 repos:
   - repo: https://github.com/PyCQA/isort
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 1b7da747f1..7a40b17247 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -17,11 +17,13 @@
 
 import logging
 import multiprocessing as mp
+import multiprocessing.connection
 from enum import Enum, auto
 
 import zmq
 
 from sglang.srt.managers.io_struct import (
+    ControllerInfo,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     TokenizedRewardReqInput,
@@ -37,12 +39,42 @@
 
 logger = logging.getLogger(__name__)
 
+import random
+
+
+# Prefix-matching helpers for the pre-radix scheduler
+def _key_match(key0, key1):
+    i = 0
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
+            break
+        i += 1
+    return i
+
+
+def get_match_len(node, key, match_length: int) -> int:
+    if len(key) == 0:
+        return match_length
+
+    if key[0] in node.children.keys():
+        child = node.children[key[0]]
+        prefix_len = _key_match(child.key, key)
+        match_length += prefix_len
+        if prefix_len < len(child.key):
+            return match_length
+        else:
+            return get_match_len(child, key[prefix_len:], match_length)
+    else:
+        return match_length
+
 
 class LoadBalanceMethod(Enum):
     """Load balance method."""
 
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    RESOURCES_AWARE = auto()
+    PRE_RADIX = auto()
 
     @classmethod
     def from_str(cls, method: str):
@@ -74,9 +106,26 @@ def __init__(self, server_args, port_args) -> None:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.RESOURCES_AWARE: self.resources_aware_scheduler,
+            LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
+        # For the resources-aware scheduler
+        self.dp_size = server_args.dp_size
+        self.controller_info = ControllerInfo(server_args.dp_size)
+        self.pre_available_kv_cache = []
+        self.main_available_kv_cache = []
+
+        self.pre_num_running_req = []
+        self.main_num_running_req = []
+
+        self.pre_num_waiting_req = []
+        self.main_num_waiting_req = []
+
+        # For the pre-radix scheduler
+        self.pre_radix = server_args.load_balance_method == "pre_radix"
+
         # Start data parallel workers
         base_gpu_id = 0
         self.workers = []
@@ -85,21 +134,32 @@ def __init__(self, server_args, port_args) -> None:
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
             send_to = self.launch_tensor_parallel_group(
-                server_args,
-                tmp_port_args,
-                base_gpu_id,
-                dp_rank,
+                server_args, tmp_port_args, base_gpu_id, dp_rank, self.controller_info
             )
             self.workers.append(send_to)
             base_gpu_id += server_args.tp_size
 
+        if self.pre_radix:
+            import threading
+
+            self.newest_tree_cache = {}
+
+            self.recv_tree_cache_lock = threading.Lock()
+            self.recv_tree_cache_thread = threading.Thread(
+                target=self.loop_for_recv_tree_cache
+            )
+        else:
+            self.newest_tree_cache = None
+            self.recv_tree_cache_thread = None
+
     def launch_tensor_parallel_group(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
         base_gpu_id: int,
         dp_rank: int,
+        controller_info: ControllerInfo,
     ):
         # Launch tensor parallel scheduler processes
         scheduler_procs = []
@@ -114,7 +174,15 @@ def launch_tensor_parallel_group(
             gpu_id = base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+                args=(
+                    server_args,
+                    port_args,
+                    gpu_id,
+                    tp_rank,
+                    dp_rank,
+                    writer,
+                    controller_info,
+                ),
             )
             proc.start()
             scheduler_procs.append(proc)
@@ -129,10 +197,107 @@ def launch_tensor_parallel_group(
 
         return send_to
 
+    def loop_for_recv_tree_cache(self):
+        while True:
+            self.recv_tree_cache()
+
+    def recv_tree_cache(self):
+        while True:
+            recv_radix_cache = self.controller_info.radix_queue.get()
+            if recv_radix_cache:
+                # logger.info('[recv_tree_cache] receive new data')
+                gpu_id = recv_radix_cache.gpu_id
+                if (
+                    gpu_id not in self.newest_tree_cache
+                    or recv_radix_cache.time > self.newest_tree_cache[gpu_id].time
+                ):
+                    with self.recv_tree_cache_lock:
+                        if gpu_id in self.newest_tree_cache:
+                            del self.newest_tree_cache[gpu_id]
+                        self.newest_tree_cache[gpu_id] = recv_radix_cache
+                del recv_radix_cache
+
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
 
+    def update_memory_and_requests(self):
+        available_mem = [k.value for k in self.controller_info.available_kv_cache]
+        num_reqs_running = [k.value for k in self.controller_info.running_reqs]
+        num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs]
+
+        if not self.pre_available_kv_cache:
+            self.pre_available_kv_cache = available_mem.copy()
+        if not self.main_available_kv_cache:
+            self.main_available_kv_cache = available_mem.copy()
+        if self.pre_available_kv_cache != available_mem:
+            self.pre_available_kv_cache = available_mem
+            self.main_available_kv_cache = available_mem.copy()
+
+        if not self.pre_num_running_req:
+            self.pre_num_running_req = num_reqs_running.copy()
+        if not self.main_num_running_req:
+            self.main_num_running_req = num_reqs_running.copy()
+        if self.pre_num_running_req != num_reqs_running:
+            self.main_num_running_req = num_reqs_running
+            self.pre_num_running_req = num_reqs_running.copy()
+
+        if not self.pre_num_waiting_req:
+            self.pre_num_waiting_req = num_reqs_waiting.copy()
+        if not self.main_num_waiting_req:
+            self.main_num_waiting_req = num_reqs_waiting.copy()
+        if self.pre_num_waiting_req != num_reqs_waiting:
+            self.main_num_waiting_req = num_reqs_waiting
+            self.pre_num_waiting_req = num_reqs_waiting.copy()
+
+    def allocate_gpu(self, req):
+        all_waiting = min(self.main_num_waiting_req) > 0
+        no_waiting = [1 if waiting == 0 else 0 for waiting in self.main_num_waiting_req]
+
+        if all_waiting:
+            ratio = [
+                run / wait
+                for run, wait in zip(
+                    self.main_num_running_req, self.main_num_waiting_req
+                )
+            ]
+            max_ratio = max(ratio)
+            indices = [i for i, x in enumerate(ratio) if x == max_ratio]
+            gpu_idx = random.choice(indices)
+        else:
+            filter_result = [
+                a * b for a, b in zip(no_waiting, self.main_available_kv_cache)
+            ]
+            max_value = max(filter_result)
+            max_indices = [
+                index for index, value in enumerate(filter_result) if value == max_value
+            ]
+            gpu_idx = random.choice(max_indices)
+
+        self.main_available_kv_cache[gpu_idx] -= len(req.input_ids)
+        return gpu_idx
+
+    def resources_aware_scheduler(self, req):
+        self.update_memory_and_requests()
+        gpu_idx = self.allocate_gpu(req)
+        self.workers[gpu_idx].send_pyobj(req)
+
+    def pre_radix_scheduler(self, req):
+        prefix_lens = [0] * self.dp_size
+
+        with self.recv_tree_cache_lock:
+            for gpu_id, radix_cache in self.newest_tree_cache.items():
+                pre_len = get_match_len(radix_cache.root_node, req.input_ids, 0)
+                prefix_lens[gpu_id] = pre_len
+
+        # NOTE: 100 is used to reduce the influence of random input
+        # e.g. if the match lengths are [1, 2, 0, 0, 0, 0], fall back to the resources-aware method
+        if max(prefix_lens) <= 100:
+            self.resources_aware_scheduler(req)
+        else:
+            gpu_idx = prefix_lens.index(max(prefix_lens))
+            self.workers[gpu_idx].send_pyobj(req)
+
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
 
@@ -144,6 +309,7 @@ def event_loop(self):
             except zmq.ZMQError:
                 break
 
+            # logger.info(f"[event_loop]{type(recv_req)}")
             if isinstance(
                 recv_req,
                 (
@@ -170,6 +336,8 @@ def run_data_parallel_controller_process(
     try:
         controller = DataParallelController(server_args, port_args)
         pipe_writer.send("ready")
+        if controller.recv_tree_cache_thread:
+            controller.recv_tree_cache_thread.start()
         controller.event_loop()
     except Exception:
         msg = get_exception_traceback()
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 2cdc3f4785..c3f037c63e 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -18,9 +18,11 @@
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
+import multiprocessing
 import uuid
 from dataclasses import dataclass
 from enum import Enum
+from multiprocessing import Value
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -363,3 +365,19 @@ class GetMemPoolSizeReq:
 @dataclass
 class GetMemPoolSizeReqOutput:
     size: int
+
+
+class ControllerInfo:
+    def __init__(self, dp_size):
+        self.available_kv_cache = []
+        self.running_reqs = []
+        self.waiting_reqs = []
+        self.lock = multiprocessing.Lock()
+
+        # For the pre-radix scheduler
+        self.radix_queue = multiprocessing.Queue()
+
+        for i in range(dp_size):
+            self.available_kv_cache.append(Value("i", 0))
+            self.running_reqs.append(Value("i", 0))
+            self.waiting_reqs.append(Value("i", 0))
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 60531ce251..47d1e14d8b 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -17,10 +17,12 @@
 
 import json
 import logging
+import multiprocessing
 import os
 import time
 import warnings
 from collections import deque
+from copy import deepcopy
 from types import SimpleNamespace
 from typing import List, Optional, Union
 
@@ -37,6 +39,7 @@
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    ControllerInfo,
     FlushCacheReq,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
@@ -63,7 +66,7 @@
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.mem_cache.radix_cache import RadixCache, RadixCacheSend
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -96,6 +99,7 @@ def __init__(
         gpu_id: int,
         tp_rank: int,
         dp_rank: Optional[int],
+        controller_info: Optional[ControllerInfo] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -268,6 +272,22 @@ def __init__(
                 with_stack=True,
             )
 
+        # Init controller info
+        if controller_info and self.tp_rank == 0:
+            self.controller_info = controller_info
+            self.gpu_id = gpu_id
+            self.controller_info.available_kv_cache[self.gpu_id].value = (
+                self.token_to_kv_pool.available_size()
+            )
+            if self.server_args.load_balance_method == "pre_radix":
+                self.pre_radix = True
+                import threading
+
+                self.change_cnt_lock = threading.Lock()
+                threading.Thread(target=self.loop_for_send_tree_cache).start()
+        else:
+            self.controller_info = None
+
     @torch.inference_mode()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
@@ -326,6 +346,25 @@ def event_loop_overlap(self):
 
             self.last_batch = batch
 
+    def loop_for_send_tree_cache(self):
+        while True:
+            self.send_tree_cache_to_queue()
+            time.sleep(1)
+
+    def send_tree_cache_to_queue(self):
+        if self.pre_radix:
+            try:
+                node = deepcopy(self.tree_cache.root_node)
+                send_data = RadixCacheSend(
+                    gpu_id=self.gpu_id, root_node=node, time=time.time()
+                )
+                del node
+                self.controller_info.radix_queue.put(send_data)
+                # logger.info("[send_tree_cache_to_queue] has send new data")
+            except Exception as e:
+                # logger.info(f"[send_tree_cache_to_queue]error:{e}")
+                return
+
     def recv_requests(self):
         if self.tp_rank == 0:
             recv_reqs = []
@@ -758,6 +797,20 @@ def process_batch_result(self, batch: ScheduleBatch, result):
         else:
             self.process_batch_result_prefill(batch, result)
 
+        # Update controller info
+        if self.controller_info:
+            with self.controller_info.lock:
+                self.controller_info.available_kv_cache[self.gpu_id].value = (
+                    self.token_to_kv_pool.available_size()
+                    + self.tree_cache.evictable_size()
+                )
+                self.controller_info.running_reqs[self.gpu_id].value = (
+                    len(self.running_batch.reqs) if self.running_batch else 0
+                )
+                self.controller_info.waiting_reqs[self.gpu_id].value = len(
+                    self.waiting_queue
+                )
+
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -1122,6 +1175,7 @@ def run_scheduler_process(
     tp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
+    controller_info,
 ):
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -1131,7 +1185,9 @@
     suppress_other_loggers()
 
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
+        scheduler = Scheduler(
+            server_args, port_args, gpu_id, tp_rank, dp_rank, controller_info
+        )
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 8cd8354b6b..012212155d 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -22,6 +22,7 @@
 import heapq
 import time
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
@@ -33,6 +34,13 @@
     from sglang.srt.managers.schedule_batch import Req
 
 
+@dataclass
+class RadixCacheSend:
+    gpu_id: int
+    root_node: TreeNode
+    time: float
+
+
 class TreeNode:
     def __init__(self):
         self.children = defaultdict(TreeNode)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6ccd891857..cc83b2bb85 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -444,9 +444,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
         choices=[
             "round_robin",
             "shortest_queue",
+            "resources_aware",
+            "pre_radix",
         ],
     )
-
     # Multi-node distributed serving args
     parser.add_argument(
         "--dist-init-addr",