[Doc] Fix typo (#100)
yzh119 authored Jan 31, 2024
1 parent 354ab22 commit 0e5482f
Showing 2 changed files with 19 additions and 11 deletions.
README.md: 28 changes (18 additions & 10 deletions)
@@ -19,7 +19,7 @@ Kernel Library for LLM Serving
FlashInfer is a library for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, PageAttention and LoRA. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.

The unique features of FlashInfer include:
- 1. **Comprehensive Attention Kernels:**: Attention kernels that cover all the common use cases of LLM serving, including *single-request* and *batching* versions of *Prefill*, *Decode*, and *Append* kernels, on different formats of KV-Cache (Padded Tensor, Ragged Tensor, and Page Table).
+ 1. **Comprehensive Attention Kernels**: Attention kernels that cover all the common use cases of LLM serving, including *single-request* and *batching* versions of *Prefill*, *Decode*, and *Append* kernels, on different formats of KV-Cache (Padded Tensor, Ragged Tensor, and Page Table).
2. **Optimized Shared-Prefix Batch Decoding**: FlashInfer enhances shared-prefix batch decoding performance through *cascading*, resulting in an impressive **up to 31x speedup** compared to the baseline vLLM PageAttention implementation (for a long prompt of 32768 tokens and a large batch size of 256).
3. **Accelerate Attention for Compressed/Quantized KV-Cache**: Modern LLMs are often deployed with quantized/compressed KV-Cache to reduce memory traffic. FlashInfer accelerates these scenarios by optimizing performance for *Grouped-Query Attention*, *Fused-RoPE Attention* and *Quantized Attention*.
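To make the three KV-Cache formats named in point 1 concrete, here is a tiny layout sketch in plain PyTorch (illustrative only, not FlashInfer's API; names such as `kv_indptr` and `page_table` are just examples):

```python
import torch

# Two requests with KV lengths 3 and 5; tiny sizes for illustration.
kv_lens = [3, 5]
num_kv_heads, head_dim = 2, 4

# Padded Tensor: every request is padded to the longest length (simple, but wastes memory).
padded_k = torch.zeros(len(kv_lens), max(kv_lens), num_kv_heads, head_dim)

# Ragged Tensor: all requests concatenated, with CSR-style offsets per request.
ragged_k = torch.zeros(sum(kv_lens), num_kv_heads, head_dim)
kv_indptr = torch.tensor([0, 3, 8])  # request i owns rows kv_indptr[i]:kv_indptr[i+1]

# Page Table: KV entries live in fixed-size pages and each request keeps a list of page ids,
# so the cache can grow incrementally without large contiguous reallocations.
page_size, num_pages = 4, 3
paged_k = torch.zeros(num_pages, page_size, num_kv_heads, head_dim)
page_table = [[0], [1, 2]]  # request 0 uses page 0; request 1 uses pages 1 and 2
```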

@@ -38,7 +38,7 @@ Using our PyTorch API is the easiest way to get started:
We provide prebuilt wheels for Linux and you can try out FlashInfer with the following command:

```bash
- pip install flashinfer -f https://flashinfer.ai/whl/
+ pip install flashinfer -i https://flashinfer.ai/whl/cu121/ # for CUDA 12.1, use cu118 for CUDA 11.8
```
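A quick import test (a minimal sketch; it assumes a CUDA-capable GPU is visible) can confirm that the installed wheel matches your environment:

```python
import torch
import flashinfer  # fails here if the wheel does not match your Python/CUDA setup

assert torch.cuda.is_available(), "FlashInfer kernels need a CUDA device"
print(torch.version.cuda)  # should correspond to the wheel you installed (e.g. 12.1 for cu121)
```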

or you can build from source:
@@ -57,23 +57,31 @@ Below is a minimal example of using FlashInfer's single-request decode/append/prefill
import torch
import flashinfer

- k = torch.randn(2048, 32, 128).half().to(0)
- v = torch.randn(2048, 32, 128).half().to(0)
+ kv_len = 2048
+ num_kv_heads = 32
+ head_dim = 128
+
+ k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
+ v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

# decode attention
- q = torch.randn(32, 128).half().to(0)
+
+ num_qo_heads = 32
+ q = torch.randn(num_qo_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, rotary_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
- q = torch.randn(128, 32, 128).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
- o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly
- o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, rotary_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly
+ append_qo_len = 128
+ q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
+ o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
+ o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, rotary_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
- q = torch.randn(2048, 32, 128).half().to(0) # prefill attention
- o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill attention without RoPE on-the-fly
+ qo_len = 2048
+ q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(0) # prefill attention
+ o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill attention without RoPE on-the-fly, do not apply causal mask
```

Check out [documentation](https://docs.flashinfer.ai/) for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels.
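As a sanity check on what the append/prefill calls above compute, the following pure-PyTorch reference (a rough sketch, not FlashInfer's implementation; it assumes equal query and KV head counts as in the example and no on-the-fly RoPE) spells out the masked attention:

```python
import torch

def prefill_reference(q, k, v, causal=True):
    # q: [qo_len, num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim]
    # The qo_len query tokens are assumed to be the *last* qo_len positions of the KV-Cache.
    qo_len, kv_len = q.shape[0], k.shape[0]
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    if causal:
        qi = torch.arange(qo_len, device=q.device)[:, None]
        kj = torch.arange(kv_len, device=q.device)[None, :]
        scores = scores.masked_fill(kj > kv_len - qo_len + qi, float("-inf"))
    p = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", p, v.float()).to(v.dtype)
```

Its output should approximately match the kernel results above (up to half-precision accumulation differences); the function name here is made up for illustration.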
docs/tutorials/recursive_attention.rst: 2 changes (1 addition & 1 deletion)
@@ -37,7 +37,7 @@ The **attention state** on the entire sequence can be defined as:
\begin{bmatrix}\mathbf{v}(\{1,2,\dots, n\})\\s(\{1,2,\dots, n\})\end{bmatrix} = \bigoplus_{i=1}^{n} \begin{bmatrix}\mathbf{v}_i\\s_i\end{bmatrix}
- Then $\mathbf{v}(\{1,2,\dots, n\})$ is the final attention output.
+ Then :math:`\mathbf{v}(\{1,2,\dots, n\})` is the final attention output.
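The merge operator :math:`\oplus` above can be written directly from this definition. Below is a small PyTorch sketch (illustrative only, not FlashInfer's kernel; it assumes the score :math:`s` is kept in log space, i.e. as the log-sum-exp of the raw attention logits, which keeps the merge numerically stable):

```python
import torch

def merge_state(v_a, s_a, v_b, s_b):
    """Merge two attention states [v, s] computed over disjoint KV index sets.

    v_a, v_b: [num_heads, head_dim] partial attention outputs
    s_a, s_b: [num_heads] scores, assumed to be log-sum-exp of the raw logits
    """
    s_max = torch.maximum(s_a, s_b)
    w_a = torch.exp(s_a - s_max)  # renormalized weight of state a
    w_b = torch.exp(s_b - s_max)  # renormalized weight of state b
    v = (v_a * w_a[:, None] + v_b * w_b[:, None]) / (w_a + w_b)[:, None]
    s = s_max + torch.log(w_a + w_b)  # log-sum-exp over the combined index set
    return v, s
```

Because this merge is associative and commutative, the per-index states can be combined in any grouping (sequentially, pairwise, or across a shared prefix and per-request suffixes), which is what the recursive formulation relies on.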

.. note::

