diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 1df64e848c..5f26999f48 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -3,6 +3,8 @@
 import argparse
 import os
 
+import torch.multiprocessing as multiprocessing
+
 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_child_process
@@ -12,6 +14,7 @@
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
+    multiprocessing.set_start_method("spawn")
 
     try:
         launch_server(server_args)
diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py
index abb11dbcec..8341e15314 100644
--- a/python/sglang/srt/managers/controller_single.py
+++ b/python/sglang/srt/managers/controller_single.py
@@ -16,12 +16,14 @@
 """A controller that manages a group of tensor parallel workers."""
 
 import logging
-import torch.multiprocessing as multiprocessing
 import multiprocessing as mp
 from typing import List
 
+import torch.multiprocessing as multiprocessing
 import zmq
 
+from sglang.srt.managers.speculative_utils import SpecInfoPipline
+from sglang.srt.managers.speculative_worker import SpecDraftServer
 from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
@@ -30,9 +32,6 @@
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback
-from sglang.srt.managers.speculative_utils import SpecInfoPipline
-from sglang.srt.managers.speculative_worker import SpecDraftServer
-from sglang.srt.managers.speculative_utils import SpecInfoPipline
 
 logger = logging.getLogger(__name__)
 
@@ -50,6 +49,7 @@ def __init__(
         dp_worker_id: int,
         mp_queue: multiprocessing.Queue,
         spec_queue: SpecInfoPipline,
+        init_flag: multiprocessing.Event = None,
     ):
         # Parse args
         self.tp_size = server_args.tp_size
