feat flash decoding for paged attention
SunflowerAries committed Mar 28, 2024
1 parent c14eede commit 97d2f34
Showing 10 changed files with 925 additions and 163 deletions.
39 changes: 26 additions & 13 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -341,6 +341,19 @@ def forward(
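# CUDA extension op: copy this step's key/value states into the paged
# KV cache blocks selected by block_tables (inferred from the argument names)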
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
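# CUDA extension op: flash decoding over the paged cache, writing the result
# into output_tensor; the positional args mirror the keyword args of the
# Triton flash_decoding_attention call in the else branch below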
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
sm_scale,
)
else:
decoding_fused_rotary_embedding(
query_states,
@@ -353,19 +366,19 @@ def forward(
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)

attn_output = torch.mm(attn_output, self.o_proj_weight)

43 changes: 43 additions & 0 deletions extensions/csrc/cuda/attention/attention_generic.h
@@ -0,0 +1,43 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2024, The Colossal-AI team.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <stdint.h>

// A vector type to store Q, K, V elements.
template <typename T, int VEC_SIZE>
struct VecType {};

// A vector type to store FP32 accumulators.
template <typename T>
struct FloatVec {};

// Template vector operations.
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template <typename T>
inline __device__ float sum(T v);

template <typename T, typename TFLOAT>
inline __device__ void from_float(T& dst, TFLOAT src);

template <typename T>
inline __device__ void fma(T a, T b, float& c);
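
For context, the decoding kernels built on this header are expected to supply concrete specializations of these templates per element type and vector width. Below is a minimal sketch of plausible fp32 specializations in the style of the FasterTransformer/vLLM code this file is adapted from; the `Type` member alias and the accumulate-into-`c` semantics of `fma` are assumptions, since the commit's real specializations live in other files.

// Illustrative fp32 specializations (assumed member name `Type`, assumed
// fma semantics); the commit's actual specializations are defined elsewhere.
template <>
struct VecType<float, 4> {
  using Type = float4;  // four Q/K/V elements loaded as one float4
};

template <>
struct FloatVec<float4> {
  using Type = float4;  // fp32 vectors accumulate in fp32 directly
};

template <>
inline __device__ float4 mul<float4, float4, float4>(float4 a, float4 b) {
  return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

template <>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;  // horizontal reduction of one vector
}

template <>
inline __device__ void from_float(float4& dst, float4 src) {
  dst = src;  // no conversion needed when T is already fp32
}

template <>
inline __device__ void fma(float a, float b, float& c) {
  c = __fmaf_rn(a, b, c);  // assumed: accumulate a * b into c
}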