Commit 39fe235

fix bug

kavioyu committed Sep 13, 2024
1 parent db37fda commit 39fe235

Showing 8 changed files with 128 additions and 81 deletions.
2 changes: 1 addition & 1 deletion python/sglang/launch_server.py
@@ -14,7 +14,7 @@
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
multiprocessing.set_start_method("spawn")
multiprocessing.set_start_method("forkserver", force=True)

try:
launch_server(server_args)
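
Context for the start-method switch above (a hedged aside, not part of this commit): spawn and forkserver both start child processes in a fresh interpreter, which is what CUDA requires, and force=True is needed because multiprocessing.set_start_method raises RuntimeError when a start method has already been chosen earlier in the process.

import multiprocessing

def pick_start_method():
    # set_start_method() raises RuntimeError if a start method was already
    # chosen (e.g. by an imported library); force=True overrides it.
    try:
        multiprocessing.set_start_method("forkserver")
    except RuntimeError:
        multiprocessing.set_start_method("forkserver", force=True)
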
160 changes: 92 additions & 68 deletions python/sglang/srt/managers/controller_single.py
@@ -17,9 +17,10 @@

import logging
import multiprocessing as mp
import time
from typing import List

import torch.multiprocessing as multiprocessing
import torch
import zmq

from sglang.srt.managers.speculative_utils import SpecInfoPipline
@@ -47,9 +48,8 @@ def __init__(
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,
mp_queue: multiprocessing.Queue,
mp_queue: torch.multiprocessing.Queue,
spec_queue: SpecInfoPipline,
init_flag: multiprocessing.Event = None,
):
# Parse args
self.tp_size = server_args.tp_size
@@ -94,8 +94,6 @@ def __init__(
spec_queue,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
if init_flag is not None:
init_flag.set()

def loop_for_forward(self):
while True:
@@ -139,69 +137,89 @@ def start_controller_process(
gpu_ids: List[int] = None,
dp_worker_id: int = None,
queue: mp.connection.Connection = None,
spec_queue: SpecInfoPipline = None,
init_flag: torch.multiprocessing.Event = None,
):
"""Start a controller process."""
if is_data_parallel_worker:
logger_prefix = f" DP{dp_worker_id} TP0"
else:
logger_prefix = " TP0"
configure_logger(server_args, prefix=logger_prefix)

if not is_data_parallel_worker:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
dp_worker_id = 0
queue = None

try:
spec_queue = None
flag = None
if server_args.speculative_algorithm is not None:
spec_queue = SpecInfoPipline()
flag = multiprocessing.Event()

controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
spec_queue,
flag,
ctx = torch.multiprocessing.get_context("forkserver")

if server_args.speculative_algorithm is not None and init_flag is None:
spec_queue_ = SpecInfoPipline()
flag = ctx.Event()
target = ctx.Process(
target=start_controller_process,
args=(
server_args,
port_args,
pipe_writer,
model_overide_args,
is_data_parallel_worker,
gpu_ids,
dp_worker_id,
queue,
spec_queue_,
flag,
),
)

if server_args.speculative_algorithm is not None:
flag.wait()
# The draft process should be launched after the target process.
proc = multiprocessing.Process(
target=start_spec_controller_process,
args=(
server_args,
port_args,
pipe_writer,
model_overide_args,
True,
gpu_ids,
dp_worker_id,
queue,
spec_queue,
),
target.start()
flag.wait()
# The draft process should be launched after the target process.
draft = ctx.Process(
target=start_spec_controller_process,
args=(
server_args,
port_args,
pipe_writer,
model_overide_args,
False,
gpu_ids,
dp_worker_id,
queue,
spec_queue_,
),
)
draft.start()
target.join()
# draft.join()
else:
if is_data_parallel_worker:
logger_prefix = f" DP{dp_worker_id} TP0"
else:
logger_prefix = " TP0"
configure_logger(server_args, prefix=logger_prefix)

if not is_data_parallel_worker:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [
i for _ in range(server_args.nnodes) for i in range(tp_size_local)
]
dp_worker_id = 0
queue = None

try:
controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
spec_queue,
)
proc.start()
except Exception:
pipe_writer.send(get_exception_traceback())
raise
if init_flag is not None:
init_flag.set()
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")

pipe_writer.send("init ok")

try:
controller.loop_for_forward()
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
kill_parent_process()
try:
controller.loop_for_forward()
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
kill_parent_process()
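
The launch path above gates the draft process on an Event that the target process sets once its ControllerSingle finishes initializing (flag.wait() in the parent, init_flag.set() in the child). A minimal standalone sketch of that handshake, not the repository's code; worker is a stand-in:

import torch.multiprocessing as mp

def worker(ready_flag):
    # ... heavy initialization (load weights, build queues) would go here ...
    ready_flag.set()   # signal the parent that initialization finished
    # ... main serving loop ...

if __name__ == "__main__":
    ctx = mp.get_context("forkserver")
    ready = ctx.Event()
    target_proc = ctx.Process(target=worker, args=(ready,))
    target_proc.start()
    ready.wait()       # only now start the second (draft) process
    draft_proc = ctx.Process(target=worker, args=(ctx.Event(),))
    draft_proc.start()
    target_proc.join()
    draft_proc.join()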


class ControllerSingleSpecDraft(ControllerSingle):
@@ -215,14 +233,15 @@ def __init__(
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,
mp_queue: multiprocessing.Queue,
mp_queue: torch.multiprocessing.Queue,
spec_queue: SpecInfoPipline,
):
# Parse args
self.tp_size = server_args.tp_size
self.is_dp_worker = is_data_parallel_worker
self.dp_worker_id = dp_worker_id

self.spec_queue = spec_queue
self.mp_queue = spec_queue.draft_input_queue
self.spec_server = SpecDraftServer(
gpu_ids[0],
@@ -263,6 +282,12 @@ def start_spec_controller_process(
logger_prefix = " Spec "
configure_logger(server_args, prefix=logger_prefix)

if not is_data_parallel_worker:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
dp_worker_id = 0
queue = None

try:
controller = ControllerSingleSpecDraft(
server_args,
@@ -277,14 +302,13 @@
except Exception:
pipe_writer.send(get_exception_traceback())
raise
finally:
kill_parent_process()

pipe_writer.send("draft init ok")

try:
controller.loop_for_forward()
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
logger.error(
"Exception in ControllerSingleSpecDraft:\n" + get_exception_traceback()
)
finally:
kill_parent_process()
7 changes: 5 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
@@ -18,6 +18,7 @@
"""Meta data for requests and batches"""

import logging
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union

@@ -26,11 +27,11 @@
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.speculative_utils import SpecDraftInfo
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.managers.speculative_utils import SpecDraftInfo

if TYPE_CHECKING:
from sglang.srt.layers.sampler import SampleOutput
@@ -330,6 +331,7 @@ class ScheduleBatch:
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool
tree_cache: BasePrefixCache
bid: str

# Batched arguments to model runner
input_ids: torch.Tensor = None
@@ -345,7 +347,7 @@ class ScheduleBatch:
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: List[int] = None

# For speculative decoding
spec_draft_info: SpecDraftInfo = None

@@ -359,6 +361,7 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
return_logprob=return_logprob,
bid=uuid.uuid4().hex,
)

def batch_size(self):
13 changes: 5 additions & 8 deletions python/sglang/srt/managers/speculative_utils.py
@@ -42,18 +42,15 @@ def get(self, name):

@DraftInfoFactory.register("EAGLE")
class EAGLEDraftInfo(SpecDraftInfo):
def __init__(
self, hidden_states: torch.Tensor, input_ids: torch.Tensor, output_token
):
hidden_states: torch.Tensor = hidden_states
input_ids: torch.Tensor = input_ids
def __init__(self, hidden_states: torch.Tensor):
self.hidden_states: torch.Tensor = hidden_states

def update_input(self, info: "EAGLEDraftInfo"):
self.hidden_states = info.hidden_states
self.input_ids = info.input_ids


class SpecInfoPipline:
def __init__(self):
self.draft_input_queue = torch.multiprocessing.Queue()
self.draft_output_queue = torch.multiprocessing.Queue()
ctx = torch.multiprocessing.get_context("forkserver")
self.draft_input_queue = ctx.Queue()
self.draft_output_queue = ctx.Queue()
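
SpecInfoPipline now builds its queues from an explicit forkserver context rather than the module-level default. A minimal sketch of the pattern (standard torch.multiprocessing API only, nothing repository-specific): create the context once and derive both the queues and the processes from it, so the queue's start-method machinery matches the processes that use it.

import torch.multiprocessing as mp

def producer(q):
    q.put("ready")

if __name__ == "__main__":
    ctx = mp.get_context("forkserver")
    q = ctx.Queue()                    # queue bound to the forkserver context
    p = ctx.Process(target=producer, args=(q,))
    p.start()
    print(q.get())                     # -> "ready"
    p.join()
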
13 changes: 13 additions & 0 deletions python/sglang/srt/managers/speculative_worker.py
@@ -14,9 +14,12 @@
"""

"""A speculative draft worker."""
from typing import List

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.speculative_utils import SpecInfoPipline
from sglang.srt.managers.tp_worker import ModelTpServer
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs


@@ -33,3 +36,13 @@ def __init__(
super().__init__(
gpu_id, tp_rank, server_args, nccl_port, model_overide_args, spec_queue
)

def exposed_step(self, recv_batches: List[ScheduleBatch]):
for batch in recv_batches:
self.forward_speculative_batch(batch)

def forward_speculative_batch(self, batch: ScheduleBatch):
for step in range(self.server_args.num_speculative_tokens):
sample_output, logits_output = self.model_runner.forward(
batch, ForwardMode.DECODE
)
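
forward_speculative_batch above runs the draft model for num_speculative_tokens decode steps. For orientation, a hedged, framework-free sketch of the usual draft-then-verify loop in speculative decoding; sample_next and verify are hypothetical helpers, not names from this repository.

def speculative_step(target_model, draft_model, prefix_ids, k=4):
    # 1) The draft model proposes k tokens autoregressively.
    ids = list(prefix_ids)
    proposal = []
    for _ in range(k):
        token = draft_model.sample_next(ids)   # hypothetical helper
        proposal.append(token)
        ids.append(token)
    # 2) The target model scores the whole proposal in one forward pass and
    #    keeps only the leading tokens it would also have generated.
    accepted = target_model.verify(prefix_ids, proposal)  # hypothetical helper
    return accepted
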
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -92,6 +92,9 @@ def __init__(
self.is_spec_worker = server_args.speculative_algorithm is not None
self.spec_queue = spec_queue
self.is_draft_worker = self.__class__.__name__ == "SpecDraftServer"
self.server_args = server_args
# Hold the batches so torch does not release the torch.Tensors inside ScheduleBatch
self.spec_running_batches = {}
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path,
2 changes: 1 addition & 1 deletion python/sglang/srt/models/llama2.py
@@ -322,7 +322,7 @@ def forward(
if input_metadata.forward_mode.is_spec_mode():
input_metadata.spec_draft_info = DraftInfoFactory.get(
input_metadata.spec_algorithm
)(hidden_states, sample_output)
)(hidden_states)
return sample_output, logits_output

def get_module_name(self, name):
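
The forward pass above looks a draft-info class up by algorithm name via DraftInfoFactory.get(...) and instantiates it with the hidden states; the matching @DraftInfoFactory.register("EAGLE") decorator appears in speculative_utils.py. A hedged sketch of that string-keyed registry pattern (DraftInfoRegistry and EagleDraftInfo are illustrative names, not the repository's implementation):

class DraftInfoRegistry:
    """String-keyed class registry, analogous in spirit to DraftInfoFactory."""

    _classes = {}

    @classmethod
    def register(cls, name):
        def wrap(klass):
            cls._classes[name] = klass
            return klass
        return wrap

    @classmethod
    def get(cls, name):
        return cls._classes[name]

@DraftInfoRegistry.register("EAGLE")
class EagleDraftInfo:
    def __init__(self, hidden_states):
        self.hidden_states = hidden_states

# Mirrors the call site above: look the class up by algorithm name,
# then instantiate it with the hidden states.
info = DraftInfoRegistry.get("EAGLE")(hidden_states=None)
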
9 changes: 8 additions & 1 deletion python/sglang/srt/server_args.py
@@ -31,6 +31,7 @@ class ServerArgs:
draft_model_path: str = None
speculative_algorithm: str = None
draft_mem_fraction: float = None
num_speculative_tokens: int = None
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
@@ -156,7 +157,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="The fraction of the memory used for static allocation of draft model in Speculative Decoding.",
required=False,
)

parser.add_argument(
"--num-speculative-tokens",
type=int,
help="The number of tokens sampled from the draft model in Speculative Decoding.",
required=False,
default=4,
)
parser.add_argument(
"--tokenizer-path",
type=str,
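
A hedged usage example for the new flag. It assumes --draft-model-path and --speculative-algorithm are also registered by ServerArgs.add_cli_args (their dataclass fields appear in this diff, but their add_argument calls are not shown in the excerpt); the model paths are placeholders.

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    [
        "--model-path", "<target-model>",        # placeholder
        "--draft-model-path", "<draft-model>",   # placeholder
        "--speculative-algorithm", "EAGLE",
        "--num-speculative-tokens", "4",
    ]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.num_speculative_tokens)        # -> 4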