@@ -91,9 +91,11 @@ def __init__(
             server_args,
             port_args.nccl_ports[dp_worker_id],
             model_overide_args,
-            spec_queue
+            spec_queue,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+        if init_flag is not None:
+            init_flag.set()
 
     def loop_for_forward(self):
         while True:
@@ -153,8 +155,26 @@ def start_controller_process(
 
     try:
         spec_queue = None
+        flag = None
         if server_args.speculative_algorithm is not None:
             spec_queue = SpecInfoPipline()
+            flag = multiprocessing.Event()
+
+        controller = ControllerSingle(
+            server_args,
+            port_args,
+            model_overide_args,
+            gpu_ids,
+            is_data_parallel_worker,
+            dp_worker_id,
+            queue,
+            spec_queue,
+            flag,
+        )
+
+        if server_args.speculative_algorithm is not None:
+            flag.wait()
+            # The draft process should be launched after the target process.
             proc = multiprocessing.Process(
                 target=start_spec_controller_process,
                 args=(
@@ -170,16 +190,6 @@ def start_controller_process(
                 ),
             )
             proc.start()
-        controller = ControllerSingle(
-            server_args,
-            port_args,
-            model_overide_args,
-            gpu_ids,
-            is_data_parallel_worker,
-            dp_worker_id,
-            queue,
-            spec_queue,
-        )
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
@@ -218,16 +228,15 @@ def __init__(
             gpu_ids[0],
             0,
             server_args,
-            port_args.nccl_ports[dp_worker_id*2+1],
+            port_args.nccl_ports[dp_worker_id * 2 + 1],
             model_overide_args,
-            spec_queue
+            spec_queue,
         )
 
     def loop_for_forward(self):
         while True:
             recv_reqs = self.recv_requests_from_mp_queue()
-            self.tp_server.exposed_step(recv_reqs)
-
+            self.spec_server.exposed_step(recv_reqs)
 
     def recv_requests_from_mp_queue(self):
         recv_reqs = []
@@ -245,7 +254,7 @@ def start_spec_controller_process(
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
     queue: mp.connection.Connection = None,
-    spec_queue: SpecInfoPipline = None
+    spec_queue: SpecInfoPipline = None,
 ):
     """Start a controller process."""
     if is_data_parallel_worker:
@@ -268,12 +277,14 @@ def start_spec_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
+    finally:
+        kill_parent_process()
 
-    pipe_writer.send("init ok")
+    pipe_writer.send("draft init ok")
 
     try:
         controller.loop_for_forward()
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
\ No newline at end of file
+        kill_parent_process()
diff --git a/python/sglang/srt/managers/speculative_utils.py b/python/sglang/srt/managers/speculative_utils.py
index 55cdf03456..875e832ba6 100644
--- a/python/sglang/srt/managers/speculative_utils.py
+++ b/python/sglang/srt/managers/speculative_utils.py
@@ -13,6 +13,8 @@
 limitations under the License.
 """
 
+from typing import Type
+
 import torch
 
 
@@ -20,6 +22,25 @@ class SpecDraftInfo:
     pass
 
 
+class SpecDraftInfoFactory:
+    def __init__(self):
+        self.factory = {}
+
+    def register(self, name: str) -> SpecDraftInfo:
+        def wrapper(info: Type[SpecDraftInfo]) -> Type[SpecDraftInfo]:
+            self.factory[name] = info
+            return info
+
+        return wrapper
+
+    def get(self, name):
+        return self.factory[name]
+
+
+DraftInfoFactory = SpecDraftInfoFactory()
+
+
+@DraftInfoFactory.register("EAGLE")
 class EAGLEDraftInfo(SpecDraftInfo):
     def __init__(
         self, hidden_states: torch.Tensor, input_ids: torch.Tensor, output_token
@@ -36,6 +57,3 @@ class SpecInfoPipline:
     def __init__(self):
         self.draft_input_queue = torch.multiprocessing.Queue()
         self.draft_output_queue = torch.multiprocessing.Queue()
-
-
-
diff --git a/python/sglang/srt/managers/speculative_worker.py b/python/sglang/srt/managers/speculative_worker.py
index 0cb45f5c01..1b3fd8ac16 100644
--- a/python/sglang/srt/managers/speculative_worker.py
+++ b/python/sglang/srt/managers/speculative_worker.py
@@ -13,10 +13,23 @@
 limitations under the License.
""" -"""A tensor parallel worker.""" +"""A speculative draft worker.""" +from sglang.srt.managers.speculative_utils import SpecInfoPipline from sglang.srt.managers.tp_worker import ModelTpServer +from sglang.srt.server_args import ServerArgs class SpecDraftServer(ModelTpServer): - is_spec_server=True + def __init__( + self, + gpu_id: int, + tp_rank: int, + server_args: ServerArgs, + nccl_port: int, + model_overide_args: dict, + spec_queue: SpecInfoPipline, + ): + super().__init__( + gpu_id, tp_rank, server_args, nccl_port, model_overide_args, spec_queue + ) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 12c26aa7df..27ae4b2c36 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -16,7 +16,6 @@ """A tensor parallel worker.""" import logging -import torch.multiprocessing as multiprocessing import os import pickle import time @@ -26,6 +25,7 @@ import torch import torch.distributed import torch.distributed as dist +import torch.multiprocessing as multiprocessing from sglang.global_config import global_config from sglang.srt.constrained.fsm_cache import FSMCache @@ -49,6 +49,7 @@ Req, ScheduleBatch, ) +from sglang.srt.managers.speculative_utils import SpecInfoPipline from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_config import ModelConfig @@ -62,7 +63,6 @@ suppress_other_loggers, ) from sglang.utils import get_exception_traceback -from sglang.srt.managers.speculative_utils import SpecInfoPipline logger = logging.getLogger(__name__) @@ -78,7 +78,7 @@ def __init__( server_args: ServerArgs, nccl_port: int, model_overide_args: dict, - spec_queue: SpecInfoPipline + spec_queue: SpecInfoPipline, ): suppress_other_loggers() @@ -89,8 +89,9 @@ def __init__( self.dp_size = server_args.dp_size self.schedule_policy = server_args.schedule_policy self.disable_regex_jump_forward = server_args.disable_regex_jump_forward - self.is_spec_worker = False - + self.is_spec_worker = server_args.speculative_algorithm is not None + self.spec_queue = spec_queue + self.is_draft_worker = self.__class__.__name__ == "SpecDraftServer" # Init model and tokenizer self.model_config = ModelConfig( server_args.model_path, @@ -101,12 +102,17 @@ def __init__( self.model_runner = ModelRunner( model_config=self.model_config, - mem_fraction_static= server_args.draft_mem_fraction if self.is_spec_worker else server_args.mem_fraction_static , + mem_fraction_static=( + server_args.draft_mem_fraction + if self.is_draft_worker + else server_args.mem_fraction_static + ), gpu_id=gpu_id, tp_rank=tp_rank, tp_size=server_args.tp_size, nccl_port=nccl_port, server_args=server_args, + is_draft_runner=self.is_draft_worker, ) if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None @@ -261,12 +267,14 @@ def forward_step(self): if new_batch is not None: # Run a new prefill batch self.forward_prefill_batch(new_batch) - - if not new_batch.is_empty(): - if self.running_batch is None: - self.running_batch = new_batch - else: - self.running_batch.merge(new_batch) + if self.is_spec_worker: + self.spec_queue.draft_input_queue.put(new_batch) + else: + if not new_batch.is_empty(): + if self.running_batch is None: + self.running_batch = new_batch + else: + self.running_batch.merge(new_batch) else: # Run a decode batch if self.running_batch is not None: @@ -285,6 +293,8 @@ def forward_step(self): if self.out_pyobjs and self.running_batch.has_stream(): break + 
+            elif self.is_spec_worker:
+                self.forward_verify_batch()
             else:
                 self.check_memory()
                 self.new_token_ratio = global_config.init_new_token_ratio
@@ -725,6 +735,12 @@ def forward_decode_batch(self, batch: ScheduleBatch):
 
         self.handle_finished_requests(batch)
 
+    def forward_verify_batch(self):
+        recv_batch = []
+        while not self.spec_queue.draft_output_queue.empty():
+            recv_batch.append(self.spec_queue.draft_output_queue.get())
+        pass
+
     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 627a24ed83..9eccd8da6f 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -24,8 +24,8 @@
 import torch
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.managers.speculative_utils import SpecDraftInfo
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -43,7 +43,9 @@ class ForwardMode(IntEnum):
     SPECEXTEND = auto()
     # Speculative verify.
     SPECVERIFY = auto()
-    
+
+    def is_spec_mode(self):
+        return self in (self.SPECEXTEND, self.SPECVERIFY)
 
 
 @dataclass
@@ -94,9 +96,10 @@ class InputMetadata:
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False
-    
+
     # Information used for speculative decoding
-    draft_info : SpecDraftInfo = None
+    spec_draft_info: SpecDraftInfo = None
+    spec_algorithm: str = None
 
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
@@ -222,7 +225,7 @@ def from_schedule_batch(
             ret.init_flashinfer_handlers(
                 model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )
-
+        ret.spec_algorithm = model_runner.server_args.speculative_algorithm
         return ret
 
     def init_triton_args(self, batch: ScheduleBatch):
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e6f5e74311..5d9912207e 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -78,6 +78,7 @@ def __init__(
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        is_draft_runner: bool,
     ):
         # Parse args
         self.model_config = model_config
@@ -87,6 +88,7 @@ def __init__(
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
+        self.is_draft_runner = is_draft_runner
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
         )
@@ -177,7 +179,11 @@ def load_model(self):
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
         self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
+            model=(
+                self.server_args.draft_model_path
+                if self.is_draft_runner
+                else self.server_args.model_path
+            ),
             quantization=self.server_args.quantization,
             tokenizer=None,
             tokenizer_mode=None,
@@ -333,9 +339,22 @@ def profile_max_num_token(self, total_gpu_memory: int):
             * 2
             * torch._utils._element_size(self.kv_cache_dtype)
         )
-        rest_memory = available_gpu_memory - total_gpu_memory * (
-            1 - self.mem_fraction_static
-        )
+        if self.server_args.speculative_algorithm is not None:
+            if self.is_draft_runner:
+                rest_memory = available_gpu_memory - total_gpu_memory * (
+                    1
+                    - self.server_args.draft_mem_fraction
+                    - self.server_args.mem_fraction_static
+                )
+            else:
+                rest_memory = available_gpu_memory - total_gpu_memory * (
+                    1 - self.server_args.mem_fraction_static
+                )
+        else:
+            rest_memory = available_gpu_memory - total_gpu_memory * (
+                1 - self.mem_fraction_static
+            )
+
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
@@ -539,9 +558,11 @@ def forward_decode(self, batch: ScheduleBatch):
             ForwardMode.DECODE,
         )
 
-        return self.model.forward(
+        ret = self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
+        batch.spec_draft_info = input_metadata.spec_draft_info
+        return ret
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
@@ -551,9 +572,11 @@ def forward_extend(self, batch: ScheduleBatch):
             forward_mode=ForwardMode.EXTEND,
         )
         if self.is_generation:
-            return self.model.forward(
+            ret = self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
             )
+            batch.spec_draft_info = input_metadata.spec_draft_info
+            return ret
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index 22751d9b67..0948fd4ace 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -42,6 +42,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.managers.speculative_utils import DraftInfoFactory
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -318,6 +319,10 @@ def forward(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        if input_metadata.forward_mode.is_spec_mode():
+            input_metadata.spec_draft_info = DraftInfoFactory.get(
+                input_metadata.spec_algorithm
+            )(hidden_states, sample_output)
         return sample_output, logits_output
 
     def get_module_name(self, name):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 9a13129457..5d5fc317d1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -30,7 +30,7 @@ class ServerArgs:
     model_path: str
     draft_model_path: str = None
    speculative_algorithm: str = None
-    draft_mem_fraction: float = 0.1
+    draft_mem_fraction: float = None
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
@@ -156,7 +156,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="The fraction of the memory used for static allocation of draft model in Speculative Decoding.",
             required=False,
         )
-
+
         parser.add_argument(
             "--tokenizer-path",
             type=str,