From 53b8f0fd34c1519b2ad7520c6f5f5e9f6f4f1ce0 Mon Sep 17 00:00:00 2001 From: YouYangxiu Date: Mon, 21 Oct 2024 16:05:34 +0800 Subject: [PATCH 01/14] add two new load-balance-method --- .pre-commit-config.yaml | 2 +- .../srt/managers/data_parallel_controller.py | 182 +++++++++++++++++- python/sglang/srt/managers/io_struct.py | 18 ++ python/sglang/srt/managers/scheduler.py | 60 +++++- python/sglang/srt/mem_cache/radix_cache.py | 8 + python/sglang/srt/server_args.py | 3 +- 6 files changed, 264 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7489004bd0..2bd3402981 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.9 + python: python3.10 repos: - repo: https://github.com/PyCQA/isort diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 1b7da747f1..87b2ab926a 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -17,11 +17,13 @@ import logging import multiprocessing as mp +import multiprocessing.connection from enum import Enum, auto import zmq from sglang.srt.managers.io_struct import ( + ControllerInfo, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, TokenizedRewardReqInput, @@ -37,12 +39,42 @@ logger = logging.getLogger(__name__) +import random + + +# for pre radix scheduler +def _key_match(key0, key1): + i = 0 + for k0, k1 in zip(key0, key1): + if k0 != k1: + break + i += 1 + return i + + +def get_match_len(node, key, match_length: int) -> int: + if len(key) == 0: + return match_length + + if key[0] in node.children.keys(): + child = node.children[key[0]] + prefix_len = _key_match(child.key, key) + match_length += prefix_len + if prefix_len < len(child.key): + return match_length + else: + return get_match_len(child, key[prefix_len:], match_length) + else: + return match_length + class LoadBalanceMethod(Enum): """Load balance method.""" ROUND_ROBIN = auto() SHORTEST_QUEUE = auto() + RESOURCES_AWARE = auto() + PRE_RADIX = auto() @classmethod def from_str(cls, method: str): @@ -74,9 +106,29 @@ def __init__(self, server_args, port_args) -> None: dispatch_lookup = { LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, + LoadBalanceMethod.RESOURCES_AWARE: self.resources_aware_scheduler, + LoadBalanceMethod.PRE_RADIX: self.pre_radix_scheduler, } self.dispatching = dispatch_lookup[self.load_balance_method] + # For resources aware + self.dp_size = server_args.dp_size + self.controller_info = ControllerInfo(server_args.dp_size) + self.pre_available_kv_cache = [] + self.main_available_kv_cache = [] + + self.pre_num_running_req = [] + self.main_num_running_req = [] + + self.pre_num_waiting_req = [] + self.main_num_waiting_req = [] + + # For pre_radix + self.choosen_gpu_per_req = {} + + # For zmq_radix + self.zmq_raidx = server_args.load_balance_method == "zmq_radix" + # Start data parallel workers base_gpu_id = 0 self.workers = [] @@ -85,21 +137,32 @@ def __init__(self, server_args, port_args) -> None: tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name send_to = self.launch_tensor_parallel_group( - server_args, - tmp_port_args, - base_gpu_id, - dp_rank, + server_args, tmp_port_args, base_gpu_id, dp_rank, self.controller_info ) self.workers.append(send_to) base_gpu_id += server_args.tp_size + if self.zmq_raidx: + import threading + + 
self.newest_tree_cache = {} + + self.recv_tree_cache_lock = threading.Lock() + self.recv_tree_cache_thread = threading.Thread( + target=self.loop_for_recv_tree_cache + ) + else: + self.newest_tree_cache = None + self.recv_tree_cache_thread = None + def launch_tensor_parallel_group( self, server_args: ServerArgs, port_args: PortArgs, base_gpu_id: int, dp_rank: int, + controller_info: ControllerInfo, ): # Launch tensor parallel scheduler processes scheduler_procs = [] @@ -114,7 +177,15 @@ def launch_tensor_parallel_group( gpu_id = base_gpu_id + tp_rank % tp_size_per_node proc = mp.Process( target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), + args=( + server_args, + port_args, + gpu_id, + tp_rank, + dp_rank, + writer, + controller_info, + ), ) proc.start() scheduler_procs.append(proc) @@ -129,10 +200,108 @@ def launch_tensor_parallel_group( return send_to + def loop_for_recv_tree_cache(self): + while True: + self.recv_tree_cache() + + def recv_tree_cache(self): + while True: + recv_radix_cache = self.controller_info.radix_queue.get() + if recv_radix_cache: + # logger.info('[recv_tree_cache] receive new data') + gpu_id = recv_radix_cache.gpu_id + if ( + gpu_id not in self.newest_tree_cache + or recv_radix_cache.time > self.newest_tree_cache[gpu_id].time + ): + with self.recv_tree_cache_lock: + if gpu_id in self.newest_tree_cache: + del self.newest_tree_cache[gpu_id] + self.newest_tree_cache[gpu_id] = recv_radix_cache + del recv_radix_cache + def round_robin_scheduler(self, req): self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers) + def update_memory_and_requests(self): + available_mem = [k.value for k in self.controller_info.available_kv_cache] + num_reqs_running = [k.value for k in self.controller_info.running_reqs] + num_reqs_waiting = [k.value for k in self.controller_info.waiting_reqs] + + if not self.pre_available_kv_cache: + self.pre_available_kv_cache = available_mem.copy() + if not self.main_available_kv_cache: + self.main_available_kv_cache = available_mem.copy() + if self.pre_available_kv_cache != available_mem: + self.pre_available_kv_cache = available_mem.copy() + self.main_available_kv_cache = available_mem.copy() + + if not self.pre_num_running_req: + self.pre_num_running_req = num_reqs_running.copy() + if not self.main_num_running_req: + self.main_num_running_req = num_reqs_running.copy() + if self.pre_num_running_req != num_reqs_running: + self.main_num_running_req = num_reqs_running.copy() + self.pre_num_running_req = num_reqs_running.copy() + + if not self.pre_num_waiting_req: + self.pre_num_waiting_req = num_reqs_waiting.copy() + if not self.main_num_waiting_req: + self.main_num_waiting_req = num_reqs_waiting.copy() + if self.pre_num_waiting_req != num_reqs_waiting: + self.main_num_waiting_req = num_reqs_waiting.copy() + self.pre_num_waiting_req = num_reqs_waiting.copy() + + def allocate_gpu(self, req): + all_waiting = min(self.main_num_waiting_req) > 0 + no_waiting = [1 if waiting == 0 else 0 for waiting in self.main_num_waiting_req] + + if all_waiting: + ratio = [ + run / wait + for run, wait in zip( + self.main_num_running_req, self.main_num_waiting_req + ) + ] + max_ratio = max(ratio) + indices = [i for i, x in enumerate(ratio) if x == max_ratio] + gpu_idx = random.choice(indices) + else: + filter_result = [ + a * b for a, b in zip(no_waiting, self.main_available_kv_cache) + ] + max_value = max(filter_result) + max_indices = [ + index for index, 
value in enumerate(filter_result) if value == max_value + ] + gpu_idx = random.choice(max_indices) + + self.main_num_waiting_req[gpu_idx] += 1 + self.main_available_kv_cache[gpu_idx] -= len(req.input_ids) + return gpu_idx + + def resources_aware_scheduler(self, req): + self.update_memory_and_requests() + gpu_idx = self.allocate_gpu(req) + self.workers[gpu_idx].send_pyobj(req) + + def pre_radix_scheduler(self, req): + prefix_lens = [0] * self.dp_size + + with self.recv_tree_cache_lock: + for gpu_id, radix_cache in self.newest_tree_cache.items(): + pre_len = get_match_len(radix_cache.root_node, req.input_ids, 0) + prefix_lens[gpu_id] = pre_len + + # NOTE: 100 is used to reduce the influence of random input + # e.g. If the match nums is [1, 2, 0, 0, 0, 0], we think the scheduer method should be resources aware + if max(prefix_lens) <= 100: + self.resources_aware_scheduler(req) + else: + gpu_idx = prefix_lens.index(max(prefix_lens)) + self.workers[gpu_idx].send_pyobj(req) + def shortest_queue_scheduler(self, input_requests): raise NotImplementedError() @@ -144,6 +313,7 @@ def event_loop(self): except zmq.ZMQError: break + # logger.info(f"[event_loop]{type(recv_req)}") if isinstance( recv_req, ( @@ -170,6 +340,8 @@ def run_data_parallel_controller_process( try: controller = DataParallelController(server_args, port_args) pipe_writer.send("ready") + if controller.recv_tree_cache_thread: + controller.recv_tree_cache_thread.start() controller.event_loop() except Exception: msg = get_exception_traceback() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9625ff44eb..3841736497 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -18,9 +18,11 @@ processes (TokenizerManager, DetokenizerManager, Controller). 
""" +import multiprocessing import uuid from dataclasses import dataclass from enum import Enum +from multiprocessing import Value from typing import Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason @@ -353,3 +355,19 @@ class AbortReq: class ProfileReq(Enum): START_PROFILE = 1 STOP_PROFILE = 2 + + +class ControllerInfo: + def __init__(self, dp_size): + self.available_kv_cache = [] + self.running_reqs = [] + self.waiting_reqs = [] + self.lock = multiprocessing.Lock() + + # For pre radix + self.radix_queue = multiprocessing.Queue() + + for i in range(dp_size): + self.available_kv_cache.append(Value("i", 0)) + self.running_reqs.append(Value("i", 0)) + self.waiting_reqs.append(Value("i", 0)) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 990fbeaa85..a47f76580d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -17,10 +17,12 @@ import json import logging +import multiprocessing import os import time import warnings from collections import deque +from copy import deepcopy from types import SimpleNamespace from typing import List, Optional, Union @@ -37,6 +39,7 @@ AbortReq, BatchEmbeddingOut, BatchTokenIDOut, + ControllerInfo, FlushCacheReq, ProfileReq, TokenizedEmbeddingReqInput, @@ -61,7 +64,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.mem_cache.chunk_cache import ChunkCache -from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.mem_cache.radix_cache import RadixCache, RadixCacheSend from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( broadcast_pyobj, @@ -94,6 +97,7 @@ def __init__( gpu_id: int, tp_rank: int, dp_rank: Optional[int], + controller_info: Optional[ControllerInfo] = None, ): # Parse args self.server_args = server_args @@ -265,6 +269,22 @@ def __init__( with_stack=True, ) + # init controller info + if controller_info and self.tp_rank == 0: + self.controller_info = controller_info + self.gpu_id = gpu_id + self.controller_info.available_kv_cache[self.gpu_id].value = ( + self.token_to_kv_pool.available_size() + ) + if self.server_args.load_balance_method == "zmq_radix": + self.pre_radix = True + import threading + + self.change_cnt_lock = threading.Lock() + threading.Thread(target=self.loop_for_send_tree_cache).start() + else: + self.controller_info = None + @torch.inference_mode() def event_loop_normal(self): """A normal blocking scheduler loop.""" @@ -323,6 +343,25 @@ def event_loop_overlap(self): self.last_batch = batch + def loop_for_send_tree_cache(self): + while True: + self.send_tree_cache_to_queue() + time.sleep(1) + + def send_tree_cache_to_queue(self): + if self.pre_radix: + try: + node = deepcopy(self.tree_cache.root_node) + send_data = RadixCacheSend( + gpu_id=self.gpu_id, root_node=node, time=time.time() + ) + del node + self.controller_info.radix_queue.put(send_data) + # logger.info("[send_tree_cache_to_queue] has send new data") + except Exception as e: + # logger.info(f"[send_tree_cache_to_queue]error:{e}") + return + def recv_requests(self): if self.tp_rank == 0: recv_reqs = [] @@ -749,6 +788,20 @@ def process_batch_result(self, batch: ScheduleBatch, result): else: self.process_batch_result_prefill(batch, result) + # update controller info + if self.controller_info: + with self.controller_info.lock: + self.controller_info.available_kv_cache[self.gpu_id].value = ( + 
self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + self.controller_info.running_reqs[self.gpu_id].value = ( + len(self.running_batch.reqs) if self.running_batch else 0 + ) + self.controller_info.waiting_reqs[self.gpu_id].value = len( + self.waiting_queue + ) + def process_batch_result_prefill(self, batch: ScheduleBatch, result): if self.is_generation: logits_output, next_token_ids, bid = result @@ -1113,6 +1166,7 @@ def run_scheduler_process( tp_rank: int, dp_rank: Optional[int], pipe_writer, + controller_info, ): if dp_rank is None: configure_logger(server_args, prefix=f" TP{tp_rank}") @@ -1122,7 +1176,9 @@ def run_scheduler_process( suppress_other_loggers() try: - scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) + scheduler = Scheduler( + server_args, port_args, gpu_id, tp_rank, dp_rank, controller_info + ) pipe_writer.send("ready") if server_args.enable_overlap_schedule: scheduler.event_loop_overlap() diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 8cd8354b6b..012212155d 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -22,6 +22,7 @@ import heapq import time from collections import defaultdict +from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, List, Optional import torch @@ -33,6 +34,13 @@ from sglang.srt.managers.schedule_batch import Req +@dataclass +class RadixCacheSend: + gpu_id: int + root_node: TreeNode + time: time + + class TreeNode: def __init__(self): self.children = defaultdict(TreeNode) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6ccd891857..cc83b2bb85 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -444,9 +444,10 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=[ "round_robin", "shortest_queue", + "resources_aware", + "pre_radix", ], ) - # Multi-node distributed serving args parser.add_argument( "--dist-init-addr", From 13387d30421bdfc63c98b20f3a59f9e77095c96e Mon Sep 17 00:00:00 2001 From: YouYangxiu Date: Mon, 21 Oct 2024 16:11:07 +0800 Subject: [PATCH 02/14] change the method name from zmq_raix to pre_radix --- python/sglang/srt/managers/data_parallel_controller.py | 5 +---- python/sglang/srt/managers/scheduler.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 87b2ab926a..794fe01579 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -124,10 +124,7 @@ def __init__(self, server_args, port_args) -> None: self.main_num_waiting_req = [] # For pre_radix - self.choosen_gpu_per_req = {} - - # For zmq_radix - self.zmq_raidx = server_args.load_balance_method == "zmq_radix" + self.zmq_raidx = server_args.load_balance_method == "pre_radix" # Start data parallel workers base_gpu_id = 0 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a47f76580d..5a17286302 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -276,7 +276,7 @@ def __init__( self.controller_info.available_kv_cache[self.gpu_id].value = ( self.token_to_kv_pool.available_size() ) - if self.server_args.load_balance_method == "zmq_radix": + if self.server_args.load_balance_method == "pre_radix": self.pre_radix = True 
import threading From 09603c6dc93244cc31de0a1092281bc685187a4f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 21 Oct 2024 01:43:16 -0700 Subject: [PATCH 03/14] Maintain seq_lens_sum to make more FlashInfer operations non-blocking (#1741) --- README.md | 1 - .../sglang/srt/layers/attention/__init__.py | 6 +- .../attention/double_sparsity_backend.py | 6 +- .../layers/attention/flashinfer_backend.py | 57 ++++++++++++++----- .../srt/layers/attention/triton_backend.py | 6 +- python/sglang/srt/managers/schedule_batch.py | 22 +++++-- .../srt/model_executor/cuda_graph_runner.py | 7 ++- .../srt/model_executor/forward_batch_info.py | 36 ++++++------ 8 files changed, 98 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index b413896715..538185cee8 100644 --- a/README.md +++ b/README.md @@ -621,7 +621,6 @@ Please cite our paper, [SGLang: Efficient Execution of Structured Language Model We also learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). -
Back To Top diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index f6d10170cb..ae0ef6b7d2 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -25,7 +25,11 @@ def init_forward_metadata_capture_cuda_graph( raise NotImplementedError() def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index e2cd98ec2c..c83fba8145 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -144,7 +144,11 @@ def init_forward_metadata_capture_cuda_graph( ) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index c2cfa5fb6a..cd4aec8598 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -127,6 +127,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.indices_updater_decode.update( forward_batch.req_pool_indices, forward_batch.seq_lens, + forward_batch.seq_lens_sum, ) self.forward_metadata = (self.decode_wrappers,) else: @@ -134,10 +135,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # Some heuristics to check whether to use ragged forward use_ragged = False - if ( - torch.sum(forward_batch.seq_lens).item() >= 4096 - and self.num_wrappers == 1 - ): + if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1: use_ragged = True extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item() @@ -181,15 +179,25 @@ def init_forward_metadata_capture_cuda_graph( ) ) - self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers + ) self.cuda_graph_metadata[bs] = decode_wrappers self.forward_metadata = (decode_wrappers,) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): self.indices_updater_decode.update( - req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs] + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + self.cuda_graph_metadata[bs], ) def get_cuda_graph_seq_len_fill_value(self): @@ -305,13 +313,30 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): assert attn_backend.num_wrappers == 1 self.update = self.update_single_wrapper - def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None): + def update_single_wrapper( + self, + req_pool_indices: torch.Tensor, + seq_lens: 
torch.Tensor, + seq_lens_sum: int, + decode_wrappers=None, + ): decode_wrappers = decode_wrappers or self.decode_wrappers self.call_begin_forward( - decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None + decode_wrappers[0], + req_pool_indices, + seq_lens, + seq_lens_sum, + self.kv_indptr[0], + None, ) - def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None): + def update_sliding_window( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + decode_wrappers=None, + ): decode_wrappers = decode_wrappers or self.decode_wrappers for wrapper_id in range(2): @@ -331,6 +356,7 @@ def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None decode_wrappers[wrapper_id], req_pool_indices, paged_kernel_lens, + seq_lens_sum, self.kv_indptr[wrapper_id], kv_start_idx, ) @@ -339,13 +365,18 @@ def update_cross_attention(self): raise NotImplementedError() def call_begin_forward( - self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx + self, + wrapper, + req_pool_indices, + paged_kernel_lens, + seq_lens_sum, + kv_indptr, + kv_start_idx, ): bs = len(req_pool_indices) kv_indptr = kv_indptr[: bs + 1] - # TODO: optimize the blocking call on kv_indptr[-1] kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda") create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index e1f5bf3710..fb3805cfe5 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -91,7 +91,11 @@ def init_forward_metadata_capture_cuda_graph( ) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, ): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7fd153e80e..b0ab2dfe5d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -416,7 +416,6 @@ class ScheduleBatch: req_to_token_pool: ReqToTokenPool = None token_to_kv_pool: BaseTokenToKVPool = None tree_cache: BasePrefixCache = None - forward_mode: ForwardMode = None sampling_info: SamplingBatchInfo = None @@ -424,9 +423,13 @@ class ScheduleBatch: input_ids: torch.Tensor = None req_pool_indices: torch.Tensor = None seq_lens: torch.Tensor = None + # The output locations of the KV cache out_cache_loc: torch.Tensor = None output_ids: torch.Tensor = None + # The sum of all sequence lengths + seq_lens_sum: int = None + # For processing logprobs return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None @@ -435,7 +438,6 @@ class ScheduleBatch: prefix_lens: List[int] = None extend_lens: List[int] = None extend_num_tokens: int = None - running_bs: int = None decoding_reqs: List[Req] = None # Stream @@ -549,10 +551,12 @@ def prepare_for_extend(self, vocab_size: int): self.device, non_blocking=True ) - self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc + + self.seq_lens_sum = sum(seq_lens) if self.return_logprob: 
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + self.extend_num_tokens = extend_num_tokens self.prefix_lens = [len(r.prefix_indices) for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] @@ -571,12 +575,11 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): input_ids = torch.cat([self.input_ids, running_batch.input_ids]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) - extend_num_tokens = self.extend_num_tokens + running_bs self.merge_batch(running_batch) self.input_ids = input_ids self.out_cache_loc = out_cache_loc - self.extend_num_tokens = extend_num_tokens + self.extend_num_tokens += running_bs # NOTE: prefix_indices is what has been cached, but we don't cache each decode step self.prefix_lens.extend( @@ -775,6 +778,7 @@ def prepare_for_decode(self, enable_overlap: bool = False): (self.req_pool_indices, self.seq_lens), self.out_cache_loc ) self.seq_lens.add_(1) + self.seq_lens_sum += bs def filter_batch( self, @@ -805,6 +809,7 @@ def filter_batch( self.req_pool_indices = self.req_pool_indices[new_indices] self.seq_lens = self.seq_lens[new_indices] self.out_cache_loc = None + self.seq_lens_sum = self.seq_lens.sum().item() self.output_ids = self.output_ids[new_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: @@ -828,6 +833,7 @@ def merge_batch(self, other: "ScheduleBatch"): ) self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) self.out_cache_loc = None + self.seq_lens_sum += other.seq_lens_sum if self.output_ids is not None: self.output_ids = torch.concat([self.output_ids, other.output_ids]) if self.return_logprob and other.return_logprob: @@ -873,9 +879,11 @@ def get_model_worker_batch(self): req_pool_indices=self.req_pool_indices, seq_lens=self.seq_lens, out_cache_loc=self.out_cache_loc, + seq_lens_sum=self.seq_lens_sum, req_to_token_pool_records=self.req_to_token_pool.get_write_records(), return_logprob=self.return_logprob, top_logprobs_nums=self.top_logprobs_nums, + extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, extend_logprob_start_lens=extend_logprob_start_lens, @@ -917,6 +925,9 @@ class ModelWorkerBatch: # The indices of output tokens in the token_to_kv_pool out_cache_loc: torch.Tensor + # The sum of all sequence lengths + seq_lens_sum: int + # The memory pool operation records req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]] @@ -925,6 +936,7 @@ class ModelWorkerBatch: top_logprobs_nums: Optional[List[int]] # For extend + extend_num_tokens: Optional[int] extend_seq_lens: Optional[List[int]] extend_prefix_lens: Optional[List[int]] extend_logprob_start_lens: Optional[List[int]] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index d3ff3cd1d5..37e3c84292 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -188,6 +188,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable): req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] out_cache_loc = self.out_cache_loc[:bs] + seq_lens_sum = seq_lens.sum().item() # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( @@ -206,6 +207,7 @@ def run_once(): token_to_kv_pool=self.model_runner.token_to_kv_pool, 
attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, return_logprob=False, top_logprobs_nums=[0] * bs, positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), @@ -252,7 +254,10 @@ def replay(self, forward_batch: ForwardBatch): # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( - bs, self.req_pool_indices, self.seq_lens + bs, + self.req_pool_indices, + self.seq_lens, + forward_batch.seq_lens_sum, ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 49ef754a21..f4e117b760 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -87,6 +87,9 @@ class ForwardBatch: # The indices of output tokens in the token_to_kv_pool out_cache_loc: torch.Tensor + # The sum of all sequence lengths + seq_lens_sum: int + # For logprob return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None @@ -95,6 +98,7 @@ class ForwardBatch: positions: torch.Tensor = None # For extend + extend_num_tokens: Optional[int] = None extend_seq_lens: Optional[torch.Tensor] = None extend_prefix_lens: Optional[torch.Tensor] = None extend_start_loc: Optional[torch.Tensor] = None @@ -175,21 +179,6 @@ def compute_mrope_positions( ) self.mrope_positions = self.mrope_positions.to(torch.int64) - def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch): - device = model_runner.device - if self.forward_mode.is_decode(): - self.positions = (self.seq_lens - 1).to(torch.int64) - else: - self.positions = torch.concat( - [ - torch.arange(prefix_len, prefix_len + extend_len, device=device) - for prefix_len, extend_len in zip( - batch.extend_prefix_lens, batch.extend_seq_lens - ) - ], - axis=0, - ) - @classmethod def init_new( cls, @@ -205,6 +194,7 @@ def init_new( req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, out_cache_loc=batch.out_cache_loc, + seq_lens_sum=batch.seq_lens_sum, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, lora_paths=batch.lora_paths, @@ -213,7 +203,17 @@ def init_new( # Init position information if not ret.forward_mode.is_decode(): + ret.positions = torch.concat( + [ + torch.arange(prefix_len, prefix_len + extend_len, device=device) + for prefix_len, extend_len in zip( + batch.extend_prefix_lens, batch.extend_seq_lens + ) + ], + axis=0, + ) ret.image_inputs = batch.image_inputs + ret.extend_num_tokens = batch.extend_num_tokens ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32 ).to(device, non_blocking=True) @@ -225,12 +225,8 @@ def init_new( ret.extend_seq_lens_cpu = batch.extend_seq_lens ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens - # Init position information - is_mrope = model_runner.model_is_mrope - if is_mrope: + if model_runner.model_is_mrope: ret.compute_mrope_positions(model_runner, batch) - else: - ret.compute_positions(model_runner, batch) # Init attention information ret.req_to_token_pool = model_runner.req_to_token_pool From efb099cdee90b9ad332fcda96d89dd91ddebe072 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 21 Oct 2024 03:54:35 -0700 Subject: [PATCH 04/14] Fix prefill oom (#1743) --- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 990fbeaa85..1b68bacd9b 100644 --- 
a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -427,7 +427,7 @@ def handle_generate_request( if req.sampling_params.max_new_tokens is not None else 1 << 30 ), - self.max_req_input_len - 1 - len(req.origin_input_ids), + self.max_req_input_len - len(req.origin_input_ids), ) self.waiting_queue.append(req) From 7ce36068914503c3a53ad7be23ab29831fb8aa63 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 21 Oct 2024 04:30:52 -0700 Subject: [PATCH 05/14] Faster overlap mode scheduler (#1738) --- .../srt/managers/tp_worker_overlap_thread.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 5d78b97ce4..8b27d2a69a 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -55,7 +55,7 @@ def __init__( (self.max_running_requests * 5,), dtype=torch.int32, device=self.device ) - # Launch a thread + # Launch threads self.input_queue = Queue() self.output_queue = Queue() self.forward_stream = torch.cuda.Stream() @@ -64,6 +64,12 @@ def __init__( ) self.forward_thread.start() + self.copy_queue = Queue() + self.copy_thread = threading.Thread( + target=self.copy_thread_func, + ) + self.copy_thread.start() + def get_worker_info(self): return self.worker.get_worker_info() @@ -86,7 +92,10 @@ def forward_thread_func(self): @torch.inference_mode() def forward_thread_func_(self): while True: + self.has_inflight_batch = False model_worker_batch, future_token_ids_ct = self.input_queue.get() + self.has_inflight_batch = True + self.launch_event = threading.Event() # Resolve future tokens in the input input_ids = model_worker_batch.input_ids @@ -100,6 +109,7 @@ def forward_thread_func_(self): logits_output, next_token_ids = self.worker.forward_batch_generation( model_worker_batch ) + self.launch_event.set() # Update the future token ids map bs = len(model_worker_batch.seq_lens) @@ -113,13 +123,23 @@ def forward_thread_func_(self): torch.int32 ) - # Set the result - next_token_ids = next_token_ids.tolist() - assert logits_output.next_token_logprobs is None, "Not supported" - self.output_queue.put((None, next_token_ids)) + next_token_ids = next_token_ids.to("cpu", non_blocking=True) + copy_event = torch.cuda.Event(blocking=True) + copy_event.record() + self.copy_queue.put((copy_event, next_token_ids)) + + def copy_thread_func(self): + while True: + copy_event, next_token_ids = self.copy_queue.get() + while not copy_event.query(): + time.sleep(1e-5) + self.output_queue.put((None, next_token_ids.tolist())) def resulve_batch_result(self, bid: int): logits_output, next_token_ids = self.output_queue.get() + if self.has_inflight_batch: + # Wait until the batch is launched + self.launch_event.wait() return logits_output, next_token_ids def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): From e68b9e7667db64e240c25c1b872f7b4d69f54698 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 21 Oct 2024 06:28:32 -0700 Subject: [PATCH 06/14] misc: add CODEOWNERS (#1737) --- .github/CODEOWNERS | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..72c04a8ca3 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,13 @@ +/python/sglang/lang @merrymercy @Ying1123 @hnyls2002 @ByronHsu +/python/sglang/srt @merrymercy @Ying1123 
@hnyls2002 @zhyncs @ispobock @ByronHsu +/python/sglang/srt/constrained @hnyls2002 +/python/sglang/srt/layers @merrymercy @Ying1123 @zhyncs @ispobock +/python/sglang/srt/lora @Ying1123 +/python/sglang/srt/managers @merrymercy @Ying1123 @hnyls2002 +/python/sglang/srt/mem_cache @merrymercy @Ying1123 @hnyls2002 +/python/sglang/srt/model_executor @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock +/python/sglang/srt/models @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu +/python/sglang/srt/openai_api @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu +/python/sglang/srt/sampling @merrymercy @hnyls2002 +/test/lang @merrymercy @Ying1123 @hnyls2002 @ByronHsu +/test/srt @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu From 00611286a1a57da6d305a634bf959beb8f5549f6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 21 Oct 2024 13:47:12 -0700 Subject: [PATCH 07/14] Fix sliding window attention and gemma-2 unit tests in CI (#1746) --- .../layers/attention/flashinfer_backend.py | 22 +++++++++++-------- python/sglang/test/runners.py | 21 +++++++++++++++++- test/srt/models/test_generation_models.py | 4 +--- test/srt/run_suite.py | 2 +- 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index cd4aec8598..231300ce0c 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -342,23 +342,25 @@ def update_sliding_window( for wrapper_id in range(2): if wrapper_id == 0: # Sliding window attention - paged_kernel_lens = torch.minimum( # TODO: replace this with clamp + paged_kernel_lens_tmp = torch.minimum( # TODO: replace this with clamp seq_lens, torch.tensor(self.sliding_window_size + 1), ) + paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item() + kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp else: # Full attention - paged_kernel_lens = seq_lens - - kv_start_idx = seq_lens - paged_kernel_lens + paged_kernel_lens_tmp = seq_lens + paged_kernel_lens_sum_tmp = seq_lens_sum + kv_start_idx_tmp = None self.call_begin_forward( decode_wrappers[wrapper_id], req_pool_indices, - paged_kernel_lens, - seq_lens_sum, + paged_kernel_lens_tmp, + paged_kernel_lens_sum_tmp, self.kv_indptr[wrapper_id], - kv_start_idx, + kv_start_idx_tmp, ) def update_cross_attention(self): @@ -369,14 +371,16 @@ def call_begin_forward( wrapper, req_pool_indices, paged_kernel_lens, - seq_lens_sum, + paged_kernel_lens_sum, kv_indptr, kv_start_idx, ): bs = len(req_pool_indices) kv_indptr = kv_indptr[: bs + 1] kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda") + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 8439aa8bbc..217065bd20 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -102,8 +102,10 @@ def needs_trust_remote_code(self, model_path): return False def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): - self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype) + # Apply model-specific patches + monkey_patch_gemma2_sdpa() + # Load the model and tokenizer if self.model_type == "generation": self.base_model = AutoModelForCausalLM.from_pretrained( model_path, @@ -128,7 
+130,9 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): ).cuda() else: raise Exception(f"Unrecognized model type {self.model_type}") + self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype) + # Run forward while True: prompts, max_new_tokens, lora_paths = in_queue.get() if lora_paths is not None: @@ -370,3 +374,18 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.runtime.shutdown() del self.runtime + + +def monkey_patch_gemma2_sdpa(): + """ + Use sdpa by default to fix the OOM issue. + Revert this commit: + https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660 + """ + from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel + + def _check_and_enable_sdpa(config, hard_check_only: bool = False): + config._attn_implementation = "sdpa" + return config + + setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa) diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index ba4c05ee48..9cd1f4207c 100755 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -46,9 +46,7 @@ class ModelCase: # Popular models that run on the CI CI_MODELS = [ ModelCase("meta-llama/Llama-3.1-8B-Instruct"), - ModelCase( - "google/gemma-2-2b", skip_long_prompt=True - ), # There is a bug with new transformers library. This can only run with transformers==4.44 + ModelCase("google/gemma-2-2b"), ] # All other models that do not run on the CI diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 3f8a1fecb1..e8fadcef72 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -15,7 +15,7 @@ "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", "test_json_constrained.py", - "test_large_max_new_tokens.py", + # "test_large_max_new_tokens.py", # This test hangs on CI due to unknown reasons "test_openai_server.py", "test_overlap_schedule.py", "test_pytorch_sampling_backend.py", From 94cde10920035648b0554abec5323176eea8486d Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 21 Oct 2024 15:01:21 -0700 Subject: [PATCH 08/14] Llama3.2 vision model support (#1551) --- python/pyproject.toml | 37 +- python/sglang/bench_latency.py | 3 +- python/sglang/lang/chat_template.py | 1 + python/sglang/srt/configs/model_config.py | 2 + python/sglang/srt/conversation.py | 13 + .../sglang/srt/layers/attention/__init__.py | 15 +- .../attention/double_sparsity_backend.py | 22 +- .../layers/attention/flashinfer_backend.py | 157 ++- .../srt/layers/attention/triton_backend.py | 22 +- python/sglang/srt/managers/image_processor.py | 88 +- python/sglang/srt/managers/schedule_batch.py | 165 ++- python/sglang/srt/managers/scheduler.py | 3 +- .../sglang/srt/managers/tokenizer_manager.py | 14 +- python/sglang/srt/mem_cache/memory_pool.py | 21 +- .../srt/model_executor/cuda_graph_runner.py | 60 +- .../srt/model_executor/forward_batch_info.py | 13 +- .../sglang/srt/model_executor/model_runner.py | 7 +- python/sglang/srt/models/mllama.py | 1004 +++++++++++++++++ python/sglang/srt/models/qwen2_vl.py | 6 +- python/sglang/srt/utils.py | 1 + test/srt/test_vision_openai_server.py | 24 +- 21 files changed, 1559 insertions(+), 119 deletions(-) create mode 100644 python/sglang/srt/models/mllama.py diff --git a/python/pyproject.toml b/python/pyproject.toml index df62361623..d51fc2331e 100644 --- a/python/pyproject.toml +++ 
b/python/pyproject.toml @@ -8,16 +8,12 @@ version = "0.3.4" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" -license = {file = "LICENSE"} +license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] -dependencies = [ - "requests", - "tqdm", - "numpy", -] +dependencies = ["requests", "tqdm", "numpy"] [project.optional-dependencies] runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", @@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] -test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"] +test = [ + "jsonlines", + "matplotlib", + "pandas", + "sentence_transformers", + "accelerate", + "peft", +] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] @@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"] "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" [tool.setuptools.packages.find] -exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] +exclude = [ + "assets*", + "benchmark*", + "docs*", + "dist*", + "playground*", + "scripts*", + "tests*", +] [tool.wheel] -exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] +exclude = [ + "assets*", + "benchmark*", + "docs*", + "dist*", + "playground*", + "scripts*", + "tests*", +] diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index a05398812b..43cb7bc3fb 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -227,8 +227,9 @@ def extend(reqs, model_runner): req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, tree_cache=None, + model_config=model_runner.model_config, ) - batch.prepare_for_extend(model_runner.model_config.vocab_size) + batch.prepare_for_extend() model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) logits_output = model_runner.forward(forward_batch) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index d7602964d4..b8f9a533de 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -229,6 +229,7 @@ def get_chat_template_by_model_path(model_path): ), }, stop_str=("<|eot_id|>",), + image_token="<|image|>", ) ) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index c1493faadb..a3c59e8d82 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -89,6 +89,8 @@ def __init__( self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size + self.is_encoder_decoder = self.hf_config.model_type in ["mllama"] + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 73bbc1e2ee..42b2d70d58 100644 --- a/python/sglang/srt/conversation.py +++ 
b/python/sglang/srt/conversation.py @@ -509,6 +509,19 @@ def generate_chat_conv( ) ) +register_conv_template( + Conversation( + name="llama_3_vision", + system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + image_token="<|image|>", + ) +) + register_conv_template( Conversation( name="llava_llama_3", diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index ae0ef6b7d2..f5d573f5f7 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod +from typing import Optional import torch from torch import nn +from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -19,7 +21,11 @@ def init_cuda_graph_state(self, max_bs: int): raise NotImplementedError() def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor] = None, ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -30,6 +36,7 @@ def init_forward_metadata_replay_cuda_graph( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor] = None, ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() @@ -43,7 +50,7 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer: nn.Module, + layer: RadixAttention, forward_batch: ForwardBatch, ): """Run forward on an attention layer.""" @@ -57,7 +64,7 @@ def forward_decode( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer: nn.Module, + layer: RadixAttention, forward_batch: ForwardBatch, ): """Run a forward for decode.""" @@ -68,7 +75,7 @@ def forward_extend( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer: nn.Module, + layer: RadixAttention, forward_batch: ForwardBatch, ): """Run a forward for extend.""" diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index c83fba8145..73c32df8f6 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -10,6 +10,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner @@ -134,8 +135,13 @@ def init_cuda_graph_state(self, max_bs: int): ) def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( self.cuda_graph_start_loc, self.cuda_graph_attn_logits, @@ -149,14 +155,18 @@ def init_forward_metadata_replay_cuda_graph( req_pool_indices: 
torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) def get_cuda_graph_seq_len_fill_value(self): return 1 - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_extend( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) @@ -172,7 +182,7 @@ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) ) forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v, k_label + layer, forward_batch.out_cache_loc, k, v, k_label ) ( @@ -201,7 +211,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) ) return o - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_decode( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) @@ -231,7 +243,7 @@ def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) ) forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v, k_label + layer, forward_batch.out_cache_loc, k, v, k_label ) # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 231300ce0c..e5e7ca29c9 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING import torch -import torch.nn as nn import triton import triton.language as tl @@ -21,6 +20,7 @@ from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner if is_flashinfer_available(): @@ -56,13 +56,13 @@ def __init__(self, model_runner: ModelRunner): assert not ( model_runner.sliding_window_size is not None - and model_runner.has_cross_attention + and model_runner.model_config.is_encoder_decoder ), "Sliding window and cross attention are not supported together" if model_runner.sliding_window_size is not None: self.num_wrappers = 2 self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW - elif model_runner.has_cross_attention: + elif model_runner.model_config.is_encoder_decoder: self.num_wrappers = 2 self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION else: @@ -128,6 +128,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.req_pool_indices, forward_batch.seq_lens, forward_batch.seq_lens_sum, + decode_wrappers=None, + encoder_lens=forward_batch.encoder_lens, ) self.forward_metadata = (self.decode_wrappers,) else: @@ -144,13 +146,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.req_pool_indices, forward_batch.seq_lens, prefix_lens, - use_ragged, + use_ragged=use_ragged, + encoder_lens=forward_batch.encoder_lens, ) - self.forward_metadata = ( - use_ragged, - 
extend_no_prefix, - ) + self.forward_metadata = (use_ragged, extend_no_prefix) def init_cuda_graph_state(self, max_bs: int): cuda_graph_kv_indices = torch.zeros( @@ -163,7 +163,11 @@ def init_cuda_graph_state(self, max_bs: int): ] def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: torch.Tensor = None, ): decode_wrappers = [] for i in range(self.num_wrappers): @@ -181,7 +185,11 @@ def init_forward_metadata_capture_cuda_graph( seq_lens_sum = seq_lens.sum().item() self.indices_updater_decode.update( - req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=decode_wrappers, + encoder_lens=encoder_lens, ) self.cuda_graph_metadata[bs] = decode_wrappers self.forward_metadata = (decode_wrappers,) @@ -192,34 +200,42 @@ def init_forward_metadata_replay_cuda_graph( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens: torch.Tensor = None, ): self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], seq_lens_sum, - self.cuda_graph_metadata[bs], + decode_wrappers=self.cuda_graph_metadata[bs], + encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, ) def get_cuda_graph_seq_len_fill_value(self): return 0 - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_extend( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): prefill_wrapper_paged = self.prefill_wrappers_paged[ self._get_wrapper_idx(layer) ] use_ragged, extend_no_prefix = self.forward_metadata + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) if not use_ragged: if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=True, + causal=not layer.is_cross_attention, sm_scale=layer.scaling, window_left=layer.sliding_window_size, logits_soft_cap=layer.logit_cap, @@ -247,20 +263,23 @@ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) o, _ = merge_state(o1, s1, o2, s2) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) return o.view(-1, layer.tp_q_head_num * layer.head_dim) - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_decode( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = decode_wrapper.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -271,7 +290,7 @@ def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) return 
o.view(-1, layer.tp_q_head_num * layer.head_dim) - def _get_wrapper_idx(self, layer: nn.Module): + def _get_wrapper_idx(self, layer: RadixAttention): if self.num_wrappers == 1: return 0 @@ -298,6 +317,8 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.sliding_window_size = model_runner.sliding_window_size + self.attn_backend = attn_backend + # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len @@ -305,20 +326,27 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.decode_wrappers = attn_backend.decode_wrappers # Dispatch - if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window - elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: self.update = self.update_cross_attention else: - assert attn_backend.num_wrappers == 1 + assert self.attn_backend.num_wrappers == 1 self.update = self.update_single_wrapper + def update( + self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens + ): + # Keep the signature for type checking, will be initialized during runtime + raise NotImplementedError() + def update_single_wrapper( self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers=None, + encoder_lens=None, ): decode_wrappers = decode_wrappers or self.decode_wrappers self.call_begin_forward( @@ -336,6 +364,7 @@ def update_sliding_window( seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers=None, + encoder_lens=None, ): decode_wrappers = decode_wrappers or self.decode_wrappers @@ -363,8 +392,35 @@ def update_sliding_window( kv_start_idx_tmp, ) - def update_cross_attention(self): - raise NotImplementedError() + def update_cross_attention( + self, + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=None, + encoder_lens=None, + ): + decode_wrappers = decode_wrappers or self.decode_wrappers + + for wrapper_id in range(2): + if wrapper_id == 0: + # Normal attention + paged_kernel_lens = seq_lens + kv_start_idx = encoder_lens + else: + # Cross attention + paged_kernel_lens = encoder_lens + kv_start_idx = torch.zeros_like(encoder_lens) + seq_lens_sum = encoder_lens.sum().item() + + self.call_begin_forward( + decode_wrappers[wrapper_id], + req_pool_indices, + paged_kernel_lens, + seq_lens_sum, + self.kv_indptr[wrapper_id], + kv_start_idx, + ) def call_begin_forward( self, @@ -421,6 +477,8 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.sliding_window_size = model_runner.sliding_window_size + self.attn_backend = attn_backend + # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len @@ -430,16 +488,20 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.wrappers_paged = attn_backend.prefill_wrappers_paged # Dispatch - if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window - elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + elif 
self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: self.update = self.update_cross_attention else: - assert attn_backend.num_wrappers == 1 + assert self.attn_backend.num_wrappers == 1 self.update = self.update_single_wrapper + def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens): + # Keep the signature for type checking, will be initialized during runtime + raise NotImplementedError() + def update_single_wrapper( - self, req_pool_indices, seq_lens, prefix_lens, use_ragged + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens ): if use_ragged: paged_kernel_lens = prefix_lens @@ -460,7 +522,7 @@ def update_single_wrapper( ) def update_sliding_window( - self, req_pool_indices, seq_lens, prefix_lens, use_ragged + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens ): for wrapper_id in range(2): if wrapper_id == 0: @@ -487,8 +549,31 @@ def update_sliding_window( use_ragged, ) - def update_cross_attention(self): - raise NotImplementedError() + def update_cross_attention( + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens + ): + for wrapper_id in range(2): + if wrapper_id == 0: + # normal attention + paged_kernel_lens = seq_lens + kv_start_idx = encoder_lens + else: + # cross attention + paged_kernel_lens = encoder_lens + kv_start_idx = torch.zeros_like(encoder_lens) + + self.call_begin_forward( + self.wrapper_ragged, + self.wrappers_paged[wrapper_id], + req_pool_indices, + paged_kernel_lens, + seq_lens, + prefix_lens, + kv_start_idx, + self.kv_indptr[wrapper_id], + self.qo_indptr[wrapper_id], + use_ragged, + ) def call_begin_forward( self, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index fb3805cfe5..47b8d3cd56 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -10,6 +10,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner @@ -81,8 +82,13 @@ def init_cuda_graph_state(self, max_bs: int): ) def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( self.cuda_graph_start_loc, self.cuda_graph_attn_logits, @@ -96,14 +102,18 @@ def init_forward_metadata_replay_cuda_graph( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) def get_cuda_graph_seq_len_fill_value(self): return 1 - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_extend( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) @@ -111,7 +121,7 @@ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) o = torch.empty_like(q) forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v + layer, 
forward_batch.out_cache_loc, k, v ) start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata @@ -133,7 +143,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) ) return o - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_decode( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) @@ -147,7 +159,7 @@ def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v + layer, forward_batch.out_cache_loc, k, v ) self.decode_attention_fwd( diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index b958ab89bc..08ad150232 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -33,26 +33,32 @@ def init_global_processor(server_args: ServerArgs): class BaseImageProcessor(ABC): + def __init__(self, hf_config, server_args, _processor): + self.hf_config = hf_config + self._processor = _processor + self.executor = concurrent.futures.ProcessPoolExecutor( + initializer=init_global_processor, + mp_context=mp.get_context("fork"), + initargs=(server_args,), + max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()), + ) + @abstractmethod - async def process_images_async(self, image_data, **kwargs): + async def process_images_async(self, image_data, input_text, **kwargs): pass class DummyImageProcessor(BaseImageProcessor): + def __init__(self): + pass + async def process_images_async(self, *args, **kwargs): return None class LlavaImageProcessor(BaseImageProcessor): - def __init__(self, hf_config, server_args, _image_processor): - self.hf_config = hf_config - self._image_processor = _image_processor - self.executor = concurrent.futures.ProcessPoolExecutor( - initializer=init_global_processor, - mp_context=mp.get_context("fork"), - initargs=(server_args,), - max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()), - ) + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) @staticmethod def _process_single_image_task( @@ -119,7 +125,7 @@ async def _process_single_image( ) async def process_images_async( - self, image_data: List[Union[str, bytes]], request_obj + self, image_data: List[Union[str, bytes]], input_text, request_obj ): if not image_data: return None @@ -177,6 +183,54 @@ async def process_images_async( } +class MllamaImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + + @staticmethod + def _process_single_image_task(images, input_text): + # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask' + return global_processor(images, input_text, return_tensors="pt") + + async def _process_single_image(self, images, input_text): + if self.executor is not None: + loop = asyncio.get_event_loop() + image_inputs = await loop.run_in_executor( + self.executor, + MllamaImageProcessor._process_single_image_task, + images, + input_text, + ) + else: + image_inputs = 
self._processor(images, input_text, return_tensors="pt") + + return image_inputs + + async def process_images_async( + self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs + ): + if not image_data: + return None + + if isinstance(input_text, list): + assert len(input_text) and isinstance(input_text[0], int) + input_text = self._processor.tokenizer.decode(input_text) + + if not isinstance(image_data, list): + image_data = [image_data] + + if len(image_data) > 0: + images = [load_image(image)[0] for image in image_data] + else: + images = load_image(image_data[0])[0] + + image_inputs = await self._process_single_image(images, input_text) + image_inputs["image_hashes"] = [hash(str(image_data))] + image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] + + return image_inputs + + class Qwen2VLImageProcessor(BaseImageProcessor): def __init__(self, hf_config, server_args, _image_processor): self.hf_config = hf_config @@ -237,7 +291,7 @@ async def _process_single_image(self, image_data: Union[bytes, str]): return self._process_single_image_task(image_data) async def process_images_async( - self, image_data: List[Union[str, bytes]], request_obj + self, image_data: List[Union[str, bytes]], input_text, request_obj ): if not image_data: return None @@ -292,12 +346,14 @@ async def process_images_async( def get_image_processor( - hf_config, server_args: ServerArgs, _image_processor + hf_config, server_args: ServerArgs, processor ) -> BaseImageProcessor: - if "Qwen2VLForConditionalGeneration" in hf_config.architectures: - return Qwen2VLImageProcessor(hf_config, server_args, _image_processor) + if "MllamaForConditionalGeneration" in hf_config.architectures: + return MllamaImageProcessor(hf_config, server_args, processor) + elif "Qwen2VLForConditionalGeneration" in hf_config.architectures: + return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor) else: - return LlavaImageProcessor(hf_config, server_args, _image_processor) + return LlavaImageProcessor(hf_config, server_args, processor.image_processor) def get_dummy_image_processor(): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b0ab2dfe5d..bcf3103ad2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -36,6 +36,7 @@ import torch from sglang.global_config import global_config +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained import RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -121,11 +122,12 @@ class ImageInputs: """The image related inputs.""" pixel_values: torch.Tensor - image_hash: int + image_hashes: Optional[list] = None image_sizes: Optional[list] = None image_offsets: Optional[list] = None pad_values: Optional[list] = None modalities: Optional[list] = None + num_image_tokens: Optional[int] = None image_embeds: Optional[List[torch.Tensor]] = None aspect_ratio_ids: Optional[List[torch.Tensor]] = None @@ -138,19 +140,27 @@ def from_dict(obj, vocab_size): # Use image hash as fake token_ids, which is then used for prefix matching ret = ImageInputs( pixel_values=obj["pixel_values"], - image_hash=hash(tuple(obj["image_hashes"])), - image_grid_thws=obj.get("image_grid_thws"), + image_hashes=hash(tuple(obj["image_hashes"])), ) - image_hash = ret.image_hash + image_hash = ret.image_hashes ret.pad_values = [ (image_hash) % vocab_size, (image_hash >> 16) % 
vocab_size, (image_hash >> 32) % vocab_size, (image_hash >> 64) % vocab_size, ] - ret.image_sizes = obj["image_sizes"] - # Only when pixel values is not None we have modalities - ret.modalities = obj["modalities"] or ["image"] + + optional_args = [ + "image_sizes", + "modalities", + "aspect_ratio_ids", + "aspect_ratio_mask", + "image_grid_thws", + ] + for arg in optional_args: + if arg in obj: + setattr(ret, arg, obj[arg]) + return ret @@ -416,6 +426,10 @@ class ScheduleBatch: req_to_token_pool: ReqToTokenPool = None token_to_kv_pool: BaseTokenToKVPool = None tree_cache: BasePrefixCache = None + + # For utility + model_config: ModelConfig = None + forward_mode: ForwardMode = None sampling_info: SamplingBatchInfo = None @@ -440,6 +454,12 @@ class ScheduleBatch: extend_num_tokens: int = None decoding_reqs: List[Req] = None + # For encoder-decoder + encoder_cached: Optional[List[bool]] = None + encoder_lens: Optional[torch.Tensor] = None + encoder_lens_cpu: Optional[List[int]] = None + encoder_out_cache_loc: Optional[torch.Tensor] = None + # Stream has_stream: bool = False @@ -450,12 +470,20 @@ class ScheduleBatch: device: str = "cuda" @classmethod - def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): + def init_new( + cls, + reqs, + req_to_token_pool, + token_to_kv_pool, + tree_cache, + model_config, + ): return cls( reqs=reqs, req_to_token_pool=req_to_token_pool, token_to_kv_pool=token_to_kv_pool, tree_cache=tree_cache, + model_config=model_config, return_logprob=any(req.return_logprob for req in reqs), has_stream=any(req.stream for req in reqs), has_regex=any(req.regex_fsm for req in reqs), @@ -493,7 +521,78 @@ def alloc_token_slots(self, num_tokens: int): return out_cache_loc - def prepare_for_extend(self, vocab_size: int): + def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): + self.encoder_lens_cpu = [] + self.encoder_cached = [] + + for req in self.reqs: + im = req.image_inputs + if im is None or im.num_image_tokens is None: + # No image input + self.encoder_lens_cpu.append(0) + self.encoder_cached.append(True) + else: + self.encoder_lens_cpu.append(im.num_image_tokens) + self.encoder_cached.append( + self.forward_mode.is_decode() + or len(req.prefix_indices) >= im.num_image_tokens + ) + + self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to( + self.device, non_blocking=True + ) + + # Strip encoder infos + pt = 0 + decoder_out_cache_loc = [] + encoder_out_cache_loc = [] + for i, req in enumerate(self.reqs): + encoder_len = self.encoder_lens_cpu[i] + seq_lens[i] -= encoder_len + + if len(req.prefix_indices) < encoder_len: + # NOTE: the encoder part should considered as a whole + assert len(req.prefix_indices) == 0 + input_ids[i] = input_ids[i][encoder_len:] + encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) + decoder_out_cache_loc.append( + self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] + ) + self.extend_lens[i] -= encoder_len + self.extend_num_tokens -= encoder_len + else: + decoder_out_cache_loc.append( + self.out_cache_loc[pt : pt + req.extend_input_len] + ) + self.prefix_lens[i] -= encoder_len + + pt += req.extend_input_len + + # Reassign + self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( + self.device, non_blocking=True + ) + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to( + self.device, non_blocking=True + ) + + if not decoder_out_cache_loc: + self.out_cache_loc = torch.empty(0, dtype=torch.int32).to( + self.device, 
non_blocking=True + ) + else: + self.out_cache_loc = torch.cat(decoder_out_cache_loc) + + if not encoder_out_cache_loc: + self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to( + self.device, non_blocking=True + ) + else: + self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc) + + assert len(self.out_cache_loc) == self.extend_num_tokens + + def prepare_for_extend(self): self.forward_mode = ForwardMode.EXTEND bs = len(self.reqs) @@ -561,8 +660,13 @@ def prepare_for_extend(self, vocab_size: int): self.extend_lens = [r.extend_input_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] + if self.model_config.is_encoder_decoder: + self.prepare_encoder_info_extend(input_ids, seq_lens) + self.sampling_info = SamplingBatchInfo.from_schedule_batch( - self, vocab_size, global_server_args_dict["disable_penalizer"] + self, + self.model_config.vocab_size, + global_server_args_dict["disable_penalizer"], ) def mix_with_running(self, running_batch: "ScheduleBatch"): @@ -752,6 +856,10 @@ def check_for_jump_forward(self, pad_input_ids_func): return jump_forward_reqs + def prepare_encoder_info_decode(self): + # Reset the encoder cached status + self.encoder_cached = [True] * len(self.reqs) + def prepare_for_decode(self, enable_overlap: bool = False): self.forward_mode = ForwardMode.DECODE @@ -766,16 +874,22 @@ def prepare_for_decode(self, enable_overlap: bool = False): bs = len(self.reqs) self.out_cache_loc = self.alloc_token_slots(bs) + if self.model_config.is_encoder_decoder: + locs = self.encoder_lens + self.seq_lens + self.prepare_encoder_info_decode() + else: + locs = self.seq_lens + if enable_overlap: # Do not use in-place operations in the overlap mode self.req_to_token_pool.write( - (self.req_pool_indices, self.seq_lens), self.out_cache_loc + (self.req_pool_indices, locs), self.out_cache_loc ) self.seq_lens = self.seq_lens + 1 else: # A faster in-place version self.req_to_token_pool.write( - (self.req_pool_indices, self.seq_lens), self.out_cache_loc + (self.req_pool_indices, locs), self.out_cache_loc ) self.seq_lens.add_(1) self.seq_lens_sum += bs @@ -802,6 +916,10 @@ def filter_batch( # No need to filter return + if self.model_config.is_encoder_decoder: + self.encoder_lens = self.encoder_lens[keep_indices] + self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] + self.reqs = [self.reqs[i] for i in keep_indices] new_indices = torch.tensor(keep_indices, dtype=torch.int32).to( self.device, non_blocking=True @@ -828,6 +946,11 @@ def merge_batch(self, other: "ScheduleBatch"): # needs to be called with pre-merged Batch.reqs. 
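The encoder-stripping step in prepare_encoder_info_extend above is easy to lose track of inside the pool bookkeeping. The following standalone sketch reproduces just the split of newly allocated cache slots into encoder and decoder portions with plain Python lists; the helper name split_cache_locs and the toy inputs are made up for illustration and are not part of this patch.

    from typing import List, Tuple

    def split_cache_locs(
        out_cache_loc: List[int],   # newly allocated KV slots for all extend tokens
        extend_lens: List[int],     # new tokens per request (encoder + decoder text)
        encoder_lens: List[int],    # encoder (image) tokens per request
        prefix_lens: List[int],     # tokens already cached for the request
    ) -> Tuple[List[int], List[int]]:
        """Split fresh KV slots into encoder and decoder portions per request."""
        pt = 0
        encoder_locs: List[int] = []
        decoder_locs: List[int] = []
        for ext_len, enc_len, pre_len in zip(extend_lens, encoder_lens, prefix_lens):
            if pre_len < enc_len:
                # Encoder part not cached yet: it is consumed as a whole,
                # so the first enc_len slots of this request belong to it.
                encoder_locs += out_cache_loc[pt : pt + enc_len]
                decoder_locs += out_cache_loc[pt + enc_len : pt + ext_len]
            else:
                # Encoder already cached: every new slot belongs to the decoder.
                decoder_locs += out_cache_loc[pt : pt + ext_len]
            pt += ext_len
        return encoder_locs, decoder_locs

    # Toy batch: req0 brings 3 encoder tokens plus 2 text tokens, req1 is text-only.
    enc, dec = split_cache_locs(
        out_cache_loc=list(range(8)),
        extend_lens=[5, 3],
        encoder_lens=[3, 0],
        prefix_lens=[0, 0],
    )
    assert enc == [0, 1, 2] and dec == [3, 4, 5, 6, 7]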
self.sampling_info.merge_batch(other.sampling_info) + # Encoder-decoder infos + if self.model_config.is_encoder_decoder: + self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens]) + self.encoder_lens_cpu.extend(other.encoder_lens_cpu) + self.req_pool_indices = torch.concat( [self.req_pool_indices, other.req_pool_indices] ) @@ -850,14 +973,11 @@ def merge_batch(self, other: "ScheduleBatch"): def get_model_worker_batch(self): if self.forward_mode.is_decode(): - extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = ( - image_inputs - ) = None + extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: extend_seq_lens = self.extend_lens extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens - image_inputs = [r.image_inputs for r in self.reqs] if self.has_regex: self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs] @@ -887,7 +1007,11 @@ def get_model_worker_batch(self): extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, extend_logprob_start_lens=extend_logprob_start_lens, - image_inputs=image_inputs, + image_inputs=[r.image_inputs for r in self.reqs], + encoder_cached=self.encoder_cached, + encoder_lens=self.encoder_lens, + encoder_lens_cpu=self.encoder_lens_cpu, + encoder_out_cache_loc=self.encoder_out_cache_loc, lora_paths=[req.lora_path for req in self.reqs], sampling_info=self.sampling_info, mrope_positions_delta=mrope_positions_delta, @@ -897,6 +1021,7 @@ def copy(self): # Only contain fields that will be used by process_batch_result return ScheduleBatch( reqs=self.reqs, + model_config=self.model_config, forward_mode=self.forward_mode, out_cache_loc=self.out_cache_loc, return_logprob=self.return_logprob, @@ -944,6 +1069,12 @@ class ModelWorkerBatch: # For multimodal image_inputs: Optional[List[ImageInputs]] + # For encoder-decoder + encoder_cached: Optional[List[bool]] + encoder_lens: Optional[torch.Tensor] + encoder_lens_cpu: Optional[List[int]] + encoder_out_cache_loc: Optional[torch.Tensor] + # For LoRA lora_paths: Optional[List[str]] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1b68bacd9b..b2f217c852 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -662,8 +662,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.req_to_token_pool, self.token_to_kv_pool, self.tree_cache, + self.model_config, ) - new_batch.prepare_for_extend(self.model_config.vocab_size) + new_batch.prepare_for_extend() # Mixed-style chunked prefill if self.is_mixed_chunk and self.running_batch is not None: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 2bc7ff04b4..fc9e235198 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -122,7 +122,7 @@ def __init__( # We want to parallelize the image pre-processing so we create an executor for it self.image_processor = get_image_processor( - self.hf_config, server_args, self.processor.image_processor + self.hf_config, server_args, self.processor ) else: self.tokenizer = get_tokenizer( @@ -191,8 +191,10 @@ async def _send_single_request( sampling_params = self._get_sampling_params(obj.sampling_params) if self.is_generation: image_inputs = await self.image_processor.process_images_async( - obj.image_data, obj + obj.image_data, input_text or input_ids, obj ) + if image_inputs and "input_ids" in 
image_inputs: + input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num @@ -217,8 +219,10 @@ async def _send_single_request( sampling_params = self._get_sampling_params(obj.sampling_params[index]) if self.is_generation: image_inputs = await self.image_processor.process_images_async( - obj.image_data[index], obj + obj.image_data[index], input_text or input_ids, obj ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob[index] logprob_start_len = obj.logprob_start_len[index] top_logprobs_num = obj.top_logprobs_num[index] @@ -263,8 +267,10 @@ async def _send_single_request( sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params.max_new_tokens = 0 image_inputs = await self.image_processor.process_images_async( - obj.image_data[0], obj + obj.image_data[0], input_text or input_ids, obj ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index bd42dfc72d..4277862a7e 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -26,6 +26,8 @@ import torch +from sglang.srt.layers.radix_attention import RadixAttention + logger = logging.getLogger(__name__) @@ -41,13 +43,17 @@ def __init__(self, size: int, max_context_len: int, device: str, use_records: bo ) self.free_slots = list(range(size)) self.write_records = [] + self.use_records = use_records - if use_records: - # records all write operations + if self.use_records: self.write = self.write_with_records else: self.write = self.write_without_records + def write(self, indices, values): + # Keep the signature for type checking, will be initialized during runtime + raise NotImplementedError() + def available_size(self): return len(self.free_slots) @@ -154,7 +160,7 @@ def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, @@ -209,11 +215,12 @@ def get_kv_buffer(self, layer_id: int): def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, ): + layer_id = layer.layer_id if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) if cache_v.dtype != self.dtype: @@ -265,11 +272,12 @@ def get_kv_buffer(self, layer_id: int): def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, ): + layer_id = layer.layer_id if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) if self.store_dtype != self.dtype: @@ -324,13 +332,14 @@ def get_kv_buffer(self, layer_id: int): def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, cache_label: torch.Tensor, ): # NOTE(Andy): ignore the dtype check + layer_id = layer.layer_id self.k_buffer[layer_id][loc] = cache_k self.v_buffer[layer_id][loc] = cache_v self.label_buffer[layer_id][loc] = cache_label diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 
37e3c84292..ffa77ec4c9 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -105,6 +105,7 @@ def __init__(self, model_runner: "ModelRunner"): self.graph_memory_pool = None self.use_torch_compile = model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder # Batch sizes to capture if self.model_runner.server_args.disable_cuda_graph_padding: @@ -132,6 +133,9 @@ def __init__(self, model_runner: "ModelRunner"): self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) + # FIXME(lsyin): leave it here for now, I don't know whether it is necessary + self.encoder_len_fill_value = 0 + if self.use_torch_compile: set_torch_compile_config() @@ -144,9 +148,18 @@ def __init__(self, model_runner: "ModelRunner"): ) self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) + if self.is_encoder_decoder: + # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch + self.encoder_lens = torch.full( + (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32 + ) + else: + self.encoder_lens = None + # Capture try: - self.capture() + with self.model_capture_mode(): + self.capture() except RuntimeError as e: raise Exception( f"Capture cuda graph failed: {e}\n" @@ -157,11 +170,32 @@ def __init__(self, model_runner: "ModelRunner"): "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" ) - def can_run(self, batch_size: int): - if self.disable_padding: - return batch_size in self.graphs - else: - return batch_size <= self.max_bs + @contextmanager + def model_capture_mode(self): + if hasattr(self.model_runner.model, "capture_mode"): + self.model_runner.model.capture_mode = True + + yield + + if hasattr(self.model_runner.model, "capture_mode"): + self.model_runner.model.capture_mode = False + + def can_run(self, forward_batch: ForwardBatch): + is_bs_supported = ( + forward_batch.batch_size in self.graphs + if self.disable_padding + else forward_batch.batch_size <= self.max_bs + ) + + # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) + # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph + # because the full_text_row_masked_out_mask tensor will always be ones + is_encoder_lens_supported = ( + torch.all(forward_batch.encoder_lens > 0) + if self.is_encoder_decoder + else True + ) + return is_bs_supported and is_encoder_lens_supported def capture(self): with graph_capture() as graph_capture_context: @@ -188,11 +222,19 @@ def capture_one_batch_size(self, bs: int, forward: Callable): req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] out_cache_loc = self.out_cache_loc[:bs] + if self.is_encoder_decoder: + encoder_lens = self.encoder_lens[:bs] + else: + encoder_lens = None + seq_lens_sum = seq_lens.sum().item() # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - bs, req_pool_indices, seq_lens + bs, + req_pool_indices, + seq_lens, + encoder_lens, ) # Run and capture @@ -208,6 +250,7 @@ def run_once(): attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, seq_lens_sum=seq_lens_sum, + encoder_lens=encoder_lens, return_logprob=False, top_logprobs_nums=[0] * bs, positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), @@ -251,6 +294,8 @@ def replay(self, forward_batch: ForwardBatch): 
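As a reference for the can_run change above, the eligibility test it now performs boils down to two independent conditions. The sketch below restates them over plain Python values instead of the runner and ForwardBatch objects; the function name and arguments are placeholders chosen for this example.

    from typing import Iterable, Optional, Set

    def can_use_cuda_graph(
        batch_size: int,
        encoder_lens: Optional[Iterable[int]],  # per-request encoder lengths
        captured_sizes: Set[int],               # batch sizes captured at startup
        max_bs: int,
        disable_padding: bool,
        is_encoder_decoder: bool,
    ) -> bool:
        # The batch size must be replayable: exactly captured when padding is
        # disabled, otherwise anything up to the largest captured size.
        bs_ok = batch_size in captured_sizes if disable_padding else batch_size <= max_bs

        # A mixed batch (some requests with encoder_len == 0) cannot be replayed,
        # because full_text_row_masked_out_mask would no longer be all ones.
        enc_ok = all(l > 0 for l in encoder_lens) if is_encoder_decoder else True

        return bs_ok and enc_ok

    # A decoder-only batch of size 3 with padding enabled is replayable.
    assert can_use_cuda_graph(3, None, {1, 2, 4, 8}, 8, False, False)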
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) + if self.is_encoder_decoder: + self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( @@ -258,6 +303,7 @@ def replay(self, forward_batch: ForwardBatch): self.req_pool_indices, self.seq_lens, forward_batch.seq_lens_sum, + self.encoder_lens, ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index f4e117b760..f3065d7a2b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -108,6 +108,12 @@ class ForwardBatch: # For multimodal image_inputs: Optional[List[ImageInputs]] = None + # Encoder-decoder + encoder_cached: Optional[List[bool]] = None + encoder_lens: Optional[torch.Tensor] = None + encoder_lens_cpu: Optional[List[int]] = None + encoder_out_cache_loc: Optional[torch.Tensor] = None + # For LoRA lora_paths: Optional[List[str]] = None @@ -194,6 +200,11 @@ def init_new( req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, out_cache_loc=batch.out_cache_loc, + image_inputs=batch.image_inputs, + encoder_cached=batch.encoder_cached, + encoder_lens=batch.encoder_lens, + encoder_lens_cpu=batch.encoder_lens_cpu, + encoder_out_cache_loc=batch.encoder_out_cache_loc, seq_lens_sum=batch.seq_lens_sum, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, @@ -212,11 +223,11 @@ def init_new( ], axis=0, ) - ret.image_inputs = batch.image_inputs ret.extend_num_tokens = batch.extend_num_tokens ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32 ).to(device, non_blocking=True) + ret.extend_prefix_lens = torch.tensor( batch.extend_prefix_lens, dtype=torch.int32 ).to(device, non_blocking=True) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 291528e079..e2a2504cbd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -270,7 +270,6 @@ def load_model(self): if hasattr(self.model, "get_attention_sliding_window_size") else None ) - self.has_cross_attention = getattr(self.model, "has_cross_attention", False) self.is_generation = is_generation_model( self.model_config.hf_config.architectures, self.server_args.is_embedding ) @@ -510,7 +509,7 @@ def init_attention_backend(self): "Window attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." ) - assert not self.has_cross_attention, ( + assert not self.model_config.is_encoder_decoder, ( "Cross attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." 
) @@ -558,9 +557,7 @@ def init_cuda_graphs(self): self.cuda_graph_runner = CudaGraphRunner(self) def forward_decode(self, forward_batch: ForwardBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run( - forward_batch.batch_size - ): + if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch): return self.cuda_graph_runner.replay(forward_batch) forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py new file mode 100644 index 0000000000..7db6f0e1f1 --- /dev/null +++ b/python/sglang/srt/models/mllama.py @@ -0,0 +1,1004 @@ +# Adapted from: +# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py +"""PyTorch Mllama model.""" +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers.models.mllama.configuration_mllama as config_mllama +import vllm.distributed.parallel_state as ps +from torch import nn +from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast +from transformers.models.mllama.modeling_mllama import ( + _prepare_aspect_ratio_attention_mask, +) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x, _ = self._linear(x) + return x + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig, is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size + ) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size, + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + + model_parallel_size = get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // model_parallel_size + self.q_size = self.num_local_heads * self.head_dim + 
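ColumnParallelConv2dPatch above expresses the patch embedding as Unfold followed by a column-parallel linear layer so the projection weight can be sharded across tensor-parallel ranks. The single-GPU check below, using plain torch modules and made-up sizes, is only meant to show that this factorization matches a strided Conv2d.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    in_c, out_c, k = 3, 8, 14                  # channels and patch size (illustrative)
    x = torch.randn(2, in_c, 2 * k, 3 * k)     # (bsz, in_channels, H, W)

    conv = nn.Conv2d(in_c, out_c, kernel_size=k, stride=k, bias=False)

    # The same weights, rewritten as unfold + linear.
    unfold = nn.Unfold(kernel_size=k, stride=k)
    linear = nn.Linear(in_c * k * k, out_c, bias=False)
    with torch.no_grad():
        linear.weight.copy_(conv.weight.view(out_c, -1))

    ref = conv(x).flatten(2).transpose(1, 2)   # (bsz, num_tokens, out_channels)
    alt = linear(unfold(x).permute(0, 2, 1))   # (bsz, num_tokens, out_channels)
    assert torch.allclose(ref, alt, atol=1e-4)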
self.kv_size = self.num_local_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view( + q.shape[0], q.shape[1], self.num_local_heads, self.head_dim + ).transpose(1, 2) + k = k.view( + k.shape[0], k.shape[1], self.num_local_heads, self.head_dim + ).transpose(1, 2) + v = v.view( + v.shape[0], v.shape[1], self.num_local_heads, self.head_dim + ).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, dropout_p=0.0 + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape( + attn_output.shape[0], attn_output.shape[1], -1 + ) + output, _ = self.o_proj(attn_output) + return output + + +class MllamaVisionMLP(nn.Module): + def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__( + self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False + ): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention(config) + self.mlp = MllamaVisionMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + self.hidden_size, eps=config.norm_eps + ) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + def __init__( + self, + config: config_mllama.MllamaVisionConfig, + num_layers=32, + is_gated=False, + output_hidden_states=None, + ): + super().__init__() + 
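The gated layers introduced in this file all follow the same tanh-gated residual pattern, with different initializations: the gated vision encoder layers start their gates at pi/4, while the cross-attention decoder gates further below start at zero so the new branch is a no-op right after loading. The snippet here is a minimal illustration of that pattern, not code from the model.

    import math
    import torch

    def gated_residual(residual: torch.Tensor, branch_out: torch.Tensor, gate: torch.Tensor):
        # hidden = residual + tanh(gate) * branch(hidden)
        return residual + gate.tanh() * branch_out

    residual = torch.randn(4, 16)
    branch_out = torch.randn(4, 16)

    # A zero gate passes the residual through unchanged (cross-attention init).
    assert torch.equal(gated_residual(residual, branch_out, torch.zeros(1)), residual)

    # A pi/4 gate lets roughly 66% of the branch output through (gated vision init).
    out = gated_residual(residual, branch_out, torch.ones(1) * math.pi / 4)
    assert torch.allclose(out, residual + math.tanh(math.pi / 4) * branch_out)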
self.config = config + self.layers = nn.ModuleList( + [MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)] + ) + self.output_hidden_states = output_hidden_states or [] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + encoder_states = () + + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + + if len(self.layers) - 1 in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = ColumnParallelConv2dPatch( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices, + ) + self.global_transformer = MllamaVisionEncoder( + config, config.num_global_layers, is_gated=True + ) + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + ) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + # patch embedding + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype) + ) + hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + 
) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state, intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + dim, + ) + hidden_state = self.global_transformer( + hidden_state, attention_mask=attention_mask + )[0] + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + return hidden_state + + +class MllamaTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + def __init__( + self, + config: Optional[config_mllama.MllamaTextConfig] = None, + layer_id: Optional[int] = None, + quant_config: 
Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.num_heads = self.config.num_attention_heads + self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_key_value_heads = self.config.num_key_value_heads + self.num_local_key_value_heads = ( + self.num_key_value_heads // self.model_parallel_size + ) + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_id = layer_id + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = RadixAttention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + layer_id=layer_id, + is_cross_attention=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv_dec, _ = self.qkv_proj(hidden_states) + q, _, _ = qkv_dec.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1 + ) + if cross_attention_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(cross_attention_states) + _, k, v = qkv_enc.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1 + ) + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + q = q.view(-1, self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + output = self.attn(q, k, v, forward_batch) + out, _ = self.o_proj(output) + return out + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" + + def __init__( + self, + config: config_mllama.MllamaTextConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig], + ) -> None: + super().__init__() + self.layer_id = layer_id + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_id=layer_id, + quant_config=quant_config, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + full_text_row_masked_out_mask: 
torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + forward_batch=forward_batch, + ) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + return hidden_states + + +class MllamaTextModel(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "model" + + def __init__( + self, + config: config_mllama.MllamaTextConfig, + quant_config: Optional[QuantizationConfig], + cache_config=None, + ): + super().__init__() + self.padding_id = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size + 8, config.hidden_size + ) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_id in range(config.num_hidden_layers): + if layer_id in self.cross_attention_layers: + layers.append( + MllamaCrossAttentionDecoderLayer( + config, layer_id, quant_config=quant_config + ) + ) + else: + # TODO: force LlamaDecoderLayer to config.attention_bias=False + layers.append( + LlamaDecoderLayer( + config, quant_config=quant_config, layer_id=layer_id + ) + ) + + self.layers = nn.ModuleList(layers) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], + forward_batch: ForwardBatch, + skip_cross_attention: bool, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + for _, decoder_layer in enumerate(self.layers): + if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): + if not skip_cross_attention: + hidden_states = decoder_layer( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + forward_batch=forward_batch, + ) + elif isinstance(decoder_layer, LlamaDecoderLayer): + hidden_states, residual = decoder_layer( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + residual=None, + ) + hidden_states = hidden_states + residual + else: + raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}") + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MllamaForCausalLM(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "language_model" + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", + "MllamaSelfAttentionDecoderLayer", + ] + + def __init__( + self, + config: config_mllama.MllamaTextConfig, + quant_config: Optional[QuantizationConfig], + cache_config=None, + ): + super().__init__() + self.vocab_size = config.vocab_size + self.model = MllamaTextModel(config, cache_config, quant_config) + self.lm_head = 
ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], + forward_batch: ForwardBatch, + skip_cross_attention: bool, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + forward_batch=forward_batch, + skip_cross_attention=skip_cross_attention, + ) + return hidden_states + + +class MllamaForConditionalGeneration(nn.Module): + def __init__( + self, + config: config_mllama.MllamaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config=None, + ): + super().__init__() + self.vocab_size = config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + self.image_size = config.vision_config.image_size + + self.vision_model = MllamaVisionModel(config.vision_config) + self.language_model = MllamaForCausalLM( + config.text_config, + cache_config=cache_config, + quant_config=quant_config, + ) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.logits_processor = LogitsProcessor(config.text_config) + self.capture_mode = False + + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + pixel_values = image_inputs.pixel_values + pad_values = image_inputs.pad_values + + num_concurrent_media, num_tiles = pixel_values.shape[1:3] + num_patches = self.vision_model.num_patches + image_len = num_concurrent_media * num_tiles * num_patches + image_inputs.num_image_tokens = image_len + + pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values)) + + return pad_ids[:image_len] + input_ids + + def _batch_image_inputs(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached): + return None, None, None, None + + # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res) + max_num_images = max_num_tiles = bs = 0 + for i, im in enumerate(forward_batch.image_inputs): + if not forward_batch.encoder_cached[i] and im is not None: + max_num_images = max(max_num_images, im.pixel_values.shape[1]) + max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2]) + bs += 1 + + if max_num_images * max_num_tiles * bs == 0: + return None, None, None, None + + with forward_batch.out_cache_loc.device: + batched_images = torch.zeros( + bs, + max_num_images, + max_num_tiles, + 3, + self.image_size, + self.image_size, + dtype=torch.float32, + ) + batched_ar_ids = torch.ones( + bs, max_num_images, dtype=torch.int64, device="cuda" + ) + batched_ar_mask = torch.zeros( + bs, max_num_images, max_num_tiles, dtype=torch.int64 + ) + i = 0 + encoder_lens_need = [] + for k, im in enumerate(forward_batch.image_inputs): + if forward_batch.encoder_cached[k] or im is None: + continue + + 
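pad_input_ids above turns each image into a run of placeholder token ids derived from the image hash, so that identical images produce identical token prefixes and can be matched by the prefix cache. A standalone restatement of that padding rule, with invented token values, looks like this:

    from typing import List

    def build_image_placeholder(
        pad_values: List[int],  # ids derived from the image hash
        image_len: int,         # number of image token slots to reserve
        input_ids: List[int],   # original text tokens
    ) -> List[int]:
        # Cycle the hash-derived values until they cover image_len slots,
        # then prepend the truncated pattern to the text tokens.
        pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
        return pad_ids[:image_len] + input_ids

    # 4 hash-derived ids covering 6 image token slots, followed by 2 text tokens.
    assert build_image_placeholder([11, 22, 33, 44], 6, [101, 102]) == [
        11, 22, 33, 44, 11, 22, 101, 102,
    ]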
encoder_lens_need.append(forward_batch.encoder_lens[k]) + for j in range(im.pixel_values.shape[1]): + img = im.pixel_values[0, j] + num_tiles = img.shape[0] + batched_images[i, j, :num_tiles] = img + batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j] + batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j] + i += 1 + + return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need + + def flat_encoder_result( + self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int] + ): + # NOTE: not all encoders need computation, some are cached + head_dim = cross_attention_states.shape[-1] + total_encoder_len = sum(encoder_lens_need) + cross_attention_states_flat = torch.zeros( + total_encoder_len, + head_dim, + device=cross_attention_states.device, + dtype=cross_attention_states.dtype, + ) + + i = start_pos = 0 + for encoder_len in encoder_lens_need: + if encoder_len == 0: + continue + end_pos = start_pos + encoder_len + cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][ + :encoder_len + ] + i += 1 + start_pos += encoder_len + + return cross_attention_states_flat + + def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_decode(): + full_text_row_masked_out_mask = forward_batch.encoder_lens != 0 + else: + full_text_row_masked_out_mask = torch.ones( + forward_batch.extend_seq_lens.sum(), dtype=torch.bool + ) + start_pos = 0 + + for seq_len, encoder_len in zip( + forward_batch.seq_lens.tolist(), forward_batch.encoder_lens_cpu + ): + if encoder_len == 0: + full_text_row_masked_out_mask[start_pos : start_pos + seq_len] = ( + False + ) + start_pos += encoder_len + + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + forward_batch.seq_lens.device + ) + + return full_text_row_masked_out_mask.reshape(-1, 1) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> Union[Tuple, CausalLMOutputWithPast]: + batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = ( + self._batch_image_inputs(forward_batch) + ) + + # TODO: support multi-image by this mask + cross_attention_mask = None + cross_attention_states = None + + if self.capture_mode: + # NOTE: when doing cuda graph capture, we do not want to skip cross attention + # Make is a constant value to avoid cuda graph capture issue + skip_cross_attention = False + else: + # NOTE: we do not need image_inputs when prefill + assert len(forward_batch.encoder_lens) == len(forward_batch.seq_lens) + assert len(forward_batch.encoder_lens_cpu) == len(forward_batch.seq_lens) + skip_cross_attention = forward_batch.encoder_lens.max() == 0 + + if not skip_cross_attention: + full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask( + forward_batch + ) + else: + full_text_row_masked_out_mask = None + + if batched_images is not None: + # NOTE: llama's reference implementation runs vision model on CPU + cross_attention_states = self.vision_model( + batched_images, batched_ar_ids, batched_ar_mask + ) + cross_attention_states = self.multi_modal_projector(cross_attention_states) + + bs, _, _, _, image_token_dim = cross_attention_states.shape + cross_attention_states = cross_attention_states.view( + bs, -1, image_token_dim + ) + + cross_attention_states = self.flat_encoder_result( + cross_attention_states, encoder_lens_need + ) + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + 
cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + forward_batch=forward_batch, + skip_cross_attention=skip_cross_attention, + ) + return self.logits_processor( + input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params = set() + for name, loaded_weight in weights: + if "patch_embedding.weight" in name: + name = name.replace( + "patch_embedding.weight", "patch_embedding._linear.weight" + ) + loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.pop(name) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = MllamaForConditionalGeneration diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index ae2d4f58c6..b1cc787710 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -605,7 +605,11 @@ def forward( ] positions = forward_batch.mrope_positions - if image_inputs is None or len(image_inputs) == 0: + if ( + forward_batch.forward_mode.is_decode() + or image_inputs is None + or len(image_inputs) == 0 + ): inputs_embeds = self.model.embed_tokens(input_ids) else: if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 762f3933d6..69aea52acf 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures): or "LlavaQwenForCausalLM" in model_architectures or "LlavaMistralForCausalLM" in model_architectures or "LlavaVidForCausalLM" in model_architectures + or "MllamaForConditionalGeneration" in model_architectures or "Qwen2VLForConditionalGeneration" in model_architectures ): return True diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 296572ea90..bf8f9d2775 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -171,7 +171,7 @@ def test_mult_images_chat_completion(self): assert isinstance(text, str) print(text) assert "man" in text or "cab" in text, text - assert "logo" in text, text + assert "logo" in text or '"S"' in text or "SG" in text, text assert response.id assert response.created assert response.usage.prompt_tokens > 0 @@ -363,5 +363,27 @@ def setUpClass(cls): cls.base_url += "/v1" +class TestMllamaServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--chat-template", + "llama_3_vision", + ], + ) + cls.base_url 
+= "/v1" + + def test_video_chat_completion(self): + pass + + if __name__ == "__main__": unittest.main() From 5e1558f1f26f0fc060ea261c9e81b767dc8e3fb9 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 21 Oct 2024 16:12:04 -0700 Subject: [PATCH 09/14] Update `max_req_len` and `max_req_input_len` (#1748) --- python/sglang/srt/managers/scheduler.py | 4 +++- python/sglang/srt/managers/tp_worker.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b2f217c852..210a243a4b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -165,6 +165,7 @@ def __init__( self.max_total_num_tokens, self.max_prefill_tokens, self.max_running_requests, + self.max_req_len, self.max_req_input_len, self.random_seed, self.device, @@ -421,13 +422,14 @@ def handle_generate_request( "the max context length. Truncated!!!" ) req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + req.sampling_params.max_new_tokens = min( ( req.sampling_params.max_new_tokens if req.sampling_params.max_new_tokens is not None else 1 << 30 ), - self.max_req_input_len - len(req.origin_input_ids), + self.max_req_len - len(req.origin_input_ids) - 1, ) self.waiting_queue.append(req) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 302c5d740b..561bfd77c5 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -90,10 +90,14 @@ def __init__( ), self.model_runner.req_to_token_pool.size, ) - self.max_req_input_len = min( + self.max_req_len = min( self.model_config.context_len - 1, self.max_total_num_tokens - 1, ) + self.max_req_input_len = self.max_req_len - 5 + assert ( + self.max_req_len > 0 and self.max_req_input_len > 0 + ), "Memory pool size is too small" # Sync random seed across TP workers self.random_seed = broadcast_pyobj( @@ -108,6 +112,7 @@ def get_worker_info(self): self.max_total_num_tokens, self.max_prefill_tokens, self.max_running_requests, + self.max_req_len, self.max_req_input_len, self.random_seed, self.device, From 1f26e8b8e4c8b884e59036dccd87929b2af592f9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 21 Oct 2024 21:16:43 -0700 Subject: [PATCH 10/14] Release v0.3.4.post1 (#1749) --- README.md | 2 +- python/pyproject.toml | 2 +- python/sglang/version.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 538185cee8..81a08e0851 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ### Method 2: From source ``` # Use the last release branch -git clone -b v0.3.4 https://github.com/sgl-project/sglang.git +git clone -b v0.3.4.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip diff --git a/python/pyproject.toml b/python/pyproject.toml index d51fc2331e..27ad6756fc 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.3.4" +version = "0.3.4.post1" description = "SGLang is yet another fast serving framework for large language models and vision language models." 
readme = "README.md" requires-python = ">=3.8" diff --git a/python/sglang/version.py b/python/sglang/version.py index 334b899568..03a502ce0d 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.3.4" +__version__ = "0.3.4.post1" From 17536e7e3dde0518097dd4c22cea35f7db8e5d5a Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Tue, 22 Oct 2024 21:00:25 -0700 Subject: [PATCH 11/14] Fix edge case for truncated (#1747) --- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 210a243a4b..16f4196bd3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -416,7 +416,7 @@ def handle_generate_request( ) # Truncate prompts that are too long - if len(req.origin_input_ids) >= self.max_req_input_len: + if len(req.origin_input_ids) > self.max_req_input_len: logger.warning( "Request length is longer than the KV cache pool size or " "the max context length. Truncated!!!" From ad4125d1a9c4796cdbc6c6a5cdb69b09e60e5509 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 22 Oct 2024 23:20:43 -0700 Subject: [PATCH 12/14] Fuse more ops & Simplify token mapping (#1758) --- docs/en/benchmark_and_profiling.md | 6 +- .../layers/attention/flashinfer_backend.py | 10 +-- python/sglang/srt/layers/sampler.py | 81 ++++++++++--------- .../srt/managers/tp_worker_overlap_thread.py | 36 +++++---- python/sglang/srt/mem_cache/memory_pool.py | 27 ++++--- .../srt/model_executor/cuda_graph_runner.py | 8 +- python/sglang/test/run_eval.py | 2 + test/srt/test_eval_accuracy_mini.py | 1 + test/srt/test_pytorch_sampling_backend.py | 3 +- 9 files changed, 99 insertions(+), 75 deletions(-) diff --git a/docs/en/benchmark_and_profiling.md b/docs/en/benchmark_and_profiling.md index 77fbbfc1b6..c0f54957d1 100644 --- a/docs/en/benchmark_and_profiling.md +++ b/docs/en/benchmark_and_profiling.md @@ -46,4 +46,8 @@ pip install nvtx import nvtx with nvtx.annotate("description", color="color"): # some critical code -``` \ No newline at end of file +``` + +## Other tips + +1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder. diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index e5e7ca29c9..c6b5393ee9 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -337,7 +337,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): def update( self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens ): - # Keep the signature for type checking, will be initialized during runtime + # Keep the signature for type checking. It will be assigned during runtime. 
raise NotImplementedError() def update_single_wrapper( @@ -432,8 +432,8 @@ def call_begin_forward( kv_start_idx, ): bs = len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) kv_indices = torch.empty( paged_kernel_lens_sum, dtype=torch.int32, device="cuda" ) @@ -497,7 +497,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.update = self.update_single_wrapper def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens): - # Keep the signature for type checking, will be initialized during runtime + # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() def update_single_wrapper( @@ -589,8 +589,8 @@ def call_begin_forward( use_ragged, ): bs = len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, @@ -602,8 +602,8 @@ def call_begin_forward( self.max_context_len, ) + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) qo_indptr = qo_indptr[: bs + 1] - qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) # extend part if use_ragged: diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 454078d59c..54fc47b736 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -33,56 +33,61 @@ def forward( if isinstance(logits, LogitsProcessorOutput): logits = logits.next_token_logits - # Post process logits logits = logits.contiguous() - logits.div_(sampling_info.temperatures) - probs = torch.softmax(logits, dim=-1) - logits = None - del logits - if self.use_nan_detectioin and torch.any(torch.isnan(probs)): - logger.warning("Detected errors during sampling! NaN in the probability.") - probs = torch.where( - torch.isnan(probs), torch.full_like(probs, 1e-10), probs + if self.use_nan_detectioin and torch.any(torch.isnan(logits)): + logger.warning("Detected errors during sampling! 
NaN in the logits.") + logits = torch.where( + torch.isnan(logits), torch.full_like(logits, -1e5), logits ) if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling - batch_next_token_ids = torch.argmax(probs, -1) - elif global_server_args_dict["sampling_backend"] == "flashinfer": - max_top_k_round, batch_size = 32, probs.shape[0] - uniform_samples = torch.rand( - (max_top_k_round, batch_size), device=probs.device - ) - if sampling_info.need_min_p_sampling: - probs = top_k_renorm_prob(probs, sampling_info.top_ks) - probs = top_p_renorm_prob(probs, sampling_info.top_ps) - batch_next_token_ids, success = min_p_sampling_from_probs( - probs, uniform_samples, sampling_info.min_ps + batch_next_token_ids = torch.argmax(logits, -1) + else: + # Post process logits + logits.div_(sampling_info.temperatures) + probs = torch.softmax(logits, dim=-1) + logits = None + del logits + + if global_server_args_dict["sampling_backend"] == "flashinfer": + max_top_k_round, batch_size = 32, probs.shape[0] + uniform_samples = torch.rand( + (max_top_k_round, batch_size), device=probs.device ) - else: - batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + if sampling_info.need_min_p_sampling: + probs = top_k_renorm_prob(probs, sampling_info.top_ks) + probs = top_p_renorm_prob(probs, sampling_info.top_ps) + batch_next_token_ids, success = min_p_sampling_from_probs( + probs, uniform_samples, sampling_info.min_ps + ) + else: + batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + sampling_info.top_ks, + sampling_info.top_ps, + filter_apply_order="joint", + ) + + if not torch.all(success): + logger.warning("Detected errors during sampling!") + batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + elif global_server_args_dict["sampling_backend"] == "pytorch": + # A slower fallback implementation with torch native operations. + batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( probs, - uniform_samples, sampling_info.top_ks, sampling_info.top_ps, - filter_apply_order="joint", + sampling_info.min_ps, + ) + else: + raise ValueError( + f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" ) - if not torch.all(success): - logger.warning("Detected errors during sampling!") - batch_next_token_ids = torch.zeros_like(batch_next_token_ids) - elif global_server_args_dict["sampling_backend"] == "pytorch": - # Here we provide a slower fallback implementation. 
- batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( - probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps - ) - else: - raise ValueError( - f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" - ) - - return batch_next_token_ids + return batch_next_token_ids.to(torch.int32) def top_k_top_p_min_p_sampling_from_probs_torch( diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 8b27d2a69a..8032915e7b 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -32,6 +32,15 @@ logger = logging.getLogger(__name__) +@torch.compile(dynamic=True) +def resolve_future_token_ids(input_ids, future_token_ids_map): + input_ids[:] = torch.where( + input_ids < 0, + future_token_ids_map[torch.clamp(-input_ids, min=0)], + input_ids, + ) + + class TpModelWorkerClient: """A tensor parallel model worker.""" @@ -99,33 +108,25 @@ def forward_thread_func_(self): # Resolve future tokens in the input input_ids = model_worker_batch.input_ids - input_ids[:] = torch.where( - input_ids < 0, - self.future_token_ids_map[torch.clamp(-input_ids, min=0)], - input_ids, - ) + resolve_future_token_ids(input_ids, self.future_token_ids_map) # Run forward logits_output, next_token_ids = self.worker.forward_batch_generation( model_worker_batch ) - self.launch_event.set() # Update the future token ids map bs = len(model_worker_batch.seq_lens) - future_next_token_ids = torch.arange( - -(future_token_ids_ct + bs), - -(future_token_ids_ct), - dtype=torch.int32, - device=self.device, - ) - self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to( - torch.int32 - ) + self.future_token_ids_map[ + future_token_ids_ct + 1 : future_token_ids_ct + bs + 1 + ] = next_token_ids + # Copy results to the CPU next_token_ids = next_token_ids.to("cpu", non_blocking=True) copy_event = torch.cuda.Event(blocking=True) copy_event.record() + + self.launch_event.set() self.copy_queue.put((copy_event, next_token_ids)) def copy_thread_func(self): @@ -149,8 +150,9 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): # Allocate output future objects bs = len(model_worker_batch.seq_lens) future_next_token_ids = torch.arange( - -(self.future_token_ids_ct + bs), - -(self.future_token_ids_ct), + -(self.future_token_ids_ct + 1), + -(self.future_token_ids_ct + 1 + bs), + -1, dtype=torch.int32, device=self.device, ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 4277862a7e..181ac7eefe 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -51,7 +51,7 @@ def __init__(self, size: int, max_context_len: int, device: str, use_records: bo self.write = self.write_without_records def write(self, indices, values): - # Keep the signature for type checking, will be initialized during runtime + # Keep the signature for type checking. It will be assigned during runtime. 
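For reference, the overlapped TP worker above encodes not-yet-sampled tokens as negative placeholder ids and later resolves them through `future_token_ids_map`. A minimal standalone sketch of that mechanism with toy values (this is illustrative only, not code from the patch):

```python
import torch

# Toy illustration of the negative-placeholder scheme: slot i of the map holds
# the real token id for placeholder -i once the forward thread has sampled it.
future_token_ids_map = torch.zeros(16, dtype=torch.int64)
future_token_ids_ct, bs = 0, 3

# Placeholders handed out for this batch: -(ct+1), ..., -(ct+bs)
placeholders = torch.arange(
    -(future_token_ids_ct + 1), -(future_token_ids_ct + 1 + bs), -1
)

# The forward thread later writes the sampled ids into the map ...
next_token_ids = torch.tensor([11, 22, 33])
future_token_ids_map[
    future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
] = next_token_ids

# ... and the next batch resolves any negative input ids through the map.
input_ids = placeholders.clone()
input_ids = torch.where(
    input_ids < 0,
    future_token_ids_map[torch.clamp(-input_ids, min=0)],
    input_ids,
)
print(input_ids)  # tensor([11, 22, 33])
```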
raise NotImplementedError() def available_size(self): @@ -221,16 +221,21 @@ def set_kv_buffer( cache_v: torch.Tensor, ): layer_id = layer.layer_id - if cache_k.dtype != self.dtype: - cache_k = cache_k.to(self.dtype) - if cache_v.dtype != self.dtype: - cache_v = cache_v.to(self.dtype) - if self.store_dtype != self.dtype: - self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) - self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) - else: - self.k_buffer[layer_id][loc] = cache_k - self.v_buffer[layer_id][loc] = cache_v + copy_two_array( + loc, + self.k_buffer[layer_id], + cache_k, + self.v_buffer[layer_id], + cache_v, + self.dtype, + self.store_dtype, + ) + + +@torch.compile(dynamic=True) +def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): + dst_1[loc] = src_1.to(dtype).view(store_dtype) + dst_2[loc] = src_2.to(dtype).view(store_dtype) class MLATokenToKVPool(BaseTokenToKVPool): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index ffa77ec4c9..b859df3588 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -92,6 +92,11 @@ def set_torch_compile_config(): torch._dynamo.config.accumulated_cache_size_limit = 1024 +@torch.compile(dynamic=True) +def clamp_position(seq_lens): + return torch.clamp((seq_lens - 1), min=0).to(torch.int64) + + class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" @@ -112,7 +117,6 @@ def __init__(self, model_runner: "ModelRunner"): self.capture_bs = list(range(1, 32)) + [64, 128] else: self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - self.capture_bs = [ bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size ] @@ -253,7 +257,7 @@ def run_once(): encoder_lens=encoder_lens, return_logprob=False, top_logprobs_nums=[0] * bs, - positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), + positions=clamp_position(seq_lens), ) return forward(input_ids, forward_batch.positions, forward_batch) diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 51b32ca01b..fe88171ce2 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -67,6 +67,7 @@ def run_eval(args): model=args.model, max_tokens=2048, base_url=base_url, + temperature=getattr(args, "temperature", 0.0), ) # Run eval @@ -119,6 +120,7 @@ def run_eval(args): parser.add_argument("--eval-name", type=str, default="mmlu") parser.add_argument("--num-examples", type=int) parser.add_argument("--num-threads", type=int, default=512) + parser.add_argument("--temperature", type=float, default=0.0) args = parser.parse_args() run_eval(args) diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index 6ddd97d940..ee977a6368 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -31,6 +31,7 @@ def test_mmlu(self): eval_name="mmlu", num_examples=64, num_threads=32, + temperature=0.1, ) metrics = run_eval(args) diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py index 5507182a73..ee06de8fae 100644 --- a/test/srt/test_pytorch_sampling_backend.py +++ b/test/srt/test_pytorch_sampling_backend.py @@ -23,7 +23,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--sampling-backend", "pytorch"], + other_args=["--sampling-backend", "pytorch", 
"--disable-radix-cache"], ) @classmethod @@ -37,6 +37,7 @@ def test_mmlu(self): eval_name="mmlu", num_examples=64, num_threads=32, + temperature=0.1, ) metrics = run_eval(args) From 2fce449b1c0a6cadde4946984426336621baed22 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Wed, 23 Oct 2024 00:02:29 -0700 Subject: [PATCH 13/14] [API] add get memory pool size (#1760) Co-authored-by: Byron Hsu --- python/sglang/srt/managers/detokenizer_manager.py | 4 ++++ python/sglang/srt/managers/io_struct.py | 10 ++++++++++ python/sglang/srt/managers/scheduler.py | 6 ++++++ python/sglang/srt/managers/tokenizer_manager.py | 14 ++++++++++++++ python/sglang/srt/server.py | 12 ++++++++++++ test/srt/test_srt_endpoint.py | 4 ++++ 6 files changed, 50 insertions(+) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 4ae31ecc8b..d0d399363f 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -27,6 +27,7 @@ BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + GetMemPoolSizeReqOutput, UpdateWeightReqOutput, ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN @@ -111,6 +112,9 @@ def event_loop(self): # If it is a weight update request, no detokenization is needed. self.send_to_tokenizer.send_pyobj(recv_obj) continue + elif isinstance(recv_obj, GetMemPoolSizeReqOutput): + self.send_to_tokenizer.send_pyobj(recv_obj) + continue elif self.tokenizer is None: # If the tokenizer is skipped, no detokenization is needed self.send_to_tokenizer.send_pyobj(recv_obj) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9625ff44eb..2cdc3f4785 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -353,3 +353,13 @@ class AbortReq: class ProfileReq(Enum): START_PROFILE = 1 STOP_PROFILE = 2 + + +@dataclass +class GetMemPoolSizeReq: + pass + + +@dataclass +class GetMemPoolSizeReqOutput: + size: int diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 16f4196bd3..60531ce251 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -38,6 +38,8 @@ BatchEmbeddingOut, BatchTokenIDOut, FlushCacheReq, + GetMemPoolSizeReq, + GetMemPoolSizeReqOutput, ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -363,6 +365,10 @@ def process_input_requests(self, recv_reqs: List): self.start_profile() else: self.stop_profile() + elif isinstance(recv_req, GetMemPoolSizeReq): + self.send_to_detokenizer.send_pyobj( + GetMemPoolSizeReqOutput(self.max_total_num_tokens) + ) else: raise ValueError(f"Invalid request: {recv_req}") diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index fc9e235198..875239a941 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -46,6 +46,8 @@ EmbeddingReqInput, FlushCacheReq, GenerateReqInput, + GetMemPoolSizeReq, + GetMemPoolSizeReqOutput, ProfileReq, RewardReqInput, TokenizedEmbeddingReqInput, @@ -531,6 +533,15 @@ def stop_profile(self): req = ProfileReq.STOP_PROFILE self.send_to_scheduler.send_pyobj(req) + async def get_memory_pool_size(self): + if self.to_create_loop: + self.create_handle_loop() + + req = GetMemPoolSizeReq() + self.send_to_scheduler.send_pyobj(req) + self.mem_pool_size = asyncio.Future() + return await 
self.mem_pool_size + async def update_weights( self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None ): @@ -590,6 +601,9 @@ async def handle_loop(self): if isinstance(recv_obj, UpdateWeightReqOutput): self.model_update_result.set_result(recv_obj) continue + elif isinstance(recv_obj, GetMemPoolSizeReqOutput): + self.mem_pool_size.set_result(recv_obj) + continue assert isinstance( recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ceb2d55c28..8912c5583a 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -172,6 +172,18 @@ async def stop_profile(): ) +@app.api_route("/get_memory_pool_size", methods=["GET", "POST"]) +async def get_memory_pool_size(): + """Get the memory pool size in number of tokens""" + try: + ret = await tokenizer_manager.get_memory_pool_size() + return ret.size + except Exception as e: + return JSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + @app.post("/update_weights") async def update_weights(obj: UpdateWeightReqInput, request: Request): """Update the weights inplace without re-launching the server.""" diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 9a0a37c607..c4c8e844d6 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -119,6 +119,10 @@ def test_logprob_start_len(self): [x[-1] for x in res["meta_info"]["output_token_logprobs"]] ) + def test_get_memory_pool_size(self): + response = requests.post(self.base_url + "/get_memory_pool_size") + assert isinstance(response.json(), int) + if __name__ == "__main__": unittest.main() From f2119176e98b2c5f7493617cf255c5b9c2b40787 Mon Sep 17 00:00:00 2001 From: YouYangxiu Date: Wed, 23 Oct 2024 18:27:34 +0800 Subject: [PATCH 14/14] fix bug in comments --- python/sglang/srt/managers/data_parallel_controller.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 794fe01579..ae2b76d827 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -124,7 +124,7 @@ def __init__(self, server_args, port_args) -> None: self.main_num_waiting_req = [] # For pre_radix - self.zmq_raidx = server_args.load_balance_method == "pre_radix" + self.pre_raidx = server_args.load_balance_method == "pre_radix" # Start data parallel workers base_gpu_id = 0 @@ -140,7 +140,7 @@ def __init__(self, server_args, port_args) -> None: self.workers.append(send_to) base_gpu_id += server_args.tp_size - if self.zmq_raidx: + if self.pre_raidx: import threading self.newest_tree_cache = {} @@ -231,7 +231,7 @@ def update_memory_and_requests(self): if not self.main_available_kv_cache: self.main_available_kv_cache = available_mem.copy() if self.pre_available_kv_cache != available_mem: - self.pre_available_kv_cache = available_mem.copy() + self.pre_available_kv_cache = available_mem self.main_available_kv_cache = available_mem.copy() if not self.pre_num_running_req: @@ -239,7 +239,7 @@ def update_memory_and_requests(self): if not self.main_num_running_req: self.main_num_running_req = num_reqs_running.copy() if self.pre_num_running_req != num_reqs_running: - self.main_num_running_req = num_reqs_running.copy() + self.main_num_running_req = num_reqs_running self.pre_num_running_req = num_reqs_running.copy() if not 
self.pre_num_waiting_req: @@ -247,7 +247,7 @@ def update_memory_and_requests(self): if not self.main_num_waiting_req: self.main_num_waiting_req = num_reqs_waiting.copy() if self.pre_num_waiting_req != num_reqs_waiting: - self.main_num_waiting_req = num_reqs_waiting.copy() + self.main_num_waiting_req = num_reqs_waiting self.pre_num_waiting_req = num_reqs_waiting.copy() def allocate_gpu(self, req):
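
For illustration, a minimal sketch of how `allocate_gpu` could rank workers from the snapshots kept in sync by `update_memory_and_requests` above. The scoring order here (waiting queue, then running requests, then free KV cache) is an assumption for the example, not necessarily the exact policy in this patch:

```python
def pick_worker(num_waiting, num_running, available_kv_cache):
    """Return the index of the least-loaded data-parallel worker.

    Workers are compared by waiting-queue length, then by number of running
    requests, with ties broken in favor of the largest free KV-cache pool.
    """
    ranks = range(len(num_waiting))
    return min(
        ranks,
        key=lambda i: (num_waiting[i], num_running[i], -available_kv_cache[i]),
    )

# Worker 1 has the shortest queues, so a new request would be dispatched there.
print(pick_worker([3, 1, 2], [8, 8, 7], [1000, 4000, 2000]))  # -> 1
```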