From 843e63d809f59e1446d6b0a61306c9a001b404d6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 4 Sep 2024 04:15:11 -0700 Subject: [PATCH 01/33] Fix the flaky test test_moe_eval_accuracy_large.py (#1326) --- test/srt/test_moe_eval_accuracy_large.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index 1183cc4e7a..d4b1354b79 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -66,7 +66,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.64, f"{metrics}" + assert metrics["score"] >= 0.63, f"{metrics}" if __name__ == "__main__": From 5ab9418f5b4c9ad1a90d72a803331d9a0b26d13e Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 4 Sep 2024 21:21:21 +1000 Subject: [PATCH 02/33] [Doc] update news (#1327) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d56a243df4..eb3099cf7a 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,8 @@ The core features include: - **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. ## News +- [2024/09] 🔥 SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). - [2024/07] 🔥 Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). -- [2024/08] 🔥 LLaVA-OneVision with single-image, multi-image and video are supported ([blog](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)). - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
From eda7c09048b39bd03b0e34aa16ffef9398072663 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 4 Sep 2024 05:37:32 -0700 Subject: [PATCH 03/33] Remove useless fields in global_config.py (#1328) --- python/sglang/global_config.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index d5f16e2ae5..7bd5aa0901 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -11,10 +11,6 @@ def __init__(self): # Default backend of the language self.default_backend = None - # Runtime constants: Request dependency time due to network delay - self.request_dependency_delay = 0.02 - self.wait_for_new_request_delay = 0.0006 - # Runtime constants: New generation token ratio estimation self.init_new_token_ratio = 0.7 self.base_min_new_token_ratio = 0.1 From 3494b32c3a77e32d1a492b8c2a408b3662c08229 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 5 Sep 2024 23:39:44 +1000 Subject: [PATCH 04/33] docs: update README (#1336) --- benchmark/benchmark_vllm_060/README.md | 83 ++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 benchmark/benchmark_vllm_060/README.md diff --git a/benchmark/benchmark_vllm_060/README.md b/benchmark/benchmark_vllm_060/README.md new file mode 100644 index 0000000000..acb55f8971 --- /dev/null +++ b/benchmark/benchmark_vllm_060/README.md @@ -0,0 +1,83 @@ +## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0 + +## Installation + +```bash +# install sglang v0.3.0 +pip install --upgrade pip +pip install "sglang[all]"==0.3.0 +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ + +# install vllm v0.6.0 +pip install vllm==0.6.0 +``` + +## Online benchmarks + +```bash +# Llama 3.1 8B Instruct on 1 x A100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096 + +# Llama 3.1 70B Instruct on 4 x H100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096 + +# bench serving +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4 +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8 +python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4 +python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8 +``` + +## Offline benchmarks + +```bash +# Llama 3.1 8B Instruct on 1 x A100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096 + +# Llama 3.1 70B Instruct on 4 x H100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88 +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 
--tensor 4 --max_model_len 4096 + +# bench serving +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000 +python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000 +``` + +## Online benchmark results + +### Llama 3.1 8B Instruct 1 x A100 80G + +| RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | +|------|-------------|--------|--------------------|-------------|-------------|------------| +| 4 | 1200 | SGLang | 1564.17 | 31.98 | 13.17 | 11.93 | +| 4 | 1200 | vLLM | 1691.97 | 100.48 | 14.14 | 129.32 | +| 8 | 2400 | SGLang | 2175.02 | 35.68 | 17.85 | 14.41 | +| 8 | 2400 | vLLM | 2137.16 | 120.39 | 17.09 | 158.63 | + +### Llama 3.1 70B Insruct 4 x H100 80G + +| RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | +|------|-------------|--------|--------------------|-------------|-------------|------------| +| 4 | 1200 | SGLang | 3005.24 | 53.94 | 25.03 | 21.67 | +| 4 | 1200 | vLLM | 2915.60 | 179.15 | 23.58 | 231.23 | +| 8 | 2400 | SGLang | 4064.98 | 58.11 | 33.07 | 24.45 | +| 8 | 2400 | vLLM | 3752.38 | 207.12 | 29.15 | 275.32 | + +## Offline benchmark results + +### Llama 3.1 8B Instruct 1 x A100 80G + +| RPS | Num Prompts | Engine | Request throughput | Output token throughput | +|------|-------------|--------|--------------------|-------------------------| +| inf | 5000 | SGLang | 22.03 | 4281.51 | +| inf | 5000 | vLLM | 21.27 | 4132.37 | + +### Llama 3.1 70B Insruct 4 x H100 80G + +| RPS | Num Prompts | Engine | Request throughput | Output token throughput | +|------|-------------|--------|--------------------|-------------------------| +| inf | 5000 | SGLang | 19.84 | 3856.01 | +| inf | 5000 | vLLM | 19.04 | 3700.64 | From 79794af52d90abfb00e73871109f0cdc4e0b7f34 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 6 Sep 2024 00:00:06 +1000 Subject: [PATCH 05/33] docs: highlight ttft itl and throughput (#1337) --- benchmark/benchmark_vllm_060/README.md | 28 +++++++++++++++----------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/benchmark/benchmark_vllm_060/README.md b/benchmark/benchmark_vllm_060/README.md index acb55f8971..157bd9df7a 100644 --- a/benchmark/benchmark_vllm_060/README.md +++ b/benchmark/benchmark_vllm_060/README.md @@ -12,6 +12,10 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ pip install vllm==0.6.0 ``` +## Notes + +We referred to the reproduction method in https://github.com/vllm-project/vllm/issues/8176, and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. The `gpu_memory_utilization` of vLLM is by default 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` is 0.88 at TP 1 and 0.85 at TP 4, so we manually set it to 0.88 at TP 4. 
+ ## Online benchmarks ```bash @@ -52,19 +56,19 @@ python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-pro | RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | |------|-------------|--------|--------------------|-------------|-------------|------------| -| 4 | 1200 | SGLang | 1564.17 | 31.98 | 13.17 | 11.93 | -| 4 | 1200 | vLLM | 1691.97 | 100.48 | 14.14 | 129.32 | -| 8 | 2400 | SGLang | 2175.02 | 35.68 | 17.85 | 14.41 | -| 8 | 2400 | vLLM | 2137.16 | 120.39 | 17.09 | 158.63 | +| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** | +| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** | +| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** | +| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** | ### Llama 3.1 70B Insruct 4 x H100 80G | RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | |------|-------------|--------|--------------------|-------------|-------------|------------| -| 4 | 1200 | SGLang | 3005.24 | 53.94 | 25.03 | 21.67 | -| 4 | 1200 | vLLM | 2915.60 | 179.15 | 23.58 | 231.23 | -| 8 | 2400 | SGLang | 4064.98 | 58.11 | 33.07 | 24.45 | -| 8 | 2400 | vLLM | 3752.38 | 207.12 | 29.15 | 275.32 | +| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** | +| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** | +| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** | +| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** | ## Offline benchmark results @@ -72,12 +76,12 @@ python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-pro | RPS | Num Prompts | Engine | Request throughput | Output token throughput | |------|-------------|--------|--------------------|-------------------------| -| inf | 5000 | SGLang | 22.03 | 4281.51 | -| inf | 5000 | vLLM | 21.27 | 4132.37 | +| inf | 5000 | SGLang | 22.03 | **4281.51** | +| inf | 5000 | vLLM | 21.27 | **4132.37** | ### Llama 3.1 70B Insruct 4 x H100 80G | RPS | Num Prompts | Engine | Request throughput | Output token throughput | |------|-------------|--------|--------------------|-------------------------| -| inf | 5000 | SGLang | 19.84 | 3856.01 | -| inf | 5000 | vLLM | 19.04 | 3700.64 | +| inf | 5000 | SGLang | 19.84 | **3856.01** | +| inf | 5000 | vLLM | 19.04 | **3700.64** | From 62f15eea5a0b4266cdae965d0337fd33f6673736 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 6 Sep 2024 04:25:14 +1000 Subject: [PATCH 06/33] docs: add conclusion (#1340) --- benchmark/benchmark_vllm_060/README.md | 74 +++++++++++++------------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/benchmark/benchmark_vllm_060/README.md b/benchmark/benchmark_vllm_060/README.md index 157bd9df7a..5a1247c5f4 100644 --- a/benchmark/benchmark_vllm_060/README.md +++ b/benchmark/benchmark_vllm_060/README.md @@ -1,5 +1,43 @@ ## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0 +In short, with multi step enabled, in online scenarios that we benchmarked, the Median TTFT of vLLM is **3 times** that of SGLang, and the Median ITL is **10 times** that of SGLang. Lower Median TTFT and ITL are better. vLLM's multi-step optimization did not improve throughput while ensuring lower Median TTFT and ITL. Also, under maximum throughput benchmark, if vLLM does not set gpu util to 0.95 separately and uses the default configuration instead, its maximum throughput is **lower** than that of SGLang. 
+ +## Online benchmark results + +### Llama 3.1 8B Instruct 1 x A100 80G + +| RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | +|------|-------------|--------|--------------------|-------------|-------------|------------| +| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** | +| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** | +| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** | +| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** | + +### Llama 3.1 70B Insruct 4 x H100 80G + +| RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | +|------|-------------|--------|--------------------|-------------|-------------|------------| +| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** | +| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** | +| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** | +| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** | + +## Offline benchmark results + +### Llama 3.1 8B Instruct 1 x A100 80G + +| RPS | Num Prompts | Engine | Request throughput | Output token throughput | +|------|-------------|--------|--------------------|-------------------------| +| inf | 5000 | SGLang | 22.03 | **4281.51** | +| inf | 5000 | vLLM | 21.27 | **4132.37** | + +### Llama 3.1 70B Insruct 4 x H100 80G + +| RPS | Num Prompts | Engine | Request throughput | Output token throughput | +|------|-------------|--------|--------------------|-------------------------| +| inf | 5000 | SGLang | 19.84 | **3856.01** | +| inf | 5000 | vLLM | 19.04 | **3700.64** | + ## Installation ```bash @@ -49,39 +87,3 @@ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-7 python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000 python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000 ``` - -## Online benchmark results - -### Llama 3.1 8B Instruct 1 x A100 80G - -| RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | -|------|-------------|--------|--------------------|-------------|-------------|------------| -| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** | -| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** | -| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** | -| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** | - -### Llama 3.1 70B Insruct 4 x H100 80G - -| RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | -|------|-------------|--------|--------------------|-------------|-------------|------------| -| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** | -| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** | -| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** | -| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** | - -## Offline benchmark results - -### Llama 3.1 8B Instruct 1 x A100 80G - -| RPS | Num Prompts | Engine | Request throughput | Output token throughput | -|------|-------------|--------|--------------------|-------------------------| -| inf | 5000 | SGLang | 22.03 | **4281.51** | -| inf | 5000 | vLLM | 21.27 | **4132.37** | - -### Llama 3.1 70B Insruct 4 x H100 80G - -| RPS | Num Prompts | Engine | Request throughput | Output token throughput | -|------|-------------|--------|--------------------|-------------------------| -| inf | 5000 | SGLang | 19.84 | **3856.01** | -| inf | 
5000 | vLLM | 19.04 | **3700.64** | From ab4a83b25909aa98330b838a224e4fe5c943e483 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 5 Sep 2024 14:30:26 -0700 Subject: [PATCH 07/33] Optimize schedule (#1339) --- .../sglang/srt/managers/policy_scheduler.py | 110 +++++++++++++++++- python/sglang/srt/managers/tp_worker.py | 21 +++- 2 files changed, 123 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py index 04169e8086..3a70bfe548 100644 --- a/python/sglang/srt/managers/policy_scheduler.py +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -108,18 +108,24 @@ class PrefillAdder: def __init__( self, tree_cache: BasePrefixCache, + running_batch: ScheduleBatch, + new_token_ratio: float, rem_total_tokens: int, rem_input_tokens: int, rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, ): self.tree_cache = tree_cache + self.running_batch = running_batch + self.new_token_ratio = new_token_ratio self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens + self.total_tokens = rem_total_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= mixed_with_decode_tokens + self.req_states = None self.can_run_list = [] self.new_inflight_req = None self.log_hit_tokens = 0 @@ -136,16 +142,14 @@ def no_remaining_tokens(self): ) ) - def remove_running_tokens( - self, running_batch: ScheduleBatch, new_token_ratio: float - ): + def remove_running_tokens(self, running_batch: ScheduleBatch): self.rem_total_tokens -= sum( [ min( (r.sampling_params.max_new_tokens - len(r.output_ids)), CLIP_MAX_NEW_TOKENS, ) - * new_token_ratio + * self.new_token_ratio for r in running_batch.reqs ] ) @@ -161,7 +165,29 @@ def _prefill_one_req( self.log_hit_tokens += prefix_len self.log_input_tokens += extend_input_len + def add_inflight_req_ignore_eos(self, req: Req): + truncated = req.extend_input_len > self.rem_chunk_tokens + req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) + req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] + self.can_run_list.append(req) + + self._prefill_one_req( + 0, + req.extend_input_len, + ( + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS) + if not truncated + else 0 + ), + ) + + # Return if chunked prefill not finished + return req if truncated else None + def add_inflight_req(self, req: Req): + if req.sampling_params.ignore_eos: + return self.add_inflight_req_ignore_eos(req) + truncated = req.extend_input_len > self.rem_chunk_tokens req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] @@ -190,7 +216,81 @@ def _lock_node(self, last_node: TreeNode): delta = self.tree_cache.dec_lock_ref(last_node) self.rem_total_tokens += delta + def add_one_req_ignore_eos(self, req: Req): + def get_req_state(r): + new_token_ratio = ( + 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio + ) + tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len( + r.output_ids + ) + tokens_occupied = len(r.origin_input_ids) + len(r.output_ids) + + if tokens_left > 0: + return (tokens_left, tokens_occupied) + + return None + + if self.req_states is None: + self.req_states = [] + if self.running_batch is not None: + for r in self.running_batch.reqs: + state = get_req_state(r) + if state is not None: + 
self.req_states.append(state) + for r in self.can_run_list: + state = get_req_state(r) + if state is not None: + self.req_states.append(state) + state = get_req_state(req) + if state is not None: + self.req_states.append(state) + + self.req_states.sort(key=lambda x: x[0]) + else: + state = get_req_state(req) + if state is not None: + for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): + if tokens_left >= state[0]: + self.req_states.insert(i, state) + break + else: + self.req_states.append(state) + + tokens_freed = 0 + for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): + decode_steps = ( + self.req_states[i + 1][0] + if i + 1 < len(self.req_states) + else tokens_left + ) + bs = len(self.req_states) - i + if self.total_tokens + tokens_freed - decode_steps * bs <= 0: + return False + tokens_freed += tokens_occupied + + if req.extend_input_len <= self.rem_chunk_tokens: + self.can_run_list.append(req) + self._prefill_one_req( + 0, + req.extend_input_len, + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS), + ) + else: + # Chunked prefill + trunc_len = self.rem_chunk_tokens + req.extend_input_len = trunc_len + req.fill_ids = req.fill_ids[:trunc_len] + self.can_run_list.append(req) + self.new_inflight_req = req + self._prefill_one_req(0, trunc_len, 0) + + return True + def add_one_req(self, req: Req): + if req.sampling_params.ignore_eos and self.tree_cache.disable: + return self.add_one_req_ignore_eos(req) + total_tokens = req.extend_input_len + min( req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS ) @@ -233,4 +333,4 @@ def add_one_req(self, req: Req): self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) - return True + return True and not self.no_remaining_tokens() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 8fc03b8599..d914a71c27 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -221,6 +221,7 @@ def __init__( ) self.new_token_ratio = self.min_new_token_ratio self.new_token_ratio_decay = global_config.new_token_ratio_decay + self.do_not_get_new_batch = False def exposed_step(self, recv_reqs: List): try: @@ -253,7 +254,13 @@ def exposed_step(self, recv_reqs: List): @torch.inference_mode() def forward_step(self): - new_batch = self.get_new_prefill_batch() + if self.current_inflight_req is not None: + self.do_not_get_new_batch = False + + new_batch = ( + self.get_new_prefill_batch() if not self.do_not_get_new_batch else None + ) + self.do_not_get_new_batch = False if new_batch is not None: # Run a new prefill batch @@ -409,6 +416,8 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: adder = PrefillAdder( self.tree_cache, + self.running_batch, + self.new_token_ratio, self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.max_prefill_tokens, self.chunked_prefill_size, @@ -416,7 +425,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: ) if self.running_batch is not None: - adder.remove_running_tokens(self.running_batch, self.new_token_ratio) + adder.remove_running_tokens(self.running_batch) has_inflight = self.current_inflight_req is not None if self.current_inflight_req is not None: @@ -428,11 +437,12 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: ) for req in self.waiting_queue: + if adder.no_remaining_tokens(): + break req.init_next_round_input(None if prefix_computed else self.tree_cache) res = adder.add_one_req(req) if ( not res - or 
adder.no_remaining_tokens() or running_bs + len(adder.can_run_list) >= self.max_running_requests ): break @@ -700,6 +710,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): next_token_ids = next_token_ids.tolist() # Check finish condition + has_finished = False for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) @@ -712,6 +723,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): if req.finished(): self.tree_cache.cache_finished_req(req) + has_finished = True if req.return_logprob: req.output_token_logprobs.append( @@ -720,6 +732,9 @@ def forward_decode_batch(self, batch: ScheduleBatch): if req.top_logprobs_num > 0: req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + if not has_finished: + self.do_not_get_new_batch = True + self.handle_finished_requests(batch) def handle_finished_requests(self, batch: ScheduleBatch): From 05bea6883c4b3f2fb7f01287cd8dccefeacd545f Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sat, 7 Sep 2024 20:46:27 -0700 Subject: [PATCH 08/33] Fix some online scheduling delay (#1345) --- .../sglang/srt/managers/policy_scheduler.py | 77 +++++++++++-------- python/sglang/srt/managers/tp_worker.py | 11 ++- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py index 3a70bfe548..b58c0e7b3b 100644 --- a/python/sglang/srt/managers/policy_scheduler.py +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -119,6 +119,7 @@ def __init__( self.running_batch = running_batch self.new_token_ratio = new_token_ratio self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens + self.rem_total_tokens_ = self.rem_total_tokens self.total_tokens = rem_total_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens @@ -153,11 +154,18 @@ def remove_running_tokens(self, running_batch: ScheduleBatch): for r in running_batch.reqs ] ) + self.rem_total_tokens_ -= sum( + [ + r.sampling_params.max_new_tokens - len(r.output_ids) + for r in running_batch.reqs + ] + ) def _prefill_one_req( self, prefix_len: int, extend_input_len: int, max_new_tokens: int ): self.rem_total_tokens -= extend_input_len + max_new_tokens + self.rem_total_tokens_ -= extend_input_len + max_new_tokens self.rem_input_tokens -= extend_input_len if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= extend_input_len @@ -231,43 +239,52 @@ def get_req_state(r): return None - if self.req_states is None: - self.req_states = [] - if self.running_batch is not None: - for r in self.running_batch.reqs: + # Quick Check + can_run = False + if ( + req.extend_input_len + req.sampling_params.max_new_tokens + <= self.rem_total_tokens + ): + can_run = True + + if not can_run: + if self.req_states is None: + self.req_states = [] + if self.running_batch is not None: + for r in self.running_batch.reqs: + state = get_req_state(r) + if state is not None: + self.req_states.append(state) + for r in self.can_run_list: state = get_req_state(r) if state is not None: self.req_states.append(state) - for r in self.can_run_list: - state = get_req_state(r) + state = get_req_state(req) if state is not None: self.req_states.append(state) - state = get_req_state(req) - if state is not None: - self.req_states.append(state) - self.req_states.sort(key=lambda x: x[0]) - else: - state = get_req_state(req) - if state is not None: - for i, (tokens_left, 
tokens_occupied) in enumerate(self.req_states): - if tokens_left >= state[0]: - self.req_states.insert(i, state) - break - else: - self.req_states.append(state) + self.req_states.sort(key=lambda x: x[0]) + else: + state = get_req_state(req) + if state is not None: + for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): + if tokens_left >= state[0]: + self.req_states.insert(i, state) + break + else: + self.req_states.append(state) - tokens_freed = 0 - for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): - decode_steps = ( - self.req_states[i + 1][0] - if i + 1 < len(self.req_states) - else tokens_left - ) - bs = len(self.req_states) - i - if self.total_tokens + tokens_freed - decode_steps * bs <= 0: - return False - tokens_freed += tokens_occupied + tokens_freed = 0 + for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): + decode_steps = ( + self.req_states[i + 1][0] + if i + 1 < len(self.req_states) + else tokens_left + ) + bs = len(self.req_states) - i + if self.total_tokens + tokens_freed - decode_steps * bs <= 0: + return False + tokens_freed += tokens_occupied if req.extend_input_len <= self.rem_chunk_tokens: self.can_run_list.append(req) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index d914a71c27..c2c0e6c2d1 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -231,6 +231,7 @@ def exposed_step(self, recv_reqs: List): recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) ): self.handle_generate_request(recv_req) + self.do_not_get_new_batch = False elif isinstance(recv_req, FlushCacheReq): self.flush_cache() elif isinstance(recv_req, AbortReq): @@ -254,12 +255,10 @@ def exposed_step(self, recv_reqs: List): @torch.inference_mode() def forward_step(self): - if self.current_inflight_req is not None: - self.do_not_get_new_batch = False - - new_batch = ( - self.get_new_prefill_batch() if not self.do_not_get_new_batch else None - ) + if self.do_not_get_new_batch and self.current_inflight_req is None: + new_batch = None + else: + new_batch = self.get_new_prefill_batch() self.do_not_get_new_batch = False if new_batch is not None: From 8e6bdf851c4aa6619baa584fc450af748720319d Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 9 Sep 2024 01:30:24 -0700 Subject: [PATCH 09/33] [triton] Support head_dim not 2^n in triton extend and decode attention (#1281) --- python/sglang/srt/layers/decode_attention.py | 50 ++++++++++++------ python/sglang/srt/layers/extend_attention.py | 51 +++++++++++++------ python/sglang/srt/layers/prefill_attention.py | 20 +++++--- 3 files changed, 84 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py index dc92a65480..9c9822b852 100644 --- a/python/sglang/srt/layers/decode_attention.py +++ b/python/sglang/srt/layers/decode_attention.py @@ -60,6 +60,7 @@ def _fwd_kernel_stage1( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -97,7 +98,7 @@ def _fwd_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=offs_n_new[:, None] < cur_batch_end_index, + mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk), other=0.0, ).to(REDUCE_TRITON_TYPE) att_value = tl.sum(q[None, :] * k, 1) @@ -128,6 +129,7 @@ def _fwd_kernel_stage2( kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + Lv: 
tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -170,14 +172,16 @@ def _fwd_kernel_stage2( old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max) e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + v = tl.load( + v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + ) acc = acc * old_scale + tl.sum(p[:, None] * v, 0) e_max = n_e_max acc = acc / e_sum off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d out_ptrs = Out + off_o - tl.store(out_ptrs, acc) + tl.store(out_ptrs, acc, mask=(offs_d < Lv)) def _decode_att_m_fwd( @@ -196,7 +200,7 @@ def _decode_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 96, 128, 256} batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -208,6 +212,8 @@ def _decode_att_m_fwd( else: num_warps = 2 + BLOCK_DMODEL = triton.next_power_of_2(Lk) + _fwd_kernel_stage1[grid]( q, k_buffer, @@ -224,11 +230,12 @@ def _decode_att_m_fwd( k_buffer.stride(1), att_out.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, logit_cap=logit_cap, num_warps=num_warps, num_stages=1, + Lk=Lk, ) @@ -248,6 +255,9 @@ def _decode_softmax_reducev_fwd( num_warps = 1 + Lv = v_buffer.shape[-1] + BLOCK_DMODEL = triton.next_power_of_2(Lv) + _fwd_kernel_stage2[grid]( logics, v_buffer, @@ -263,10 +273,11 @@ def _decode_softmax_reducev_fwd( o.stride(1), req_to_tokens.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=3, + Lv=Lv, ) @@ -293,6 +304,7 @@ def _fwd_grouped_kernel_stage1( BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, logit_cap: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -324,9 +336,9 @@ def _fwd_grouped_kernel_stage1( block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) for start_mark in range(0, block_mask, 1): - q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to( - REDUCE_TRITON_TYPE - ) + q = tl.load( + Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk) + ).to(REDUCE_TRITON_TYPE) offs_n_new = cur_batch_start_index + offs_n k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, @@ -340,7 +352,7 @@ def _fwd_grouped_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=offs_n_new[None, :] < cur_batch_end_index, + mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk), other=0.0, ).to(REDUCE_TRITON_TYPE) qk = tl.dot(q, k) @@ -395,6 +407,7 @@ def _fwd_grouped_kernel_stage2( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -441,7 +454,9 @@ def _fwd_grouped_kernel_stage2( old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) e_sum = e_sum * old_scale + tl.sum(p, 1) - v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + v = tl.load( + v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + ) p = p.to(v.dtype) acc = acc * old_scale[:, None] + tl.dot(p, v) e_max = n_e_max @@ -449,7 +464,7 @@ def _fwd_grouped_kernel_stage2( acc = acc / e_sum[:, None] off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=mask_h[:, None]) + tl.store(out_ptrs, acc, 
mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) def _decode_grouped_att_m_fwd( @@ -468,13 +483,13 @@ def _decode_grouped_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256, 576} + assert Lk in {16, 32, 64, 96, 128, 256, 576} if Lk == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 else: - BLOCK_DMODEL = Lk + BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -513,6 +528,7 @@ def _decode_grouped_att_m_fwd( logit_cap=logit_cap, num_warps=num_warps, num_stages=1, + Lk=Lk, ) @@ -533,6 +549,9 @@ def _decode_grouped_softmax_reducev_fwd( num_warps = 8 + Lv = v_buffer.shape[-1] + BLOCK_DMODEL = triton.next_power_of_2(Lv) + _fwd_grouped_kernel_stage2[grid]( logics, v_buffer, @@ -549,11 +568,12 @@ def _decode_grouped_softmax_reducev_fwd( req_to_tokens.stride(0), kv_group_num=kv_group_num, q_head_num=head_num, - BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=1, + Lv=Lv, ) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 6c7686971e..8880622854 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -15,7 +15,7 @@ """ Memory-efficient attention for prefill. -It supporst page size = 1 and prefill with KV cache (i.e. extend). +It supports page size = 1 and prefill with KV cache (i.e. extend). """ import torch @@ -67,6 +67,8 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -86,13 +88,18 @@ def _fwd_kernel( offs_m = tl.arange(0, BLOCK_M) mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + mask_d = offs_d < Lq + mask_dv = offs_dv < Lv + offs_q = ( (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] ) - q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0) + q = tl.load( + Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0 + ) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) @@ -125,7 +132,9 @@ def _fwd_kernel( + cur_kv_head * stride_buf_kh + offs_d[:, None] ) - k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) + k = tl.load( + K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) qk = tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: @@ -157,7 +166,9 @@ def _fwd_kernel( + cur_kv_head * stride_buf_vh + offs_dv[None, :] ) - v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0) + v = tl.load( + V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) @@ -176,7 +187,9 @@ def _fwd_kernel( + cur_kv_head * stride_kh + offs_d[:, None] ) - k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) + k = tl.load( + K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: @@ -214,7 +227,9 @@ def _fwd_kernel( + cur_kv_head * stride_vh + offs_dv[None, :] ) - v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0) + v = tl.load( + V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) @@ -226,7 +241,9 @@ def _fwd_kernel( + 
cur_head * stride_oh + offs_dv[None, :] ) - tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) + tl.store( + O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :] + ) def extend_attention_fwd( @@ -261,16 +278,18 @@ def extend_attention_fwd( ) assert Lq == Lk and Lv == Lo - assert Lq in {16, 32, 64, 128, 256, 576} - assert Lv in {16, 32, 64, 128, 256, 512} + + # TODO: is the assertion necessary? + assert Lq in {16, 32, 64, 96, 128, 256, 576} + assert Lv in {16, 32, 64, 96, 128, 256, 512} if Lq == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 else: - BLOCK_DMODEL = Lq + BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 - BLOCK_DV = Lv + BLOCK_DV = triton.next_power_of_2(Lv) if CUDA_CAPABILITY[0] >= 9: if Lq <= 256: @@ -330,6 +349,8 @@ def extend_attention_fwd( num_warps=num_warps, num_stages=num_stages, logit_cap=logit_cap, + Lq=Lq, + Lv=Lv, ) @@ -373,10 +394,7 @@ def redundant_attention( pt += cur_seq_len_extend -def test(): - torch.manual_seed(0) - - B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128 +def test_once(B, N_CTX, H_Q, H_KV, D): dtype = torch.float16 b_seq_len_prefix = torch.randint( @@ -473,4 +491,5 @@ def test(): if __name__ == "__main__": - test() + test_once(19, 12331, 12, 4, 128) + test_once(19, 12331, 12, 4, 96) diff --git a/python/sglang/srt/layers/prefill_attention.py b/python/sglang/srt/layers/prefill_attention.py index 99343a4df7..fbf9976fbc 100644 --- a/python/sglang/srt/layers/prefill_attention.py +++ b/python/sglang/srt/layers/prefill_attention.py @@ -48,6 +48,7 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -72,7 +73,11 @@ def _fwd_kernel( off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + mask_d = offs_d < Lk + + q = tl.load( + Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0 + ) k_ptrs = K + off_k v_ptrs = V + off_v @@ -89,7 +94,7 @@ def _fwd_kernel( # -- compute qk ---- k = tl.load( k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), other=0.0, ) # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) @@ -118,7 +123,7 @@ def _fwd_kernel( # update acc v = tl.load( v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), other=0.0, ) @@ -134,7 +139,9 @@ def _fwd_kernel( + offs_d[None, :] ) out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + tl.store( + out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]) + ) def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): @@ -145,7 +152,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 96, 128, 256} sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] @@ -172,8 +179,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, 
max_input_len): o.stride(1), kv_group_num=kv_group_num, BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, + BLOCK_DMODEL=triton.next_power_of_2(Lk), BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, + Lk=Lk, ) From 662ecd93680c8195eda799cb9a497f93efdc521a Mon Sep 17 00:00:00 2001 From: Kaichen Zhang - NTU Date: Mon, 9 Sep 2024 17:07:34 +0800 Subject: [PATCH 10/33] [Feat] Add modalities for vision server when handling pixel values for llava (#1346) --- .../llava_onevision/http_llava_onevision_test.py | 3 +++ python/sglang/srt/conversation.py | 3 +++ python/sglang/srt/managers/io_struct.py | 4 ++++ python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/managers/tokenizer_manager.py | 5 +++++ python/sglang/srt/managers/tp_worker.py | 2 ++ .../sglang/srt/model_executor/forward_batch_info.py | 2 ++ python/sglang/srt/models/llava.py | 11 +++++++++-- python/sglang/srt/openai_api/adapter.py | 7 +++++++ python/sglang/srt/openai_api/protocol.py | 1 + test/srt/test_vision_openai_server.py | 3 +++ 11 files changed, 40 insertions(+), 2 deletions(-) diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/llava_onevision/http_llava_onevision_test.py index 0c93d2ce2b..2c7c2bd38b 100644 --- a/examples/runtime/llava_onevision/http_llava_onevision_test.py +++ b/examples/runtime/llava_onevision/http_llava_onevision_test.py @@ -93,12 +93,14 @@ def multi_image_stream_request_test(client): "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" }, + "modalities": "multi-images", }, { "type": "image_url", "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" }, + "modalities": "multi-images", }, { "type": "text", @@ -218,6 +220,7 @@ def prepare_video_messages(video_path): frame_format = { "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", } for base64_frame in base64_frames: diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index dbc376d959..9a1227218b 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -71,6 +71,7 @@ class Conversation: # Stop criteria (the default one is EOS token) stop_str: Union[str, List[str]] = None image_data: Optional[List[str]] = None + modalities: Optional[List[str]] = None def get_prompt(self) -> str: """Get the prompt for generation.""" @@ -379,6 +380,7 @@ def generate_chat_conv( sep2=conv.sep2, stop_str=conv.stop_str, image_data=[], + modalities=[], ) if isinstance(request.messages, str): @@ -408,6 +410,7 @@ def generate_chat_conv( for content in message.content: if content.type == "image_url": num_image_url += 1 + conv.modalities.append(content.modalities) if num_image_url > 1: image_token = "" else: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5b91ff62e9..8e53df3355 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -50,6 +50,8 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. 
stream: bool = False + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None def post_init(self): if (self.text is None and self.input_ids is None) or ( @@ -177,6 +179,8 @@ class TokenizedGenerateReqInput: top_logprobs_num: int # Whether to stream output stream: bool + # Modalities of the input images + modalites: Optional[List[str]] = None @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c80cf2e272..f126cc9f3a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -130,6 +130,7 @@ def __init__(self, rid, origin_input_text, origin_input_ids): self.image_sizes = None self.image_offsets = None self.pad_value = None + self.modalities = None # Prefix info self.extend_input_len = 0 diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6af8206415..d0cfed08cd 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -188,6 +188,7 @@ async def _handle_single_request( pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data if not_use_index else obj.image_data[index] ) + modalities = obj.modalities return_logprob = ( obj.return_logprob if not_use_index else obj.return_logprob[index] ) @@ -243,6 +244,7 @@ async def _handle_single_request( pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data[0] ) + modalities = obj.modalities return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] @@ -263,6 +265,7 @@ async def _handle_single_request( logprob_start_len, top_logprobs_num, obj.stream, + modalities, ) else: # is embedding tokenized_obj = TokenizedEmbeddingReqInput( @@ -346,6 +349,7 @@ async def _handle_batch_request( pixel_values, image_hashes, image_sizes = ( await self._get_pixel_values(obj.image_data[index]) ) + modalities = obj.modalities tokenized_obj = TokenizedGenerateReqInput( rid, @@ -359,6 +363,7 @@ async def _handle_batch_request( obj.logprob_start_len[index], obj.top_logprobs_num[index], obj.stream, + modalities, ) else: tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index c2c0e6c2d1..7bb9c43356 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -358,6 +358,8 @@ def handle_generate_request( req.pixel_values, req.image_sizes, ) + # Only when pixel values is not None we have modalities + req.modalities = recv_req.modalites req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len req.top_logprobs_num = recv_req.top_logprobs_num diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a443b113d4..75f9136d39 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -78,6 +78,7 @@ class InputMetadata: pixel_values: List[torch.Tensor] = None image_sizes: List[List[List[int]]] = None image_offsets: List[List[int]] = None + modalities: List[List[str]] = None # Trition attention backend triton_max_seq_len: int = 0 @@ -96,6 +97,7 @@ def init_multimuldal_info(self, batch: ScheduleBatch): self.pixel_values = [r.pixel_values for r in reqs] 
self.image_sizes = [r.image_sizes for r in reqs] self.image_offsets = [r.image_offsets for r in reqs] + self.modalities = [r.modalities for r in reqs] def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 2e3c9ceba1..62041a8955 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -138,6 +138,12 @@ def forward( ) -> torch.Tensor: if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size + # Got List[List[str]] extend it to List[str] + # The length of the List should be equal to batch size + modalities_list = [] + for modalities in input_metadata.modalities: + if modalities is not None: + modalities_list.extend(modalities) # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) @@ -179,7 +185,7 @@ def forward( new_image_features = [] height = width = self.num_patches_per_side for image_idx, image_feature in enumerate(image_features): - if len(image_sizes[image_idx]) == 1: + if modalities_list[image_idx] == 1: image_aspect_ratio = ( self.config.image_aspect_ratio ) # single image @@ -191,6 +197,7 @@ def forward( if ( image_feature.shape[0] > 1 and "anyres" in image_aspect_ratio + and modalities_list[image_idx] == "image" ): base_image_feature = image_feature[0] image_feature = image_feature[1:] @@ -290,7 +297,7 @@ def forward( ) image_feature = image_feature.unsqueeze(0) else: - if image_feature.shape[0] > 16: # video + if modalities_list[image_idx] == "video": # video # 2x2 pooling num_of_frames = image_feature.shape[0] image_feature = image_feature.view( diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index cd7526b0d9..f1195aff7c 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -832,6 +832,7 @@ def v1_chat_generate_request( return_logprobs = [] logprob_start_lens = [] top_logprobs_nums = [] + modalities_list = [] # NOTE: with openai API, the prompt's logprobs are always not computed @@ -864,10 +865,12 @@ def v1_chat_generate_request( ) stop = request.stop image_data = None + modalities = [] else: conv = generate_chat_conv(request, chat_template_name) prompt = conv.get_prompt() image_data = conv.image_data + modalities = conv.modalities stop = conv.stop_str or [] if request.stop: if isinstance(request.stop, str): @@ -880,6 +883,7 @@ def v1_chat_generate_request( prompt_ids = request.messages stop = request.stop image_data = None + modalities = [] input_ids.append(prompt_ids) return_logprobs.append(request.logprobs) logprob_start_lens.append(-1) @@ -901,6 +905,7 @@ def v1_chat_generate_request( } ) image_data_list.append(image_data) + modalities_list.extend(modalities) if len(all_requests) == 1: input_ids = input_ids[0] if isinstance(input_ids, str): @@ -912,6 +917,7 @@ def v1_chat_generate_request( return_logprobs = return_logprobs[0] logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] + modalities_list = modalities_list[:1] else: if isinstance(input_ids[0], str): prompt_kwargs = {"text": input_ids} @@ -928,6 +934,7 @@ def v1_chat_generate_request( stream=all_requests[0].stream, return_text_in_logprobs=True, rid=request_ids, + modalities=modalities_list, ) if len(all_requests) == 1: return adapted_request, all_requests[0] diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 
8073df7952..5525cd8827 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -213,6 +213,7 @@ class ChatCompletionMessageContentImageURL(BaseModel): class ChatCompletionMessageContentImagePart(BaseModel): type: Literal["image_url"] image_url: ChatCompletionMessageContentImageURL + modalities: Optional[Literal["image", "multi-images", "video"]] = "image" ChatCompletionMessageContentPart = Union[ diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 4f764c09cd..727f5774ca 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -140,12 +140,14 @@ def test_mult_images_chat_completion(self): "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" }, + "modalities": "multi-images", }, { "type": "image_url", "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" }, + "modalities": "multi-images", }, { "type": "text", @@ -192,6 +194,7 @@ def prepare_video_messages(self, video_path): frame_format = { "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", } for base64_frame in base64_frames: From c9b75917d577523ba1c1c581c638b9d2e94b777b Mon Sep 17 00:00:00 2001 From: Kai-Hsun Chen Date: Mon, 9 Sep 2024 02:14:25 -0700 Subject: [PATCH 11/33] [server] Passing `model_override_args` to `launch_server` via the CLI. (#1298) Signed-off-by: Kai-Hsun Chen --- benchmark/blog_v0_2/405b_sglang.sh | 2 +- python/sglang/bench_latency.py | 1 + python/sglang/launch_server.py | 12 ++++----- python/sglang/launch_server_llavavid.py | 11 +++----- python/sglang/srt/server_args.py | 34 +++++++++++++++++++++++++ test/srt/run_suite.py | 1 + test/srt/test_server_args.py | 24 +++++++++++++++++ test/srt/test_serving_latency.py | 2 +- 8 files changed, 71 insertions(+), 16 deletions(-) create mode 100644 test/srt/test_server_args.py diff --git a/benchmark/blog_v0_2/405b_sglang.sh b/benchmark/blog_v0_2/405b_sglang.sh index eae5e22060..d31f8daf8e 100644 --- a/benchmark/blog_v0_2/405b_sglang.sh +++ b/benchmark/blog_v0_2/405b_sglang.sh @@ -6,7 +6,7 @@ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json # Launch sglang -# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 +# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 # offline python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 9006b7150a..6113495776 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -480,6 +480,7 @@ def main(server_args, bench_args): if __name__ == "__main__": + # TODO(kevin85421): Make the parser setup unit testable. 
parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser) diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 1df64e848c..06aa140d9b 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -1,20 +1,18 @@ """Launch the inference server.""" -import argparse import os +import sys from sglang.srt.server import launch_server -from sglang.srt.server_args import ServerArgs +from sglang.srt.server_args import prepare_server_args from sglang.srt.utils import kill_child_process if __name__ == "__main__": - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - args = parser.parse_args() - server_args = ServerArgs.from_cli_args(args) + server_args = prepare_server_args(sys.argv[1:]) + model_override_args = server_args.json_model_override_args try: - launch_server(server_args) + launch_server(server_args, model_override_args=model_override_args) except Exception as e: raise e finally: diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index 43eefef4ef..6b8d151ee1 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -1,14 +1,11 @@ """Launch the inference server for Llava-video model.""" -import argparse +import sys -from sglang.srt.server import ServerArgs, launch_server +from sglang.srt.server import launch_server, prepare_server_args if __name__ == "__main__": - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - args = parser.parse_args() - server_args = ServerArgs.from_cli_args(args) + server_args = prepare_server_args(sys.argv[1:]) model_override_args = {} model_override_args["mm_spatial_pool_stride"] = 2 @@ -20,7 +17,7 @@ model_override_args["max_sequence_length"] = 4096 * 2 model_override_args["tokenizer_model_max_length"] = 4096 * 2 model_override_args["model_max_length"] = 4096 * 2 - if "34b" in args.model_path.lower(): + if "34b" in server_args.model_path.lower(): model_override_args["image_token_index"] = 64002 launch_server(server_args, model_override_args, None) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8a56c02e16..e21f02108c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -17,6 +17,7 @@ import argparse import dataclasses +import json import logging import random from typing import List, Optional, Union @@ -95,6 +96,9 @@ class ServerArgs: nnodes: int = 1 node_rank: Optional[int] = None + # Model override args in JSON + json_model_override_args: Optional[dict] = None + def __post_init__(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -455,10 +459,22 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) + # Model override args + parser.add_argument( + "--json-model-override-args", + type=str, + help="A dictionary in JSON string format used to override default model configurations.", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size + args.json_model_override_args = ( + json.loads(args.json_model_override_args) + if args.json_model_override_args + else None + ) attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) @@ -482,6 +498,24 @@ def check_server_args(self): self.disable_flashinfer = False 
+def prepare_server_args(args: argparse.Namespace) -> ServerArgs: + """ + Prepare the server arguments from the command line arguments. + + Args: + args: The command line arguments. Typically, it should be `sys.argv[1:]` + to ensure compatibility with `parse_args` when no arguments are passed. + + Returns: + The server arguments. + """ + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + raw_args = parser.parse_args(args) + server_args = ServerArgs.from_cli_args(raw_args) + return server_args + + @dataclasses.dataclass class PortArgs: tokenizer_port: int diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index cafcf3f2d5..d5982844ce 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -19,6 +19,7 @@ "test_triton_attn_backend.py", "test_update_weights.py", "test_vision_openai_server.py", + "test_server_args.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py new file mode 100644 index 0000000000..71129e3eb1 --- /dev/null +++ b/test/srt/test_server_args.py @@ -0,0 +1,24 @@ +import unittest + +from sglang.srt.server_args import prepare_server_args + + +class TestPrepareServerArgs(unittest.TestCase): + def test_prepare_server_args(self): + server_args = prepare_server_args( + [ + "--model-path", + "model_path", + "--json-model-override-args", + '{"rope_scaling": {"factor": 2.0, "type": "linear"}}', + ] + ) + self.assertEqual(server_args.model_path, "model_path") + self.assertEqual( + server_args.json_model_override_args, + {"rope_scaling": {"factor": 2.0, "type": "linear"}}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_serving_latency.py b/test/srt/test_serving_latency.py index e762892c8e..3dae4541a0 100644 --- a/test/srt/test_serving_latency.py +++ b/test/srt/test_serving_latency.py @@ -12,7 +12,7 @@ def test_default(self): "python3", "-m", "sglang.bench_latency", - "--model", + "--model-path", DEFAULT_MODEL_NAME_FOR_TEST, "--batch-size", "1", From e4d68afcf00869a5467f101d176fecc3cd97b7b8 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 9 Sep 2024 04:14:11 -0700 Subject: [PATCH 12/33] [Minor] Many cleanup (#1357) --- benchmark/gsm8k/README.md | 5 - benchmark/gsm8k/bench_other.py | 30 ++-- benchmark/gsm8k/bench_sglang.py | 39 +++-- benchmark/gsm8k/download_data.sh | 2 - benchmark/hellaswag/README.md | 5 - benchmark/hellaswag/bench_other.py | 23 +-- benchmark/hellaswag/bench_sglang.py | 24 +-- .../usage/llava_video/srt_example_llava_v.py | 3 +- python/sglang/bench_serving.py | 71 ++++---- python/sglang/launch_server.py | 3 +- python/sglang/launch_server_llavavid.py | 4 +- python/sglang/srt/constrained/fsm_cache.py | 67 ++++---- .../sglang/srt/managers/controller_multi.py | 6 +- .../sglang/srt/managers/controller_single.py | 5 - .../sglang/srt/managers/tokenizer_manager.py | 4 +- python/sglang/srt/managers/tp_worker.py | 157 +++++++++--------- .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/server.py | 9 +- python/sglang/srt/server_args.py | 40 ++--- python/sglang/test/few_shot_gsm8k.py | 132 +++++++++++++++ python/sglang/test/test_programs.py | 12 +- python/sglang/utils.py | 69 ++++---- test/srt/test_moe_eval_accuracy_large.py | 4 +- test/srt/test_server_args.py | 3 +- 24 files changed, 419 insertions(+), 299 deletions(-) delete mode 100755 benchmark/gsm8k/download_data.sh create mode 100644 python/sglang/test/few_shot_gsm8k.py diff --git a/benchmark/gsm8k/README.md 
b/benchmark/gsm8k/README.md index a7dc04d9a9..c110f533c7 100644 --- a/benchmark/gsm8k/README.md +++ b/benchmark/gsm8k/README.md @@ -1,8 +1,3 @@ -## Download data -``` -bash download_data.sh -``` - ## Run benchmark ### Benchmark sglang diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py index 2a938d6bb9..a8bbcfb5c1 100644 --- a/benchmark/gsm8k/bench_other.py +++ b/benchmark/gsm8k/bench_other.py @@ -10,7 +10,7 @@ from tqdm import tqdm from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate -from sglang.utils import dump_state_text, read_jsonl +from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl INVALID = -9999999 @@ -41,24 +41,28 @@ def get_answer_value(answer_str): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + call_generate = get_call_generate(args) + + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) # Construct prompts - k = args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) labels.append(get_answer_value(lines[i]["answer"])) assert all(l != INVALID for l in labels) states = [None] * len(labels) - # Select backend - call_generate = get_call_generate(args) - # Run requests if args.backend != "lmql": # Use thread pool @@ -113,11 +117,13 @@ async def batched_call(batch_size): # Compute accuracy acc = np.mean(np.array(preds) == np.array(labels)) invalid = np.mean(np.array(preds) == INVALID) - print(f"Latency: {latency:.3f}") - print(f"Invalid: {invalid:.3f}") + + # Print results print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") - # Write results + # Dump results dump_state_text(f"tmp_output_{args.backend}.txt", states) with open(args.result_file, "a") as fout: @@ -138,7 +144,7 @@ async def batched_call(batch_size): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=5) + parser.add_argument("--num-shots", type=int, default=5) parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_other_args_and_parse(parser) diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index d32790fe0c..9fe9b79baa 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -6,11 +6,12 @@ import numpy as np +from sglang.api import set_default_backend from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) -from sglang.utils import dump_state_text, read_jsonl +from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl INVALID = -9999999 @@ -41,15 +42,22 @@ def get_answer_value(answer_str): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + set_default_backend(select_sglang_backend(args)) + + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) # Construct prompts - k = 
args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) labels.append(get_answer_value(lines[i]["answer"])) assert all(l != INVALID for l in labels) @@ -72,15 +80,11 @@ def few_shot_gsm8k(s, question): ########## SGL Program End ########## ##################################### - # Select backend - backend = select_sglang_backend(args) - # Run requests tic = time.time() states = few_shot_gsm8k.run_batch( arguments, temperature=0, - backend=backend, num_threads=args.parallel, progress_bar=True, ) @@ -96,11 +100,20 @@ def few_shot_gsm8k(s, question): # Compute accuracy acc = np.mean(np.array(preds) == np.array(labels)) invalid = np.mean(np.array(preds) == INVALID) - print(f"Latency: {latency:.3f}") - print(f"Invalid: {invalid:.3f}") + + # Compute speed + num_output_tokens = sum( + s.get_meta_info("answer")["completion_tokens"] for s in states + ) + output_throughput = num_output_tokens / latency + + # Print results print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + print(f"Output throughput: {output_throughput:.3f} token/s") - # Write results + # Dump results dump_state_text(f"tmp_output_{args.backend}.txt", states) with open(args.result_file, "a") as fout: @@ -121,7 +134,7 @@ def few_shot_gsm8k(s, question): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=5) + parser.add_argument("--num-shots", type=int, default=5) parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_sglang_args_and_parse(parser) diff --git a/benchmark/gsm8k/download_data.sh b/benchmark/gsm8k/download_data.sh deleted file mode 100755 index a9aa7756d2..0000000000 --- a/benchmark/gsm8k/download_data.sh +++ /dev/null @@ -1,2 +0,0 @@ -wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl -wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl \ No newline at end of file diff --git a/benchmark/hellaswag/README.md b/benchmark/hellaswag/README.md index b3e7abc30f..cb7e65366f 100644 --- a/benchmark/hellaswag/README.md +++ b/benchmark/hellaswag/README.md @@ -1,8 +1,3 @@ -## Download data -``` -wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl -``` - ## Run benchmark ### Benchmark sglang diff --git a/benchmark/hellaswag/bench_other.py b/benchmark/hellaswag/bench_other.py index 5b9ba797bd..04be4569a9 100644 --- a/benchmark/hellaswag/bench_other.py +++ b/benchmark/hellaswag/bench_other.py @@ -8,7 +8,7 @@ from tqdm import tqdm from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select -from sglang.utils import read_jsonl +from sglang.utils import download_and_cache_file, read_jsonl def get_one_example(lines, i, include_answer): @@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + call_select = get_call_select(args) + + # Read data + url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" + filename = download_and_cache_file(url) + lines = 
list(read_jsonl(filename)) # Construct prompts - k = args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] choices = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) choices.append(lines[i]["endings"]) labels.append(lines[i]["label"]) preds = [None] * len(labels) - # Select backend - call_select = get_call_select(args) - # Run requests if args.backend != "lmql": # Use thread pool @@ -65,7 +69,6 @@ def get_one_answer(i): total=len(questions), ) ) - else: # Use asyncio async def batched_call(batch_size): @@ -108,7 +111,7 @@ async def batched_call(batch_size): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=20) + parser.add_argument("--num-shots", type=int, default=20) parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_other_args_and_parse(parser) diff --git a/benchmark/hellaswag/bench_sglang.py b/benchmark/hellaswag/bench_sglang.py index 2ccf1aaee2..f09d7256da 100644 --- a/benchmark/hellaswag/bench_sglang.py +++ b/benchmark/hellaswag/bench_sglang.py @@ -4,11 +4,12 @@ import numpy as np +from sglang.api import set_default_backend from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) -from sglang.utils import read_jsonl +from sglang.utils import download_and_cache_file, read_jsonl def get_one_example(lines, i, include_answer): @@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + set_default_backend(select_sglang_backend(args)) + + # Read data + url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) # Construct prompts - k = args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] choices = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) choices.append(lines[i]["endings"]) labels.append(lines[i]["label"]) @@ -56,15 +64,11 @@ def few_shot_hellaswag(s, question, choices): ########## SGL Program End ########## ##################################### - # Select backend - backend = select_sglang_backend(args) - # Run requests tic = time.time() rets = few_shot_hellaswag.run_batch( arguments, temperature=0, - backend=backend, num_threads=args.parallel, progress_bar=True, ) @@ -95,7 +99,7 @@ def few_shot_hellaswag(s, question, choices): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=20) + parser.add_argument("--num-shots", type=int, default=20) parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_sglang_args_and_parse(parser) diff --git a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py index 02bab342ac..c3b8da7d6a 
100644 --- a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py +++ b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py @@ -7,6 +7,7 @@ import argparse import csv +import json import os import time @@ -223,7 +224,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= tokenizer_path=tokenizer_path, port=cur_port, additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4], - model_override_args=model_override_args, + json_model_override_args=json.dumps(model_override_args), tp_size=1, ) sgl.set_default_backend(runtime) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 69d175d843..d51aee4ec9 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -298,34 +298,41 @@ class BenchmarkMetrics: median_e2e_latency_ms: float -default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" +SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" -def download_sharegpt_dataset(path): - url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) - print(f"Downloading dataset from {url}") - try: - response = requests.get(url, stream=True) - response.raise_for_status() + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") - total_size = int(response.headers.get("content-length", 0)) - block_size = 8192 + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors - with open(path, "wb") as f, tqdm( - desc="Downloading", - total=total_size, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as progress_bar: - for data in response.iter_content(block_size): - size = f.write(data) - progress_bar.update(size) + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB - print(f"Dataset downloaded and saved to {path}") - except requests.RequestException as e: - raise Exception(f"Failed to download dataset: {e}") + # Use tqdm to display the progress bar + with open(filename, "wb") as f, tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename def sample_sharegpt_requests( @@ -338,13 +345,8 @@ def sample_sharegpt_requests( raise ValueError("output_len too small") # Download sharegpt if necessary - if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): - download_sharegpt_dataset(default_sharegpt_path) - dataset_path = default_sharegpt_path - else: - dataset_path = ( - dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path - ) + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. 
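`download_and_cache_file` derives the cache path from the URL basename under `/tmp` and returns early if that file already exists, so only the first benchmark run pays for the download. A small sketch of that behavior, with the assertion purely illustrative:

```python
from sglang.bench_serving import SHAREGPT_URL, download_and_cache_file

path = download_and_cache_file(SHAREGPT_URL)    # first call: streams to /tmp/ShareGPT_V3_unfiltered_cleaned_split.json
cached = download_and_cache_file(SHAREGPT_URL)  # later calls: cache hit, no network request
assert path == cached
```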
with open(dataset_path) as f: @@ -412,15 +414,8 @@ def sample_random_requests( # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens # Download sharegpt if necessary - if not os.path.isfile(dataset_path) and not os.path.isfile( - default_sharegpt_path - ): - download_sharegpt_dataset(default_sharegpt_path) - dataset_path = default_sharegpt_path - else: - dataset_path = ( - dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path - ) + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. with open(dataset_path) as f: diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 06aa140d9b..ce4cb07c2b 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -9,10 +9,9 @@ if __name__ == "__main__": server_args = prepare_server_args(sys.argv[1:]) - model_override_args = server_args.json_model_override_args try: - launch_server(server_args, model_override_args=model_override_args) + launch_server(server_args) except Exception as e: raise e finally: diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index 6b8d151ee1..6816dcc112 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -1,5 +1,6 @@ """Launch the inference server for Llava-video model.""" +import json import sys from sglang.srt.server import launch_server, prepare_server_args @@ -19,5 +20,6 @@ model_override_args["model_max_length"] = 4096 * 2 if "34b" in server_args.model_path.lower(): model_override_args["image_token_index"] = 64002 + server_args.json_model_override_args = json.dumps(model_override_args) - launch_server(server_args, model_override_args, None) + launch_server(server_args) diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 57c4913062..fd5995dad1 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -16,6 +16,7 @@ """Cache for the compressed finite state machine.""" from outlines.fsm.json_schema import build_regex_from_schema +from transformers import AutoTokenizer from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained.base_tool_cache import BaseToolCache @@ -28,12 +29,9 @@ def __init__( tokenizer_args_dict, enable=True, skip_tokenizer_init=False, - json_schema_mode=False, ): super().__init__(enable=enable) - self.json_schema_mode = json_schema_mode - if ( skip_tokenizer_init or tokenizer_path.endswith(".json") @@ -42,44 +40,37 @@ def __init__( # Do not support TiktokenTokenizer or SentencePieceTokenizer return - from importlib.metadata import version + tokenizer_args_dict.setdefault("padding_side", "left") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) + try: + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + except AttributeError: + # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) + origin_pad_token_id = tokenizer.pad_token_id - if version("outlines") >= "0.0.35": - from transformers import AutoTokenizer + def fset(self, value): + self._value = value - tokenizer_args_dict.setdefault("padding_side", "left") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, **tokenizer_args_dict + type(tokenizer).pad_token_id = property( + fget=type(tokenizer).pad_token_id.fget, fset=fset ) - try: - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - except AttributeError: 
- # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) - origin_pad_token_id = tokenizer.pad_token_id - - def fset(self, value): - self._value = value - - type(tokenizer).pad_token_id = property( - fget=type(tokenizer).pad_token_id.fget, fset=fset - ) - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token = ( - self.outlines_tokenizer.tokenizer.pad_token - ) - self.outlines_tokenizer.vocabulary = ( - self.outlines_tokenizer.tokenizer.get_vocab() - ) - else: - self.outlines_tokenizer = TransformerTokenizer( - tokenizer_path, **tokenizer_args_dict + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token = ( + self.outlines_tokenizer.tokenizer.pad_token + ) + self.outlines_tokenizer.vocabulary = ( + self.outlines_tokenizer.tokenizer.get_vocab() ) - def init_value(self, value): - if self.json_schema_mode: - regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*") - return RegexGuide(regex, self.outlines_tokenizer), regex + def init_value(self, key): + key_type, key_string = key + if key_type == "json": + regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*") + elif key_type == "regex": + regex = key_string else: - return RegexGuide(value, self.outlines_tokenizer) + raise ValueError(f"Invalid key_type: {key_type}") + + return RegexGuide(regex, self.outlines_tokenizer), regex diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py index ba626d4cff..e4b316155a 100644 --- a/python/sglang/srt/managers/controller_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -71,12 +71,10 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_override_args, ): # Parse args self.server_args = server_args self.port_args = port_args - self.model_override_args = model_override_args self.load_balance_method = LoadBalanceMethod.from_str( server_args.load_balance_method ) @@ -114,7 +112,6 @@ def start_dp_worker(self, dp_worker_id: int): self.server_args, self.port_args, pipe_controller_writer, - self.model_override_args, True, gpu_ids, dp_worker_id, @@ -189,14 +186,13 @@ def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer, - model_override_args: dict, ): """Start a controller process.""" configure_logger(server_args) try: - controller = ControllerMulti(server_args, port_args, model_override_args) + controller = ControllerMulti(server_args, port_args) except Exception: pipe_writer.send(get_exception_traceback()) raise diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index 2ae37059c1..fe03ca1d47 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -40,7 +40,6 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_override_args: dict, gpu_ids: List[int], is_data_parallel_worker: bool, dp_worker_id: int, @@ -76,7 +75,6 @@ def __init__( tp_rank_range, server_args, port_args.nccl_ports[dp_worker_id], - model_override_args, ) # Launch tp rank 0 @@ -85,7 +83,6 @@ def __init__( 0, server_args, port_args.nccl_ports[dp_worker_id], - model_override_args, ) self.tp_cpu_group = 
self.tp_server.model_runner.tp_group.cpu_group @@ -126,7 +123,6 @@ def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer: multiprocessing.connection.Connection, - model_override_args: dict, is_data_parallel_worker: bool = False, gpu_ids: List[int] = None, dp_worker_id: int = None, @@ -149,7 +145,6 @@ def start_controller_process( controller = ControllerSingle( server_args, port_args, - model_override_args, gpu_ids, is_data_parallel_worker, dp_worker_id, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d0cfed08cd..d2fa676012 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -18,6 +18,7 @@ import asyncio import concurrent.futures import dataclasses +import json import logging import multiprocessing as mp import os @@ -77,7 +78,6 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_override_args: dict = None, ): self.server_args = server_args @@ -95,7 +95,7 @@ def __init__( self.hf_config = get_config( self.model_path, trust_remote_code=server_args.trust_remote_code, - model_override_args=model_override_args, + model_override_args=json.loads(server_args.json_model_override_args), ) self.is_generation = is_generation_model( self.hf_config.architectures, self.server_args.is_embedding diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 7bb9c43356..513bc517f5 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -15,13 +15,14 @@ """A tensor parallel worker.""" +import json import logging import multiprocessing import os import pickle import time import warnings -from typing import Any, List, Optional, Union +from typing import Any, List, Optional import torch import torch.distributed @@ -66,6 +67,7 @@ logger = logging.getLogger(__name__) +# Crash on warning if we are running CI tests crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true" @@ -76,11 +78,10 @@ def __init__( tp_rank: int, server_args: ServerArgs, nccl_port: int, - model_override_args: dict, ): suppress_other_loggers() - # Copy arguments + # Parse arguments self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = server_args.tp_size @@ -93,9 +94,8 @@ def __init__( server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length, - model_override_args=model_override_args, + model_override_args=json.loads(server_args.json_model_override_args), ) - self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -136,7 +136,7 @@ def __init__( self.max_total_num_tokens - 1, ) - # Sync random seed + # Sync random seed across TP workers server_args.random_seed = broadcast_recv_input( [server_args.random_seed], self.tp_rank, @@ -144,7 +144,7 @@ def __init__( )[0] set_random_seed(server_args.random_seed) - # Print info + # Print debug info logger.info( f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_prefill_tokens={self.max_prefill_tokens}, " @@ -181,7 +181,7 @@ def __init__( self.num_generated_tokens = 0 self.last_stats_tic = time.time() - # Chunked prefill + # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size self.current_inflight_req = None self.is_mixed_chunk = ( @@ -197,16 +197,6 @@ def __init__( "trust_remote_code": server_args.trust_remote_code, }, skip_tokenizer_init=server_args.skip_tokenizer_init, - 
json_schema_mode=False, - ) - self.json_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - json_schema_mode=True, ) self.jump_forward_cache = JumpForwardCache() @@ -227,11 +217,12 @@ def exposed_step(self, recv_reqs: List): try: # Recv requests for recv_req in recv_reqs: - if isinstance( - recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) - ): + if isinstance(recv_req, TokenizedGenerateReqInput): self.handle_generate_request(recv_req) self.do_not_get_new_batch = False + elif isinstance(recv_req, TokenizedEmbeddingReqInput): + self.handle_embedding_request(recv_req) + self.do_not_get_new_batch = False elif isinstance(recv_req, FlushCacheReq): self.flush_cache() elif isinstance(recv_req, AbortReq): @@ -331,57 +322,56 @@ def check_memory(self): def handle_generate_request( self, - recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], + recv_req: TokenizedGenerateReqInput, ): req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req.tokenizer = self.tokenizer req.sampling_params = recv_req.sampling_params - if self.model_runner.is_generation: - req.pixel_values = recv_req.pixel_values - if req.pixel_values is not None: - # Use image hash as fake token_ids, which is then used - # for prefix matching - image_hash = hash(tuple(recv_req.image_hashes)) - req.pad_value = [ - (image_hash) % self.model_config.vocab_size, - (image_hash >> 16) % self.model_config.vocab_size, - (image_hash >> 32) % self.model_config.vocab_size, - (image_hash >> 64) % self.model_config.vocab_size, - ] - req.image_sizes = recv_req.image_sizes - ( - req.origin_input_ids, - req.image_offsets, - ) = self.model_runner.model.pad_input_ids( - req.origin_input_ids_unpadded, - req.pad_value, - req.pixel_values, - req.image_sizes, - ) - # Only when pixel values is not None we have modalities - req.modalities = recv_req.modalites - req.return_logprob = recv_req.return_logprob - req.logprob_start_len = recv_req.logprob_start_len - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream - - # Init regex fsm fron json + req.pixel_values = recv_req.pixel_values + if req.pixel_values is not None: + # Use image hash as fake token_ids, which is then used + # for prefix matching + image_hash = hash(tuple(recv_req.image_hashes)) + req.pad_value = [ + (image_hash) % self.model_config.vocab_size, + (image_hash >> 16) % self.model_config.vocab_size, + (image_hash >> 32) % self.model_config.vocab_size, + (image_hash >> 64) % self.model_config.vocab_size, + ] + req.image_sizes = recv_req.image_sizes + ( + req.origin_input_ids, + req.image_offsets, + ) = self.model_runner.model.pad_input_ids( + req.origin_input_ids_unpadded, + req.pad_value, + req.pixel_values, + req.image_sizes, + ) + # Only when pixel values is not None we have modalities + req.modalities = recv_req.modalites + req.return_logprob = recv_req.return_logprob + req.logprob_start_len = recv_req.logprob_start_len + req.top_logprobs_num = recv_req.top_logprobs_num + req.stream = recv_req.stream + + # Init regex FSM + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + ): if req.sampling_params.json_schema is not None: - req.regex_fsm, computed_regex_string = self.json_fsm_cache.query( - req.sampling_params.json_schema + req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( + ("json", 
req.sampling_params.json_schema) ) - if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - computed_regex_string - ) - - # Init regex fsm elif req.sampling_params.regex is not None: - req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) - if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - req.sampling_params.regex - ) + req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( + ("regex", req.sampling_params.regex) + ) + if not self.disable_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( + computed_regex_string + ) # Truncate prompts that are too long if len(req.origin_input_ids) >= self.max_req_input_len: @@ -390,16 +380,32 @@ def handle_generate_request( "the max context length. Truncated!!!" ) req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + req.sampling_params.max_new_tokens = min( + ( + req.sampling_params.max_new_tokens + if req.sampling_params.max_new_tokens is not None + else 1 << 30 + ), + self.max_req_input_len - 1 - len(req.origin_input_ids), + ) - if self.model_runner.is_generation: - req.sampling_params.max_new_tokens = min( - ( - req.sampling_params.max_new_tokens - if req.sampling_params.max_new_tokens is not None - else 1 << 30 - ), - self.max_req_input_len - 1 - len(req.origin_input_ids), + self.waiting_queue.append(req) + + def handle_embedding_request( + self, + recv_req: TokenizedEmbeddingReqInput, + ): + req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) + req.tokenizer = self.tokenizer + req.sampling_params = recv_req.sampling_params + + # Truncate prompts that are too long + if len(req.origin_input_ids) >= self.max_req_input_len: + logger.warn( + "Request length is longer than the KV cache pool size or " + "the max context length. Truncated!!!" 
) + req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] self.waiting_queue.append(req) @@ -892,7 +898,6 @@ def run_tp_server( tp_rank: int, server_args: ServerArgs, nccl_port: int, - model_override_args: dict, ): """Run a tensor parallel model server.""" configure_logger(server_args, prefix=f" TP{tp_rank}") @@ -903,7 +908,6 @@ def run_tp_server( tp_rank, server_args, nccl_port, - model_override_args, ) tp_cpu_group = model_server.model_runner.tp_group.cpu_group @@ -920,14 +924,13 @@ def launch_tp_servers( tp_rank_range: List[int], server_args: ServerArgs, nccl_port: int, - model_override_args: dict, ): """Launch multiple tensor parallel servers.""" procs = [] for i in tp_rank_range: proc = multiprocessing.Process( target=run_tp_server, - args=(gpu_ids[i], i, server_args, nccl_port, model_override_args), + args=(gpu_ids[i], i, server_args, nccl_port), ) proc.start() procs.append(proc) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3d3e0cde9d..9c82b2a813 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -18,6 +18,7 @@ import gc import importlib import importlib.resources +import json import logging import pkgutil from functools import lru_cache diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index feaf91dd39..d44d617522 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str): def launch_server( server_args: ServerArgs, - model_override_args: Optional[dict] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None, ): """Launch an HTTP server.""" @@ -317,7 +316,6 @@ def launch_server( tp_rank_range, server_args, ports[3], - model_override_args, ) try: @@ -328,7 +326,7 @@ def launch_server( return # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args) + tokenizer_manager = TokenizerManager(server_args, port_args) if server_args.chat_template: load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) @@ -341,7 +339,7 @@ def launch_server( proc_controller = mp.Process( target=start_controller_process, - args=(server_args, port_args, pipe_controller_writer, model_override_args), + args=(server_args, port_args, pipe_controller_writer), ) proc_controller.start() @@ -501,7 +499,6 @@ class Runtime: def __init__( self, log_level: str = "error", - model_override_args: Optional[dict] = None, *args, **kwargs, ): @@ -525,7 +522,7 @@ def __init__( proc = mp.Process( target=launch_server, - args=(self.server_args, model_override_args, pipe_writer), + args=(self.server_args, pipe_writer), ) proc.start() pipe_writer.close() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e21f02108c..14dd63b5ad 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -76,6 +76,14 @@ class ServerArgs: dp_size: int = 1 load_balance_method: str = "round_robin" + # Distributed args + nccl_init_addr: Optional[str] = None + nnodes: int = 1 + node_rank: Optional[int] = None + + # Model override args in JSON + json_model_override_args: str = "{}" + # Optimization/debug options disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False @@ -91,14 +99,6 @@ class ServerArgs: enable_mla: bool = False triton_attention_reduce_in_fp32: bool = 
False - # Distributed args - nccl_init_addr: Optional[str] = None - nnodes: int = 1 - node_rank: Optional[int] = None - - # Model override args in JSON - json_model_override_args: Optional[dict] = None - def __post_init__(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -385,6 +385,14 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument("--node-rank", type=int, help="The node rank.") + # Model override args + parser.add_argument( + "--json-model-override-args", + type=str, + help="A dictionary in JSON string format used to override default model configurations.", + default=ServerArgs.json_model_override_args, + ) + # Optimization/debug options parser.add_argument( "--disable-flashinfer", @@ -459,22 +467,10 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) - # Model override args - parser.add_argument( - "--json-model-override-args", - type=str, - help="A dictionary in JSON string format used to override default model configurations.", - ) - @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size - args.json_model_override_args = ( - json.loads(args.json_model_override_args) - if args.json_model_override_args - else None - ) attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) @@ -498,7 +494,7 @@ def check_server_args(self): self.disable_flashinfer = False -def prepare_server_args(args: argparse.Namespace) -> ServerArgs: +def prepare_server_args(argv: List[str]) -> ServerArgs: """ Prepare the server arguments from the command line arguments. @@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs: """ parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) - raw_args = parser.parse_args(args) + raw_args = parser.parse_args(argv) server_args = ServerArgs.from_cli_args(raw_args) return server_args diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py new file mode 100644 index 0000000000..18ae2d8c35 --- /dev/null +++ b/python/sglang/test/few_shot_gsm8k.py @@ -0,0 +1,132 @@ +""" +Run few-shot GSM-8K evaluation. 
+ +Usage: +python3 -m sglang.test.few_shot_gsm8k --num-questions 200 +""" + +import argparse +import ast +import re +import time + +import numpy as np + +from sglang.api import set_default_backend +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl + +INVALID = -9999999 + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def main(args): + # Select backend + set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}")) + + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + questions = [] + labels = [] + for i in range(len(lines[:num_questions])): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q} for q in questions] + + ##################################### + ######### SGL Program Begin ######### + ##################################### + + import sglang as sgl + + @sgl.function + def few_shot_gsm8k(s, question): + s += few_shot_examples + question + s += sgl.gen( + "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"] + ) + + ##################################### + ########## SGL Program End ########## + ##################################### + + # Run requests + tic = time.time() + states = few_shot_gsm8k.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=True, + ) + latency = time.time() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i]["answer"])) + + # print(f"{preds=}") + # print(f"{labels=}") + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Compute speed + num_output_tokens = sum( + s.get_meta_info("answer")["completion_tokens"] for s in states + ) + output_throughput = num_output_tokens / latency + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + print(f"Output throughput: {output_throughput:.3f} token/s") + + # Dump results + dump_state_text("tmp_output_gsm8k.txt", states) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--parallel", type=int, default=128) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + main(args) diff --git a/python/sglang/test/test_programs.py 
b/python/sglang/test/test_programs.py index bdecdff2f9..41f466f730 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -7,7 +7,7 @@ import numpy as np import sglang as sgl -from sglang.utils import fetch_and_cache_jsonl +from sglang.utils import download_and_cache_file, read_jsonl def test_few_shot_qa(): @@ -456,10 +456,6 @@ def gen_character_spec(s): def test_hellaswag_select(): """Benchmark the accuracy of sgl.select on the HellaSwag dataset.""" - url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" - lines = fetch_and_cache_jsonl(url) - - # Construct prompts def get_one_example(lines, i, include_answer): ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " " if include_answer: @@ -472,6 +468,12 @@ def get_few_shot_examples(lines, k): ret += get_one_example(lines, i, True) + "\n\n" return ret + # Read data + url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) + + # Construct prompts num_questions = 200 num_shots = 20 few_shot_examples = get_few_shot_examples(lines, num_shots) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index b212f6caa3..621efb5373 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps -from typing import Union +from typing import Optional, Union import numpy as np import requests @@ -38,13 +38,11 @@ def is_same_type(values: list): def read_jsonl(filename: str): """Read a JSONL file.""" - rets = [] with open(filename) as fin: for line in fin: if line.startswith("#"): continue - rets.append(json.loads(line)) - return rets + yield json.loads(line) def dump_state_text(filename: str, states: list, mode: str = "w"): @@ -264,38 +262,35 @@ def __call__(self, *args, **kwargs): return module(*args, **kwargs) -def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"): - """Read and cache a jsonl file from a url.""" +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) # Check if the cache file already exists - if os.path.exists(cache_file): - print("Loading data from cache...") - with open(cache_file, "r") as f: - data = [json.loads(line) for line in f] - else: - print("Downloading data from URL...") - # Stream the response to show the progress bar - response = requests.get(url, stream=True) - response.raise_for_status() # Check for request errors - - # Total size of the file in bytes - total_size = int(response.headers.get("content-length", 0)) - chunk_size = 1024 # Download in chunks of 1KB - - # Use tqdm to display the progress bar - with open(cache_file, "wb") as f, tqdm( - desc=cache_file, - total=total_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for chunk in response.iter_content(chunk_size=chunk_size): - f.write(chunk) - bar.update(len(chunk)) - - # Convert the data to a list of dictionaries - with open(cache_file, "r") as f: - data = [json.loads(line) for line in f] - - return data + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes 
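Since `read_jsonl` above now yields records lazily instead of building a list, callers throughout this patch wrap it in `list(...)` before slicing or indexing. A short sketch of the assumed usage with a placeholder filename:

```python
from sglang.utils import read_jsonl

# "test.jsonl" is a placeholder path; materialize the generator before slicing.
lines = list(read_jsonl("test.jsonl"))
subset = lines[:200]
```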
+ total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as f, tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index d4b1354b79..b15308dcec 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -42,7 +42,7 @@ def test_mmlu(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.63, f"{metrics}" + assert metrics["score"] >= 0.62, f"{metrics}" def test_human_eval(self): args = SimpleNamespace( @@ -66,7 +66,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.63, f"{metrics}" + assert metrics["score"] >= 0.62, f"{metrics}" if __name__ == "__main__": diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py index 71129e3eb1..d8f31ce1b9 100644 --- a/test/srt/test_server_args.py +++ b/test/srt/test_server_args.py @@ -1,3 +1,4 @@ +import json import unittest from sglang.srt.server_args import prepare_server_args @@ -15,7 +16,7 @@ def test_prepare_server_args(self): ) self.assertEqual(server_args.model_path, "model_path") self.assertEqual( - server_args.json_model_override_args, + json.loads(server_args.json_model_override_args), {"rope_scaling": {"factor": 2.0, "type": "linear"}}, ) From a7c47e0f028c2a9e67cbc99ab67692ec765d3dd0 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 9 Sep 2024 05:32:41 -0700 Subject: [PATCH 13/33] Add torchao quant (int4/int8/fp8) to llama models (#1341) Co-authored-by: Lianmin Zheng --- python/pyproject.toml | 2 +- python/sglang/srt/layers/torchao_utils.py | 36 +++++++++ .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/models/llama.py | 22 ++++++ python/sglang/srt/server_args.py | 9 ++- test/srt/test_eval_accuracy_mini.py | 4 +- test/srt/test_moe_eval_accuracy_large.py | 6 +- test/srt/test_torch_compile.py | 6 +- test/srt/test_torchao.py | 73 +++++++++++++++++++ test/srt/test_triton_attn_backend.py | 4 +- 10 files changed, 151 insertions(+), 12 deletions(-) create mode 100644 python/sglang/srt/layers/torchao_utils.py create mode 100644 test/srt/test_torchao.py diff --git a/python/pyproject.toml b/python/pyproject.toml index daf09ea25d..1389822a34 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ [project.optional-dependencies] srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", "psutil", "pydantic", "python-multipart", - "torch", "uvicorn", "uvloop", "zmq", + "torch", "torchao", "uvicorn", "uvloop", "zmq", "vllm==0.5.5", "outlines>=0.0.44"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py new file mode 100644 index 0000000000..16eb1f2c5c --- /dev/null +++ b/python/sglang/srt/layers/torchao_utils.py @@ -0,0 +1,36 @@ +""" +Common utilities for torchao. 
+""" + +import torch +from torchao.quantization import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, +) + + +def torchao_quantize_param_data(param, torchao_config): + dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) + dummy_linear.weight = param + if "int8wo" in torchao_config: + quantize_(dummy_linear, int8_weight_only()) + elif "int8dq" in torchao_config: + quantize_(dummy_linear, int8_dynamic_activation_int8_weight()) + elif "int4wo" in torchao_config: + group_size = int(torchao_config.split("-")[-1]) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" + quantize_(dummy_linear, int4_weight_only(group_size=group_size)) + elif "fp8wo" in torchao_config: + from torchao.quantization import float8_weight_only + + # this requires newer hardware + # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 + quantize_(dummy_linear, float8_weight_only()) + return dummy_linear.weight diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9c82b2a813..78f99dcd67 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -97,6 +97,7 @@ def __init__( "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "enable_mla": server_args.enable_mla, + "torchao_config": server_args.torchao_config, } ) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 926d87db8b..ac53712fca 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -42,6 +42,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.sampler import Sampler +from sglang.srt.layers.torchao_utils import torchao_quantize_param_data +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -299,6 +301,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config + self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -361,6 +364,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if self.torchao_config: + if name.endswith("proj.weight") and param.ndim == 2: + params_dict[name] = torchao_quantize_param_data( + param, self.torchao_config + ) + + if self.torchao_config: + # quantizing the loaded, stacked params, e.g. 
"...qkv_proj" + stacked_params = set(entry[0] for entry in stacked_params_mapping) + for param_suffix in stacked_params: + for name in params_dict: + if param_suffix in name: + param = params_dict[name] + params_dict[name] = torchao_quantize_param_data( + param, self.torchao_config + ) + + self.load_state_dict(params_dict, assign=True) + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 14dd63b5ad..3dfb1dc411 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -95,6 +95,7 @@ class ServerArgs: disable_custom_all_reduce: bool = False enable_mixed_chunk: bool = False enable_torch_compile: bool = False + torchao_config: str = "" enable_p2p_check: bool = False enable_mla: bool = False triton_attention_reduce_in_fp32: bool = False @@ -443,7 +444,13 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-torch-compile", action="store_true", - help="Optimize the model with torch.compile, experimental feature.", + help="Optimize the model with torch.compile. Experimental feature.", + ) + parser.add_argument( + "--torchao-config", + type=str, + default=ServerArgs.torchao_config, + help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-, fp8wo", ) parser.add_argument( "--enable-p2p-check", diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index 25aa0ca116..6ddd97d940 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -29,12 +29,12 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=32, + num_examples=64, num_threads=32, ) metrics = run_eval(args) - assert metrics["score"] >= 0.6 + assert metrics["score"] >= 0.65 if __name__ == "__main__": diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index b15308dcec..b6027b61cb 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -42,7 +42,7 @@ def test_mmlu(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.62, f"{metrics}" + assert metrics["score"] >= 0.625, f"{metrics}" def test_human_eval(self): args = SimpleNamespace( @@ -54,7 +54,7 @@ def test_human_eval(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.42, f"{metrics}" + assert metrics["score"] >= 0.425, f"{metrics}" def test_mgsm_en(self): args = SimpleNamespace( @@ -66,7 +66,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.62, f"{metrics}" + assert metrics["score"] >= 0.625, f"{metrics}" if __name__ == "__main__": diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index e8cafa15d2..40f47d6b6b 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -22,7 +22,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile", "--disable-radix-cache"], + other_args=["--enable-torch-compile"], ) @classmethod @@ -34,12 +34,12 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=32, + num_examples=64, num_threads=32, ) metrics = run_eval(args) - assert metrics["score"] >= 0.6 + assert metrics["score"] >= 0.65 def run_decode(self, max_new_tokens): response = requests.post( diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py new file mode 100644 index 
0000000000..d2084e7d53 --- /dev/null +++ b/test/srt/test_torchao.py @@ -0,0 +1,73 @@ +import unittest +from types import SimpleNamespace + +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestTorchCompile(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--torchao-config", "int4wo-128"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.65 + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + import time + + max_tokens = 256 + + tic = time.time() + res = self.run_decode(max_tokens) + tok = time.time() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 210 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py index a94ca92124..b3f65ac13a 100644 --- a/test/srt/test_triton_attn_backend.py +++ b/test/srt/test_triton_attn_backend.py @@ -32,12 +32,12 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=32, + num_examples=64, num_threads=32, ) metrics = run_eval(args) - assert metrics["score"] >= 0.6 + assert metrics["score"] >= 0.65 if __name__ == "__main__": From 689ff588eca5b6d401b6bfd736cf98cd2b776144 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 9 Sep 2024 13:05:13 -0700 Subject: [PATCH 14/33] [CI] Return output logprobs in unit test (#1361) --- python/sglang/test/runners.py | 57 +++++++++++++++++------ test/srt/models/test_generation_models.py | 37 ++++++++++++--- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index ac69ab875b..1d18d305fc 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -50,6 +50,12 @@ def get_dtype_str(torch_dtype): raise NotImplementedError() +def get_top_logprobs(logits, k): + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1) + return logprobs + + @dataclass class ModelOutput: output_strs: List[str] = None @@ -108,7 +114,8 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): if prompts is not None: if self.is_generation: output_strs = [] - prefill_logprobs = [] + top_input_logprobs = [] + top_output_logprobs = [] for p in prompts: if isinstance(p, str): input_ids = self.tokenizer.encode( @@ -117,32 +124,43 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): else: input_ids = torch.tensor([p], device="cuda") - output_ids = self.model.generate( - input_ids, do_sample=False, 
max_new_tokens=max_new_tokens + outputs = self.model.generate( + input_ids, + do_sample=False, + temperature=None, + top_p=None, + max_new_tokens=max_new_tokens, + return_dict_in_generate=True, + output_scores=True, ) output_strs.append( - self.tokenizer.decode(output_ids[0][len(input_ids[0]) :]) + self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :]) ) + # outputs.scores: (num_token, 1, vocab_size) + top_output_logprobs.append( + [ + get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist() + for logits in outputs.scores + ] + ) + del outputs - logits = self.model.forward(input_ids).logits[0] - logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) - logprobs, top_indices = torch.topk( - logprobs, k=NUM_TOP_LOGPROBS, dim=-1 + input_logits = self.model.forward(input_ids).logits[0] + top_input_logprobs.append( + get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist() ) - # print("index", top_indices) - prefill_logprobs.append(logprobs.tolist()) - del logits - del logprobs + del input_logits out_queue.put( ModelOutput( - output_strs=output_strs, top_input_logprobs=prefill_logprobs + output_strs=output_strs, + top_input_logprobs=top_input_logprobs, + top_output_logprobs=top_output_logprobs, ) ) else: logits = self.model.encode(prompts).tolist() - out_queue.put(ModelOutput(embed_logits=logits)) def forward( @@ -194,6 +212,7 @@ def forward( # the return value contains logprobs from prefill output_strs = [] top_input_logprobs = [] + top_output_logprobs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} for prompt in prompts: response = self.runtime.generate( @@ -219,9 +238,17 @@ def forward( ] ] ) + top_output_logprobs.append( + [ + [tup[0] for tup in x[:NUM_TOP_LOGPROBS]] + for x in response["meta_info"]["output_top_logprobs"] + ] + ) return ModelOutput( - output_strs=output_strs, top_input_logprobs=top_input_logprobs + output_strs=output_strs, + top_input_logprobs=top_input_logprobs, + top_output_logprobs=top_output_logprobs, ) else: response = self.runtime.encode(prompts) diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 08288c510c..46854b3e86 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -21,9 +21,9 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner MODELS = [ - ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1), - ("google/gemma-2-2b", 1, 3, 3e-2, 1), - ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1), + ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1), + ("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1), + ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1), ] TORCH_DTYPES = [torch.float16] @@ -70,6 +70,7 @@ def assert_close_prefill_logits_and_output_strs( torch_dtype, max_new_tokens, prefill_tolerance, + output_tolerance, rouge_threshold, long_context_tolerance, ) -> None: @@ -89,15 +90,37 @@ def assert_close_prefill_logits_and_output_strs( srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens) for i in range(len(prompts)): + # input logprobs comparison hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) - - print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs))) - if hf_logprobs.shape[0] <= 100: + input_len = hf_logprobs.shape[0] + print( + "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs)) + ) + if input_len <= 100: assert torch.all( abs(hf_logprobs - srt_logprobs) < 
prefill_tolerance ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}" + # output logprobs comparison + hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i]) + # print( + # "output logprobs diff", + # [ + # float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j]))) + # for j in range(max_new_tokens) + # ], + # ) + print( + "output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs)) + ) + if input_len <= 100: + assert torch.all( + abs(hf_logprobs - srt_logprobs) < output_tolerance + ), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}" + + # output strings comparison print(f"hf_outputs.output_strs={hf_outputs.output_strs}") print(f"srt_outputs.output_strs={srt_outputs.output_strs}") rouge_l_scores = calculate_rouge_l( @@ -114,6 +137,7 @@ def test_prefill_logits_and_output_strs(self): tp_size, long_context_tolerance, prefill_tolerance, + output_tolerance, rouge_threshold, ) in MODELS: for torch_dtype in TORCH_DTYPES: @@ -125,6 +149,7 @@ def test_prefill_logits_and_output_strs(self): torch_dtype, max_new_tokens, prefill_tolerance=prefill_tolerance, + output_tolerance=output_tolerance, rouge_threshold=rouge_threshold, long_context_tolerance=long_context_tolerance, ) From 69b3bb9ae1c504925455e8b258eefa0fcc15bd81 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 9 Sep 2024 13:49:29 -0700 Subject: [PATCH 15/33] Unify forward mode (#1360) --- python/sglang/bench_latency.py | 5 +-- python/sglang/srt/layers/logits_processor.py | 6 +-- python/sglang/srt/layers/radix_attention.py | 4 +- python/sglang/srt/managers/schedule_batch.py | 7 ++++ python/sglang/srt/managers/tp_worker.py | 11 ++--- .../srt/model_executor/forward_batch_info.py | 41 +++++++++++-------- .../sglang/srt/model_executor/model_runner.py | 30 +++++--------- python/sglang/srt/models/llava.py | 4 +- python/sglang/srt/models/llavavid.py | 4 +- 9 files changed, 54 insertions(+), 58 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 6113495776..be67958460 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -60,7 +60,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_config import ModelConfig -from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -208,14 +207,14 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND) + sample_output, logits_output = model_runner.forward(batch) next_token_ids = sample_output.batch_next_token_ids.tolist() return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids) - sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE) + sample_output, logits_output = model_runner.forward(batch) next_token_ids = sample_output.batch_next_token_ids.tolist() return next_token_ids, logits_output.next_token_logits diff --git a/python/sglang/srt/layers/logits_processor.py 
b/python/sglang/srt/layers/logits_processor.py index b81f3d2a04..72a926cab6 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -103,7 +103,7 @@ def _get_normalized_prompt_logprobs( @staticmethod def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): - if logits_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode.is_decode(): output_top_logprobs = [] max_k = max(logits_metadata.top_logprobs_nums) ret = all_logprobs.topk(max_k, dim=1) @@ -163,7 +163,7 @@ def forward( assert isinstance(logits_metadata, LogitsMetadata) # Get the last hidden states and last logits for the next token prediction - if logits_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode.is_decode(): last_index = None last_hidden = hidden_states else: @@ -195,7 +195,7 @@ def forward( ) else: # When logprob is requested, compute the logits for all tokens. - if logits_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode.is_decode(): last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) # Get the logprob of top-k tokens diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 91735a1b81..1a2feacd3d 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -197,9 +197,9 @@ def forward(self, q, k, v, input_metadata: InputMetadata): k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - if input_metadata.forward_mode == ForwardMode.EXTEND: + if input_metadata.forward_mode.is_extend(): return self.extend_forward(q, k, v, input_metadata) - elif input_metadata.forward_mode == ForwardMode.DECODE: + elif input_metadata.forward_mode.is_decode(): return self.decode_forward(q, k, v, input_metadata) def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f126cc9f3a..6c6b7f8426 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -29,6 +29,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo if TYPE_CHECKING: @@ -334,6 +335,8 @@ class ScheduleBatch: token_to_kv_pool: BaseTokenToKVPool tree_cache: BasePrefixCache + forward_mode: ForwardMode = None + # Batched arguments to model runner input_ids: torch.Tensor = None req_pool_indices: torch.Tensor = None @@ -397,6 +400,8 @@ def alloc_token_slots(self, num_tokens: int): return out_cache_loc def prepare_for_extend(self, vocab_size: int): + self.forward_mode = ForwardMode.EXTEND + bs = self.batch_size() reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] @@ -626,6 +631,8 @@ def check_for_jump_forward(self, model_runner): return jump_forward_reqs def prepare_for_decode(self, input_ids=None): + self.forward_mode = ForwardMode.DECODE + if input_ids is None: input_ids = [ r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1] diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 513bc517f5..736929a654 100644 --- 
a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -53,7 +53,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_config import ModelConfig -from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -521,9 +520,7 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - sample_output, logits_output = self.model_runner.forward( - batch, ForwardMode.EXTEND - ) + sample_output, logits_output = self.model_runner.forward(batch) next_token_ids = batch.check_sample_results(sample_output) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids @@ -588,7 +585,7 @@ def forward_prefill_batch(self, batch: ScheduleBatch): pt += req.extend_input_len else: assert batch.extend_num_tokens != 0 - logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND) + logits_output = self.model_runner.forward(batch) embeddings = logits_output.embeddings.tolist() # Check finish conditions @@ -699,9 +696,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): batch.prepare_for_decode() # Forward and sample the next tokens - sample_output, logits_output = self.model_runner.forward( - batch, ForwardMode.DECODE - ) + sample_output, logits_output = self.model_runner.forward(batch) next_token_ids = batch.check_sample_results(sample_output) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 75f9136d39..a6ad63ce18 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -25,10 +25,9 @@ import triton import triton.language as tl -from sglang.srt.managers.schedule_batch import ScheduleBatch -from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool - if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -41,6 +40,15 @@ class ForwardMode(IntEnum): # Decode one token. 
DECODE = auto() + def is_prefill(self): + return self == ForwardMode.PREFILL + + def is_extend(self): + return self == ForwardMode.EXTEND + + def is_decode(self): + return self == ForwardMode.DECODE + @dataclass class InputMetadata: @@ -102,7 +110,7 @@ def init_multimuldal_info(self, batch: ScheduleBatch): def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): if True: self.positions = self.seq_lens - 1 else: @@ -141,7 +149,7 @@ def compute_positions(self, batch: ScheduleBatch): self.positions = self.positions.to(torch.int64) def compute_extend_infos(self, batch: ScheduleBatch): - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None else: @@ -173,10 +181,9 @@ def from_schedule_batch( cls, model_runner: "ModelRunner", batch: ScheduleBatch, - forward_mode: ForwardMode, ): ret = cls( - forward_mode=forward_mode, + forward_mode=batch.forward_mode, sampling_info=batch.sampling_info, batch_size=batch.batch_size(), req_pool_indices=batch.req_pool_indices, @@ -194,13 +201,11 @@ def from_schedule_batch( ret.compute_extend_infos(batch) - if ( - forward_mode != ForwardMode.DECODE - or model_runner.server_args.disable_flashinfer - ): + fm = batch.forward_mode + if not fm.is_decode() or model_runner.server_args.disable_flashinfer: ret.total_num_tokens = int(torch.sum(ret.seq_lens)) - if forward_mode != ForwardMode.DECODE: + if not fm.is_decode(): ret.init_multimuldal_info(batch) if model_runner.server_args.disable_flashinfer: @@ -209,7 +214,7 @@ def from_schedule_batch( flashinfer_use_ragged = False if not model_runner.server_args.disable_flashinfer: if ( - forward_mode != ForwardMode.DECODE + not fm.is_decode() and int(torch.sum(ret.seq_lens)) > 4096 and model_runner.sliding_window_size is None ): @@ -226,7 +231,7 @@ def init_triton_args(self, batch: ScheduleBatch): self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32) self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): self.triton_max_extend_len = None else: self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") @@ -239,7 +244,7 @@ def init_flashinfer_handlers( prefix_lens_cpu, flashinfer_use_ragged, ): - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): prefix_lens = None else: prefix_lens = self.extend_prefix_lens @@ -339,7 +344,7 @@ def update_flashinfer_indices( kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") - if forward_mode == ForwardMode.DECODE: + if forward_mode.is_decode(): # CUDA graph uses different flashinfer_decode_wrapper if flashinfer_decode_wrapper is None: flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper @@ -388,7 +393,7 @@ def update_flashinfer_indices( kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") for wrapper_id in range(2): if wrapper_id == 0: - if forward_mode == ForwardMode.DECODE: + if forward_mode.is_decode(): paged_kernel_lens = torch.minimum( seq_lens, torch.tensor(model_runner.sliding_window_size + 1) ) @@ -418,7 +423,7 @@ def update_flashinfer_indices( kv_indices, ) - if forward_mode == ForwardMode.DECODE: + if forward_mode.is_decode(): # CUDA graph uses different 
flashinfer_decode_wrapper if flashinfer_decode_wrapper is None: flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 78f99dcd67..3cb123c482 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -530,11 +530,7 @@ def forward_decode(self, batch: ScheduleBatch): ): return self.cuda_graph_runner.replay(batch) - input_metadata = InputMetadata.from_schedule_batch( - self, - batch, - ForwardMode.DECODE, - ) + input_metadata = InputMetadata.from_schedule_batch(self, batch) return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata @@ -542,11 +538,7 @@ def forward_decode(self, batch: ScheduleBatch): @torch.inference_mode() def forward_extend(self, batch: ScheduleBatch): - input_metadata = InputMetadata.from_schedule_batch( - self, - batch, - forward_mode=ForwardMode.EXTEND, - ) + input_metadata = InputMetadata.from_schedule_batch(self, batch) if self.is_generation: return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata @@ -562,11 +554,7 @@ def forward_extend(self, batch: ScheduleBatch): @torch.inference_mode() def forward_extend_multi_modal(self, batch: ScheduleBatch): - input_metadata = InputMetadata.from_schedule_batch( - self, - batch, - forward_mode=ForwardMode.EXTEND, - ) + input_metadata = InputMetadata.from_schedule_batch(self, batch) return self.model.forward( batch.input_ids, input_metadata.positions, @@ -577,16 +565,18 @@ def forward_extend_multi_modal(self, batch: ScheduleBatch): ) def forward( - self, batch: ScheduleBatch, forward_mode: ForwardMode + self, batch: ScheduleBatch ) -> Tuple[SampleOutput, LogitsProcessorOutput]: - if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: + assert batch.forward_mode is not None + + if self.is_multimodal_model and batch.forward_mode.is_extend(): return self.forward_extend_multi_modal(batch) - elif forward_mode == ForwardMode.DECODE: + elif batch.forward_mode.is_decode(): return self.forward_decode(batch) - elif forward_mode == ForwardMode.EXTEND: + elif batch.forward_mode.is_extend(): return self.forward_extend(batch) else: - raise ValueError(f"Invaid forward mode: {forward_mode}") + raise ValueError(f"Invaid forward mode: {batch.forward_mode}") @lru_cache() diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 62041a8955..9e20a726a7 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -136,7 +136,7 @@ def forward( image_sizes: Optional[List[List[int]]] = None, image_offsets: Optional[List[int]] = None, ) -> torch.Tensor: - if input_metadata.forward_mode == ForwardMode.EXTEND: + if input_metadata.forward_mode.is_extend(): bs = input_metadata.batch_size # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size @@ -357,7 +357,7 @@ def forward( return self.language_model( input_ids, positions, input_metadata, input_embeds=input_embeds ) - elif input_metadata.forward_mode == ForwardMode.DECODE: + elif input_metadata.forward_mode.is_decode(): return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index f268ecbbcd..45f47cffcb 100644 --- a/python/sglang/srt/models/llavavid.py +++ 
b/python/sglang/srt/models/llavavid.py @@ -116,7 +116,7 @@ def forward( image_sizes: Optional[List[List[int]]] = None, image_offsets: Optional[List[int]] = None, ) -> torch.Tensor: - if input_metadata.forward_mode == ForwardMode.EXTEND: + if input_metadata.forward_mode.is_extend(): bs = input_metadata.batch_size # Embed text inputs @@ -199,7 +199,7 @@ def forward( return self.language_model( input_ids, positions, input_metadata, input_embeds=input_embeds ) - elif input_metadata.forward_mode == ForwardMode.DECODE: + elif input_metadata.forward_mode.is_decode(): return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From 9144ed1067f27ae682d48fc4f183e24098b72f6d Mon Sep 17 00:00:00 2001 From: zifeitong Date: Mon, 9 Sep 2024 19:08:25 -0700 Subject: [PATCH 16/33] Support OpenAI API json_schema response format (#1363) --- python/sglang/srt/constrained/jump_forward.py | 1 - python/sglang/srt/openai_api/adapter.py | 43 ++++++++++++------- python/sglang/srt/openai_api/protocol.py | 13 ++++-- test/srt/test_json_constrained.py | 5 ++- 4 files changed, 41 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index 244931e050..b00c48d478 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -23,7 +23,6 @@ import interegular import outlines.caching -from outlines.fsm.json_schema import build_regex_from_schema from sglang.srt.constrained import ( FSMInfo, diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index f1195aff7c..d1b296e9b9 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -28,6 +28,13 @@ from fastapi.responses import JSONResponse, StreamingResponse from pydantic import ValidationError +try: + from outlines.fsm.json_schema import convert_json_schema_to_str +except ImportError: + # Before outlines 0.0.47, convert_json_schema_to_str is under + # outlines.integrations.utils + from outlines.integrations.utils import convert_json_schema_to_str + from sglang.srt.conversation import ( Conversation, SeparatorStyle, @@ -888,22 +895,26 @@ def v1_chat_generate_request( return_logprobs.append(request.logprobs) logprob_start_lens.append(-1) top_logprobs_nums.append(request.top_logprobs) - sampling_params_list.append( - { - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "min_new_tokens": request.min_tokens, - "stop": stop, - "stop_token_ids": request.stop_token_ids, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "repetition_penalty": request.repetition_penalty, - "regex": request.regex, - "json_schema": request.json_schema, - "n": request.n, - } - ) + + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "n": request.n, + } + if request.response_format and request.response_format.type == "json_schema": + sampling_params["json_schema"] = convert_json_schema_to_str( + request.response_format.json_schema.schema_ + ) + 
sampling_params_list.append(sampling_params) + image_data_list.append(image_data) modalities_list.extend(modalities) if len(all_requests) == 1: diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 5525cd8827..3d7d450c9d 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -82,6 +82,14 @@ class StreamOptions(BaseModel): include_usage: Optional[bool] = False +class JsonSchemaResponseFormat(BaseModel): + name: str + description: Optional[str] = None + # use alias to workaround pydantic conflict + schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + strict: Optional[bool] = False + + class FileRequest(BaseModel): # https://platform.openai.com/docs/api-reference/files/create file: bytes # The File object (not file name) to be uploaded @@ -237,8 +245,8 @@ class ChatCompletionMessageUserParam(BaseModel): class ResponseFormat(BaseModel): - # type must be "json_object" or "text" - type: Literal["text", "json_object"] + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None class ChatCompletionRequest(BaseModel): @@ -264,7 +272,6 @@ class ChatCompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. regex: Optional[str] = None - json_schema: Optional[str] = None min_tokens: Optional[int] = 0 repetition_penalty: Optional[float] = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 5393ecc33c..122d79968c 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -79,7 +79,10 @@ def test_json_openai(self): ], temperature=0, max_tokens=128, - extra_body={"json_schema": self.json_schema}, + response_format={ + "type": "json_schema", + "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)}, + }, ) text = response.choices[0].message.content From 743007e1ce07b99529b49d95413f4879853be1ac Mon Sep 17 00:00:00 2001 From: Chayenne Date: Tue, 10 Sep 2024 10:09:13 +0800 Subject: [PATCH 17/33] Adding Documentation for installation (#1300) Co-authored-by: zhaochen20 --- docs/en/index.rst | 24 +++++++-- docs/en/install.md | 116 ++++++++++++++++++++++++++++++++++++++++++ docs/requirements.txt | 2 - 3 files changed, 136 insertions(+), 6 deletions(-) create mode 100644 docs/en/install.md diff --git a/docs/en/index.rst b/docs/en/index.rst index 5e4701c53b..d1a96e8cb0 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -1,4 +1,4 @@ -Welcome to SGLang's tutorials! +Welcome to SGLang! ==================================== .. figure:: ./_static/image/logo.png @@ -27,9 +27,22 @@ SGLang has the following core features: * **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. +* **Extensive Model Support**: SGLang supports a wide range of generative models including the Llama series (up to Llama 3.1), Mistral, Gemma, Qwen, DeepSeek, LLaVA, Yi-VL, StableLM, Command-R, DBRX, Grok, ChatGLM, InternLM 2 and Exaone 3. It also supports embedding models such as e5-mistral and gte-Qwen2. Easily extensible to support new models. + +* **Open Source Community**: SGLang is an open source project with a vibrant community of contributors. 
We welcome contributions from anyone interested in advancing the state of the art in LLM and VLM serving. + Documentation ------------- +.. In this documentation, we'll dive into these following areas to help you get the most out of SGLang. + +.. _installation: +.. toctree:: + :maxdepth: 1 + :caption: Installation + + install.md + .. _hyperparameter_tuning: .. toctree:: :maxdepth: 1 @@ -58,7 +71,10 @@ Documentation sampling_params.md -Search Bar -================== -* :ref:`search` +.. _benchmark_and_profilling: +.. toctree:: + :maxdepth: 1 + :caption: Benchmark and Profilling + + benchmark_and_profiling.md \ No newline at end of file diff --git a/docs/en/install.md b/docs/en/install.md new file mode 100644 index 0000000000..877f69d680 --- /dev/null +++ b/docs/en/install.md @@ -0,0 +1,116 @@ +# SGLang Installation Guide + +SGLang consists of a frontend language (Structured Generation Language, SGLang) and a backend runtime (SGLang Runtime, SRT). The frontend can be used separately from the backend, allowing for a detached frontend-backend setup. + +## Quick Installation Options + +### 1. Frontend Installation (Client-side, any platform) + +```bash +pip install --upgrade pip +pip install sglang +``` + +**Note: You can check [these examples](https://github.com/sgl-project/sglang/tree/main/examples/frontend_language/usage) for how to use frontend and backend separately.** + +### 2. Backend Installation (Server-side, Linux only) + +```bash +pip install --upgrade pip +pip install "sglang[all]" +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ +``` + +**Note: The backend (SRT) is only needed on the server side and is only available for Linux right now.** + +**Important: Please check the [flashinfer installation guidance](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.** + +### 3. From Source (Latest version, Linux only for full installation) + +```bash +# Use the latest release branch +# As of this documentation, it's v0.2.15, but newer versions may be available +# Do not clone the main branch directly; always use a specific release version +# The main branch may contain unresolved bugs before a new release +git clone -b v0.2.15 https://github.com/sgl-project/sglang.git +cd sglang +pip install -e "python[all]" +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ +``` + +### 4. OpenAI Backend Only (Client-side, any platform) + +If you only need to use the OpenAI backend, you can avoid installing other dependencies by using: + +```bash +pip install "sglang[openai]" +``` + +## Advanced Installation Options + +### 1. Using Docker (Server-side, Linux only) + +The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker). Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). + +```bash +docker run --gpus all -p 30000:30000 \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HF_TOKEN=" --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 +``` + +### 2.Using docker compose + +This method is recommended if you plan to serve it as a service. A better approach is to use the [k8s-sglang-service.yaml](https://github.com/sgl-project/sglang/blob/main/docker/k8s-sglang-service.yaml). + +1. 
Copy the [compose.yml](https://github.com/sgl-project/sglang/blob/main/docker/compose.yaml) to your local machine
+2. Execute the command `docker compose up -d` in your terminal.
+
+### 3. Run on Kubernetes or Clouds with SkyPilot
+
+More + +To deploy on Kubernetes or 12+ clouds, you can use [SkyPilot](https://github.com/skypilot-org/skypilot). + +1. Install SkyPilot and set up Kubernetes cluster or cloud access: see [SkyPilot's documentation](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html). +2. Deploy on your own infra with a single command and get the HTTP API endpoint: +
+SkyPilot YAML: sglang.yaml + +```yaml +# sglang.yaml +envs: + HF_TOKEN: null + +resources: + image_id: docker:lmsysorg/sglang:latest + accelerators: A100 + ports: 30000 + +run: | + conda deactivate + python3 -m sglang.launch_server \ + --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \ + --host 0.0.0.0 \ + --port 30000 +``` +
+ +```bash +# Deploy on any cloud or Kubernetes cluster. Use --cloud to select a specific cloud provider. +HF_TOKEN= sky launch -c sglang --env HF_TOKEN sglang.yaml + +# Get the HTTP API endpoint +sky status --endpoint 30000 sglang +``` +3. To further scale up your deployment with autoscaling and failure recovery, check out the [SkyServe + SGLang guide](https://github.com/skypilot-org/skypilot/tree/master/llm/sglang#serving-llama-2-with-sglang-for-more-traffic-using-skyserve). +
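+
+Once the endpoint is reachable, any OpenAI-compatible client can talk to it. A minimal sketch, assuming the `IP:PORT` printed by `sky status --endpoint 30000 sglang` above; `<endpoint-ip>` is a placeholder you must replace with that address:
+
+```python
+import openai
+
+# Replace <endpoint-ip> with the address returned by `sky status --endpoint 30000 sglang`.
+client = openai.Client(base_url="http://<endpoint-ip>:30000/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="default",
+    messages=[{"role": "user", "content": "List three US cities."}],
+    temperature=0,
+    max_tokens=32,
+)
+print(response.choices[0].message.content)
+```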
+ +## Troubleshooting + +- For FlashInfer issues on newer GPUs, use `--disable-flashinfer --disable-flashinfer-sampling` when launching the server. +- For out-of-memory errors, try `--mem-fraction-static 0.7` when launching the server. + +For more details and advanced usage, visit the [SGLang GitHub repository](https://github.com/sgl-project/sglang). \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index 826a34bc15..2f86ac9970 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,6 +7,4 @@ sphinx-tabs sphinxcontrib-mermaid pillow pydantic -torch -transformers urllib3<2.0.0 From 8d1095dbf0565cb7d6e5e3d10728a6542c8db6ae Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 9 Sep 2024 20:48:28 -0700 Subject: [PATCH 18/33] [Docs] Improve documentations (#1368) --- .gitignore | 3 + README.md | 19 ++-- docs/en/backend.md | 171 +++++++++++++++++++++++++++++++ docs/en/frontend.md | 239 ++++++++++++++++++++++++++++++++++++++++++++ docs/en/index.rst | 82 ++++----------- docs/en/install.md | 84 ++++++---------- 6 files changed, 474 insertions(+), 124 deletions(-) create mode 100644 docs/en/backend.md create mode 100644 docs/en/frontend.md diff --git a/.gitignore b/.gitignore index ca43e1ccba..14f5212ece 100644 --- a/.gitignore +++ b/.gitignore @@ -166,6 +166,9 @@ cython_debug/ # Vim *.swp +# Documentation +docs/en/_build + # SGL benchmark/mmlu/data benchmark/mmlu/data.tar diff --git a/README.md b/README.md index eb3099cf7a..92c5c2ec39 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,12 @@ SGLang is a fast serving framework for large language models and vision language models. It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language. - The core features include: -- **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). -- **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. + +- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ). +- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions. +- **Extensive Model Support**: Supports a wide range of generative models (Llama 3, Gemma 2, Mistral, QWen, DeepSeek, LLaVA, etc.) and embedding models (e5-mistral), with easy extensibility for integrating new models. +- **Active Community**: SGLang is open-source and backed by an active community with industry adoption, welcoming contributions to improve LLM and VLM serving. ## News - [2024/09] 🔥 SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). @@ -44,6 +46,8 @@ The core features include: ## Install +You can install SGLang using any of the methods below. 
+ ### Method 1: With pip ``` pip install --upgrade pip @@ -67,7 +71,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ``` ### Method 3: Using docker -The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](docker). +The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker). Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). ```bash @@ -218,6 +222,10 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 4096 ``` +- To enable torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes. +- To enable fp8 weight quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments. +- To enable fp8 kv cache quanzation, you can add `--kv-cache-dtype fp8_e5m2`. +- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md). - Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port. ``` # Node 0 @@ -226,9 +234,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct # Node 1 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1 ``` -- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md). -- To enable experimental torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes. -- To enable fp8 quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments. ### Supported Models diff --git a/docs/en/backend.md b/docs/en/backend.md new file mode 100644 index 0000000000..af874f5933 --- /dev/null +++ b/docs/en/backend.md @@ -0,0 +1,171 @@ +## Backend: SGLang Runtime (SRT) +The SGLang Runtime (SRT) is an efficient serving engine. + +### Quick Start +Launch a server +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 +``` + +Send a request +``` +curl http://localhost:30000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Once upon a time,", + "sampling_params": { + "max_new_tokens": 16, + "temperature": 0 + } + }' +``` +Learn more about the argument format [here](docs/en/sampling_params.md). + +### OpenAI Compatible API +In addition, the server supports OpenAI-compatible APIs. 
+ +```python +import openai +client = openai.Client( + base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + +# Text completion +response = client.completions.create( + model="default", + prompt="The capital of France is", + temperature=0, + max_tokens=32, +) +print(response) + +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0, + max_tokens=64, +) +print(response) + +# Text embedding +response = client.embeddings.create( + model="default", + input="How are you today", +) +print(response) +``` + +It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/). + +### Additional Server Arguments +- Add `--tp 2` to enable multi-GPU tensor parallelism. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2 +``` +- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2 +``` +- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --mem-fraction-static 0.7 +``` +- See [hyperparameter_tuning.md](docs/en/hyperparameter_tuning.md) on tuning hyperparameters for better performance. +- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 4096 +``` +- To enable torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes. +- To enable fp8 weight quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments. +- To enable fp8 kv cache quanzation, you can add `--kv-cache-dtype fp8_e5m2`. +- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md). +- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port. 
+``` +# Node 0 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0 + +# Node 1 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1 +``` + +### Supported Models + +**Generative Models** +- Llama / Llama 2 / Llama 3 / Llama 3.1 +- Mistral / Mixtral / Mistral NeMo +- Gemma / Gemma 2 +- Qwen / Qwen 2 / Qwen 2 MoE +- DeepSeek / DeepSeek 2 +- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) + - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava` + - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava` + - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) +- LLaVA 1.5 / 1.6 / NeXT + - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3` + - `python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8 --chat-template=chatml-llava` + - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) +- Yi-VL +- StableLM +- Command-R +- DBRX +- Grok +- ChatGLM +- InternLM 2 +- Exaone 3 + +**Embedding Models** + +- e5-mistral +- gte-Qwen2 + - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding` + +Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md). + +#### Use Models From ModelScope +
+More + +To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE. +``` +export SGLANG_USE_MODELSCOPE=true +``` +Launch [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) Server +``` +SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 +``` + +
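+
+To confirm the server picked up the ModelScope checkpoint, you can query its model metadata route (a quick check, assuming the default port used above and the `/get_model_info` route the runtime exposes):
+
+```bash
+# The returned model_path should point at the ModelScope copy of Qwen2-7B-Instruct.
+curl http://localhost:30000/get_model_info
+```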
+ +#### Run Llama 3.1 405B +
+
More

+```bash
+# Run 405B (fp8) on a single node
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8
+
+# Run 405B (fp16) on two nodes
+## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port
+GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph
+
+## on the second node, replace the `172.16.4.52:20000` with your own first node ip address and port
+GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph
+```
+
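+
+After both ranks are up, the HTTP server runs on the `--node-rank 0` machine, so all requests go to that node (default port 30000 here). A minimal smoke test, reusing the `/generate` request from the Quick Start above; `<node-0-ip>` is a placeholder for your first node's address:
+
+```bash
+curl http://<node-0-ip>:30000/generate \
+  -H "Content-Type: application/json" \
+  -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 8, "temperature": 0}}'
+```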
+ +### Benchmark Performance + +- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. + Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. + A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, please use `sglang.bench_serving` instead. + ``` + python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32 + ``` +- Benchmark online serving. Launch a server first and run the following command. + ``` + python3 -m sglang.bench_serving --backend sglang --num-prompt 10 + ``` \ No newline at end of file diff --git a/docs/en/frontend.md b/docs/en/frontend.md new file mode 100644 index 0000000000..4f18939b3a --- /dev/null +++ b/docs/en/frontend.md @@ -0,0 +1,239 @@ +## Frontend: Structured Generation Language (SGLang) +The frontend language can be used with local models or API models. It is an alternative to the OpenAI API. You may found it easier to use for complex prompting workflow. + +### Quick Start +The example below shows how to use sglang to answer a mulit-turn question. + +#### Using Local Models +First, launch a server with +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 +``` + +Then, connect to the server and answer a multi-turn question. + +```python +from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(RuntimeEndpoint("http://localhost:30000")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", +) + +for m in state.messages(): + print(m["role"], ":", m["content"]) + +print(state["answer_1"]) +``` + +#### Using OpenAI Models +Set the OpenAI API Key +``` +export OPENAI_API_KEY=sk-****** +``` + +Then, answer a multi-turn question. +```python +from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI + +@function +def multi_turn_question(s, question_1, question_2): + s += system("You are a helpful assistant.") + s += user(question_1) + s += assistant(gen("answer_1", max_tokens=256)) + s += user(question_2) + s += assistant(gen("answer_2", max_tokens=256)) + +set_default_backend(OpenAI("gpt-3.5-turbo")) + +state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", +) + +for m in state.messages(): + print(m["role"], ":", m["content"]) + +print(state["answer_1"]) +``` + +#### More Examples + +Anthropic and VertexAI (Gemini) models are also supported. +You can find more examples at [examples/quick_start](examples/frontend_language/quick_start). + +### Language Feature +To begin with, import sglang. +```python +import sglang as sgl +``` + +`sglang` provides some simple primitives such as `gen`, `select`, `fork`, `image`. +You can implement your prompt flow in a function decorated by `sgl.function`. +You can then invoke the function with `run` or `run_batch`. +The system will manage the state, chat template, parallelism and batching for you. 
+ +The complete code for the examples below can be found at [readme_examples.py](examples/frontend_language/usage/readme_examples.py) + +#### Control Flow +You can use any Python code within the function body, including control flow, nested function calls, and external libraries. + +```python +@sgl.function +def tool_use(s, question): + s += "To answer this question: " + question + ". " + s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". " + + if s["tool"] == "calculator": + s += "The math expression is" + sgl.gen("expression") + elif s["tool"] == "search engine": + s += "The key word to search is" + sgl.gen("word") +``` + +#### Parallelism +Use `fork` to launch parallel prompts. +Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel. + +```python +@sgl.function +def tip_suggestion(s): + s += ( + "Here are two tips for staying healthy: " + "1. Balanced Diet. 2. Regular Exercise.\n\n" + ) + + forks = s.fork(2) + for i, f in enumerate(forks): + f += f"Now, expand tip {i+1} into a paragraph:\n" + f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") + + s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" + s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") +``` + +#### Multi-Modality +Use `sgl.image` to pass an image as input. + +```python +@sgl.function +def image_qa(s, image_file, question): + s += sgl.user(sgl.image(image_file) + question) + s += sgl.assistant(sgl.gen("answer", max_tokens=256) +``` + +See also [srt_example_llava.py](examples/frontend_language/quick_start/local_example_llava_next.py). + +#### Constrained Decoding +Use `regex` to specify a regular expression as a decoding constraint. +This is only supported for local models. + +```python +@sgl.function +def regular_expression_gen(s): + s += "Q: What is the IP address of the Google DNS servers?\n" + s += "A: " + sgl.gen( + "answer", + temperature=0, + regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)", + ) +``` + +#### JSON Decoding +Use `regex` to specify a JSON schema with a regular expression. + +```python +character_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" + + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" + + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" + + r""" "wand": \{\n""" + + r""" "wood": "[\w\d\s]{1,16}",\n""" + + r""" "core": "[\w\d\s]{1,16}",\n""" + + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "alive": "(Alive|Deceased)",\n""" + + r""" "patronus": "[\w\d\s]{1,16}",\n""" + + r""" "bogart": "[\w\d\s]{1,16}"\n""" + + r"""\}""" +) + +@sgl.function +def character_gen(s, name): + s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n" + s += sgl.gen("json_output", max_tokens=256, regex=character_regex) +``` + +See also [json_decode.py](examples/frontend_language/usage/json_decode.py) for an additional example of specifying formats with Pydantic models. + +#### Batching +Use `run_batch` to run a batch of requests with continuous batching. 
+ +```python +@sgl.function +def text_qa(s, question): + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n") + +states = text_qa.run_batch( + [ + {"question": "What is the capital of the United Kingdom?"}, + {"question": "What is the capital of France?"}, + {"question": "What is the capital of Japan?"}, + ], + progress_bar=True +) +``` + +#### Streaming +Add `stream=True` to enable streaming. + +```python +@sgl.function +def text_qa(s, question): + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n") + +state = text_qa.run( + question="What is the capital of France?", + temperature=0.1, + stream=True +) + +for out in state.text_iter(): + print(out, end="", flush=True) +``` + +#### Roles + +Use `sgl.system`, `sgl.user` and `sgl.assistant` to set roles when using Chat models. You can also define more complex role prompts using begin and end tokens. + +```python +@sgl.function +def chat_example(s): + s += sgl.system("You are a helpful assistant.") + # Same as: s += s.system("You are a helpful assistant.") + + with s.user(): + s += "Question: What is the capital of France?" + + s += sgl.assistant_begin() + s += "Answer: " + sgl.gen(max_tokens=100, stop="\n") + s += sgl.assistant_end() +``` + +#### Tips and Implementation Details +- The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. +- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. diff --git a/docs/en/index.rst b/docs/en/index.rst index d1a96e8cb0..1c3e947c0c 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -1,80 +1,32 @@ -Welcome to SGLang! +SGLang Documentation ==================================== -.. figure:: ./_static/image/logo.png - :width: 50% - :align: center - :alt: SGLang - :class: no-scaled-link +SGLang is a fast serving framework for large language models and vision language models. +It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language. +The core features include: -.. raw:: html +- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ). +- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions. +- **Extensive Model Support**: Supports a wide range of generative models (Llama 3, Gemma 2, Mistral, QWen, DeepSeek, LLaVA, etc.) and embedding models (e5-mistral), with easy extensibility for integrating new models. +- **Active Community**: SGLang is open-source and backed by an active community with industry adoption, welcoming contributions to improve LLM and VLM serving. -

- SGLang is yet another fast serving framework for large language models and vision language models. - -

-

- - Star - Watch - Fork -

- -SGLang has the following core features: - -* **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, flashinfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). - -* **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. - -* **Extensive Model Support**: SGLang supports a wide range of generative models including the Llama series (up to Llama 3.1), Mistral, Gemma, Qwen, DeepSeek, LLaVA, Yi-VL, StableLM, Command-R, DBRX, Grok, ChatGLM, InternLM 2 and Exaone 3. It also supports embedding models such as e5-mistral and gte-Qwen2. Easily extensible to support new models. - -* **Open Source Community**: SGLang is an open source project with a vibrant community of contributors. We welcome contributions from anyone interested in advancing the state of the art in LLM and VLM serving. - -Documentation -------------- - -.. In this documentation, we'll dive into these following areas to help you get the most out of SGLang. - -.. _installation: .. toctree:: :maxdepth: 1 - :caption: Installation + :caption: Getting Started install.md + backend.md + frontend.md -.. _hyperparameter_tuning: .. toctree:: :maxdepth: 1 - :caption: Hyperparameter Tuning + :caption: References + sampling_params.md hyperparameter_tuning.md - -.. _custom_chat_template: -.. toctree:: - :maxdepth: 1 - :caption: Custom Chat Template - - custom_chat_template.md - -.. _model_support: -.. toctree:: - :maxdepth: 1 - :caption: Model Support - model_support.md - -.. _sampling_params: -.. toctree:: - :maxdepth: 1 - :caption: Sampling Params - - sampling_params.md - - -.. _benchmark_and_profilling: -.. toctree:: - :maxdepth: 1 - :caption: Benchmark and Profilling - - benchmark_and_profiling.md \ No newline at end of file + contributor_guide.md + choices_methods.md + benchmark_and_profiling.md + troubleshooting.md diff --git a/docs/en/install.md b/docs/en/install.md index 877f69d680..656bc6840a 100644 --- a/docs/en/install.md +++ b/docs/en/install.md @@ -1,73 +1,56 @@ -# SGLang Installation Guide +## Install SGLang -SGLang consists of a frontend language (Structured Generation Language, SGLang) and a backend runtime (SGLang Runtime, SRT). The frontend can be used separately from the backend, allowing for a detached frontend-backend setup. +You can install SGLang using any of the methods below. -## Quick Installation Options - -### 1. Frontend Installation (Client-side, any platform) - -```bash -pip install --upgrade pip -pip install sglang +### Method 1: With pip ``` - -**Note: You can check [these examples](https://github.com/sgl-project/sglang/tree/main/examples/frontend_language/usage) for how to use frontend and backend separately.** - -### 2. Backend Installation (Server-side, Linux only) - -```bash pip install --upgrade pip pip install "sglang[all]" -pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ -``` - -**Note: The backend (SRT) is only needed on the server side and is only available for Linux right now.** - -**Important: Please check the [flashinfer installation guidance](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.** -### 3. 
From Source (Latest version, Linux only for full installation) - -```bash -# Use the latest release branch -# As of this documentation, it's v0.2.15, but newer versions may be available -# Do not clone the main branch directly; always use a specific release version -# The main branch may contain unresolved bugs before a new release -git clone -b v0.2.15 https://github.com/sgl-project/sglang.git -cd sglang -pip install -e "python[all]" +# Install FlashInfer CUDA kernels pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ``` -### 4. OpenAI Backend Only (Client-side, any platform) - -If you only need to use the OpenAI backend, you can avoid installing other dependencies by using: - -```bash -pip install "sglang[openai]" +### Method 2: From source ``` +# Use the last release branch +git clone -b v0.3.0 https://github.com/sgl-project/sglang.git +cd sglang -## Advanced Installation Options +pip install --upgrade pip +pip install -e "python[all]" -### 1. Using Docker (Server-side, Linux only) +# Install FlashInfer CUDA kernels +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ +``` -The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker). Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). +### Method 3: Using docker +The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker). +Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). ```bash -docker run --gpus all -p 30000:30000 \ +docker run --gpus all \ + -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ - --env "HF_TOKEN=" --ipc=host \ + --env "HF_TOKEN=" \ + --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 ``` -### 2.Using docker compose +### Method 4: Using docker compose + +
+More -This method is recommended if you plan to serve it as a service. A better approach is to use the [k8s-sglang-service.yaml](https://github.com/sgl-project/sglang/blob/main/docker/k8s-sglang-service.yaml). +> This method is recommended if you plan to serve it as a service. +> A better approach is to use the [k8s-sglang-service.yaml](./docker/k8s-sglang-service.yaml). -1. Copy the [compose.yml](https://github.com/sgl-project/sglang/blob/main/docker/compose.yaml) to your local machine +1. Copy the [compose.yml](./docker/compose.yaml) to your local machine 2. Execute the command `docker compose up -d` in your terminal. +
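Once the containers are up, you can verify that the runtime is ready by polling the `/health` endpoint. The snippet below is only a sketch; it assumes the compose file publishes the server on `localhost:30000`, matching the `docker run` example above.

```python
import time

import requests


def wait_for_server(url: str = "http://localhost:30000/health", timeout: float = 120.0) -> bool:
    """Poll the health endpoint until the server responds or the timeout expires."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=5).status_code == 200:
                return True
        except requests.RequestException:
            pass  # the server is still starting up
        time.sleep(2)
    return False


if __name__ == "__main__":
    print("server ready:", wait_for_server())
```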
-### 3.Run on Kubernetes or Clouds with SkyPilot +### Method 5: Run on Kubernetes or Clouds with SkyPilot
More @@ -108,9 +91,6 @@ sky status --endpoint 30000 sglang 3. To further scale up your deployment with autoscaling and failure recovery, check out the [SkyServe + SGLang guide](https://github.com/skypilot-org/skypilot/tree/master/llm/sglang#serving-llama-2-with-sglang-for-more-traffic-using-skyserve).
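As a quick smoke test of a SkyPilot deployment, you can send a single generation request to the printed endpoint. This is a sketch that assumes the native `/generate` route and its default JSON request format; replace the placeholder address with the output of `sky status --endpoint 30000 sglang`.

```python
import requests

# Substitute the address printed by `sky status --endpoint 30000 sglang`,
# e.g. "http://1.2.3.4:30000". The exact value depends on your deployment.
ENDPOINT = "http://<your-endpoint>:30000"

response = requests.post(
    f"{ENDPOINT}/generate",
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    },
    timeout=60,
)
print(response.json())
```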
-## Troubleshooting - -- For FlashInfer issues on newer GPUs, use `--disable-flashinfer --disable-flashinfer-sampling` when launching the server. -- For out-of-memory errors, try `--mem-fraction-static 0.7` when launching the server. - -For more details and advanced usage, visit the [SGLang GitHub repository](https://github.com/sgl-project/sglang). \ No newline at end of file +### Common Notes +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue. +- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. From fec2d1223c82f5701a384030c842dc92e0543e22 Mon Sep 17 00:00:00 2001 From: wangchao Date: Tue, 10 Sep 2024 16:17:37 +0800 Subject: [PATCH 19/33] [Fix] fix bug of `undefined is_single` in meth `create_abort_task` (#1370) --- python/sglang/srt/managers/io_struct.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8e53df3355..f5279eb8d0 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -53,6 +53,8 @@ class GenerateReqInput: # The modalities of the image data [image, multi-images, video] modalities: Optional[List[str]] = None + is_single: bool = True + def post_init(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None @@ -194,6 +196,8 @@ class EmbeddingReqInput: # Dummy sampling params for compatibility sampling_params: Union[List[Dict], Dict] = None + is_single: bool = True + def post_init(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None From e72275cf7f6f9783cbd6031a1dcfd93bd45e40da Mon Sep 17 00:00:00 2001 From: William Date: Tue, 10 Sep 2024 17:57:52 +0800 Subject: [PATCH 20/33] Support MiniCPM3 (#1371) --- README.md | 1 + python/sglang/srt/layers/decode_attention.py | 5 +- python/sglang/srt/layers/extend_attention.py | 5 +- python/sglang/srt/model_config.py | 5 + python/sglang/srt/models/minicpm3.py | 669 +++++++++++++++++++ 5 files changed, 683 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/models/minicpm3.py diff --git a/README.md b/README.md index 92c5c2ec39..8af73c49d2 100644 --- a/README.md +++ b/README.md @@ -259,6 +259,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - ChatGLM - InternLM 2 - Exaone 3 +- MiniCPM / MiniCPM 3 **Embedding Models** diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py index 9c9822b852..ebf29cc592 100644 --- a/python/sglang/srt/layers/decode_attention.py +++ b/python/sglang/srt/layers/decode_attention.py @@ -483,11 +483,14 @@ def _decode_grouped_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 96, 128, 256, 576} + assert Lk in {16, 32, 64, 96, 128, 256, 576, 288} if Lk == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 else: BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 diff --git a/python/sglang/srt/layers/extend_attention.py 
b/python/sglang/srt/layers/extend_attention.py index 8880622854..5c8e51c5fe 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -280,12 +280,15 @@ def extend_attention_fwd( assert Lq == Lk and Lv == Lo # TODO: is the assertion necessary? - assert Lq in {16, 32, 64, 96, 128, 256, 576} + assert Lq in {16, 32, 64, 96, 128, 256, 576, 288} assert Lv in {16, 32, 64, 96, 128, 256, 512} if Lq == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 + elif Lq == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 else: BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py index edf89f6b97..14702f0b5d 100644 --- a/python/sglang/srt/model_config.py +++ b/python/sglang/srt/model_config.py @@ -64,6 +64,11 @@ def __init__( self.attention_arch = AttentionArch.MLA self.kv_lora_rank = self.hf_config.kv_lora_rank self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + elif "MiniCPM3ForCausalLM" in self.hf_config.architectures: + self.head_dim = 128 + self.attention_arch = AttentionArch.MLA + self.kv_lora_rank = self.hf_config.kv_lora_rank + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim else: self.attention_arch = AttentionArch.MHA diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py new file mode 100644 index 0000000000..c559d0f310 --- /dev/null +++ b/python/sglang/srt/models/minicpm3.py @@ -0,0 +1,669 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" + +import math +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from flashinfer import bmm_fp8 +from torch import nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +class MiniCPM3MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def input_to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +class MiniCPM3Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_id=None, + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + ) + # O projection. 
+ self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + # TODO support head_size 96 + self.attn = RadixAttention( + self.num_local_heads, + 128, + self.scaling, + num_kv_heads=self.num_local_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank :] + original_shapes = [q_pe.shape, k_pe.shape] + q_pe, k_pe = self.rotary_emb( + positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1) + ) + q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1]) + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view( + -1, self.num_local_heads * 128 + ) + k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view( + -1, self.num_local_heads * 128 + ) + v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view( + -1, self.num_local_heads * 128 + ) + attn_output = self.attn(q, k, v, input_metadata) + attn_output = attn_output.view(-1, self.num_local_heads, 128)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPM3AttentionMLA(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_id=None, + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not 
None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + ) + # O projection. + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.attn = RadixAttention( + self.num_local_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + self.scaling, + num_kv_heads=1, + layer_id=layer_id, + v_head_dim=self.kv_lora_rank, + ) + + self.w_kc = None + self.w_vc = None + self.w_scale = None + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + q_len = hidden_states.shape[0] + q_input = hidden_states.new_empty( + q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim + ) + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + if self.w_kc.dtype == torch.float8_e4m3fn: + q_nope_val, q_nope_scale = input_to_float8( + q_nope.transpose(0, 1), torch.float8_e4m3fn + ) + q_nope_out = bmm_fp8( + q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 + ) + else: + q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) + q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1) + + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + v_input = latent_cache[..., : self.kv_lora_rank] + v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1) + k_input = latent_cache.unsqueeze(1) + k_input[..., : self.kv_lora_rank] = v_input + k_pe = k_input[..., self.kv_lora_rank :] + + original_shapes = [q_pe.shape, k_pe.shape] + q_pe, k_pe = self.rotary_emb( + positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1) + ) + q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1]) + q_input[..., self.kv_lora_rank :] = q_pe + k_input[..., self.kv_lora_rank :] = k_pe + + attn_output = self.attn(q_input, k_input, v_input, input_metadata) + attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) + + if self.w_vc.dtype == torch.float8_e4m3fn: + attn_output_val, attn_output_scale = input_to_float8( + attn_output.transpose(0, 1), torch.float8_e4m3fn + ) + attn_bmm_output = bmm_fp8( + attn_output_val, + self.w_vc, + attn_output_scale, + self.w_scale, + torch.bfloat16, + ) + 
else: + attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) + attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) + output, _ = self.o_proj(attn_output) + + return output + + +class MiniCPM3DecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + if global_server_args_dict["enable_mla"]: + self.self_attn = MiniCPM3AttentionMLA( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=self.hidden_size // config.num_attention_heads, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + layer_id=layer_id, + ) + else: + self.self_attn = MiniCPM3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=self.hidden_size // config.num_attention_heads, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + layer_id=layer_id, + ) + self.mlp = MiniCPM3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) + ) + + return hidden_states, None + + +class MiniCPM3Model(nn.Module): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) 
+ self.layers = nn.ModuleList( + [ + MiniCPM3DecoderLayer( + config, i, cache_config=cache_config, quant_config=quant_config + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb + else: + hidden_states = input_embeds + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + input_metadata, + residual, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MiniCPM3ForCausalLM(nn.Module): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + + self.num_experts = getattr(self.config, "num_experts", 0) + self.quant_config = quant_config + self.model = MiniCPM3Model( + config, cache_config=cache_config, quant_config=quant_config + ) + # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if not self.config.tie_word_embeddings: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is not None: + input_embeds = input_embeds * self.config.scale_emb + hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = hidden_states / self.scale_width + if self.config.tie_word_embeddings: + lm_head_weight = self.model.embed_tokens.weight + else: + lm_head_weight = self.lm_head.weight + logits_output = self.logits_processor( + input_ids, hidden_states, lm_head_weight, input_metadata + ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ( + "ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, loaded_weight, weight_name, expert_id=expert_id + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + if global_server_args_dict["enable_mla"]: + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if hasattr(self_attn.kv_b_proj, "weight_scale"): + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + del self_attn.kv_b_proj + + +EntryClass = MiniCPM3ForCausalLM From dff2860a690757966e408b598a8f0b47a29a4713 Mon Sep 17 00:00:00 2001 From: josephrocca <1167575+josephrocca@users.noreply.github.com> Date: Wed, 11 Sep 2024 00:35:03 +0800 Subject: [PATCH 21/33] Fix CORS compatibility with OpenAI, vLLM, TGI, LMDeploy (#1373) Co-authored-by: Yineng Zhang --- python/sglang/srt/server.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index d44d617522..b73a01265f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -37,6 +37,7 @@ import uvicorn import uvloop from fastapi import FastAPI, File, Form, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint @@ -93,6 +94,14 @@ app = FastAPI() tokenizer_manager = None +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + @app.get("/health") async def health() -> Response: From 6c7cb903655d4b8523e45838e597c11e10a6600f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 10 Sep 2024 11:27:03 -0700 Subject: [PATCH 22/33] [Minor] improve kill scripts and torchao import (#1375) --- python/sglang/srt/layers/torchao_utils.py | 14 ++++++++------ test/killall_sglang.sh | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 16eb1f2c5c..bc7bde86ef 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -3,15 +3,17 @@ """ import torch -from torchao.quantization import ( - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, - quantize_, -) def torchao_quantize_param_data(param, torchao_config): + # Lazy import to suppress some warnings + from torchao.quantization import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, + ) + dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) 
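    # Note: torchao's `quantize_` operates on nn.Module objects rather than on raw
    # tensors, which is why the parameter is wrapped in a throwaway nn.Linear here.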
dummy_linear.weight = param if "int8wo" in torchao_config: diff --git a/test/killall_sglang.sh b/test/killall_sglang.sh index 0e2cb82a81..c536548d4d 100644 --- a/test/killall_sglang.sh +++ b/test/killall_sglang.sh @@ -1 +1 @@ -kill -9 $(ps aux | grep 'sglang' | grep -v 'grep' | awk '{print $2}') +kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') From fbb4754cb8c6585763ab631231508e84e6c287e2 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 10 Sep 2024 13:10:36 -0700 Subject: [PATCH 23/33] Fix vocab mask update bug (#1376) --- python/sglang/srt/managers/schedule_batch.py | 2 - .../srt/model_executor/forward_batch_info.py | 3 +- .../srt/sampling/sampling_batch_info.py | 46 +++++++++++-------- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6c6b7f8426..2e2489cd29 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -652,8 +652,6 @@ def prepare_for_decode(self, input_ids=None): self.req_pool_indices, self.seq_lens - 1 ] = self.out_cache_loc - self.sampling_info.update_regex_vocab_mask(self) - def filter_batch(self, unfinished_indices: List[int]): if unfinished_indices is None or len(unfinished_indices) == 0: # Filter out all requests diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a6ad63ce18..c1fb233579 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -195,7 +195,8 @@ def from_schedule_batch( top_logprobs_nums=batch.top_logprobs_nums, ) - ret.sampling_info.prepare_penalties() + ret.sampling_info.update_penalties() + ret.sampling_info.update_regex_vocab_mask(batch) ret.compute_positions(batch) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 20b1968d24..622f27df11 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -34,6 +34,9 @@ class SamplingBatchInfo: linear_penalties: torch.Tensor = None scaling_penalties: torch.Tensor = None + def __len__(self): + return len(self.temperatures) + def can_run_in_cuda_graph(self): # Vocab bias and min_ps are not supported in CUDA graph return ( @@ -118,11 +121,9 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): # Handle logit bias but only allocate when needed ret.logit_bias = None - ret.update_regex_vocab_mask(batch) - return ret - def prepare_penalties(self): + def update_penalties(self): self.scaling_penalties = None self.linear_penalties = None @@ -174,6 +175,26 @@ def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor): if self_val is not None: # logit_bias can be None setattr(self, item, self_val[new_indices]) + @staticmethod + def merge_bias_tensor( + lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0 + ): + # bias tensor can be None + if lhs is not None or rhs is not None: + shape, dtype = None, None + if lhs is not None: + shape, dtype = lhs.shape[1:], lhs.dtype + else: + shape, dtype = rhs.shape[1:], rhs.dtype + with torch.dtype(dtype): + if lhs is None: + lhs = torch.empty((bs1, *shape), device="cuda").fill_(default) + if rhs is None: + rhs = torch.empty((bs2, *shape), device="cuda").fill_(default) + return torch.cat([lhs, rhs]) + + return None + def merge(self, other: 
"SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) @@ -187,19 +208,6 @@ def merge(self, other: "SamplingBatchInfo"): other_val = getattr(other, item, None) setattr(self, item, torch.concat([self_val, other_val])) - # logit_bias can be None - if self.logit_bias is not None or other.logit_bias is not None: - vocab_size = ( - self.logit_bias.shape[1] - if self.logit_bias is not None - else other.logit_bias.shape[1] - ) - if self.logit_bias is None: - self.logit_bias = torch.zeros( - (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - if other.logit_bias is None: - other.logit_bias = torch.zeros( - (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - self.logit_bias = torch.concat([self.logit_bias, other.logit_bias]) + self.logit_bias = SamplingBatchInfo.merge_bias_tensor( + self.logit_bias, other.logit_bias, len(self), len(other) + ) From 3a6e8b6d78d8d33b5c241b4d95f531ac20e31964 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 10 Sep 2024 15:15:08 -0700 Subject: [PATCH 24/33] [Minor] move triton attention kernels into a separate folder (#1379) --- python/sglang/bench_latency.py | 2 +- python/sglang/srt/{ => configs}/model_config.py | 0 python/sglang/srt/constrained/__init__.py | 2 ++ python/sglang/srt/conversation.py | 2 +- python/sglang/srt/hf_transformers_utils.py | 4 +--- python/sglang/srt/layers/radix_attention.py | 15 +++++++++++---- .../{ => triton_attention}/decode_attention.py | 0 .../{ => triton_attention}/extend_attention.py | 2 +- .../{ => triton_attention}/prefill_attention.py | 0 python/sglang/srt/managers/tp_worker.py | 2 +- .../srt/model_executor/forward_batch_info.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 3 +-- scripts/deprecated/test_flashinfer.py | 5 ++++- 13 files changed, 24 insertions(+), 15 deletions(-) rename python/sglang/srt/{ => configs}/model_config.py (100%) rename python/sglang/srt/layers/{ => triton_attention}/decode_attention.py (100%) rename python/sglang/srt/layers/{ => triton_attention}/extend_attention.py (99%) rename python/sglang/srt/layers/{ => triton_attention}/prefill_attention.py (100%) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index be67958460..bfe7394322 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -57,9 +57,9 @@ import torch import torch.distributed as dist +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch -from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/configs/model_config.py similarity index 100% rename from python/sglang/srt/model_config.py rename to python/sglang/srt/configs/model_config.py diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 7e097c6fc2..c47c5c8dd5 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -13,6 +13,8 @@ limitations under the License. 
""" +"""For constrained decoding.""" + import json from typing import Dict, Optional, Union diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 9a1227218b..341551eca3 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -13,7 +13,7 @@ limitations under the License. """ -"""Conversation templates.""" +"""Conversation chat templates.""" # Adapted from # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index ae3070c5a7..f6c414ec3a 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -16,11 +16,9 @@ """Utilities for Huggingface Transformers.""" import contextlib -import functools -import json import os import warnings -from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union +from typing import Dict, Optional, Type, Union from huggingface_hub import snapshot_download from transformers import ( diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 1a2feacd3d..adada7cdaf 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -22,13 +22,20 @@ from torch import nn from sglang.global_config import global_config -from sglang.srt.layers.decode_attention import decode_attention_fwd -from sglang.srt.layers.extend_attention import extend_attention_fwd +from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd +from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.model_runner import global_server_args_dict class RadixAttention(nn.Module): + """ + The attention layer implementation. + Now it has two backends: FlashInfer and Triton. + FlashInfer is faster and Triton is easier to customize. + It supports two operators: extend (i.e. prefill with cached prefix) and decode. 
+ """ + def __init__( self, num_heads: int, @@ -49,8 +56,10 @@ def __init__( self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.scaling = scaling self.layer_id = layer_id + self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 self.sliding_window_size = sliding_window_size if sliding_window_size else -1 + # Choose backend if ( not global_server_args_dict.get("disable_flashinfer", False) and self.qk_head_dim == self.v_head_dim @@ -61,8 +70,6 @@ def __init__( self.extend_forward = self.extend_forward_triton self.decode_forward = self.decode_forward_triton - self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 - def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): if self.qk_head_dim != self.v_head_dim: o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py similarity index 100% rename from python/sglang/srt/layers/decode_attention.py rename to python/sglang/srt/layers/triton_attention/decode_attention.py diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/triton_attention/extend_attention.py similarity index 99% rename from python/sglang/srt/layers/extend_attention.py rename to python/sglang/srt/layers/triton_attention/extend_attention.py index 5c8e51c5fe..81039e6760 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/triton_attention/extend_attention.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -from sglang.srt.layers.prefill_attention import context_attention_fwd +from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd CUDA_CAPABILITY = torch.cuda.get_device_capability() diff --git a/python/sglang/srt/layers/prefill_attention.py b/python/sglang/srt/layers/triton_attention/prefill_attention.py similarity index 100% rename from python/sglang/srt/layers/prefill_attention.py rename to python/sglang/srt/layers/triton_attention/prefill_attention.py diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 736929a654..fe7c4bcabe 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ import torch.distributed as dist from sglang.global_config import global_config +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer @@ -52,7 +53,6 @@ ) from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache -from sglang.srt.model_config import ModelConfig from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index c1fb233579..867bd95a1f 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -15,7 +15,7 @@ limitations under the License. 
""" -"""ModelRunner runs the forward passes of the models.""" +"""Meta data for a forward pass.""" from dataclasses import dataclass from enum import IntEnum, auto from typing import TYPE_CHECKING, List diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3cb123c482..3033a7ce46 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -18,7 +18,6 @@ import gc import importlib import importlib.resources -import json import logging import pkgutil from functools import lru_cache @@ -45,6 +44,7 @@ from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config +from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict @@ -53,7 +53,6 @@ MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_config import AttentionArch, ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( diff --git a/scripts/deprecated/test_flashinfer.py b/scripts/deprecated/test_flashinfer.py index 638647677e..7f0a081f6c 100644 --- a/scripts/deprecated/test_flashinfer.py +++ b/scripts/deprecated/test_flashinfer.py @@ -6,8 +6,11 @@ ) from flashinfer.decode import _grouped_size_compiled_for_decode_kernels -from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention from sglang.srt.layers.token_attention import token_attention_fwd +from sglang.srt.layers.triton_attention.extend_attention import ( + extend_attention_fwd, + redundant_attention, +) flashinfer_prefill_wrapper = None flashinfer_decode_wrapper = None From 46094e0c1b9c81a1f12f356472af694d9ef613cc Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 10 Sep 2024 17:11:16 -0700 Subject: [PATCH 25/33] Deprecate --disable-flashinfer and introduce --attention-backend (#1380) --- README.md | 2 +- docs/en/install.md | 2 +- python/sglang/srt/layers/radix_attention.py | 8 +- python/sglang/srt/layers/sampler.py | 8 +- python/sglang/srt/managers/schedule_batch.py | 10 ++- python/sglang/srt/managers/tp_worker.py | 2 +- .../srt/model_executor/forward_batch_info.py | 6 +- .../sglang/srt/model_executor/model_runner.py | 23 +++--- python/sglang/srt/server.py | 2 +- python/sglang/srt/server_args.py | 75 +++++++++++++------ test/srt/test_moe_serving_throughput.py | 9 +-- test/srt/test_serving_throughput.py | 11 ++- test/srt/test_triton_attn_backend.py | 2 +- 13 files changed, 99 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 8af73c49d2..7ebada73d6 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ sky status --endpoint 30000 sglang ### Common Notes -- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue. +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. 
If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. ## Backend: SGLang Runtime (SRT) diff --git a/docs/en/install.md b/docs/en/install.md index 656bc6840a..60645ce849 100644 --- a/docs/en/install.md +++ b/docs/en/install.md @@ -92,5 +92,5 @@ sky status --endpoint 30000 sglang
### Common Notes -- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue. +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index adada7cdaf..48567e43d4 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -61,14 +61,18 @@ def __init__( # Choose backend if ( - not global_server_args_dict.get("disable_flashinfer", False) + global_server_args_dict["attention_backend"] == "flashinfer" and self.qk_head_dim == self.v_head_dim ): self.extend_forward = self.extend_forward_flashinfer self.decode_forward = self.decode_forward_flashinfer - else: + elif global_server_args_dict["attention_backend"] == "triton": self.extend_forward = self.extend_forward_triton self.decode_forward = self.decode_forward_triton + else: + raise ValueError( + f"Invalid attention backend: {global_server_args_dict['attention_backend']}" + ) def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): if self.qk_head_dim != self.v_head_dim: diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 6cb7d5b550..16b6b80e9c 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -78,7 +78,7 @@ def forward_cuda( probs = self._get_probs(logits, sampling_info) - if not global_server_args_dict["disable_flashinfer_sampling"]: + if global_server_args_dict["sampling_backend"] == "flashinfer": max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device @@ -93,11 +93,15 @@ def forward_cuda( batch_next_token_ids, success = flashinfer_top_k_top_p( probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps ) - else: + elif global_server_args_dict["sampling_backend"] == "pytorch": # Here we provide a slower fallback implementation. 
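            # The fallback masks out tokens beyond the per-request top-k / top-p / min-p
            # thresholds and samples from the renormalized probabilities.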
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch( probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps ) + else: + raise ValueError( + f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" + ) return SampleOutput(success, probs, batch_next_token_ids) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2e2489cd29..e46177bdbc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -31,6 +31,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: from sglang.srt.layers.sampler import SampleOutput @@ -40,10 +41,11 @@ # Put some global args for easy access global_server_args_dict = { - "disable_flashinfer": False, - "disable_flashinfer_sampling": False, - "triton_attention_reduce_in_fp32": False, - "enable_mla": False, + "attention_backend": ServerArgs.attention_backend, + "sampling_backend": ServerArgs.sampling_backend, + "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32, + "enable_mla": ServerArgs.enable_mla, + "torchao_config": ServerArgs.torchao_config, } diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index fe7c4bcabe..b1131b011f 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -128,7 +128,7 @@ def __init__( if server_args.max_running_requests is None else server_args.max_running_requests ), - self.model_runner.req_to_token_pool.size - 1, + self.model_runner.req_to_token_pool.size, ) self.max_req_input_len = min( self.model_config.context_len - 1, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 867bd95a1f..c158b3ce23 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -203,17 +203,17 @@ def from_schedule_batch( ret.compute_extend_infos(batch) fm = batch.forward_mode - if not fm.is_decode() or model_runner.server_args.disable_flashinfer: + if not fm.is_decode() or model_runner.server_args.attention_backend == "triton": ret.total_num_tokens = int(torch.sum(ret.seq_lens)) if not fm.is_decode(): ret.init_multimuldal_info(batch) - if model_runner.server_args.disable_flashinfer: + if model_runner.server_args.attention_backend == "triton": ret.init_triton_args(batch) flashinfer_use_ragged = False - if not model_runner.server_args.disable_flashinfer: + if model_runner.server_args.attention_backend == "flashinfer": if ( not fm.is_decode() and int(torch.sum(ret.seq_lens)) > 4096 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3033a7ce46..b04b0d7c01 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -53,7 +53,7 @@ MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, @@ -92,8 +92,8 @@ def __init__( ) global_server_args_dict.update( { 
- "disable_flashinfer": server_args.disable_flashinfer, - "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, + "attention_backend": server_args.attention_backend, + "sampling_backend": server_args.sampling_backend, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "enable_mla": server_args.enable_mla, "torchao_config": server_args.torchao_config, @@ -111,7 +111,7 @@ def __init__( self.load_model() self.init_memory_pool( min_per_gpu_memory, - server_args.max_num_reqs, + server_args.max_running_requests, server_args.max_total_tokens, ) self.init_cublas() @@ -344,8 +344,8 @@ def profile_max_num_token(self, total_gpu_memory: int): def init_memory_pool( self, total_gpu_memory: int, - max_num_reqs: int = None, - max_total_tokens: int = None, + max_num_reqs: Optional[int] = None, + max_total_tokens: Optional[int] = None, ): if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype @@ -379,7 +379,7 @@ def init_memory_pool( ), 2048, ), - 5120, + 4096, ) self.req_to_token_pool = ReqToTokenPool( @@ -399,7 +399,7 @@ def init_memory_pool( ) logger.info("using MLA Triton implementaion, flashinfer is disabled") # FIXME: temporarily only Triton MLA is supported - self.server_args.disable_flashinfer = True + self.server_args.attention_backend = "triton" else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, @@ -424,7 +424,7 @@ def init_cublas(self): def init_flashinfer(self): """Init flashinfer attention kernel wrappers.""" - if self.server_args.disable_flashinfer: + if self.server_args.attention_backend != "flashinfer": assert ( self.sliding_window_size is None ), "turn on flashinfer to support window attention" @@ -491,7 +491,10 @@ def init_cuda_graphs(self): from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner - if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: + if ( + self.server_args.disable_cuda_graph + or self.server_args.attention_backend != "flashinfer" + ): self.cuda_graph_runner = None return diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index b73a01265f..4aaf018a1b 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -425,7 +425,7 @@ def _set_envs_and_config(server_args: ServerArgs): maybe_set_triton_cache_manager() # Check flashinfer version - if not server_args.disable_flashinfer: + if server_args.attention_backend == "flashinfer": assert_pkg_version( "flashinfer", "0.1.6", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3dfb1dc411..0881344c08 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -17,7 +17,6 @@ import argparse import dataclasses -import json import logging import random from typing import List, Optional, Union @@ -50,7 +49,6 @@ class ServerArgs: # Memory and scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None - max_num_reqs: Optional[int] = None max_total_tokens: Optional[int] = None chunked_prefill_size: int = 8192 max_prefill_tokens: int = 16384 @@ -85,6 +83,9 @@ class ServerArgs: json_model_override_args: str = "{}" # Optimization/debug options + attention_backend: str = "flashinfer" + sampling_backend: str = "flashinfer" + disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False disable_radix_cache: bool = False @@ -101,6 +102,7 @@ class ServerArgs: triton_attention_reduce_in_fp32: bool = False def __post_init__(self): + # Set missing default values 
if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -111,6 +113,7 @@ def __post_init__(self): # Disable chunked prefill self.chunked_prefill_size = None + # Mem fraction depends on the tensor parallelism size if self.mem_fraction_static is None: if self.tp_size >= 16: self.mem_fraction_static = 0.79 @@ -131,6 +134,29 @@ def __post_init__(self): if self.random_seed is None: self.random_seed = random.randint(0, 1 << 30) + # Deprecation warnings + if self.disable_flashinfer: + logger.warning( + "The option '--disable-flashinfer' will be deprecated in the next release. " + "Please use '--attention-backend triton' instead." + ) + if self.disable_flashinfer_sampling: + logger.warning( + "The option '--disable-flashinfer-sampling' will be deprecated in the next release. " + "Please use '--sampling-backend pytorch' instead. " + ) + + # Model-specific patches + if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: + logger.info( + "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" + ) + self.trust_remote_code = False + + if "gemma-2" in self.model_path.lower(): + logger.info("When using sliding window in gemma-2, turn on flashinfer.") + self.attention_backend = "flashinfer" + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -214,11 +240,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", ) - parser.add_argument( - "--is-embedding", - action="store_true", - help="Whether to use a CausalLM as an embedding model.", - ) parser.add_argument( "--context-length", type=int, @@ -253,6 +274,11 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.chat_template, help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.", ) + parser.add_argument( + "--is-embedding", + action="store_true", + help="Whether to use a CausalLM as an embedding model.", + ) parser.add_argument( "--mem-fraction-static", type=float, @@ -265,17 +291,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.max_running_requests, help="The maximum number of running requests.", ) - parser.add_argument( - "--max-num-reqs", - type=int, - default=ServerArgs.max_num_reqs, - help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", - ) parser.add_argument( "--max-total-tokens", type=int, default=ServerArgs.max_total_tokens, - help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.", + help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. 
" + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -395,15 +416,29 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # Optimization/debug options + parser.add_argument( + "--attention-backend", + type=str, + choices=["flashinfer", "triton"], + default=ServerArgs.attention_backend, + help="Choose the kernels for attention layers.", + ) + parser.add_argument( + "--sampling-backend", + type=str, + choices=["flashinfer", "pytorch"], + default=ServerArgs.sampling_backend, + help="Choose the kernels for sampling layers.", + ) parser.add_argument( "--disable-flashinfer", action="store_true", - help="Disable flashinfer attention kernels.", + help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.", ) parser.add_argument( "--disable-flashinfer-sampling", action="store_true", - help="Disable flashinfer sampling kernels.", + help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.", ) parser.add_argument( "--disable-radix-cache", @@ -491,14 +526,6 @@ def check_server_args(self): assert not ( self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" - if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: - logger.info( - "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" - ) - self.trust_remote_code = False - if "gemma-2" in self.model_path.lower(): - logger.info("When using sliding window in gemma-2, turn on flashinfer.") - self.disable_flashinfer = False def prepare_server_args(argv: List[str]) -> ServerArgs: diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index 2acf626c1c..e0c851d4ed 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -14,13 +14,12 @@ class TestServingThroughput(unittest.TestCase): - def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size): + def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size): # Launch the server other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - if disable_flashinfer: - other_args.append("--disable-flashinfer") + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) other_args.extend(["--tensor-parallel-size", "2"]) @@ -70,7 +69,7 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) @@ -80,7 +79,7 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index d4ed12612a..1b458e9e6a 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -14,13 +14,12 @@ class TestServingThroughput(unittest.TestCase): - def run_test(self, 
disable_radix_cache, disable_flashinfer, chunked_prefill_size): + def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size): # Launch the server other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - if disable_flashinfer: - other_args.append("--disable-flashinfer") + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) model = DEFAULT_MODEL_NAME_FOR_TEST @@ -69,7 +68,7 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size def test_default(self): res = self.run_test( disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) @@ -79,7 +78,7 @@ def test_default(self): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) @@ -89,7 +88,7 @@ def test_default_without_radix_cache(self): def test_default_without_chunked_prefill(self): res = self.run_test( disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=-1, ) diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py index b3f65ac13a..9c6519d911 100644 --- a/test/srt/test_triton_attn_backend.py +++ b/test/srt/test_triton_attn_backend.py @@ -20,7 +20,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--disable-flashinfer"], + other_args=["--attention-backend", "triton"], ) @classmethod From 144bc70fcceede77fc2c2fbd286676b57f9a0c94 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 10 Sep 2024 17:38:59 -0700 Subject: [PATCH 26/33] Organize flashinfer indices update (#1378) --- python/sglang/srt/layers/flashinfer_utils.py | 237 ++++++++++++++++++ python/sglang/srt/managers/schedule_batch.py | 4 + .../srt/model_executor/cuda_graph_runner.py | 7 +- .../srt/model_executor/forward_batch_info.py | 200 +-------------- test/srt/test_create_kvindices.py | 4 +- 5 files changed, 252 insertions(+), 200 deletions(-) create mode 100644 python/sglang/srt/layers/flashinfer_utils.py diff --git a/python/sglang/srt/layers/flashinfer_utils.py b/python/sglang/srt/layers/flashinfer_utils.py new file mode 100644 index 0000000000..1f9ab15145 --- /dev/null +++ b/python/sglang/srt/layers/flashinfer_utils.py @@ -0,0 +1,237 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + max_context_len, + kv_indices_ptr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + req_to_token_ptr += req_pool_index * max_context_len + kv_indices_ptr += kv_indices_offset + + ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) + st_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(kv_end 
- kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = ld_offset < kv_end + data = tl.load(req_to_token_ptr + ld_offset, mask=mask) + tl.store(kv_indices_ptr + st_offset, data, mask=mask) + ld_offset += BLOCK_SIZE + st_offset += BLOCK_SIZE + + +class FlashinferUpdater: + def __init__( + self, + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + flashinfer_decode_wrapper=None, + flashinfer_use_ragged=False, + ): + self.forward_mode = forward_mode + self.model_runner = model_runner + self.req_pool_indices = req_pool_indices + self.seq_lens = seq_lens + self.prefix_lens = prefix_lens + self.flashinfer_use_ragged = flashinfer_use_ragged + + self.num_qo_heads = ( + model_runner.model_config.num_attention_heads // model_runner.tp_size + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + model_runner.tp_size + ) + self.head_dim = model_runner.model_config.head_dim + self.batch_size = len(req_pool_indices) + + self.kv_last_page_len = torch.ones( + (self.batch_size,), dtype=torch.int32, device="cuda" + ) + + ( + self.flashinfer_decode_wrapper, + self.flashinfer_prefill_wrapper_ragged, + self.flashinfer_prefill_wrapper_paged, + ) = ( + flashinfer_decode_wrapper, + self.model_runner.flashinfer_prefill_wrapper_ragged, + self.model_runner.flashinfer_prefill_wrapper_paged, + ) + # CUDA graph uses different flashinfer_decode_wrapper + if self.flashinfer_decode_wrapper is None: + self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper + + def _init_indices_no_window(self): + if self.flashinfer_use_ragged: + paged_kernel_lens = self.prefix_lens + else: + paged_kernel_lens = self.seq_lens + + self.kv_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + self.kv_indices = torch.empty( + self.kv_indptr[-1], dtype=torch.int32, device="cuda" + ) + + create_flashinfer_kv_indices_triton[(self.batch_size,)]( + self.model_runner.req_to_token_pool.req_to_token, + self.req_pool_indices, + paged_kernel_lens, + self.kv_indptr, + None, + self.model_runner.req_to_token_pool.req_to_token.size(1), + self.kv_indices, + ) + + def _init_indices_window(self, wrapper_id): + # window attention use paged only + if wrapper_id == 0: + if self.forward_mode.is_decode(): + paged_kernel_lens = torch.minimum( + self.seq_lens, + torch.tensor(self.model_runner.sliding_window_size + 1), + ) + else: + paged_kernel_lens = torch.minimum( + self.seq_lens, + torch.tensor(self.model_runner.sliding_window_size) + + self.seq_lens + - self.prefix_lens, + ) + else: + paged_kernel_lens = self.seq_lens + + kv_start_idx = self.seq_lens - paged_kernel_lens + self.kv_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + self.kv_indices = torch.empty( + self.kv_indptr[-1], dtype=torch.int32, device="cuda" + ) + create_flashinfer_kv_indices_triton[(self.batch_size,)]( + self.model_runner.req_to_token_pool.req_to_token, + self.req_pool_indices, + paged_kernel_lens, + self.kv_indptr, + kv_start_idx, + self.model_runner.req_to_token_pool.req_to_token.size(1), + self.kv_indices, + ) + + def _update_decode_indices(self, decode_wrapper): + decode_wrapper.end_forward() + decode_wrapper.begin_forward( + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.model_runner.kv_cache_dtype, + 
q_data_type=self.model_runner.dtype, + ) + + def _update_extend_indices(self, ragged_wrapper, paged_wrapper): + # extend part + qo_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) + + if self.flashinfer_use_ragged: + ragged_wrapper.end_forward() + ragged_wrapper.begin_forward( + qo_indptr, + qo_indptr, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + ) + + # cached part + paged_wrapper.end_forward() + paged_wrapper.begin_forward( + qo_indptr, + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + ) + + def update_indices_no_window(self): + self._init_indices_no_window() + + if self.forward_mode.is_decode(): + self._update_decode_indices(self.flashinfer_decode_wrapper) + else: + self._update_extend_indices( + self.flashinfer_prefill_wrapper_ragged, + self.flashinfer_prefill_wrapper_paged, + ) + + def update_indices_window(self): + assert self.flashinfer_use_ragged is False + + for wrapper_id in range(2): + self._init_indices_window(wrapper_id) + if self.forward_mode.is_decode(): + self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id]) + else: + self._update_extend_indices( + None, + self.flashinfer_prefill_wrapper_paged[wrapper_id], + ) + + +def update_flashinfer_indices( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + flashinfer_decode_wrapper=None, + flashinfer_use_ragged=False, +): + flashinfer_updater = FlashinferUpdater( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + flashinfer_decode_wrapper, + flashinfer_use_ragged, + ) + + if model_runner.sliding_window_size is None: + flashinfer_updater.update_indices_no_window() + else: + flashinfer_updater.update_indices_window() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e46177bdbc..b6000734a2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -349,6 +349,7 @@ class ScheduleBatch: # For mixed chunekd prefill prefix_lens_cpu: List[int] = None + running_bs: int = None # For processing logprobs return_logprob: bool = False @@ -446,6 +447,9 @@ def prepare_for_extend(self, vocab_size: int): self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) def mix_with_running(self, running_batch: "ScheduleBatch"): + self.forward_mode = ForwardMode.MIXED + self.running_bs = running_batch.batch_size() + # NOTE: prefix_indices is what has been cached, but we don't cache each decode step prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs] prefix_lens_cpu.extend( diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 4459213b02..c24dd50846 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -25,6 +25,7 @@ from vllm.distributed.parallel_state import graph_capture from vllm.model_executor.custom_op import CustomOp +from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices from sglang.srt.layers.logits_processor import ( LogitsMetadata, LogitsProcessor, @@ -32,11 +33,7 @@ ) from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch -from sglang.srt.model_executor.forward_batch_info import ( - ForwardMode, - InputMetadata, - 
update_flashinfer_indices, -) +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import monkey_patch_vllm_all_gather diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index c158b3ce23..f3bed6bcf1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -22,8 +22,8 @@ import numpy as np import torch -import triton -import triton.language as tl + +from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -39,16 +39,21 @@ class ForwardMode(IntEnum): EXTEND = auto() # Decode one token. DECODE = auto() + # Contains both PREFILL and EXTEND. + MIXED = auto() def is_prefill(self): return self == ForwardMode.PREFILL def is_extend(self): - return self == ForwardMode.EXTEND + return self == ForwardMode.EXTEND or self == ForwardMode.MIXED def is_decode(self): return self == ForwardMode.DECODE + def is_mixed(self): + return self == ForwardMode.MIXED + @dataclass class InputMetadata: @@ -270,192 +275,3 @@ def init_flashinfer_handlers( model_runner.flashinfer_decode_wrapper, flashinfer_use_ragged, ) - - -@triton.jit -def create_flashinfer_kv_indices_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices_ptr, - page_kernel_lens_ptr, - kv_indptr, - kv_start_idx, - max_context_len, - kv_indices_ptr, -): - BLOCK_SIZE: tl.constexpr = 512 - pid = tl.program_id(axis=0) - req_pool_index = tl.load(req_pool_indices_ptr + pid) - kv_indices_offset = tl.load(kv_indptr + pid) - - kv_start = 0 - kv_end = 0 - if kv_start_idx: - kv_start = tl.load(kv_start_idx + pid).to(tl.int32) - kv_end = kv_start - kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) - - req_to_token_ptr += req_pool_index * max_context_len - kv_indices_ptr += kv_indices_offset - - ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) - st_offset = tl.arange(0, BLOCK_SIZE) - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for _ in range(num_loop): - mask = ld_offset < kv_end - data = tl.load(req_to_token_ptr + ld_offset, mask=mask) - tl.store(kv_indices_ptr + st_offset, data, mask=mask) - ld_offset += BLOCK_SIZE - st_offset += BLOCK_SIZE - - -def update_flashinfer_indices( - forward_mode, - model_runner, - req_pool_indices, - seq_lens, - prefix_lens, - flashinfer_decode_wrapper=None, - flashinfer_use_ragged=False, -): - """Init auxiliary variables for FlashInfer attention backend.""" - num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size - num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) - head_dim = model_runner.model_config.head_dim - batch_size = len(req_pool_indices) - - if model_runner.sliding_window_size is None: - if flashinfer_use_ragged: - paged_kernel_lens = prefix_lens - else: - paged_kernel_lens = seq_lens - - kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - - kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") - create_flashinfer_kv_indices_triton[(batch_size,)]( - model_runner.req_to_token_pool.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - None, - model_runner.req_to_token_pool.req_to_token.size(1), - kv_indices, - ) - - kv_last_page_len = 
torch.ones((batch_size,), dtype=torch.int32, device="cuda") - - if forward_mode.is_decode(): - # CUDA graph uses different flashinfer_decode_wrapper - if flashinfer_decode_wrapper is None: - flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper - - flashinfer_decode_wrapper.end_forward() - flashinfer_decode_wrapper.begin_forward( - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - data_type=model_runner.kv_cache_dtype, - q_data_type=model_runner.dtype, - ) - else: - # extend part - qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) - - if flashinfer_use_ragged: - model_runner.flashinfer_prefill_wrapper_ragged.end_forward() - model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( - qo_indptr, - qo_indptr, - num_qo_heads, - num_kv_heads, - head_dim, - ) - - # cached part - model_runner.flashinfer_prefill_wrapper_paged.end_forward() - model_runner.flashinfer_prefill_wrapper_paged.begin_forward( - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - ) - else: - # window attention use paged only - kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") - for wrapper_id in range(2): - if wrapper_id == 0: - if forward_mode.is_decode(): - paged_kernel_lens = torch.minimum( - seq_lens, torch.tensor(model_runner.sliding_window_size + 1) - ) - else: - paged_kernel_lens = torch.minimum( - seq_lens, - torch.tensor(model_runner.sliding_window_size) - + seq_lens - - prefix_lens, - ) - else: - paged_kernel_lens = seq_lens - - kv_start_idx = seq_lens - paged_kernel_lens - - kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - - kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") - create_flashinfer_kv_indices_triton[(batch_size,)]( - model_runner.req_to_token_pool.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - model_runner.req_to_token_pool.req_to_token.size(1), - kv_indices, - ) - - if forward_mode.is_decode(): - # CUDA graph uses different flashinfer_decode_wrapper - if flashinfer_decode_wrapper is None: - flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper - - flashinfer_decode_wrapper[wrapper_id].end_forward() - flashinfer_decode_wrapper[wrapper_id].begin_forward( - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - data_type=model_runner.kv_cache_dtype, - q_data_type=model_runner.dtype, - ) - else: - # extend part - qo_indptr = torch.zeros( - (batch_size + 1,), dtype=torch.int32, device="cuda" - ) - qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) - - model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward() - model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward( - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - ) diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py index 230302f264..2159cca958 100644 --- a/test/srt/test_create_kvindices.py +++ b/test/srt/test_create_kvindices.py @@ -4,9 +4,7 @@ import numpy as np import torch -from sglang.srt.model_executor.forward_batch_info import ( - create_flashinfer_kv_indices_triton, -) +from sglang.srt.layers.flashinfer_utils import create_flashinfer_kv_indices_triton class 
TestCreateKvIndices(unittest.TestCase): From 8c0efa514dd17c49de2cf334a3cb49ec40fa3f3a Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Wed, 11 Sep 2024 03:22:07 -0700 Subject: [PATCH 27/33] remove assertion in triton attention and add an unit test (#1385) --- .../triton_attention/decode_attention.py | 4 - .../triton_attention/extend_attention.py | 107 --------- .../triton_attention/prefill_attention.py | 2 - test/srt/test_triton_attention_kernels.py | 213 ++++++++++++++++++ 4 files changed, 213 insertions(+), 113 deletions(-) create mode 100644 test/srt/test_triton_attention_kernels.py diff --git a/python/sglang/srt/layers/triton_attention/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py index ebf29cc592..82ce6efc54 100644 --- a/python/sglang/srt/layers/triton_attention/decode_attention.py +++ b/python/sglang/srt/layers/triton_attention/decode_attention.py @@ -199,8 +199,6 @@ def _decode_att_m_fwd( BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 96, 128, 256} batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -482,8 +480,6 @@ def _decode_grouped_att_m_fwd( BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 96, 128, 256, 576, 288} if Lk == 576: BLOCK_DMODEL = 512 diff --git a/python/sglang/srt/layers/triton_attention/extend_attention.py b/python/sglang/srt/layers/triton_attention/extend_attention.py index 81039e6760..1193c4124a 100644 --- a/python/sglang/srt/layers/triton_attention/extend_attention.py +++ b/python/sglang/srt/layers/triton_attention/extend_attention.py @@ -277,12 +277,6 @@ def extend_attention_fwd( o_extend.shape[-1], ) - assert Lq == Lk and Lv == Lo - - # TODO: is the assertion necessary? 
- assert Lq in {16, 32, 64, 96, 128, 256, 576, 288} - assert Lv in {16, 32, 64, 96, 128, 256, 512} - if Lq == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 @@ -395,104 +389,3 @@ def redundant_attention( pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i] o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr] pt += cur_seq_len_extend - - -def test_once(B, N_CTX, H_Q, H_KV, D): - dtype = torch.float16 - - b_seq_len_prefix = torch.randint( - 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" - ) - b_seq_len_extend = torch.randint( - 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" - ) - b_seq_len = b_seq_len_prefix + b_seq_len_extend - max_len_in_batch = torch.max(b_seq_len, 0)[0].item() - - b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") - req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda") - b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") - b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) - b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - for i in range(B): - req_to_tokens[i, : b_seq_len[i]] = torch.arange( - b_start_loc[i], b_start_loc[i] + b_seq_len[i] - ) - - total_token_num = torch.sum(b_seq_len).item() - extend_token_num = torch.sum(b_seq_len_extend).item() - k_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" - ).normal_(mean=0.1, std=0.2) - v_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" - ).normal_(mean=0.1, std=0.2) - - k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") - for i in range(B): - extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] - extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] - extend_start = b_start_loc_extend[i] - extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] - k_extend[extend_start:extend_end] = k_buffer[ - extend_start_in_buffer:extend_end_in_buffer - ] - v_extend[extend_start:extend_end] = v_buffer[ - extend_start_in_buffer:extend_end_in_buffer - ] - q_extend[extend_start:extend_end] = torch.empty( - (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" - ).normal_(mean=0.1, std=0.2) - - o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") - o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") - - b_seq_len_extend = b_seq_len - b_seq_len_prefix - b_start_loc_extend = torch.zeros_like(b_seq_len) - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() - extend_attention_fwd( - q_extend, - k_extend, - v_extend, - o_extend, - k_buffer, - v_buffer, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_seq_len_prefix, - b_start_loc_extend, - b_seq_len_extend, - max_len_in_batch, - max_len_extend, - ) - - redundant_attention( - q_extend, - k_extend, - v_extend, - o_redundant, - k_buffer, - v_buffer, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_seq_len_prefix, - max_len_in_batch, - ) - - print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant))) - print("Max: ", torch.max(torch.abs(o_extend - o_redundant))) - - assert torch.allclose(o_extend, o_redundant, rtol=1e-2) - - -if __name__ == "__main__": - test_once(19, 12331, 12, 4, 128) - test_once(19, 12331, 
12, 4, 96) diff --git a/python/sglang/srt/layers/triton_attention/prefill_attention.py b/python/sglang/srt/layers/triton_attention/prefill_attention.py index fbf9976fbc..e19e73ec1f 100644 --- a/python/sglang/srt/layers/triton_attention/prefill_attention.py +++ b/python/sglang/srt/layers/triton_attention/prefill_attention.py @@ -151,8 +151,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): BLOCK = 64 Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 96, 128, 256} sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py new file mode 100644 index 0000000000..0d094b5571 --- /dev/null +++ b/test/srt/test_triton_attention_kernels.py @@ -0,0 +1,213 @@ +import random +import unittest + +import torch + +from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd +from sglang.srt.layers.triton_attention.extend_attention import ( + extend_attention_fwd, + redundant_attention, +) +from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd + + +class TestExtendAttention(unittest.TestCase): + + def _set_all_seeds(self, seed): + """Set all random seeds for reproducibility.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def setUp(self): + # Set seeds before each test method + self._set_all_seeds(42) + + def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): + dtype = torch.bfloat16 + + b_seq_len_prefix = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len_extend = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() + + b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") + req_to_tokens = torch.empty( + (B, max_len_in_batch), dtype=torch.int32, device="cuda" + ) + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + for i in range(B): + req_to_tokens[i, : b_seq_len[i]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len[i] + ) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + v_extend[extend_start:extend_end] = v_buffer[ + 
extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.empty( + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_redundant = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + ) + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + b_start_loc_extend = torch.zeros_like(b_seq_len) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + b_start_loc_extend, + b_seq_len_extend, + max_len_in_batch, + max_len_extend, + ) + + redundant_attention( + q_extend, + k_extend, + v_extend, + o_redundant, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + max_len_in_batch, + ) + + self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2)) + + def test_extend_attention(self): + + # Define the varying parameter values + attention_values = [128, 96, 80, 13] + + # Loop through the values and call the method + for value in attention_values: + self._test_extend_attention_once(19, 12331, 12, 4, value) + + def _test_context_attention_once(self, head_dim): + # Set up a simple test case + batch_size = 2 + num_heads = 4 + seq_lens = [8, 12] + max_seq_len = max(seq_lens) + + # Create random input tensors + q = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda") + k = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda") + v = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda") + o = torch.zeros(sum(seq_lens), num_heads, head_dim, device="cuda") + + # Create b_start_loc and b_seq_len tensors + b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda") + b_seq_len = torch.tensor(seq_lens, device="cuda") + + context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len) + + def test_context_attention(self): + # Here we just to ensure there is no error + # TODO: correctnesss test + head_dim = [128, 96, 80, 13] + + for dim in head_dim: + self._test_context_attention_once(dim) + + def _test_decode_attention_once(self, B, H_Q, H_KV, D): + dtype = torch.bfloat16 + seq_len = 10 # This represents the number of tokens already in the sequence + total_tokens = B * seq_len + sm_scale = 1.0 / (D**0.5) + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") + v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") + + # o will have the same shape as q + o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + + req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) + b_req_idx = torch.arange(B, device="cuda") + b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") + b_seq_len = torch.full((B,), seq_len, device="cuda") + + decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + seq_len, + total_tokens, + sm_scale, + ) + + def test_decode_attention(self): + # Here we just to ensure there is no error + # TODO: correctnesss test + + # Test configurations + configs = [ + (2, 4, 4, 64), # MHA + (2, 4, 2, 64), # 
GQA + (2, 4, 4, 80), # Non-standard head dim + (2, 4, 4, 13), # Prime number head dim + ] + + for B, H_Q, H_KV, D in configs: + self._test_decode_attention_once(B, H_Q, H_KV, D) + + +if __name__ == "__main__": + unittest.main() From 224200e3c2accfe4e1c1ca4fb5906a5b8b609586 Mon Sep 17 00:00:00 2001 From: Vectory Date: Wed, 11 Sep 2024 18:55:24 +0800 Subject: [PATCH 28/33] BaiChuan2 Model (#1367) Co-authored-by: wanpenghan --- README.md | 2 +- python/sglang/srt/models/baichuan.py | 421 +++++++++++++++++++++++++++ 2 files changed, 422 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/models/baichuan.py diff --git a/README.md b/README.md index 7ebada73d6..d1449abfce 100644 --- a/README.md +++ b/README.md @@ -259,8 +259,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - ChatGLM - InternLM 2 - Exaone 3 +- BaiChuan2 - MiniCPM / MiniCPM 3 - **Embedding Models** - e5-mistral diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py new file mode 100644 index 0000000000..d699064b38 --- /dev/null +++ b/python/sglang/srt/models/baichuan.py @@ -0,0 +1,421 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only BaiChuan model compatible with HuggingFace weights.""" +import math +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BaiChuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class BaiChuanAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + position_embedding: str, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + layer_id: int = 0, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + self.head_dim = hidden_size // self.total_num_heads + self.postion_embedding = position_embedding + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.total_num_kv_heads = self.num_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + # pylint: disable=invalid-name + self.W_pack = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + # Create the alibi slopes and slice them. 
+ if self.postion_embedding == "ALIBI": + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.W_pack(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.postion_embedding != "ALIBI": + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class BaiChuanDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + position_embedding: str, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = BaiChuanAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + position_embedding=position_embedding, + rope_theta=rope_theta, + layer_id=layer_id, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + ) + self.mlp = BaiChuanMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BaiChuanModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + position_embedding: str, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + BaiChuanDecoderLayer( + config, + layer_id=i, + position_embedding=position_embedding, + quant_config=quant_config, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BaiChuanBaseForCausalLM(nn.Module): + packed_modules_mapping = { + "W_pack": ["W_pack"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + "W_pack", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.config = config + + self.quant_config = quant_config + self.model = BaiChuanModel(config, position_embedding, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata) + logits_output = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + + return sample_output, logits_output + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name == "lm_head.weight": + # Unlike Baichuan, Baichuan2 normalizes the head weights. + # Refer to: + # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 + # Distinguish between Baichuan and Baichuan2 by checking the + # vocab size. This is suggested by + # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 + is_baichuan2 = self.config.vocab_size == 125696 + if is_baichuan2: + loaded_weight = torch.nn.functional.normalize(loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +class BaichuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 13B and Baichuan2 7B/13B.""" + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + if config.hidden_size == 4096: # baichuan2 7b + super().__init__(config, "ROPE", cache_config, quant_config) + else: # baichuan 13b, baichuan2 13b + super().__init__(config, "ALIBI", cache_config, quant_config) + + +EntryClass = [BaichuanForCausalLM] From 15c75e41462dfdb6e405bf061ab0640bb04ccdbf Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 11 Sep 2024 04:36:21 -0700 Subject: [PATCH 29/33] [Fix] Fix --disable-flashinfer (#1389) --- python/sglang/srt/server_args.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0881344c08..776f7bec32 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -140,11 +140,13 @@ def __post_init__(self): "The option '--disable-flashinfer' will be deprecated in the next release. " "Please use '--attention-backend triton' instead." ) + self.attention_backend = "triton" if self.disable_flashinfer_sampling: logger.warning( "The option '--disable-flashinfer-sampling' will be deprecated in the next release. " "Please use '--sampling-backend pytorch' instead. " ) + self.sampling_backend = "pytorch" # Model-specific patches if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: From c03cece42f425cc8e73df77a6d1fcc316fd44b50 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 11 Sep 2024 04:50:04 -0700 Subject: [PATCH 30/33] Improve error reporting during server launch (#1390) --- python/sglang/srt/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 4aaf018a1b..5bdee03de3 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -447,13 +447,12 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid): time.sleep(1) try: res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res}" + assert res.status_code == 200, f"{res=}, {res.text=}" success = True break - except (AssertionError, requests.exceptions.RequestException) as e: + except (AssertionError, requests.exceptions.RequestException): last_traceback = get_exception_traceback() pass - model_info = res.json() if not success: if pipe_finish_writer is not None: @@ -462,6 +461,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid): kill_child_process(pid, including_parent=False) return + model_info = res.json() + # Send a warmup request request_name = "/generate" if model_info["is_generation"] else "/encode" max_new_tokens = 8 if model_info["is_generation"] else 1 From fec185ce0cbaf3b0597d0d1a71c335a8c52ce1ba Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 11 Sep 2024 11:44:26 -0700 Subject: [PATCH 31/33] Refactor attention backend (#1381) --- python/sglang/srt/layers/attention_backend.py | 383 ++++++++++++++++++ python/sglang/srt/layers/flashinfer_utils.py | 72 ++-- python/sglang/srt/layers/radix_attention.py | 175 +------- .../triton_attention/decode_attention.py | 7 +- .../triton_attention/extend_attention.py | 31 +- python/sglang/srt/managers/schedule_batch.py | 2 
+- .../srt/model_executor/cuda_graph_runner.py | 154 +++---- .../srt/model_executor/forward_batch_info.py | 134 ++---- .../sglang/srt/model_executor/model_runner.py | 125 ++---- .../srt/sampling/sampling_batch_info.py | 8 +- python/sglang/srt/server.py | 10 +- python/sglang/srt/server_args.py | 15 +- test/srt/test_create_kvindices.py | 2 +- test/srt/test_moe_serving_throughput.py | 3 +- test/srt/test_serving_throughput.py | 3 +- test/srt/test_triton_attention_kernels.py | 8 +- 16 files changed, 568 insertions(+), 564 deletions(-) create mode 100644 python/sglang/srt/layers/attention_backend.py diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py new file mode 100644 index 0000000000..35fe4ed925 --- /dev/null +++ b/python/sglang/srt/layers/attention_backend.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +""" +Support different attention backends. +Now there are two backends: FlashInfer and Triton. +FlashInfer is faster and Triton is easier to customize. +Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, +) +from flashinfer.cascade import merge_state +from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + +from sglang.global_config import global_config +from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + + +class AttentionBackend(ABC): + """The base class of attention backends""" + + @abstractmethod + def init_forward_metadata( + self, batch: ScheduleBatch, input_metadata: InputMetadata + ): + pass + + def forward(self, q, k, v, layer, input_metadata: InputMetadata): + if input_metadata.forward_mode.is_decode(): + return self.forward_decode(q, k, v, layer, input_metadata) + else: + return self.forward_extend(q, k, v, layer, input_metadata) + + +class FlashInferAttnBackend(AttentionBackend): + """Flashinfer attention kernels.""" + + def __init__(self, model_runner: ModelRunner): + super().__init__() + self.model_runner = model_runner + + if not _grouped_size_compiled_for_decode_kernels( + model_runner.model_config.num_attention_heads // model_runner.tp_size, + model_runner.model_config.get_num_kv_heads(model_runner.tp_size), + ): + self.decode_use_tensor_cores = True + else: + self.decode_use_tensor_cores = False + + self.workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device="cuda", + ) + + if model_runner.sliding_window_size is None: + self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( + self.workspace_buffer, "NHD" + ) + self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD" + ) + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_tensor_cores=self.decode_use_tensor_cores, + ) + else: + # Two wrappers: one for sliding window attention and one for full attention. 
+ # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs + self.prefill_wrapper_ragged = None + self.prefill_wrapper_paged = [] + self.decode_wrapper = [] + for _ in range(2): + self.prefill_wrapper_paged.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) + self.decode_wrapper.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_tensor_cores=self.decode_use_tensor_cores, + ) + ) + + self.forward_metadata = None + self.cuda_graph_metadata = {} + + def init_forward_metadata( + self, batch: ScheduleBatch, input_metadata: InputMetadata + ): + if input_metadata.forward_mode.is_decode(): + prefix_lens = None + use_ragged = False + total_num_tokens = None + else: + prefix_lens = input_metadata.extend_prefix_lens + + # Some heuristics to check whether to use ragged forward + use_ragged = False + if ( + int(torch.sum(input_metadata.seq_lens)) > 4096 + and self.model_runner.sliding_window_size is None + ): + use_ragged = True + + total_num_tokens = torch.sum(input_metadata.seq_lens).item() + + update_flashinfer_indices( + input_metadata.forward_mode, + self.model_runner, + input_metadata.req_pool_indices, + input_metadata.seq_lens, + prefix_lens, + use_ragged=use_ragged, + ) + + self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper) + + def init_cuda_graph_state(self, max_bs: int): + self.cuda_graph_kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device="cuda" + ) + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.model_runner.model_config.context_len,), + dtype=torch.int32, + device="cuda", + ) + self.cuda_graph_kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device="cuda" + ) + + if self.model_runner.sliding_window_size is not None: + self.cuda_graph_kv_indptr = [ + self.cuda_graph_kv_indptr, + self.cuda_graph_kv_indptr.clone(), + ] + self.cuda_graph_kv_indices = [ + self.cuda_graph_kv_indices, + self.cuda_graph_kv_indices.clone(), + ] + + def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens): + if self.model_runner.sliding_window_size is None: + decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices, + paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs], + ) + else: + decode_wrapper = [] + for i in range(2): + decode_wrapper.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[ + :bs + ], + ) + ) + + update_flashinfer_indices( + ForwardMode.DECODE, + self.model_runner, + req_pool_indices, + seq_lens, + None, + decode_wrapper, + ) + + self.cuda_graph_metadata[bs] = decode_wrapper + + self.forward_metadata = (False, None, decode_wrapper) + + def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens): + update_flashinfer_indices( + ForwardMode.DECODE, + self.model_runner, + req_pool_indices[:bs], + seq_lens[:bs], + None, + self.cuda_graph_metadata[bs], + ) + + def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + if not isinstance(self.prefill_wrapper_paged, list): + 
prefill_wrapper_paged = self.prefill_wrapper_paged + else: + if layer.sliding_window_size != -1: + prefill_wrapper_paged = self.prefill_wrapper_paged[0] + else: + prefill_wrapper_paged = self.prefill_wrapper_paged[1] + + use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata + + if not use_ragged: + if k is not None: + assert v is not None + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + o = prefill_wrapper_paged.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=True, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap, + ) + else: + o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + if input_metadata.extend_no_prefix: + o = o1 + else: + o2, s2 = prefill_wrapper_paged.forward_return_lse( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=False, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + o, _ = merge_state(o1, s1, o2, s2) + + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + if total_num_tokens >= global_config.layer_sync_threshold: + torch.cuda.synchronize() + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata + + if isinstance(decode_wrapper, list): + if layer.sliding_window_size != -1: + decode_wrapper = decode_wrapper[0] + else: + decode_wrapper = decode_wrapper[1] + + if k is not None: + assert v is not None + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + o = decode_wrapper.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + +class TritonAttnBackend(AttentionBackend): + def __init__(self, model_runner: ModelRunner): + # Lazy import to avoid the initialization of cuda context + from sglang.srt.layers.triton_attention.decode_attention import ( + decode_attention_fwd, + ) + from sglang.srt.layers.triton_attention.extend_attention import ( + extend_attention_fwd, + ) + + super().__init__() + + self.decode_attention_fwd = decode_attention_fwd + self.extend_attention_fwd = extend_attention_fwd + + self.forward_metadata = None + + def init_forward_metadata( + self, batch: ScheduleBatch, input_metadata: InputMetadata + ): + """Init auxiliary variables for triton attention backend.""" + + if input_metadata.forward_mode.is_decode(): + max_seq_len = torch.max(input_metadata.seq_lens).item() + start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32) + start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0) + + total_num_tokens = torch.sum(input_metadata.seq_lens).item() + max_extend_len = None + else: + start_loc = max_seq_len = total_num_tokens = None + prefix_lens = 
torch.tensor(batch.prefix_lens_cpu, device="cuda") + max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item() + + self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens + + def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata + + self.extend_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.contiguous(), + v.contiguous(), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id), + input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id), + input_metadata.req_to_token_pool.req_to_token, + input_metadata.req_pool_indices, + input_metadata.seq_lens, + input_metadata.extend_seq_lens, + input_metadata.extend_start_loc, + max_extend_len, + layer.scaling, + layer.logit_cap, + ) + + return o + + def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata + + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + self.decode_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id), + input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + input_metadata.req_to_token_pool.req_to_token, + input_metadata.req_pool_indices, + start_loc, + input_metadata.seq_lens, + max_seq_len, + total_num_tokens, + layer.scaling, + layer.logit_cap, + ) + + return o diff --git a/python/sglang/srt/layers/flashinfer_utils.py b/python/sglang/srt/layers/flashinfer_utils.py index 1f9ab15145..c473d6e452 100644 --- a/python/sglang/srt/layers/flashinfer_utils.py +++ b/python/sglang/srt/layers/flashinfer_utils.py @@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton( page_kernel_lens_ptr, kv_indptr, kv_start_idx, - max_context_len, kv_indices_ptr, + max_context_len: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 512 pid = tl.program_id(axis=0) @@ -47,15 +47,15 @@ def __init__( req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper=None, - flashinfer_use_ragged=False, + decode_wrapper=None, + use_ragged=False, ): self.forward_mode = forward_mode self.model_runner = model_runner self.req_pool_indices = req_pool_indices self.seq_lens = seq_lens self.prefix_lens = prefix_lens - self.flashinfer_use_ragged = flashinfer_use_ragged + self.use_ragged = use_ragged self.num_qo_heads = ( model_runner.model_config.num_attention_heads // model_runner.tp_size @@ -71,20 +71,17 @@ def __init__( ) ( - self.flashinfer_decode_wrapper, - self.flashinfer_prefill_wrapper_ragged, - self.flashinfer_prefill_wrapper_paged, + self.decode_wrapper, + self.prefill_wrapper_ragged, + self.prefill_wrapper_paged, ) = ( - flashinfer_decode_wrapper, - self.model_runner.flashinfer_prefill_wrapper_ragged, - self.model_runner.flashinfer_prefill_wrapper_paged, + decode_wrapper or 
self.model_runner.attn_backend.decode_wrapper, + self.model_runner.attn_backend.prefill_wrapper_ragged, + self.model_runner.attn_backend.prefill_wrapper_paged, ) - # CUDA graph uses different flashinfer_decode_wrapper - if self.flashinfer_decode_wrapper is None: - self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper - def _init_indices_no_window(self): - if self.flashinfer_use_ragged: + def _init_indices_no_sliding_window(self): + if self.use_ragged: paged_kernel_lens = self.prefix_lens else: paged_kernel_lens = self.seq_lens @@ -103,13 +100,13 @@ def _init_indices_no_window(self): paged_kernel_lens, self.kv_indptr, None, - self.model_runner.req_to_token_pool.req_to_token.size(1), self.kv_indices, + self.model_runner.req_to_token_pool.req_to_token.size(1), ) - def _init_indices_window(self, wrapper_id): - # window attention use paged only + def _init_indices_sliding_window(self, wrapper_id): if wrapper_id == 0: + # window attention use paged only if self.forward_mode.is_decode(): paged_kernel_lens = torch.minimum( self.seq_lens, @@ -123,6 +120,7 @@ def _init_indices_window(self, wrapper_id): - self.prefix_lens, ) else: + # full attention paged_kernel_lens = self.seq_lens kv_start_idx = self.seq_lens - paged_kernel_lens @@ -139,8 +137,8 @@ def _init_indices_window(self, wrapper_id): paged_kernel_lens, self.kv_indptr, kv_start_idx, - self.model_runner.req_to_token_pool.req_to_token.size(1), self.kv_indices, + self.model_runner.req_to_token_pool.req_to_token.size(1), ) def _update_decode_indices(self, decode_wrapper): @@ -164,7 +162,7 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper): ) qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) - if self.flashinfer_use_ragged: + if self.use_ragged: ragged_wrapper.end_forward() ragged_wrapper.begin_forward( qo_indptr, @@ -187,28 +185,28 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper): 1, ) - def update_indices_no_window(self): - self._init_indices_no_window() + def update_indices_no_sliding_window(self): + self._init_indices_no_sliding_window() if self.forward_mode.is_decode(): - self._update_decode_indices(self.flashinfer_decode_wrapper) + self._update_decode_indices(self.decode_wrapper) else: self._update_extend_indices( - self.flashinfer_prefill_wrapper_ragged, - self.flashinfer_prefill_wrapper_paged, + self.prefill_wrapper_ragged, + self.prefill_wrapper_paged, ) - def update_indices_window(self): - assert self.flashinfer_use_ragged is False + def update_indices_sliding_window(self): + assert self.use_ragged is False for wrapper_id in range(2): - self._init_indices_window(wrapper_id) + self._init_indices_sliding_window(wrapper_id) if self.forward_mode.is_decode(): - self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id]) + self._update_decode_indices(self.decode_wrapper[wrapper_id]) else: self._update_extend_indices( None, - self.flashinfer_prefill_wrapper_paged[wrapper_id], + self.prefill_wrapper_paged[wrapper_id], ) @@ -218,20 +216,20 @@ def update_flashinfer_indices( req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper=None, - flashinfer_use_ragged=False, + decode_wrapper=None, + use_ragged=False, ): - flashinfer_updater = FlashinferUpdater( + updater = FlashinferUpdater( forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper, - flashinfer_use_ragged, + decode_wrapper, + use_ragged, ) if model_runner.sliding_window_size is None: - flashinfer_updater.update_indices_no_window() + 
updater.update_indices_no_sliding_window() else: - flashinfer_updater.update_indices_window() + updater.update_indices_sliding_window() diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 48567e43d4..8454d29281 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -15,25 +15,14 @@ """Radix attention.""" -from typing import Optional - -import torch -from flashinfer.cascade import merge_state from torch import nn -from sglang.global_config import global_config -from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd -from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.model_executor.model_runner import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import InputMetadata class RadixAttention(nn.Module): """ The attention layer implementation. - Now it has two backends: FlashInfer and Triton. - FlashInfer is faster and Triton is easier to customize. - It supports two operators: extend (i.e. prefill with cached prefix) and decode. """ def __init__( @@ -43,8 +32,8 @@ def __init__( scaling: float, num_kv_heads: int, layer_id: int, - sliding_window_size: Optional[int] = None, - logit_cap: int = -1, + sliding_window_size: int = -1, + logit_cap: float = 0.0, v_head_dim: int = -1, ): super().__init__() @@ -56,164 +45,14 @@ def __init__( self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.scaling = scaling self.layer_id = layer_id - self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 - self.sliding_window_size = sliding_window_size if sliding_window_size else -1 - - # Choose backend - if ( - global_server_args_dict["attention_backend"] == "flashinfer" - and self.qk_head_dim == self.v_head_dim - ): - self.extend_forward = self.extend_forward_flashinfer - self.decode_forward = self.decode_forward_flashinfer - elif global_server_args_dict["attention_backend"] == "triton": - self.extend_forward = self.extend_forward_triton - self.decode_forward = self.decode_forward_triton - else: - raise ValueError( - f"Invalid attention backend: {global_server_args_dict['attention_backend']}" - ) - - def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): - if self.qk_head_dim != self.v_head_dim: - o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) - else: - o = torch.empty_like(q) - - self.store_kv_cache(k, v, input_metadata) - extend_attention_fwd( - q.view(-1, self.tp_q_head_num, self.qk_head_dim), - k.contiguous(), - v.contiguous(), - o.view(-1, self.tp_q_head_num, self.v_head_dim), - input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), - input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), - input_metadata.req_to_token_pool.req_to_token, - input_metadata.req_pool_indices, - input_metadata.triton_start_loc, - input_metadata.seq_lens, - input_metadata.triton_prefix_lens, - input_metadata.extend_start_loc, - input_metadata.extend_seq_lens, - input_metadata.triton_max_seq_len, - input_metadata.triton_max_extend_len, - sm_scale=self.scaling, - logit_cap=self.logit_cap, - ) - - return o - - def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): - if self.qk_head_dim != self.v_head_dim: - o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) - else: - o = torch.empty_like(q) - 
self.store_kv_cache(k, v, input_metadata) - - decode_attention_fwd( - q.view(-1, self.tp_q_head_num, self.qk_head_dim), - input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), - input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), - o.view(-1, self.tp_q_head_num, self.v_head_dim), - input_metadata.req_to_token_pool.req_to_token, - input_metadata.req_pool_indices, - input_metadata.triton_start_loc, - input_metadata.seq_lens, - input_metadata.triton_max_seq_len, - input_metadata.total_num_tokens, - sm_scale=self.scaling, - logit_cap=self.logit_cap, - ) - - return o - - def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): - # using two wrappers is unnecessary in the current PR, but are prepared for future PRs - prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged - if self.sliding_window_size != -1: - prefill_wrapper_paged = prefill_wrapper_paged[0] - else: - if isinstance(prefill_wrapper_paged, list): - prefill_wrapper_paged = prefill_wrapper_paged[1] - - if not input_metadata.flashinfer_use_ragged: - if k is not None: - assert v is not None - self.store_kv_cache(k, v, input_metadata) - - o = prefill_wrapper_paged.forward( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), - causal=True, - sm_scale=self.scaling, - window_left=self.sliding_window_size, - logits_soft_cap=self.logit_cap, - ) - else: - o1, s1 = ( - input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - k.contiguous().view(-1, self.tp_k_head_num, self.head_dim), - v.contiguous().view(-1, self.tp_v_head_num, self.head_dim), - causal=True, - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) - ) - - if input_metadata.extend_no_prefix: - o = o1 - else: - o2, s2 = prefill_wrapper_paged.forward_return_lse( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), - causal=False, - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) - - o, _ = merge_state(o1, s1, o2, s2) - - self.store_kv_cache(k, v, input_metadata) - - if input_metadata.total_num_tokens >= global_config.layer_sync_threshold: - torch.cuda.synchronize() - - return o.view(-1, self.tp_q_head_num * self.head_dim) - - def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): - decode_wrapper = input_metadata.flashinfer_decode_wrapper - if self.sliding_window_size != -1: - decode_wrapper = decode_wrapper[0] - else: - if isinstance(decode_wrapper, list): - decode_wrapper = decode_wrapper[1] - - if k is not None: - assert v is not None - self.store_kv_cache(k, v, input_metadata) - - o = decode_wrapper.forward( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) - - return o.view(-1, self.tp_q_head_num * self.head_dim) + self.logit_cap = logit_cap + self.sliding_window_size = sliding_window_size or -1 def forward(self, q, k, v, input_metadata: InputMetadata): if k is not None: + # For cross-layer sharing, kv can be None assert v is not None k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - if input_metadata.forward_mode.is_extend(): - return self.extend_forward(q, k, v, input_metadata) - elif input_metadata.forward_mode.is_decode(): - return self.decode_forward(q, k, v, 
input_metadata) - - def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): - input_metadata.token_to_kv_pool.set_kv_buffer( - self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v - ) + return input_metadata.attn_backend.forward(q, k, v, self, input_metadata) diff --git a/python/sglang/srt/layers/triton_attention/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py index 82ce6efc54..adfa0d936d 100644 --- a/python/sglang/srt/layers/triton_attention/decode_attention.py +++ b/python/sglang/srt/layers/triton_attention/decode_attention.py @@ -15,6 +15,7 @@ """ Memory-efficient attention for decoding. +It supports page size = 1. """ # Adapted from @@ -197,7 +198,6 @@ def _decode_att_m_fwd( logit_cap, ): BLOCK = 32 - # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd( logit_cap, ): BLOCK = 32 - # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] if Lk == 576: @@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd( BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, + Lv=Lv, num_warps=num_warps, num_stages=1, - Lv=Lv, ) @@ -588,7 +587,7 @@ def decode_attention_fwd( max_len_in_batch, total_num_tokens, sm_scale, - logit_cap=-1, + logit_cap=0.0, att_m=None, ): if att_m is None: diff --git a/python/sglang/srt/layers/triton_attention/extend_attention.py b/python/sglang/srt/layers/triton_attention/extend_attention.py index 1193c4124a..3cf150d8dd 100644 --- a/python/sglang/srt/layers/triton_attention/extend_attention.py +++ b/python/sglang/srt/layers/triton_attention/extend_attention.py @@ -61,14 +61,14 @@ def _fwd_kernel( stride_buf_vbs, stride_buf_vh, stride_req_to_tokens_b, + logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - logit_cap: tl.constexpr, - Lq: tl.constexpr, - Lv: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -111,7 +111,7 @@ def _fwd_kernel( ) qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) - # stage1: compute scores with prefix + # stage 1: compute scores with prefix offs_n = tl.arange(0, BLOCK_N) acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) @@ -174,7 +174,7 @@ def _fwd_kernel( e_max = n_e_max - # stage2: compute the trianlge part + # stage 2: compute the trianlge part cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) for start_n in range(0, cur_block_m_end, BLOCK_N): @@ -255,26 +255,22 @@ def extend_attention_fwd( v_buffer, req_to_tokens, b_req_idx, - b_start_loc, b_seq_len, - b_seq_len_prefix, - b_start_loc_extend, b_seq_len_extend, - max_len_in_batch, + b_start_loc_extend, max_len_extend, sm_scale=None, - logit_cap=-1, + logit_cap=0.0, ): """ q_extend, k_extend, v_extend, o_extend: contiguous tensors k_buffer, v_buffer: (prefix + extend) tensors in mem_manager """ - Lq, Lk, Lv, Lo = ( + Lq, Lk, Lv = ( q_extend.shape[-1], k_extend.shape[-1], v_extend.shape[-1], - o_extend.shape[-1], ) if Lq == 576: @@ -303,7 +299,7 @@ def extend_attention_fwd( else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) - sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale + sm_scale = sm_scale or 1.0 / (Lq**0.5) batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] @@ -338,27 +334,24 @@ def extend_attention_fwd( 
v_buffer.stride(0), v_buffer.stride(1), req_to_tokens.stride(0), + logit_cap=logit_cap, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, BLOCK_DV=BLOCK_DV, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - logit_cap=logit_cap, Lq=Lq, Lv=Lv, + num_warps=num_warps, + num_stages=num_stages, ) def redundant_attention( q_extend, - k_extend, - v_extend, o_extend, k_buffer, v_buffer, - req_to_tokens, b_req_idx, b_start_loc, b_seq_len, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b6000734a2..0a5eb3cdf6 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -368,7 +368,7 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): ) def batch_size(self): - return len(self.reqs) if self.reqs is not None else 0 + return len(self.reqs) if self.reqs else 0 def is_empty(self): return len(self.reqs) == 0 diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c24dd50846..ecaeb404c8 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -13,15 +13,13 @@ limitations under the License. """ -"""Run the model with cuda graph.""" +"""Run the model with cuda graph and torch.compile.""" import bisect from contextlib import contextmanager -from typing import Callable, List +from typing import Callable import torch -from flashinfer import BatchDecodeWithPagedKVCacheWrapper -from flashinfer.decode import _grouped_size_compiled_for_decode_kernels from vllm.distributed.parallel_state import graph_capture from vllm.model_executor.custom_op import CustomOp @@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): def patch_model( model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" ): + """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None try: @@ -86,23 +85,28 @@ def set_torch_compile_config(): class CudaGraphRunner: - def __init__( - self, - model_runner: "ModelRunner", - max_batch_size_to_capture: int, - use_torch_compile: bool, - disable_padding: bool, - ): + """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" + + def __init__(self, model_runner: "ModelRunner"): + # Parse args self.model_runner = model_runner self.graphs = {} self.input_buffers = {} self.output_buffers = {} self.flashinfer_handlers = {} self.graph_memory_pool = None - self.disable_padding = disable_padding + self.use_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + + # Batch sizes to capture + if self.model_runner.server_args.disable_cuda_graph_padding: + self.capture_bs = list(range(1, 32)) + [64, 128] + else: + self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else [] # Common inputs - self.max_bs = max_batch_size_to_capture + self.max_bs = max(self.capture_bs) self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") self.req_pool_indices = torch.zeros( (self.max_bs,), dtype=torch.int32, device="cuda" @@ -115,56 +119,39 @@ def __init__( (self.max_bs,), dtype=torch.int32, device="cuda" ) - # FlashInfer inputs - self.flashinfer_kv_indptr = torch.zeros( - (self.max_bs + 1,), dtype=torch.int32, device="cuda" - ) - 
self.flashinfer_kv_indices = torch.zeros( - (self.max_bs * model_runner.model_config.context_len,), - dtype=torch.int32, - device="cuda", - ) - self.flashinfer_kv_last_page_len = torch.ones( - (self.max_bs,), dtype=torch.int32, device="cuda" - ) - if model_runner.sliding_window_size is None: - self.flashinfer_workspace_buffer = ( - self.model_runner.flashinfer_workspace_buffer - ) - else: - self.flashinfer_workspace_buffer = ( - self.model_runner.flashinfer_workspace_buffer - ) - - self.flashinfer_kv_indptr = [ - self.flashinfer_kv_indptr, - self.flashinfer_kv_indptr.clone(), - ] - self.flashinfer_kv_indices = [ - self.flashinfer_kv_indices, - self.flashinfer_kv_indices.clone(), - ] + # Attention backend + self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - # Sampling inputs + # Sampling info vocab_size = model_runner.model_config.vocab_size self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) - self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] - - if use_torch_compile: + if self.use_torch_compile: set_torch_compile_config() + # Capture + try: + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n" + "Possible solutions:\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + def can_run(self, batch_size: int): if self.disable_padding: return batch_size in self.graphs else: return batch_size <= self.max_bs - def capture(self, batch_size_list: List[int]): - self.batch_size_list = batch_size_list + def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream - for bs in batch_size_list: + for bs in self.capture_bs: with patch_model( self.model_runner.model, bs in self.compile_bs, @@ -172,14 +159,10 @@ def capture(self, batch_size_list: List[int]): ) as forward: ( graph, - input_buffers, output_buffers, - flashinfer_handler, ) = self.capture_one_batch_size(bs, forward) self.graphs[bs] = graph - self.input_buffers[bs] = input_buffers self.output_buffers[bs] = output_buffers - self.flashinfer_handlers[bs] = flashinfer_handler def capture_one_batch_size(self, bs: int, forward: Callable): graph = torch.cuda.CUDAGraph() @@ -192,48 +175,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable): position_ids_offsets = self.position_ids_offsets[:bs] out_cache_loc = self.out_cache_loc[:bs] - # FlashInfer inputs - if not _grouped_size_compiled_for_decode_kernels( - self.model_runner.model_config.num_attention_heads - // self.model_runner.tp_size, - self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size), - ): - use_tensor_cores = True - else: - use_tensor_cores = False - if self.model_runner.sliding_window_size is None: - flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=use_tensor_cores, - paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], - paged_kv_indices_buffer=self.flashinfer_kv_indices, - paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], - ) - else: - flashinfer_decode_wrapper = [] - for i in range(2): - flashinfer_decode_wrapper.append( - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=use_tensor_cores, - 
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1], - paged_kv_indices_buffer=self.flashinfer_kv_indices[i], - paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[ - :bs - ], - ) - ) - update_flashinfer_indices( - ForwardMode.DECODE, - self.model_runner, - req_pool_indices, - seq_lens, - None, - flashinfer_decode_wrapper, + # Attention backend + self.model_runner.attn_backend.capture_cuda_graph_init( + bs, req_pool_indices, seq_lens ) # Run and capture @@ -246,13 +190,12 @@ def run_once(): seq_lens=seq_lens, req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, return_logprob=False, top_logprobs_nums=0, positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), - flashinfer_decode_wrapper=flashinfer_decode_wrapper, ) - return forward(input_ids, input_metadata.positions, input_metadata) for _ in range(2): @@ -274,15 +217,15 @@ def run_once(): self.model_runner.tp_group.barrier() self.graph_memory_pool = graph.pool() - return graph, None, out, flashinfer_decode_wrapper + return graph, out def replay(self, batch: ScheduleBatch): assert batch.out_cache_loc is not None raw_bs = len(batch.reqs) # Pad - index = bisect.bisect_left(self.batch_size_list, raw_bs) - bs = self.batch_size_list[index] + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] if bs != raw_bs: self.seq_lens.zero_() self.position_ids_offsets.fill_(1) @@ -295,14 +238,9 @@ def replay(self, batch: ScheduleBatch): self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets self.out_cache_loc[:raw_bs] = batch.out_cache_loc - # FlashInfer inputs - update_flashinfer_indices( - ForwardMode.DECODE, - self.model_runner, - self.req_pool_indices[:bs], - self.seq_lens[:bs], - None, - self.flashinfer_handlers[bs], + # Attention backend + self.model_runner.attn_backend.replay_cuda_graph_init( + bs, self.req_pool_indices, self.seq_lens ) # Sampling inputs diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index f3bed6bcf1..8542ced358 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -23,9 +23,8 @@ import numpy as np import torch -from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices - if TYPE_CHECKING: + from sglang.srt.layers.attention_backend import AttentionBackend from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner @@ -66,12 +65,11 @@ class InputMetadata: seq_lens: torch.Tensor req_to_token_pool: ReqToTokenPool token_to_kv_pool: BaseTokenToKVPool + attn_backend: AttentionBackend # Output location of the KV cache out_cache_loc: torch.Tensor - total_num_tokens: int = None - # Position information positions: torch.Tensor = None @@ -93,18 +91,6 @@ class InputMetadata: image_offsets: List[List[int]] = None modalities: List[List[str]] = None - # Trition attention backend - triton_max_seq_len: int = 0 - triton_max_extend_len: int = 0 - triton_start_loc: torch.Tensor = None - triton_prefix_lens: torch.Tensor = None - - # FlashInfer attention backend - flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None - flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None - flashinfer_decode_wrapper: 
"BatchDecodeWithPagedKVCacheWrapper" = None - flashinfer_use_ragged: bool = False - def init_multimuldal_info(self, batch: ScheduleBatch): reqs = batch.reqs self.pixel_values = [r.pixel_values for r in reqs] @@ -154,32 +140,27 @@ def compute_positions(self, batch: ScheduleBatch): self.positions = self.positions.to(torch.int64) def compute_extend_infos(self, batch: ScheduleBatch): - if self.forward_mode.is_decode(): - self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None - self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None - else: - extend_lens_cpu = [ - len(r.fill_ids) - batch.prefix_lens_cpu[i] - for i, r in enumerate(batch.reqs) - ] - self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") - self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") - self.extend_start_loc = torch.zeros_like(self.seq_lens) - self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) - self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) - - self.extend_seq_lens_cpu = extend_lens_cpu - self.logprob_start_lens_cpu = [ - ( - min( - req.logprob_start_len - batch.prefix_lens_cpu[i], - extend_lens_cpu[i] - 1, - ) - if req.logprob_start_len >= batch.prefix_lens_cpu[i] - else extend_lens_cpu[i] - 1 # Fake extend, actually decode + extend_lens_cpu = [ + len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs) + ] + self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") + self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") + self.extend_start_loc = torch.zeros_like(self.seq_lens) + self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) + self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) + + self.extend_seq_lens_cpu = extend_lens_cpu + self.logprob_start_lens_cpu = [ + ( + min( + req.logprob_start_len - batch.prefix_lens_cpu[i], + extend_lens_cpu[i] - 1, ) - for i, req in enumerate(batch.reqs) - ] + if req.logprob_start_len >= batch.prefix_lens_cpu[i] + else extend_lens_cpu[i] - 1 # Fake extend, actually decode + ) + for i, req in enumerate(batch.reqs) + ] @classmethod def from_schedule_batch( @@ -195,6 +176,7 @@ def from_schedule_batch( seq_lens=batch.seq_lens, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, + attn_backend=model_runner.attn_backend, out_cache_loc=batch.out_cache_loc, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, @@ -202,76 +184,12 @@ def from_schedule_batch( ret.sampling_info.update_penalties() ret.sampling_info.update_regex_vocab_mask(batch) - ret.compute_positions(batch) - ret.compute_extend_infos(batch) - - fm = batch.forward_mode - if not fm.is_decode() or model_runner.server_args.attention_backend == "triton": - ret.total_num_tokens = int(torch.sum(ret.seq_lens)) - - if not fm.is_decode(): + if not batch.forward_mode.is_decode(): ret.init_multimuldal_info(batch) + ret.compute_extend_infos(batch) - if model_runner.server_args.attention_backend == "triton": - ret.init_triton_args(batch) - - flashinfer_use_ragged = False - if model_runner.server_args.attention_backend == "flashinfer": - if ( - not fm.is_decode() - and int(torch.sum(ret.seq_lens)) > 4096 - and model_runner.sliding_window_size is None - ): - flashinfer_use_ragged = True - ret.init_flashinfer_handlers( - model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged - ) + model_runner.attn_backend.init_forward_metadata(batch, ret) return ret - - def 
init_triton_args(self, batch: ScheduleBatch): - """Init auxiliary variables for triton attention backend.""" - self.triton_max_seq_len = int(torch.max(self.seq_lens)) - self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32) - self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) - - if self.forward_mode.is_decode(): - self.triton_max_extend_len = None - else: - self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") - extend_seq_lens = self.seq_lens - self.triton_prefix_lens - self.triton_max_extend_len = int(torch.max(extend_seq_lens)) - - def init_flashinfer_handlers( - self, - model_runner, - prefix_lens_cpu, - flashinfer_use_ragged, - ): - if self.forward_mode.is_decode(): - prefix_lens = None - else: - prefix_lens = self.extend_prefix_lens - - update_flashinfer_indices( - self.forward_mode, - model_runner, - self.req_pool_indices, - self.seq_lens, - prefix_lens, - flashinfer_use_ragged=flashinfer_use_ragged, - ) - - ( - self.flashinfer_prefill_wrapper_ragged, - self.flashinfer_prefill_wrapper_paged, - self.flashinfer_decode_wrapper, - self.flashinfer_use_ragged, - ) = ( - model_runner.flashinfer_prefill_wrapper_ragged, - model_runner.flashinfer_prefill_wrapper_paged, - model_runner.flashinfer_decode_wrapper, - flashinfer_use_ragged, - ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b04b0d7c01..80c741652d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -25,12 +25,6 @@ import torch import torch.nn as nn -from flashinfer import ( - BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - BatchPrefillWithRaggedKVCacheWrapper, -) -from flashinfer.decode import _grouped_size_compiled_for_decode_kernels from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig from vllm.distributed import ( @@ -43,8 +37,8 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry -from sglang.global_config import global_config from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict @@ -69,6 +63,8 @@ class ModelRunner: + """ModelRunner runs the forward passes of the models.""" + def __init__( self, model_config: ModelConfig, @@ -100,6 +96,7 @@ def __init__( } ) + # Model-specific adjustment if self.is_multimodal_model: logger.info( "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." 
@@ -107,6 +104,7 @@ def __init__( server_args.chunked_prefill_size = None server_args.mem_fraction_static *= 0.95 + # Init componnets min_per_gpu_memory = self.init_torch_distributed() self.load_model() self.init_memory_pool( @@ -115,7 +113,7 @@ def __init__( server_args.max_total_tokens, ) self.init_cublas() - self.init_flashinfer() + self.init_attention_backend() self.init_cuda_graphs() def init_torch_distributed(self): @@ -397,9 +395,6 @@ def init_memory_pool( qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, ) - logger.info("using MLA Triton implementaion, flashinfer is disabled") - # FIXME: temporarily only Triton MLA is supported - self.server_args.attention_backend = "triton" else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, @@ -422,106 +417,42 @@ def init_cublas(self): c = a @ b return c - def init_flashinfer(self): - """Init flashinfer attention kernel wrappers.""" - if self.server_args.attention_backend != "flashinfer": - assert ( - self.sliding_window_size is None - ), "turn on flashinfer to support window attention" - self.flashinfer_prefill_wrapper_ragged = None - self.flashinfer_prefill_wrapper_paged = None - self.flashinfer_decode_wrapper = None - return - - if not _grouped_size_compiled_for_decode_kernels( - self.model_config.num_attention_heads // self.tp_size, - self.model_config.get_num_kv_heads(self.tp_size), - ): - use_tensor_cores = True - else: - use_tensor_cores = False - - if self.sliding_window_size is None: - self.flashinfer_workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device="cuda", - ) - self.flashinfer_prefill_wrapper_ragged = ( - BatchPrefillWithRaggedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD" - ) - ) - self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD" - ) - self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_tensor_cores=use_tensor_cores, + def init_attention_backend(self): + """Init attention kernel backend.""" + if self.server_args.attention_backend == "flashinfer": + self.attn_backend = FlashInferAttnBackend(self) + elif self.server_args.attention_backend == "triton": + assert self.sliding_window_size is None, ( + "Window attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." 
) + self.attn_backend = TritonAttnBackend(self) else: - self.flashinfer_workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device="cuda", + raise ValueError( + f"Invalid attention backend: {self.server_args.attention_backend}" ) - self.flashinfer_prefill_wrapper_ragged = None - self.flashinfer_prefill_wrapper_paged = [] - self.flashinfer_decode_wrapper = [] - for i in range(2): - self.flashinfer_prefill_wrapper_paged.append( - BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD" - ) - ) - self.flashinfer_decode_wrapper.append( - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_tensor_cores=use_tensor_cores, - ) - ) def init_cuda_graphs(self): """Capture cuda graphs.""" + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + + self.cuda_graph_runner = None + if not self.is_generation: # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models return - from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + if self.server_args.disable_cuda_graph: + return - if ( - self.server_args.disable_cuda_graph - or self.server_args.attention_backend != "flashinfer" - ): - self.cuda_graph_runner = None + if self.server_args.attention_backend != "flashinfer": + logger.warning( + f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}" + ) return logger.info("Capture cuda graph begin. This can take up to several minutes.") - - if self.server_args.disable_cuda_graph_padding: - batch_size_list = list(range(1, 32)) + [64, 128] - else: - batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)] - - self.cuda_graph_runner = CudaGraphRunner( - self, - max_batch_size_to_capture=max(batch_size_list), - use_torch_compile=self.server_args.enable_torch_compile, - disable_padding=self.server_args.disable_cuda_graph_padding, - ) - try: - self.cuda_graph_runner.capture(batch_size_list) - except RuntimeError as e: - raise Exception( - f"Capture cuda graph failed: {e}\n" - "Possible solutions:\n" - "1. disable cuda graph by --disable-cuda-graph\n" - "2. set --mem-fraction-static to a smaller value\n" - "3. 
disable torch compile by not using --enable-torch-compile\n" - "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" - ) + self.cuda_graph_runner = CudaGraphRunner(self) @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 622f27df11..6f6bb61265 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -143,18 +143,16 @@ def update_penalties(self): self.linear_penalties = penalizer.apply(self.linear_penalties) def update_regex_vocab_mask(self, batch: ScheduleBatch): - bs, reqs = batch.batch_size(), batch.reqs - device = "cuda" - has_regex = any(req.regex_fsm is not None for req in reqs) + has_regex = any(req.regex_fsm is not None for req in batch.reqs) # Reset the vocab mask self.vocab_mask = None if has_regex: self.vocab_mask = torch.zeros( - bs, self.vocab_size, dtype=torch.bool, device=device + batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda" ) - for i, req in enumerate(reqs): + for i, req in enumerate(batch.reqs): if req.regex_fsm is not None: self.vocab_mask[i].fill_(1) self.vocab_mask[i][ diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 5bdee03de3..52806daa98 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -335,23 +335,19 @@ def launch_server( return # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) - pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) if server_args.dp_size == 1: start_controller_process = start_controller_process_single else: start_controller_process = start_controller_process_multi - proc_controller = mp.Process( target=start_controller_process, args=(server_args, port_args, pipe_controller_writer), ) proc_controller.start() + pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) proc_detoken = mp.Process( target=start_detokenizer_process, args=( @@ -362,6 +358,10 @@ def launch_server( ) proc_detoken.start() + tokenizer_manager = TokenizerManager(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + # Wait for the model to finish loading controller_init_state = pipe_controller_reader.recv() detoken_init_state = pipe_detoken_reader.recv() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 776f7bec32..36a16bc9f5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -83,8 +83,8 @@ class ServerArgs: json_model_override_args: str = "{}" # Optimization/debug options - attention_backend: str = "flashinfer" - sampling_backend: str = "flashinfer" + attention_backend: Optional[str] = None + sampling_backend: Optional[str] = None disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False @@ -148,6 +148,17 @@ def __post_init__(self): ) self.sampling_backend = "pytorch" + # Default kernel backends + if self.enable_mla: + logger.info("MLA optimization is tunred on. 
Use triton backend.") + self.attention_backend = "triton" + + if self.attention_backend is None: + self.attention_backend = "flashinfer" + + if self.sampling_backend is None: + self.sampling_backend = "flashinfer" + # Model-specific patches if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: logger.info( diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py index 2159cca958..8fb0231d81 100644 --- a/test/srt/test_create_kvindices.py +++ b/test/srt/test_create_kvindices.py @@ -55,8 +55,8 @@ def _run_test(self, batch, max_batch, max_context_len): paged_kernel_lens, kv_indptr, None, - req_to_token.size(1), kv_indices_triton, + req_to_token.size(1), ) # Check diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index e0c851d4ed..65b0b55b92 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -19,7 +19,8 @@ def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size) other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - other_args.extend(["--attention-backend", attention_backend]) + if attention_backend: + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) other_args.extend(["--tensor-parallel-size", "2"]) diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index 1b458e9e6a..81aff3ed2c 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -19,7 +19,8 @@ def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size) other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - other_args.extend(["--attention-backend", attention_backend]) + if attention_backend: + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) model = DEFAULT_MODEL_NAME_FOR_TEST diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 0d094b5571..79b26f67ab 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -96,23 +96,17 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): v_buffer, req_to_tokens, b_req_idx, - b_start_loc, b_seq_len, - b_seq_len_prefix, - b_start_loc_extend, b_seq_len_extend, - max_len_in_batch, + b_start_loc_extend, max_len_extend, ) redundant_attention( q_extend, - k_extend, - v_extend, o_redundant, k_buffer, v_buffer, - req_to_tokens, b_req_idx, b_start_loc, b_seq_len, From 446213777773f217a36b0b415e28c6a8d88d793f Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Wed, 11 Sep 2024 14:40:45 -0700 Subject: [PATCH 32/33] Add no commit to main rule (#1393) --- .pre-commit-config.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2fa1254a66..7489004bd0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,6 @@ +default_language_version: + python: python3.9 + repos: - repo: https://github.com/PyCQA/isort rev: 5.13.2 @@ -7,3 +10,8 @@ repos: rev: 24.4.2 hooks: - id: black + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: no-commit-to-branch From 2a71be5e2554e01368b3bc4265db1f7822b0ae3c Mon Sep 17 00:00:00 2001 From: William Date: Thu, 12 Sep 2024 14:46:51 +0800 Subject: [PATCH 33/33] Fix README format (#1399) --- 
README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d1449abfce..327e965feb 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - Exaone 3 - BaiChuan2 - MiniCPM / MiniCPM 3 + + **Embedding Models** - e5-mistral
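
The attention backend refactor in [PATCH 31/33] above is organized around one dispatch abstraction: `ModelRunner` builds a single `AttentionBackend` object (`FlashInferAttnBackend` or `TritonAttnBackend`), `InputMetadata` carries a reference to it, and `RadixAttention.forward` simply calls `input_metadata.attn_backend.forward(q, k, v, self, input_metadata)`, which routes to `forward_extend` or `forward_decode` based on the forward mode. Below is a minimal, self-contained sketch of that dispatch pattern only; the toy backend classes, the `Metadata` and `AttentionLayer` stand-ins, and their method bodies are hypothetical illustrations, not SGLang code.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto


class ForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()

    def is_decode(self) -> bool:
        return self is ForwardMode.DECODE


class AttentionBackend(ABC):
    """Base class: one entry point that routes to two operators."""

    @abstractmethod
    def forward_extend(self, q, k, v, layer, metadata):
        ...

    @abstractmethod
    def forward_decode(self, q, k, v, layer, metadata):
        ...

    def forward(self, q, k, v, layer, metadata):
        # The single dispatch point: pick the operator from the forward mode.
        if metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, metadata)
        return self.forward_extend(q, k, v, layer, metadata)


class ToyFlashBackend(AttentionBackend):
    """Hypothetical backend A; a real backend would call attention kernels."""

    def forward_extend(self, q, k, v, layer, metadata):
        return f"flash-like extend on {layer}"

    def forward_decode(self, q, k, v, layer, metadata):
        return f"flash-like decode on {layer}"


class ToyTritonBackend(AttentionBackend):
    """Hypothetical backend B exposing the same interface."""

    def forward_extend(self, q, k, v, layer, metadata):
        return f"triton-like extend on {layer}"

    def forward_decode(self, q, k, v, layer, metadata):
        return f"triton-like decode on {layer}"


@dataclass
class Metadata:
    """Stand-in for InputMetadata: carries the mode and the chosen backend."""

    forward_mode: ForwardMode
    attn_backend: AttentionBackend


class AttentionLayer:
    """Stand-in for RadixAttention: it never branches on the backend itself."""

    def __init__(self, name: str):
        self.name = name

    def forward(self, q, k, v, metadata: Metadata):
        return metadata.attn_backend.forward(q, k, v, self.name, metadata)


if __name__ == "__main__":
    layer = AttentionLayer("layer0")
    for backend in (ToyFlashBackend(), ToyTritonBackend()):
        for mode in (ForwardMode.EXTEND, ForwardMode.DECODE):
            meta = Metadata(forward_mode=mode, attn_backend=backend)
            print(layer.forward(None, None, None, meta))
```

On the configuration side, the same series makes the backend an explicit choice: `--disable-flashinfer` now maps to `--attention-backend triton`, `--disable-flashinfer-sampling` maps to `--sampling-backend pytorch`, and when no backend is specified the defaults resolve to FlashInfer for both, except that enabling MLA switches the attention backend to Triton.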