Integrate PagedAttention Optimization custom kernel into vLLM #22

Merged: 13 commits, May 30, 2024

commit f58458882a6995d74133260d570bf92c30929d68
initial commit for v0.4.0 with paged attn optimization
lcskrishna committed May 23, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -196,6 +196,7 @@ set(CUSTOM_SRC
"csrc/custom/custom_kernels.cu"
"csrc/custom/fused_kernels.cu"
"csrc/custom/custom.cu"
"csrc/custom/paged_attention/attention_ll4mi.cu"
)

define_gpu_extension_target(
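The build change is a single line: the new attention_ll4mi.cu source joins the CUSTOM_SRC list, so the PagedAttention kernel is compiled and linked into the same custom-ops extension as the existing custom and fused kernels.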
25 changes: 25 additions & 0 deletions csrc/custom/custom.cu
@@ -64,11 +64,36 @@ void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
at::cuda::getCurrentCUDAStream());
}

void paged_attention_custom(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
#if 0  // optional debug outputs, currently compiled out of the build
torch::Tensor& qk_out,
torch::Tensor& softmax_out,
#endif
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);

// declare the extension module with the AddGPU function:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.doc() = "pybind11 example plugin";
m.def("LLMM1", &LLMM1);
m.def("LLMM_Silu", &LLMM_Silu);
m.def("LLZZ", &LLZZ);
m.def(
"paged_attention_custom",
&paged_attention_custom,
"PagedAttention LL4Mi Custom.");
//m.def("MMCustomGPU", &MMCustomGPU);
}
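For orientation, here is a minimal C++ call-site sketch of the newly bound op. It is not part of this PR: the workspace shapes and the 512-wide context partition are assumptions modeled on vLLM's usual paged-attention-v2 layout, and the "auto" kv_cache_dtype string is likewise an assumption, since the launcher in attention_ll4mi.cu is not shown in this commit.

// Hypothetical call-site sketch, assuming the paged_attention_custom
// declaration above is in scope. PARTITION_SIZE and all workspace shapes
// are assumptions following the paged_attention_v2 convention.
#include <cmath>
#include <torch/torch.h>

constexpr int64_t PARTITION_SIZE = 512;  // assumed split width over the context

void run_paged_attention_example(
    torch::Tensor query,         // [num_seqs, num_heads, head_size]
    torch::Tensor key_cache,     // paged K cache blocks
    torch::Tensor value_cache,   // paged V cache blocks
    torch::Tensor block_tables,  // [num_seqs, max_blocks_per_seq], int32
    torch::Tensor context_lens,  // [num_seqs], int32
    int num_kv_heads, int block_size, int max_context_len) {
  const int64_t num_seqs = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t head_size = query.size(2);
  const int64_t max_parts =
      (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

  // Final output plus per-partition reduction workspaces: the kernel is
  // expected to write partials into tmp_out/exp_sums/max_logits and then
  // reduce them into out.
  auto out = torch::empty_like(query);
  auto f32 = query.options().dtype(torch::kFloat32);
  auto exp_sums = torch::empty({num_seqs, num_heads, max_parts}, f32);
  auto max_logits = torch::empty({num_seqs, num_heads, max_parts}, f32);
  auto tmp_out = torch::empty({num_seqs, num_heads, max_parts, head_size},
                              query.options());

  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  paged_attention_custom(out, exp_sums, max_logits, tmp_out, query, key_cache,
                         value_cache, num_kv_heads, scale, block_tables,
                         context_lens, block_size, max_context_len,
                         /*alibi_slopes=*/c10::nullopt,
                         /*kv_cache_dtype=*/"auto");
}

From Python, the same op is reached through whatever extension module name define_gpu_extension_target assigns (its arguments are not visible in this hunk), with the identical positional argument order.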