[Doc] More examples (#103)
Add examples for all classes/functions.
yzh119 authored Feb 1, 2024
1 parent 922f0c6 commit 0bedda7
Showing 8 changed files with 480 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/cascade.rst
@@ -34,5 +34,7 @@ Cascade Attention Wrapper Classes
.. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper
:members:


.. autoclass:: BatchPrefillWithSharedPrefixPagedKVCacheWrapper
:members:

3 changes: 3 additions & 0 deletions docs/conf.py
@@ -32,6 +32,9 @@
"sphinx.ext.mathjax",
]

autodoc_default_flags = ['members']
autosummary_generate = True

source_suffix = [".rst"]

language = "en"
4 changes: 4 additions & 0 deletions docs/installation.rst
@@ -12,9 +12,13 @@ Prerequisites
^^^^^^^^^^^^^

- OS: Linux only

- Python: 3.10, 3.11

- PyTorch with CUDA 11.8/12.1

- Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version.

- Supported GPU architectures: sm_80, sm_86, sm_89, sm_90 (sm_75 support is work in progress). A quick way to check your GPU's architecture is shown below.
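To see which architecture your local GPU reports, one quick check (an illustrative command, not part of the original guide) is to query the compute capability from PyTorch::

   python -c "import torch; print(torch.cuda.get_device_capability())"
   # e.g. (8, 0) corresponds to sm_80, (8, 9) to sm_89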

Quick Start
10 changes: 8 additions & 2 deletions docs/tutorials/kv_layout.rst
@@ -65,11 +65,17 @@ to index the pages in KV-Cache.
:align: center
:alt: Data structure of Paged KV-Cache.

For each request, we keep a record of its ``page_indices`` and ``last_page_len``, which
track the pages used by this request and the number of entries in the last page. The KV
sequence length of request ``i`` is ``page_size * (len(page_indices[i]) - 1) + last_page_len[i]``.

.. note::
The ``last_page_len`` of each request must be greater than zero, and less than or equal to ``page_size``.

The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed as:
``[0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), ...]``.
The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``.
The overall ``kv_last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_len``.
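
Concretely, a minimal Python sketch (using hypothetical per-request page tables, for illustration only) that assembles these flat arrays::

   import itertools

   page_size = 16
   page_indices = [[0, 1, 2], [3, 4]]  # pages used by request 0 and request 1
   last_page_len = [5, 16]             # entries in each request's last page

   # kv_indptr[i] marks where request i's pages begin in kv_page_indices
   kv_indptr = [0]
   for indices in page_indices:
       kv_indptr.append(kv_indptr[-1] + len(indices))

   kv_page_indices = list(itertools.chain.from_iterable(page_indices))
   kv_last_page_lens = list(last_page_len)

   # KV sequence length of request i, as defined above
   kv_len = [page_size * (len(page_indices[i]) - 1) + last_page_len[i]
             for i in range(len(page_indices))]

   assert kv_indptr == [0, 3, 5]
   assert kv_len == [37, 32]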

The ``kv_data`` tensor is a 5-D tensor with shape (in ``NHD`` layout)::

   (max_num_pages, 2, page_size, num_kv_heads, head_dim)
223 changes: 218 additions & 5 deletions python/flashinfer/cascade.py
@@ -76,6 +76,23 @@ def merge_state(
S : torch.Tensor
The logsumexp value from the merged KV-segment ``[A: B]``, shape:
``[seq_len, num_heads]``.
Example
-------
>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_merged, s_merged = flashinfer.merge_state(va, sa, vb, sb)
>>> v_merged.shape
torch.Size([2048, 32, 128])
>>> s_merged.shape
torch.Size([2048, 32])
"""
return _kernels.merge_state(v_a, s_a, v_b, s_b)
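
For intuition, the following is a minimal PyTorch sketch of the standard logsumexp-weighted merge rule that this operation computes (an illustrative reference, not the CUDA kernel; it assumes natural-log logsumexp values). ``merge_state_in_place`` below applies the same rule but writes the result back into ``v`` and ``s``::

   import torch

   def merge_state_reference(v_a, s_a, v_b, s_b):
       # Subtract the running maximum for numerical stability.
       s_max = torch.maximum(s_a, s_b)
       w_a = torch.exp(s_a - s_max)  # un-normalized weight of segment A
       w_b = torch.exp(s_b - s_max)  # un-normalized weight of segment B
       s_merged = s_max + torch.log(w_a + w_b)
       v_merged = (w_a.unsqueeze(-1) * v_a + w_b.unsqueeze(-1) * v_b) \
           / (w_a + w_b).unsqueeze(-1)
       return v_merged, s_merged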

@@ -100,22 +117,34 @@ def merge_state_in_place(
s_other : torch.Tensor
The other logsumexp value to be merged, expected to be a float32 tensor,
shape: ``(seq_len, num_heads)``.
Example
-------
>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> v = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> s = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> flashinfer.merge_state_in_place(v, s, v_other, s_other)
"""
_kernels.merge_state_in_place(v, s, v_other, s_other)


def merge_states(v: torch.Tensor, s: torch.Tensor):
r"""Merge the attention output ``V`` and the logsumexp value ``S`` from multiple
KV-segments.
r"""Merge multiple attention states (v, s).
Parameters
----------
v : torch.Tensor
The attention output from the KV segments, shape:
``[seq_len, num_states, num_heads, head_dim]``.
s : torch.Tensor
The logsumexp value from the KV segments, shape:
``[seq_len, num_states, num_heads]``, expected
to be a float32 tensor.
Returns
@@ -125,6 +154,22 @@ def merge_states(v: torch.Tensor, s: torch.Tensor):
S : torch.Tensor
The logsumexp value from the merged KV-segments, shape:
``[seq_len, num_heads]``.
Example
-------
>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> num_states = 100
>>> v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0")
>>> s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_merged, s_merged = flashinfer.merge_states(v, s)
>>> v_merged.shape
torch.Size([2048, 32, 128])
>>> s_merged.shape
torch.Size([2048, 32])
"""
return _kernels.merge_states(v, s)
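
Semantically, this is the same merge rule reduced over the ``num_states`` dimension; a minimal PyTorch sketch (illustrative only, again assuming natural-log logsumexp values)::

   import torch

   def merge_states_reference(v, s):
       # v: [seq_len, num_states, num_heads, head_dim]
       # s: [seq_len, num_states, num_heads]
       s_merged = torch.logsumexp(s, dim=1)  # [seq_len, num_heads]
       # Per-state weights; they sum to one over the num_states dimension.
       w = torch.exp(s - s_merged.unsqueeze(1)).unsqueeze(-1)
       v_merged = (w * v).sum(dim=1)  # [seq_len, num_heads, head_dim]
       return v_merged, s_merged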

@@ -184,6 +229,38 @@ def batch_decode_with_shared_prefix_padded_kv_cache(
-------
V : torch.Tensor
The attention output, shape: ``[batch_size, num_heads, head_dim]``
Example
-------
>>> import torch
>>> import flashinfer
>>> shared_prefix_len = 16384
>>> padded_unique_suffix_len = 2048
>>> batch_size = 53
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
>>> k_shared = torch.randn(shared_prefix_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> v_shared = torch.randn(shared_prefix_len, num_kv_heads, head_dim).half().to("cuda:0")
>>> k_unique = torch.randn(
... batch_size,
... padded_unique_suffix_len,
... num_kv_heads,
... head_dim
... ).half().to("cuda:0")
>>> v_unique = torch.randn(
... batch_size,
... padded_unique_suffix_len,
... num_kv_heads,
... head_dim
... ).half().to("cuda:0")
>>> o = flashinfer.batch_decode_with_shared_prefix_padded_kv_cache(
... q, k_shared, v_shared, k_unique, v_unique, kv_layout="NHD",
... allow_fp16_qk_reduction=True
... )
>>> o.shape
torch.Size([53, 32, 128])
"""
check_kv_layout(kv_layout)
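# Attend to the shared prefix once, keeping the logsumexp (lse) so this
# partial result can be merged with each request's unique-suffix attention.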
V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
@@ -218,12 +295,78 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
Check :ref:`our tutorial<page-layout>` for page table layout.
Example
-------
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 16MB workspace buffer
>>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
... workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_prefix_len = 8192
>>> unique_kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> unique_kv_page_indptr = torch.tensor(
... [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= kv_last_page_len <= page_size
>>> unique_kv_last_page_len = torch.tensor(
... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> unique_kv_data_at_layer = [
... torch.randn(
... max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... ) for _ in range(num_layers)
... ]
>>> shared_k_data_at_layer = [
... torch.randn(
... shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... ) for _ in range(num_layers)
... ]
>>> shared_v_data_at_layer = [
... torch.randn(
... shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch decode attention
>>> wrapper.begin_forward(
... unique_kv_page_indptr,
... unique_kv_page_indices,
... unique_kv_last_page_len,
... num_qo_heads,
... num_kv_heads,
... head_dim,
... page_size,
... data_type=torch.float16
... )
>>> outputs = []
>>> for i in range(num_layers):
... q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
... k_shared = shared_k_data_at_layer[i]
... v_shared = shared_v_data_at_layer[i]
... unique_kv_data = unique_kv_data_at_layer[i]
... # compute batch decode attention, reuse auxiliary data structures for all layers
... o = wrapper.forward(q, k_shared, v_shared, unique_kv_data)
... outputs.append(o)
...
>>> # clear auxiliary data structures
>>> wrapper.end_forward()
>>> outputs[0].shape
torch.Size([7, 64, 128])
Note
----
To accelerate computation, FlashInfer's shared-prefix batch decode attention creates
some auxiliary data structures, which can be reused across multiple
batch decode attention calls (e.g. different Transformer layers). This wrapper class
manages the lifecycle of these data structures.
"""

def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
@@ -383,6 +526,76 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
Check :ref:`our tutorial<page-layout>` for paged kv-cache layout.
Example
-------
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 16
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 16MB workspace buffer
>>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
... workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_prefix_len = 8192
>>> nnz_qo = 100
>>> qo_indptr = torch.tensor(
... [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
... )
>>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
>>> paged_kv_indptr = torch.tensor(
... [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> # 1 <= paged_kv_last_page_len <= page_size
>>> paged_kv_last_page_len = torch.tensor(
... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_data_at_layer = [
... torch.randn(
... max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... ) for _ in range(num_layers)
... ]
>>> shared_k_data_at_layer = [
... torch.randn(
... shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... ) for _ in range(num_layers)
... ]
>>> shared_v_data_at_layer = [
... torch.randn(
... shared_prefix_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... ) for _ in range(num_layers)
... ]
>>> # create auxiliary data structures for batch prefill attention
>>> prefill_wrapper.begin_forward(
... qo_indptr,
... paged_kv_indptr,
... paged_kv_indices,
... paged_kv_last_page_len,
... num_qo_heads,
... num_kv_heads
... )
>>> outputs = []
>>> for i in range(num_layers):
... q = torch.randn(nnz_qo, num_qo_heads, head_dim).half().to("cuda:0")
... kv_data = kv_data_at_layer[i]
... k_shared = shared_k_data_at_layer[i]
... v_shared = shared_v_data_at_layer[i]
... # compute batch prefill attention, reuse auxiliary data structures
... o = prefill_wrapper.forward(
... q, k_shared, v_shared, kv_data, causal=True
... )
... outputs.append(o)
...
>>> # clear auxiliary data structures
>>> prefill_wrapper.end_forward()
>>> outputs[0].shape
torch.Size([100, 64, 128])
Note
----
To accelerate computation, FlashInfer's shared-prefix batch prefill/append attention
creates some auxiliary data structures, which can be reused across multiple
batch prefill attention calls (e.g. different Transformer layers). This wrapper class
manages the lifecycle of these data structures.