From 39fe235bc92fa4ca05f01b92223674eb0b603ab2 Mon Sep 17 00:00:00 2001
From: kavioyu
Date: Fri, 13 Sep 2024 14:52:08 +0800
Subject: [PATCH] fix bug

---
 python/sglang/launch_server.py                 |   2 +-
 .../sglang/srt/managers/controller_single.py   | 160 ++++++++++--------
 python/sglang/srt/managers/schedule_batch.py   |   7 +-
 .../sglang/srt/managers/speculative_utils.py   |  13 +-
 .../sglang/srt/managers/speculative_worker.py  |  13 ++
 python/sglang/srt/managers/tp_worker.py        |   3 +
 python/sglang/srt/models/llama2.py             |   2 +-
 python/sglang/srt/server_args.py               |   9 +-
 8 files changed, 128 insertions(+), 81 deletions(-)

diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 5f26999f48..b0467c95de 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -14,7 +14,7 @@
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
-    multiprocessing.set_start_method("spawn")
+    multiprocessing.set_start_method("forkserver", force=True)
 
     try:
         launch_server(server_args)
diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py
index 8341e15314..80cfcdcd88 100644
--- a/python/sglang/srt/managers/controller_single.py
+++ b/python/sglang/srt/managers/controller_single.py
@@ -17,9 +17,10 @@
 
 import logging
 import multiprocessing as mp
+import time
 from typing import List
 
-import torch.multiprocessing as multiprocessing
+import torch
 import zmq
 
 from sglang.srt.managers.speculative_utils import SpecInfoPipline
@@ -47,9 +48,8 @@ def __init__(
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
-        mp_queue: multiprocessing.Queue,
+        mp_queue: torch.multiprocessing.Queue,
         spec_queue: SpecInfoPipline,
-        init_flag: multiprocessing.Event = None,
     ):
         # Parse args
         self.tp_size = server_args.tp_size
