Merge pull request #22 from ROCm/csrikris_pa_opt_shomy_1_16
Integrate PagedAttention Optimization custom kernel into vLLM
shajrawi authored May 30, 2024
2 parents c47256c + c774517 commit 87ec0c7
Showing 7 changed files with 1,313 additions and 41 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -196,6 +196,7 @@ set(CUSTOM_SRC
"csrc/custom/custom_kernels.cu"
"csrc/custom/fused_kernels.cu"
"csrc/custom/custom.cu"
"csrc/custom/paged_attention/attention_ll4mi.cu"
)

define_gpu_extension_target(
6 changes: 6 additions & 0 deletions ROCm_performance.md
@@ -12,3 +12,9 @@ The default attention function on ROCm is using triton attention kernel. To fall
## Tunable ops
PyTorch TunableOp is supported.
Define the environment variable `PYTORCH_TUNABLEOP_ENABLED=1` to enable both runtime tuning and the subsequent use of tuned results. To use only the already-tuned results without tuning any newly encountered shapes, also define `PYTORCH_TUNABLEOP_TUNING=0`.
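A minimal Python sketch of wiring these variables into a launcher script, assuming they are set before any tuned operations run (the snippet is illustrative and not part of vLLM):

```python
import os

# Enable PyTorch TunableOp, but reuse previously tuned results only.
# Set these before importing torch / loading the model so they take effect.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "0"

import torch  # noqa: E402  (imported after the environment is configured)
```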

## Custom PagedAttention

On ROCm, a custom paged attention kernel is available for better performance; it is controlled by the environment variable `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
This variable is currently enabled by default. To fall back to the PagedAttention v2 kernel, set it to 0.
The custom PagedAttention kernel is used only for dtype fp16, block-size 16, head-size 128, max context length <= 16k, and a GQA ratio (num_heads // num_kv_heads) between 1 and 16. In all other cases, the implementation falls back to the PagedAttention v2 kernel.
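As an illustration of the gating described above, a hedged sketch of the selection check is shown below; the function name, the environment-variable default, and the exact constants are assumptions for illustration, not the actual vLLM implementation:

```python
import os

import torch


def use_rocm_custom_paged_attention(dtype: torch.dtype, block_size: int,
                                    head_size: int, max_context_len: int,
                                    num_heads: int, num_kv_heads: int) -> bool:
    """Return True when the custom ROCm PagedAttention kernel would be used;
    in every other case the implementation falls back to PagedAttention v2."""
    gqa_ratio = num_heads // num_kv_heads
    return (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0"
            and dtype == torch.half
            and block_size == 16
            and head_size == 128
            and max_context_len <= 16 * 1024
            and 1 <= gqa_ratio <= 16)
```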
64 changes: 45 additions & 19 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -6,10 +6,11 @@
import torch

from vllm._C import ops
from vllm._custom_C import paged_attention_custom
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
PARTITION_SIZE = 256


@torch.inference_mode()
@@ -77,6 +78,9 @@ def main(
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
if version == "v2":
if not args.custom_paged_attn:
global PARTITION_SIZE
PARTITION_SIZE = 512
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
tmp_output = torch.empty(
@@ -118,24 +122,43 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
kv_scale,
)
elif version == "v2":
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
if not args.custom_paged_attn:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
paged_attention_custom(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize()
@@ -191,6 +214,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument("--custom-paged-attn",
action="store_true",
help="Use custom paged attention")
args = parser.parse_args()
print(args)

25 changes: 25 additions & 0 deletions csrc/custom/custom.cu
@@ -64,11 +64,36 @@ void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
at::cuda::getCurrentCUDAStream());
}

void paged_attention_custom(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
#if 0
torch::Tensor& qk_out,
torch::Tensor& softmax_out,
#endif
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);

// declare the extension module with the AddGPU function:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.doc() = "pybind11 example plugin";
m.def("LLMM1", &LLMM1);
m.def("LLMM_Silu", &LLMM_Silu);
m.def("LLZZ", &LLZZ);
m.def(
"paged_attention_custom",
&paged_attention_custom,
"PagedAttention LL4Mi Custom.");
//m.def("MMCustomGPU", &MMCustomGPU);
}
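A hedged sketch of how this binding is consumed from Python, mirroring the dispatch added to the benchmark above (the wrapper function is an illustrative assumption; only the imports and argument lists come from the diff, with the custom kernel taking the same arguments as `paged_attention_v2` minus the trailing `kv_scale`):

```python
from vllm._C import ops
from vllm._custom_C import paged_attention_custom


def dispatch_paged_attention_v2(use_custom_paged_attn, shared_args, kv_scale):
    """shared_args: (output, exp_sums, max_logits, tmp_output, query,
    key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens,
    block_size, max_context_len, alibi_slopes, kv_cache_dtype)."""
    if use_custom_paged_attn:
        paged_attention_custom(*shared_args)
    else:
        ops.paged_attention_v2(*shared_args, kv_scale)
```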
(Diffs for the remaining changed files were not loaded.)
