doc: update the docstring related to alibi (#147)
Follow-up of #146.
yzh119 authored Mar 3, 2024
1 parent 383518b commit bf2117b
Showing 4 changed files with 14 additions and 17 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -72,13 +72,13 @@ num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
-o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly
+o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
-o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask
+o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ROPE_LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
qo_len = 2048
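
The README change above only renames the RoPE mode from ``LLAMA`` to ``ROPE_LLAMA``; the docstrings updated below additionally list ``ALIBI``. As a minimal sketch (not part of this commit), a decode call using the ALiBi mode would look like the following, assuming the NHD k/v layout and illustrative tensor sizes (the actual ``kv_len``, ``num_kv_heads``, and ``head_dim`` values are not shown in this diff):

```python
import torch
import flashinfer

# Illustrative sizes only -- the full README defines its own values.
kv_len, num_kv_heads, num_qo_heads, head_dim = 4096, 32, 32, 128

k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
q = torch.randn(num_qo_heads, head_dim).half().to(0)

# Decode attention with ALiBi bias applied inside the kernel instead of RoPE.
o_alibi = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ALIBI")
```
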
3 changes: 0 additions & 3 deletions python/flashinfer/cascade.py
@@ -419,9 +419,6 @@ def begin_forward(
The dimension of the heads
page_size : int
The page size of the paged kv cache
-pos_encoding_mode : str
-Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
data_type : Union[str, torch.dtype]
The data type of the paged kv cache
12 changes: 6 additions & 6 deletions python/flashinfer/decode.py
@@ -79,7 +79,7 @@ def single_decode_with_kv_cache(
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -168,7 +168,7 @@ def batch_decode_with_padded_kv_cache(
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -257,7 +257,7 @@ def batch_decode_with_padded_kv_cache_return_lse(
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -456,7 +456,7 @@ def begin_forward(
The page size of the paged kv cache
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
data_type : Union[str, torch.dtype]
The data type of the paged kv cache
@@ -525,7 +525,7 @@ def forward(
:attr:`kv_layout` is ``HND``.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -586,7 +586,7 @@ def forward_return_lse(
:attr:`kv_layout` is ``HND``.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
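
Every ``pos_encoding_mode`` entry in ``decode.py`` now advertises the same three values. Below is a rough sketch of the padded batch-decode entry point exercising each documented mode; the keyword names come from these docstrings, but the tensor shapes and sizes are assumptions rather than anything shown in this diff:

```python
import torch
import flashinfer

# Assumed sizes and NHD padded layout: [batch_size, padded_kv_len, num_kv_heads, head_dim].
batch_size, padded_kv_len = 8, 1024
num_qo_heads, num_kv_heads, head_dim = 32, 32, 128

q = torch.randn(batch_size, num_qo_heads, head_dim).half().to(0)
k_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to(0)
v_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to(0)

# One call per documented mode: no positional encoding, LLaMA-style RoPE, ALiBi.
for mode in ("NONE", "ROPE_LLAMA", "ALIBI"):
    o = flashinfer.batch_decode_with_padded_kv_cache(
        q, k_padded, v_padded, kv_layout="NHD", pos_encoding_mode=mode
    )
```
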
12 changes: 6 additions & 6 deletions python/flashinfer/prefill.py
@@ -93,7 +93,7 @@ def single_prefill_with_kv_cache(
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
allow_fp16_qk_reduction : bool
Whether to use f16 for qk reduction (faster at the cost of slight precision
loss).
@@ -191,7 +191,7 @@ def single_prefill_with_kv_cache_return_lse(
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
allow_fp16_qk_reduction : bool
Whether to use f16 for qk reduction (faster at the cost of slight precision
loss).
@@ -460,7 +460,7 @@ def forward(
Whether to apply causal mask to the attention matrix.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
allow_fp16_qk_reduction : bool
Whether to use f16 for qk reduction (faster at the cost of slight precision
loss).
@@ -529,7 +529,7 @@ def forward_return_lse(
Whether to apply causal mask to the attention matrix.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
allow_fp16_qk_reduction : bool
Whether to use f16 for qk reduction (faster at the cost of slight precision
loss).
@@ -744,7 +744,7 @@ def forward(
Whether to apply causal mask to the attention matrix.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
allow_fp16_qk_reduction : bool
Whether to use f16 for qk reduction (faster at the cost of slight precision
loss).
@@ -811,7 +811,7 @@ def forward_return_lse(
Whether to apply causal mask to the attention matrix.
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
allow_fp16_qk_reduction : bool
Whether to use f16 for qk reduction (faster at the cost of slight precision
loss).
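
``prefill.py`` receives the same docstring update for its single-request functions and batch wrappers. A sketch of a causal append/prefill call using the ALiBi mode, mirroring the README snippet above; the sequence lengths and head counts are illustrative assumptions:

```python
import torch
import flashinfer

# Assumed sizes; q is [qo_len, num_qo_heads, head_dim], k/v are [kv_len, num_kv_heads, head_dim] (NHD).
qo_len, kv_len = 2048, 4096
num_qo_heads, num_kv_heads, head_dim = 32, 32, 128

q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(0)
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

# Causal prefill/append attention with ALiBi slopes added to the logits in-kernel.
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ALIBI")
```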
