
Commit

could launch
kavioyu committed Sep 11, 2024
1 parent eb76d6c commit db37fda
Showing 9 changed files with 144 additions and 52 deletions.
3 changes: 3 additions & 0 deletions python/sglang/launch_server.py
@@ -3,6 +3,8 @@
import argparse
import os

import torch.multiprocessing as multiprocessing

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process
@@ -12,6 +14,7 @@
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
multiprocessing.set_start_method("spawn")

try:
launch_server(server_args)
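The set_start_method("spawn") call above is what lets the new child processes launch safely: a CUDA context cannot be reused across fork(), so any subprocess that will touch the GPU must start as a fresh interpreter, and the start method has to be chosen before the first Process is created. A minimal sketch of the pattern, separate from the server code:

import multiprocessing

def worker(q):
    # In the real server this is where a GPU worker would initialize CUDA;
    # under the default "fork" start method on Linux that can fail if the
    # parent process already holds a CUDA context.
    q.put("worker started cleanly")

if __name__ == "__main__":
    # Must run before any Process is created, hence its placement right
    # after argument parsing in launch_server.py.
    multiprocessing.set_start_method("spawn")
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=worker, args=(q,))
    p.start()
    print(q.get())
    p.join()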
55 changes: 33 additions & 22 deletions python/sglang/srt/managers/controller_single.py
@@ -16,12 +16,14 @@
"""A controller that manages a group of tensor parallel workers."""

import logging
import torch.multiprocessing as multiprocessing
import multiprocessing as mp
from typing import List

import torch.multiprocessing as multiprocessing
import zmq

from sglang.srt.managers.speculative_utils import SpecInfoPipline
from sglang.srt.managers.speculative_worker import SpecDraftServer
from sglang.srt.managers.tp_worker import (
ModelTpServer,
broadcast_recv_input,
@@ -30,9 +32,6 @@
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import get_exception_traceback
from sglang.srt.managers.speculative_utils import SpecInfoPipline
from sglang.srt.managers.speculative_worker import SpecDraftServer
from sglang.srt.managers.speculative_utils import SpecInfoPipline

logger = logging.getLogger(__name__)

@@ -50,6 +49,7 @@ def __init__(
dp_worker_id: int,
mp_queue: multiprocessing.Queue,
spec_queue: SpecInfoPipline,
init_flag: multiprocessing.Event = None,
):
# Parse args
self.tp_size = server_args.tp_size
@@ -91,9 +91,11 @@ def __init__(
server_args,
port_args.nccl_ports[dp_worker_id],
model_overide_args,
spec_queue
spec_queue,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
if init_flag is not None:
init_flag.set()

def loop_for_forward(self):
while True:
@@ -153,8 +155,26 @@ def start_controller_process(

try:
spec_queue = None
flag = None
if server_args.speculative_algorithm is not None:
spec_queue = SpecInfoPipline()
flag = multiprocessing.Event()

controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
spec_queue,
flag,
)

if server_args.speculative_algorithm is not None:
flag.wait()
# The draft process should be launched after the target process.
proc = multiprocessing.Process(
target=start_spec_controller_process,
args=(
@@ -170,16 +190,6 @@
),
)
proc.start()
controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
spec_queue,
)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
@@ -218,16 +228,15 @@ def __init__(
gpu_ids[0],
0,
server_args,
port_args.nccl_ports[dp_worker_id*2+1],
port_args.nccl_ports[dp_worker_id * 2 + 1],
model_overide_args,
spec_queue
spec_queue,
)

def loop_for_forward(self):
while True:
recv_reqs = self.recv_requests_from_mp_queue()
self.tp_server.exposed_step(recv_reqs)

self.spec_server.exposed_step(recv_reqs)

def recv_requests_from_mp_queue(self):
recv_reqs = []
@@ -245,7 +254,7 @@ def start_spec_controller_process(
gpu_ids: List[int] = None,
dp_worker_id: int = None,
queue: mp.connection.Connection = None,
spec_queue: SpecInfoPipline = None
spec_queue: SpecInfoPipline = None,
):
"""Start a controller process."""
if is_data_parallel_worker:
@@ -268,12 +277,14 @@
except Exception:
pipe_writer.send(get_exception_traceback())
raise
finally:
kill_parent_process()

pipe_writer.send("init ok")
pipe_writer.send("draft init ok")

try:
controller.loop_for_forward()
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
kill_parent_process()
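The reordering in start_controller_process pairs with the spawn change above: the target ControllerSingle is now constructed first and handed an Event (init_flag) that it sets once its ModelTpServer has finished loading, and only after flag.wait() returns is the draft controller process spawned; the pipe message is also renamed to "draft init ok" so the two readiness signals can be told apart. A hedged sketch of the ordering, with stand-in functions rather than the repository's API:

import torch.multiprocessing as multiprocessing

def _draft_main():
    # Stand-in for start_spec_controller_process.
    print("draft controller started after the target was ready")

def _build_target(init_flag):
    # Stand-in for ControllerSingle.__init__: signal readiness once the
    # (expensive) target model load has finished.
    init_flag.set()

if __name__ == "__main__":
    ready = multiprocessing.Event()
    _build_target(init_flag=ready)  # target controller loads its model first
    ready.wait()                    # block until the target has signaled
    proc = multiprocessing.Process(target=_draft_main)
    proc.start()                    # the draft process launches strictly after
    proc.join()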
24 changes: 21 additions & 3 deletions python/sglang/srt/managers/speculative_utils.py
@@ -13,13 +13,34 @@
limitations under the License.
"""

from typing import Callable, Type

import torch


class SpecDraftInfo:
pass


class SpecDraftInfoFactory:
def __init__(self):
self.factory = {}

def register(self, name: str) -> Callable[[Type[SpecDraftInfo]], Type[SpecDraftInfo]]:
def wrapper(info: Type[SpecDraftInfo]) -> Type[SpecDraftInfo]:
self.factory[name] = info
return info

return wrapper

def get(self, name):
return self.factory[name]


DraftInfoFactory = SpecDraftInfoFactory()


@DraftInfoFactory.register("EAGLE")
class EAGLEDraftInfo(SpecDraftInfo):
def __init__(
self, hidden_states: torch.Tensor, input_ids: torch.Tensor, output_token
@@ -36,6 +57,3 @@ class SpecInfoPipline:
def __init__(self):
self.draft_input_queue = torch.multiprocessing.Queue()
self.draft_output_queue = torch.multiprocessing.Queue()



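SpecDraftInfoFactory is a small decorator-based registry: each draft-info class registers itself under an algorithm name, and callers look the class up by the string configured in server_args.speculative_algorithm. Hypothetical usage, with made-up tensor shapes (the EAGLEDraftInfo constructor body is elided in this diff):

import torch

from sglang.srt.managers.speculative_utils import DraftInfoFactory

draft_cls = DraftInfoFactory.get("EAGLE")  # resolves to EAGLEDraftInfo
info = draft_cls(
    hidden_states=torch.zeros(1, 4096),    # assumed hidden size
    input_ids=torch.tensor([1, 2, 3]),
    output_token=3,
)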
17 changes: 15 additions & 2 deletions python/sglang/srt/managers/speculative_worker.py
@@ -13,10 +13,23 @@
limitations under the License.
"""

"""A tensor parallel worker."""
"""A speculative draft worker."""

from sglang.srt.managers.speculative_utils import SpecInfoPipline
from sglang.srt.managers.tp_worker import ModelTpServer
from sglang.srt.server_args import ServerArgs


class SpecDraftServer(ModelTpServer):
is_spec_server = True
def __init__(
self,
gpu_id: int,
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_overide_args: dict,
spec_queue: SpecInfoPipline,
):
super().__init__(
gpu_id, tp_rank, server_args, nccl_port, model_overide_args, spec_queue
)
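SpecDraftServer reuses the whole ModelTpServer initialization; as the tp_worker.py hunk below shows, the base class tells the two roles apart by checking the concrete class name (is_draft_worker), which is what selects draft_mem_fraction instead of mem_fraction_static for the draft model. A minimal sketch of that dispatch idiom, with illustrative names:

class Base:
    def __init__(self):
        # True only when constructed through the subclass below.
        self.is_draft_worker = type(self).__name__ == "SpecDraftServer"

class SpecDraftServer(Base):
    pass

assert SpecDraftServer().is_draft_worker
assert not Base().is_draft_worker

Comparing the class name as a string works, but an isinstance check or an overridden class attribute (the otherwise unused is_spec_server flag above looks intended for this) would survive renames and further subclassing more gracefully.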
40 changes: 28 additions & 12 deletions python/sglang/srt/managers/tp_worker.py
@@ -16,7 +16,6 @@
"""A tensor parallel worker."""

import logging
import torch.multiprocessing as multiprocessing
import os
import pickle
import time
@@ -26,6 +25,7 @@
import torch
import torch.distributed
import torch.distributed as dist
import torch.multiprocessing as multiprocessing

from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
@@ -49,6 +49,7 @@
Req,
ScheduleBatch,
)
from sglang.srt.managers.speculative_utils import SpecInfoPipline
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
@@ -62,7 +63,6 @@
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
from sglang.srt.managers.speculative_utils import SpecInfoPipline

logger = logging.getLogger(__name__)

@@ -78,7 +78,7 @@ def __init__(
server_args: ServerArgs,
nccl_port: int,
model_overide_args: dict,
spec_queue: SpecInfoPipline
spec_queue: SpecInfoPipline,
):
suppress_other_loggers()

@@ -89,8 +89,9 @@
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
self.is_spec_worker = False

self.is_spec_worker = server_args.speculative_algorithm is not None
self.spec_queue = spec_queue
self.is_draft_worker = self.__class__.__name__ == "SpecDraftServer"
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path,
Expand All @@ -101,12 +102,17 @@ def __init__(

self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static= server_args.draft_mem_fraction if self.is_spec_worker else server_args.mem_fraction_static ,
mem_fraction_static=(
server_args.draft_mem_fraction
if self.is_draft_worker
else server_args.mem_fraction_static
),
gpu_id=gpu_id,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=nccl_port,
server_args=server_args,
is_draft_runner=self.is_draft_worker,
)
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
@@ -261,12 +267,14 @@ def forward_step(self):
if new_batch is not None:
# Run a new prefill batch
self.forward_prefill_batch(new_batch)

if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
if self.is_spec_worker:
self.spec_queue.draft_input_queue.put(new_batch)
else:
if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
else:
# Run a decode batch
if self.running_batch is not None:
Expand All @@ -285,6 +293,8 @@ def forward_step(self):

if self.out_pyobjs and self.running_batch.has_stream():
break
elif self.is_spec_worker:
self.forward_verify_batch()
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
@@ -725,6 +735,12 @@ def forward_decode_batch(self, batch: ScheduleBatch):

self.handle_finished_requests(batch)

def forward_verify_batch(self):
recv_batch = []
while not self.spec_queue.draft_output_queue.empty():
recv_batch.append(self.spec_queue.draft_output_queue.get())
pass

def handle_finished_requests(self, batch: ScheduleBatch):
output_rids = []
output_meta_info = []
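The forward_step changes route freshly prefilled batches to the draft worker through spec_queue.draft_input_queue, while the new forward_verify_batch drains whatever the draft worker has pushed back on draft_output_queue (the verification itself is still a stub, hence the trailing pass). One caveat: empty() is documented as unreliable for multiprocessing queues, so a non-blocking get is the safer drain pattern. A sketch, not the repository's code:

import queue

import torch.multiprocessing as multiprocessing

def drain(q):
    # Collect everything currently available without blocking;
    # get_nowait() raises queue.Empty once the queue is exhausted.
    out = []
    while True:
        try:
            out.append(q.get_nowait())
        except queue.Empty:
            return out

batches = drain(multiprocessing.Queue())  # returns [] for an empty queue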
13 changes: 8 additions & 5 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -24,8 +24,8 @@
import torch

from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.managers.speculative_utils import SpecDraftInfo
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -43,7 +43,9 @@ class ForwardMode(IntEnum):
SPECEXTEND = auto()
# Speculative verify.
SPECVERIFY = auto()


def is_spec_mode(self):
return self in (self.SPECEXTEND, self.SPECVERIFY)


@dataclass
@@ -94,9 +96,10 @@ class InputMetadata:
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
flashinfer_use_ragged: bool = False

# Information used for speculative decoding
draft_info : SpecDraftInfo = None
spec_draft_info: SpecDraftInfo = None
spec_algorithm: str = None

def init_multimuldal_info(self, batch: ScheduleBatch):
reqs = batch.reqs
@@ -222,7 +225,7 @@ def from_schedule_batch(
ret.init_flashinfer_handlers(
model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
)

ret.spec_algorithm = model_runner.server_args.speculative_algorithm
return ret

def init_triton_args(self, batch: ScheduleBatch):
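ForwardMode gains the two speculative members plus an is_spec_mode() helper, and InputMetadata now carries the draft info (spec_draft_info) and the configured algorithm name. A tiny illustration of the helper, assuming the import path and that DECODE is one of the enum's existing members:

from sglang.srt.model_executor.forward_batch_info import ForwardMode

assert ForwardMode.SPECVERIFY.is_spec_mode()
assert ForwardMode.SPECEXTEND.is_spec_mode()
assert not ForwardMode.DECODE.is_spec_mode()  # DECODE assumed from the pre-existing enum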
