Integrate PagedAttention Optimization custom kernel into vLLM #22

Merged · 13 commits · May 30, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -196,6 +196,7 @@ set(CUSTOM_SRC
   "csrc/custom/custom_kernels.cu"
   "csrc/custom/fused_kernels.cu"
   "csrc/custom/custom.cu"
+  "csrc/custom/paged_attention/attention_ll4mi.cu"
 )

 define_gpu_extension_target(
63 changes: 44 additions & 19 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -7,9 +7,10 @@

 from vllm._C import ops
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
+from vllm._custom_C import paged_attention_custom

 NUM_BLOCKS = 1024
-PARTITION_SIZE = 512
+PARTITION_SIZE = 256


 @torch.inference_mode()
@@ -77,6 +78,9 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
+        if not args.custom_paged_attn:
+            global PARTITION_SIZE
+            PARTITION_SIZE = 512
         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                           PARTITION_SIZE)
         tmp_output = torch.empty(
@@ -118,24 +122,43 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     kv_scale,
                 )
             elif version == "v2":
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    context_lens,
-                    block_size,
-                    max_context_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    kv_scale,
-                )
+                if not args.custom_paged_attn:
+                    ops.paged_attention_v2(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        context_lens,
+                        block_size,
+                        max_context_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        kv_scale,
+                    )
+                else:
+                    paged_attention_custom(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        context_lens,
+                        block_size,
+                        max_context_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                    )
             else:
                 raise ValueError(f"Invalid version: {version}")
         torch.cuda.synchronize()
@@ -191,6 +214,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         'FP8_E5M2 (without scaling) is only supported on cuda version greater '
         'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
         'common inference criteria.')
+    parser.add_argument("--custom-paged-attn", action="store_true",
+                        help="Use custom paged attention")
     args = parser.parse_args()
     print(args)

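For context on the PARTITION_SIZE handling above (256 for the custom kernel, reset to 512 for the stock v2 kernel), the partition size determines the shapes of the intermediate buffers that both v2-style kernels consume. Below is a minimal sketch; the helper name is illustrative and the exact shapes are assumed from the surrounding benchmark code rather than shown in this diff.

import torch

def make_v2_buffers(num_seqs: int, num_query_heads: int, head_size: int,
                    max_context_len: int, partition_size: int,
                    dtype: torch.dtype, device: str = "cuda"):
    # Each sequence is split into ceil(max_context_len / partition_size)
    # partitions that are combined in a second reduction pass.
    num_partitions = (max_context_len + partition_size - 1) // partition_size
    # Per-partition partial outputs, later reduced into the final output.
    tmp_output = torch.empty(
        (num_seqs, num_query_heads, num_partitions, head_size),
        dtype=dtype, device=device)
    # Per-partition softmax statistics consumed by the reduction step.
    exp_sums = torch.empty(
        (num_seqs, num_query_heads, num_partitions),
        dtype=torch.float32, device=device)
    max_logits = torch.empty_like(exp_sums)
    return tmp_output, exp_sums, max_logits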
25 changes: 25 additions & 0 deletions csrc/custom/custom.cu
@@ -64,11 +64,36 @@ void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
                      at::cuda::getCurrentCUDAStream());
 }

+void paged_attention_custom(
+    torch::Tensor& out,
+    torch::Tensor& exp_sums,
+    torch::Tensor& max_logits,
+    torch::Tensor& tmp_out,
+    torch::Tensor& query,
+    torch::Tensor& key_cache,
+    torch::Tensor& value_cache,
+    int num_kv_heads,
+    float scale,
+    torch::Tensor& block_tables,
+    torch::Tensor& context_lens,
+    int block_size,
+    int max_context_len,
+#if 0
+    torch::Tensor& qk_out,
+    torch::Tensor& softmax_out,
+#endif
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype);
+
 // declare the extension module with the AddGPU function:
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
   m.doc() = "pybind11 example plugin";
   m.def("LLMM1", &LLMM1);
   m.def("LLMM_Silu", &LLMM_Silu);
   m.def("LLZZ", &LLZZ);
+  m.def(
+      "paged_attention_custom",
+      &paged_attention_custom,
+      "PagedAttention LL4Mi Custom.");
   //m.def("MMCustomGPU", &MMCustomGPU);
 }
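As a usage note, here is a minimal Python-side sketch of how the newly bound paged_attention_custom can be dispatched next to the stock v2 kernel, mirroring the benchmark change in this PR. The wrapper name and the use_custom flag are illustrative; all tensor arguments are assumed to be prepared exactly as in benchmark_paged_attention.py.

from vllm._C import ops
from vllm._custom_C import paged_attention_custom


def run_paged_attention_v2(use_custom, output, exp_sums, max_logits,
                           tmp_output, query, key_cache, value_cache,
                           num_kv_heads, scale, block_tables, context_lens,
                           block_size, max_context_len, alibi_slopes,
                           kv_cache_dtype, kv_scale):
    if use_custom:
        # Custom kernel: same argument order as the stock kernel, but no
        # kv_scale parameter; the benchmark sizes its scratch buffers with
        # PARTITION_SIZE = 256 on this path.
        paged_attention_custom(output, exp_sums, max_logits, tmp_output,
                               query, key_cache, value_cache, num_kv_heads,
                               scale, block_tables, context_lens, block_size,
                               max_context_len, alibi_slopes, kv_cache_dtype)
    else:
        # Stock vLLM v2 kernel (PARTITION_SIZE = 512) keeps the kv_scale argument.
        ops.paged_attention_v2(output, exp_sums, max_logits, tmp_output,
                               query, key_cache, value_cache, num_kv_heads,
                               scale, block_tables, context_lens, block_size,
                               max_context_len, alibi_slopes, kv_cache_dtype,
                               kv_scale)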