@@ -94,8 +94,6 @@ def __init__(
             spec_queue,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
-        if init_flag is not None:
-            init_flag.set()
 
     def loop_for_forward(self):
         while True:
@@ -139,69 +137,89 @@ def start_controller_process(
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
     queue: mp.connection.Connection = None,
+    spec_queue: SpecInfoPipline = None,
+    init_flag: torch.multiprocessing.Event = None,
 ):
     """Start a controller process."""
-    if is_data_parallel_worker:
-        logger_prefix = f" DP{dp_worker_id} TP0"
-    else:
-        logger_prefix = " TP0"
-    configure_logger(server_args, prefix=logger_prefix)
-
-    if not is_data_parallel_worker:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-        dp_worker_id = 0
-        queue = None
-
-    try:
-        spec_queue = None
-        flag = None
-        if server_args.speculative_algorithm is not None:
-            spec_queue = SpecInfoPipline()
-            flag = multiprocessing.Event()
-
-        controller = ControllerSingle(
-            server_args,
-            port_args,
-            model_overide_args,
-            gpu_ids,
-            is_data_parallel_worker,
-            dp_worker_id,
-            queue,
-            spec_queue,
-            flag,
+    ctx = torch.multiprocessing.get_context("forkserver")
+
+    if server_args.speculative_algorithm is not None and init_flag is None:
+        spec_queue_ = SpecInfoPipline()
+        flag = ctx.Event()
+        target = ctx.Process(
+            target=start_controller_process,
+            args=(
+                server_args,
+                port_args,
+                pipe_writer,
+                model_overide_args,
+                is_data_parallel_worker,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+                spec_queue_,
+                flag,
+            ),
         )
-
-        if server_args.speculative_algorithm is not None:
-            flag.wait()
-            # draft process should be launch after target process.
-            proc = multiprocessing.Process(
-                target=start_spec_controller_process,
-                args=(
-                    server_args,
-                    port_args,
-                    pipe_writer,
-                    model_overide_args,
-                    True,
-                    gpu_ids,
-                    dp_worker_id,
-                    queue,
-                    spec_queue,
-                ),
+        target.start()
+        flag.wait()
+        # The draft process should be launched after the target process.
+        draft = ctx.Process(
+            target=start_spec_controller_process,
+            args=(
+                server_args,
+                port_args,
+                pipe_writer,
+                model_overide_args,
+                False,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+                spec_queue_,
+            ),
+        )
+        draft.start()
+        target.join()
+        # draft.join()
+    else:
+        if is_data_parallel_worker:
+            logger_prefix = f" DP{dp_worker_id} TP0"
+        else:
+            logger_prefix = " TP0"
+        configure_logger(server_args, prefix=logger_prefix)
+
+        if not is_data_parallel_worker:
+            tp_size_local = server_args.tp_size // server_args.nnodes
+            gpu_ids = [
+                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+            ]
+            dp_worker_id = 0
+            queue = None
+
+        try:
+            controller = ControllerSingle(
+                server_args,
+                port_args,
+                model_overide_args,
+                gpu_ids,
+                is_data_parallel_worker,
+                dp_worker_id,
+                queue,
+                spec_queue,
             )
-            proc.start()
-    except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
+            if init_flag is not None:
+                init_flag.set()
+        except Exception:
+            pipe_writer.send(get_exception_traceback())
+            raise
+        pipe_writer.send("init ok")
 
-    pipe_writer.send("init ok")
-
-    try:
-        controller.loop_for_forward()
-    except Exception:
-        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
-    finally:
-        kill_parent_process()
+        try:
+            controller.loop_for_forward()
+        except Exception:
+            logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+        finally:
+            kill_parent_process()
 
 
 class ControllerSingleSpecDraft(ControllerSingle):
@@ -215,7 +233,7 @@ def __init__(
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
-        mp_queue: multiprocessing.Queue,
+        mp_queue: torch.multiprocessing.Queue,
         spec_queue: SpecInfoPipline,
     ):
         # Parse args
@@ -223,6 +241,7 @@ def __init__(
 
         self.is_dp_worker = is_data_parallel_worker
         self.dp_worker_id = dp_worker_id
+        self.spec_queue = spec_queue
         self.mp_queue = spec_queue.draft_input_queue
         self.spec_server = SpecDraftServer(
             gpu_ids[0],
@@ -263,6 +282,12 @@ def start_spec_controller_process(
         logger_prefix = " Spec "
     configure_logger(server_args, prefix=logger_prefix)
 
+    if not is_data_parallel_worker:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
     try:
         controller = ControllerSingleSpecDraft(
             server_args,
@@ -277,14 +302,13 @@ def start_spec_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-    finally:
-        kill_parent_process()
-
-
+    pipe_writer.send("draft init ok")
 
     try:
         controller.loop_for_forward()
     except Exception:
-        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+        logger.error(
+            "Exception in ControllerSingleSpecDraft:\n" + get_exception_traceback()
+        )
     finally:
         kill_parent_process()
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 68ac8f5b12..6835ee9e2a 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -18,6 +18,7 @@
 """Meta data for requests and batches"""
 
 import logging
+import uuid
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Union
 
@@ -26,11 +27,11 @@
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.managers.speculative_utils import SpecDraftInfo
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.managers.speculative_utils import SpecDraftInfo
 
 if TYPE_CHECKING:
     from sglang.srt.layers.sampler import SampleOutput
@@ -330,6 +331,7 @@ class ScheduleBatch:
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
     tree_cache: BasePrefixCache
+    bid: str
 
     # Batched arguments to model runner
    input_ids: torch.Tensor = None
@@ -345,7 +347,7 @@ class ScheduleBatch:
 
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
-
+
     # For speculative decoding
     spec_draft_info: SpecDraftInfo = None
@@ -359,6 +361,7 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             return_logprob=return_logprob,
+            bid=uuid.uuid4().hex,
         )
 
     def batch_size(self):
diff --git a/python/sglang/srt/managers/speculative_utils.py b/python/sglang/srt/managers/speculative_utils.py
index 875e832ba6..ab081503f8 100644
--- a/python/sglang/srt/managers/speculative_utils.py
+++ b/python/sglang/srt/managers/speculative_utils.py
@@ -42,18 +42,15 @@ def get(self, name):
 
 @DraftInfoFactory.register("EAGLE")
 class EAGLEDraftInfo(SpecDraftInfo):
-    def __init__(
-        self, hidden_states: torch.Tensor, input_ids: torch.Tensor, output_token
-    ):
-        hidden_states: torch.Tensor = hidden_states
-        input_ids: torch.Tensor = input_ids
+    def __init__(self, hidden_states: torch.Tensor):
+        self.hidden_states: torch.Tensor = hidden_states
 
     def update_input(self, info: "EAGLEDraftInfo"):
         self.hidden_states = info.hidden_states
-        self.input_ids = info.input_ids
 
 
 class SpecInfoPipline:
     def __init__(self):
-        self.draft_input_queue = torch.multiprocessing.Queue()
-        self.draft_output_queue = torch.multiprocessing.Queue()
+        ctx = torch.multiprocessing.get_context("forkserver")
+        self.draft_input_queue = ctx.Queue()
+        self.draft_output_queue = ctx.Queue()
diff --git a/python/sglang/srt/managers/speculative_worker.py b/python/sglang/srt/managers/speculative_worker.py
index 1b3fd8ac16..e74e7d38d0 100644
--- a/python/sglang/srt/managers/speculative_worker.py
+++ b/python/sglang/srt/managers/speculative_worker.py
@@ -14,9 +14,12 @@
 """
 
 """A speculative draft worker."""
+from typing import List
 
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.speculative_utils import SpecInfoPipline
 from sglang.srt.managers.tp_worker import ModelTpServer
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import ServerArgs
 
 
@@ -33,3 +36,13 @@ def __init__(
         super().__init__(
             gpu_id, tp_rank, server_args, nccl_port, model_overide_args, spec_queue
         )
+
+    def exposed_step(self, recv_batches: List[ScheduleBatch]):
+        for batch in recv_batches:
+            self.forward_speculative_batch(batch)
+
+    def forward_speculative_batch(self, batch: ScheduleBatch):
+        for step in range(self.server_args.num_speculative_tokens):
+            sample_output, logits_output = self.model_runner.forward(
+                batch, ForwardMode.DECODE
+            )
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 27ae4b2c36..2aaed49164 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -92,6 +92,9 @@ def __init__(
         self.is_spec_worker = server_args.speculative_algorithm is not None
         self.spec_queue = spec_queue
         self.is_draft_worker = self.__class__.__name__ == "SpecDraftServer"
+        self.server_args = server_args
+        # Hold the batches so torch does not release the torch.Tensor objects in ScheduleBatch.
+        self.spec_running_batches = {}
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index 0948fd4ace..34ddac7276 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -322,7 +322,7 @@ def forward(
         if input_metadata.forward_mode.is_spec_mode():
             input_metadata.spec_draft_info = DraftInfoFactory.get(
                 input_metadata.spec_algorithm
-            )(hidden_states, sample_output)
+            )(hidden_states)
         return sample_output, logits_output
 
     def get_module_name(self, name):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5d5fc317d1..2df426e336 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -31,6 +31,7 @@ class ServerArgs:
     draft_model_path: str = None
     speculative_algorithm: str = None
    draft_mem_fraction: float = None
+    num_speculative_tokens: int = None
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
@@ -156,7 +157,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="The fraction of the memory used for static allocation of draft model in Speculative Decoding.",
             required=False,
         )
-
+        parser.add_argument(
+            "--num-speculative-tokens",
+            type=int,
+            help="The number of tokens sampled from the draft model in Speculative Decoding.",
+            required=False,
+            default=4,
+        )
         parser.add_argument(
             "--tokenizer-path",
             type=str,