diff --git a/.gitignore b/.gitignore
index 1266124b..e42c8015 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,16 @@
python/csrc/generated/
python/flashinfer/_build_meta.py
+# Generated documentation files
+docs/generated
+
+# DS_Store files
+.DS_Store
+
+# Microbenchmark files
microbenchmark/
+
+# vscode
.vscode/
# Byte-compiled / optimized / DLL files
diff --git a/docs/_static/FlashInfer-black-background.png b/docs/_static/FlashInfer-black-background.png
new file mode 100644
index 00000000..79eccd90
Binary files /dev/null and b/docs/_static/FlashInfer-black-background.png differ
diff --git a/docs/_static/FlashInfer-white-background.png b/docs/_static/FlashInfer-white-background.png
new file mode 100644
index 00000000..d0936895
Binary files /dev/null and b/docs/_static/FlashInfer-white-background.png differ
diff --git a/docs/api/python/cascade.rst b/docs/api/python/cascade.rst
new file mode 100644
index 00000000..dd095a6c
--- /dev/null
+++ b/docs/api/python/cascade.rst
@@ -0,0 +1,38 @@
+.. _apicascade:
+
+flashinfer.cascade
+==================
+
+.. currentmodule:: flashinfer.cascade
+
+.. _api-merge-states:
+
+Merge Attention States
+----------------------
+
+.. autosummary::
+ :toctree: ../../generated
+
+ merge_state
+ merge_state_in_place
+ merge_states
+
+.. _api-cascade-attention:
+
+Cascade Attention
+-----------------
+
+.. autosummary::
+ :toctree: ../../generated
+
+ batch_decode_with_shared_prefix_padded_kv_cache
+
+
+Cascade Attention Wrapper Classes
+---------------------------------
+
+.. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper
+ :members:
+
+.. autoclass:: BatchPrefillWithSharedPrefixPagedKVCacheWrapper
+ :members:
diff --git a/docs/api/python/decode.rst b/docs/api/python/decode.rst
new file mode 100644
index 00000000..ca972664
--- /dev/null
+++ b/docs/api/python/decode.rst
@@ -0,0 +1,26 @@
+.. _apidecode:
+
+flashinfer.decode
+=================
+
+.. currentmodule:: flashinfer.decode
+
+Single Request Decoding
+-----------------------
+
+.. autosummary::
+ :toctree: ../../generated
+
+ single_decode_with_kv_cache
+
+Batch Decoding
+--------------
+
+.. autosummary::
+ :toctree: ../../generated
+
+ batch_decode_with_padded_kv_cache
+ batch_decode_with_padded_kv_cache_return_lse
+
+.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
+ :members:
diff --git a/docs/api/python/page.rst b/docs/api/python/page.rst
new file mode 100644
index 00000000..66e64f68
--- /dev/null
+++ b/docs/api/python/page.rst
@@ -0,0 +1,16 @@
+.. _apipage:
+
+flashinfer.page
+===============
+
+Kernels to manipulate the paged kv-cache.
+
+.. currentmodule:: flashinfer.page
+
+Append new K/V tensors to Paged KV-Cache
+----------------------------------------
+
+.. autosummary::
+ :toctree: ../../generated
+
+ append_paged_kv_cache
diff --git a/docs/api/python/prefill.rst b/docs/api/python/prefill.rst
new file mode 100644
index 00000000..9f50f195
--- /dev/null
+++ b/docs/api/python/prefill.rst
@@ -0,0 +1,27 @@
+.. _apiprefill:
+
+flashinfer.prefill
+==================
+
+Attention kernels for prefill & append attention in both single-request and batch serving settings.
+
+.. currentmodule:: flashinfer.prefill
+
+Single Request Prefill/Append Attention
+---------------------------------------
+
+.. autosummary::
+ :toctree: ../../generated
+
+ single_prefill_with_kv_cache
+ single_prefill_with_kv_cache_return_lse
+
+Batch Prefill/Append Attention
+------------------------------
+
+.. autoclass:: BatchPrefillWithPagedKVCacheWrapper
+ :members:
+
+.. autoclass:: BatchPrefillWithRaggedKVCacheWrapper
+ :members:
+
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index 534fc4e3..55345c89 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,7 +1,7 @@
import os
import sys
-import tlcpack_sphinx_addon
+# import tlcpack_sphinx_addon
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
@@ -10,6 +10,10 @@
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+sys.path.insert(0, os.path.abspath("../python"))
+os.environ["BUILD_DOC"] = "1"
+autodoc_mock_imports = ["torch"]
+
project = 'FlashInfer'
author = "FlashInfer Contributors"
footer_copyright = '2023-2024, {}'.format(author)
@@ -22,11 +26,10 @@
extensions = [
"sphinx_tabs.tabs",
- "sphinx_toolbox.collapse",
- "sphinxcontrib.httpdomain",
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
- "sphinx_reredirects",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.mathjax",
]
source_suffix = [".rst"]
@@ -44,11 +47,7 @@
# -- Options for HTML output ----------------------------------------------
-# The theme is set by the make target
-import sphinx_rtd_theme
-
-html_theme = "sphinx_rtd_theme"
-html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+html_theme = "furo" #"sphinx_rtd_theme"
templates_path = []
@@ -60,27 +59,8 @@
"logo_only": True,
}
-header_links = [
- ("Home", "https://flashinfer.ai"),
- ("Github", "https://github.com/flashinfer-ai/flashinfer"),
- ("Discussions", "https://github.com/orgs/flashinfer-ai/discussions"),
-]
-
-html_context = {
- "footer_copyright": footer_copyright,
- "footer_note": footer_note,
- "header_links": header_links,
- "display_github": True,
- "github_user": "flashinfer-ai",
- "github_repo": "flashinfer",
- "github_version": "main/docs/",
- "theme_vcs_pageview_mode": "edit",
- # "header_logo": "/path/to/logo",
- # "header_logo_link": "",
- # "version_selecter": "",
+html_static_path = ["_static"]
+html_theme_options = {
+ "light_logo": "FlashInfer-white-background.png",
+ "dark_logo": "FlashInfer-black-background.png",
}
-
-# add additional overrides
-templates_path += [tlcpack_sphinx_addon.get_templates_path()]
-html_static_path += [tlcpack_sphinx_addon.get_static_path()]
-
diff --git a/docs/index.rst b/docs/index.rst
index 3a23a81a..d171a797 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -6,15 +6,29 @@
Welcome to FlashInfer's documentation!
======================================
+`Blog <https://flashinfer.ai>`_ | `Discussion Forum <https://github.com/orgs/flashinfer-ai/discussions>`_ | `GitHub <https://github.com/flashinfer-ai/flashinfer>`_
+
+FlashInfer is a library for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, PageAttention and LoRA. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.
+
.. toctree::
:maxdepth: 2
- :caption: Contents:
+   :caption: Get Started
+
+   installation
+
+.. toctree::
+ :maxdepth: 2
+   :caption: Tutorials
+
-Indices and tables
-==================
+ tutorials/recursive_attention
+ tutorials/kv_layout
+
+.. toctree::
+ :maxdepth: 2
+   :caption: PyTorch API Reference
+
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
+ api/python/decode
+ api/python/prefill
+ api/python/cascade
+ api/python/page
+
\ No newline at end of file
diff --git a/docs/installation.rst b/docs/installation.rst
new file mode 100644
index 00000000..3ad79832
--- /dev/null
+++ b/docs/installation.rst
@@ -0,0 +1,47 @@
+.. _installation:
+
+Installation
+============
+
+Python Package
+--------------
+FlashInfer is available as a Python package, built on top of `PyTorch <https://pytorch.org/>`_ to
+easily integrate with your Python applications.
+
+Prerequisites
+^^^^^^^^^^^^^
+
+- OS: Linux only
+- Python: 3.10, 3.11
+- PyTorch CUDA 11.8/12.1
+ - Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version.
+- Supported GPU architectures: sm_80, sm_86, sm_89, sm_90 (sm_75 support is work in progress).
+
+Quick Start
+^^^^^^^^^^^
+
+.. tabs::
+ .. tab:: PyTorch CUDA 11.8
+
+ .. code-block:: bash
+
+ pip install flashinfer -i https://flashinfer.ai/whl/cu118/
+
+ .. tab:: PyTorch CUDA 12.1
+
+ .. code-block:: bash
+
+ pip install flashinfer -i https://flashinfer.ai/whl/cu121/
+
+
+C++ API
+-------
+
+FlashInfer is a header-only library with only CUDA/C++ standard library dependency
+that can be directly integrated into your C++ project without installation.
+
+For now, you can check our `unittests and benchmarks `_ to learn how to use our C++ APIs.
+
+.. note::
+    The ``nvbench`` and ``googletest`` dependencies in the ``3rdparty`` directory are only
+    used to compile unittests and benchmarks, and are not required by the library itself.
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 0604736c..7717b04c 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,9 +1,7 @@
sphinx-tabs == 3.4.1
-sphinx-rtd-theme
-sphinx == 5.2.3
+sphinx == 7.2.6
sphinx-toolbox == 3.4.0
-tlcpack-sphinx-addon==0.2.2
-sphinxcontrib_httpdomain==1.8.1
-sphinxcontrib-napoleon==0.7
-sphinx-reredirects==0.1.2
-
+sphinxcontrib_httpdomain == 1.8.1
+sphinxcontrib-napoleon == 0.7
+sphinx-reredirects == 0.1.2
+furo == 2024.01.29
diff --git a/docs/tutorials/kv_layout.rst b/docs/tutorials/kv_layout.rst
new file mode 100644
index 00000000..8ee06282
--- /dev/null
+++ b/docs/tutorials/kv_layout.rst
@@ -0,0 +1,98 @@
+.. _kv-layout:
+
+KV-Cache Layout in FlashInfer
+=============================
+
+Layout: NHD/HND
+---------------
+
+FlashInfer provides two layouts for the last 3 dimensions of the KV-Cache: ``NHD`` and ``HND``:
+
+- ``NHD``: the last 3 dimensions are organized as ``(seq_len, num_heads, head_dim)``.
+- ``HND``: the last 3 dimensions are organized as ``(num_heads, seq_len, head_dim)``.
+
+The ``NHD`` layout is more natural because it's consistent with the output of
+:math:`xW_k` and :math:`xW_v` without transpose. The ``HND`` layout is more friendly
+for GPU implementation when the KV-Cache uses a low-precision data type (e.g. fp8).
+In practice we don't observe a significant performance difference between the two layouts
+on fp16 KV-Cache, so we prioritize the ``NHD`` layout for better readability. FlashInfer implements
+attention kernels for both layouts and provides an option to select between them (``NHD``
+by default).
+
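+For example, converting between the two layouts is a simple permute; the snippet below is an
+illustrative sketch (not a FlashInfer API) with made-up sizes:
+
+.. code-block:: python
+
+   import torch
+
+   seq_len, num_heads, head_dim = 4096, 32, 128
+   # NHD layout: (seq_len, num_heads, head_dim)
+   k_nhd = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16)
+   # HND layout: (num_heads, seq_len, head_dim)
+   k_hnd = k_nhd.permute(1, 0, 2).contiguous()
+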
+.. _ragged-layout:
+
+Ragged Tensor
+-------------
+
+In batched inference/serving, the input sequence length may vary across different samples.
+When there is no need to change the sequence length (e.g. in prefilling stage), we can use ``RaggedTensor`` to store
+the key/value tensors in KV-Cache:
+
+.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/ragged.png
+ :width: 400
+ :align: center
+ :alt: Data structure of Ragged KV-Cache.
+
+The keys (or values) of all requests are packed into a single ``data`` tensor without padding.
+We use an ``indptr`` array (``num_requests+1`` elements, the first element is always zero)
+to store the variable sequence lengths of the requests
+(``indptr[i+1]-indptr[i]`` is the sequence length of request ``i``); the ``data`` tensor has
+shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.
+
+We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.
+
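+The following is an illustrative sketch (not a FlashInfer API) of how such a packed
+``data`` tensor and its ``indptr`` array can be created with PyTorch:
+
+.. code-block:: python
+
+   import torch
+
+   num_heads, head_dim = 32, 128
+   seq_lens = [17, 4, 29]  # made-up sequence lengths of 3 requests
+   # indptr has num_requests + 1 elements and starts with 0
+   indptr = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
+   # packed keys of all requests without padding: (indptr[-1], num_heads, head_dim)
+   data = torch.randn(int(indptr[-1]), num_heads, head_dim, dtype=torch.float16)
+   # keys of request 1, shape (seq_lens[1], num_heads, head_dim)
+   k_1 = data[indptr[1] : indptr[2]]
+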
+.. _page-layout:
+
+FlashInfer APIs
+~~~~~~~~~~~~~~~
+
+FlashInfer provides :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper` to compute
+the prefill attention between queries stored in ragged tensor and keys/values stored in ragged
+KV-Cache.
+
+Page Table
+----------
+
+When the KV-Cache is dynamic (e.g. in the append or decode stage), packing all keys/values is not
+efficient because the sequence length per request changes over time. `vLLM `_
+proposes organizing the KV-Cache as a Page Table. In FlashInfer, we treat the page table as
+a block sparse matrix (each used page can be viewed as a non-zero block in the block sparse matrix)
+and use the `CSR format `_
+to index the pages in the KV-Cache.
+
+.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/page_layout.png
+ :width: 800
+ :align: center
+ :alt: Data structure of Paged KV-Cache.
+
+For each request, we keep a record of its ``page_indices`` and ``last_page_length``.
+The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed as:
+``[0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), ...]``.
+The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``.
+The overall ``last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_length``.
+
+The ``kv_data`` tensor is a 5-D tensor with shape (in ``NHD`` layout):
+
+.. code::
+
+ (max_num_pages, 2, page_size, num_heads, head_dim)
+
+where ``max_num_pages`` is the maximum number of pages used by all requests, and ``page_size`` is the number of tokens
+we fit into each page. The ``2`` in the second dimension distinguishes keys from values (the first slot stores keys, the second stores values).
+
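+The following illustrative sketch (not a FlashInfer API) shows how these overall arrays and the
+``kv_data`` tensor can be assembled from per-request page bookkeeping:
+
+.. code-block:: python
+
+   import torch
+
+   # made-up page bookkeeping for 3 requests
+   page_indices = [[0, 1, 2], [3], [4, 5]]
+   last_page_length = [7, 16, 2]
+
+   kv_page_indices = torch.tensor(sum(page_indices, []), dtype=torch.int32)
+   kv_indptr = torch.cumsum(
+       torch.tensor([0] + [len(p) for p in page_indices]), dim=0
+   ).int()
+   last_page_lens = torch.tensor(last_page_length, dtype=torch.int32)
+
+   page_size, num_heads, head_dim = 16, 32, 128
+   max_num_pages = 6
+   # paged KV-Cache storage in NHD layout
+   kv_data = torch.zeros(
+       max_num_pages, 2, page_size, num_heads, head_dim, dtype=torch.float16
+   )
+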
+FlashInfer APIs
+~~~~~~~~~~~~~~~
+
+:meth:`flashinfer.page.append_paged_kv_cache` can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
+(the pages for these appended keys/values must be allocated prior to calling this API).
+
+:class:`BatchDecodeWithPagedKVCacheWrapper` and :class:`BatchPrefillWithPagedKVCacheWrapper` implement the decode attention
+and prefill/append attention between queries stored in ragged tensors and keys/values stored in paged KV-Cache.
+
+FAQ
+^^^
+
+How does FlashInfer manage the KV-Cache?
+  FlashInfer itself is not responsible for managing the page table (popping and allocating new pages, etc.); we leave the strategy
+  to the user: different serving engines might have different strategies to manage the page table. FlashInfer is only responsible
+  for computing the attention between queries and keys/values stored in the KV-Cache.
diff --git a/docs/tutorials/recursive_attention.rst b/docs/tutorials/recursive_attention.rst
new file mode 100644
index 00000000..702d306e
--- /dev/null
+++ b/docs/tutorials/recursive_attention.rst
@@ -0,0 +1,80 @@
+.. _recursive-attention:
+
+Attention States and Recursive form of Self-Attention
+=====================================================
+
+
+FlashInfer introduces the concept of **attention states**, which fully characterizes
+the attention between a query and a set of key/value pairs. We further define a
+**merge** operator on the **attention states**. This merge operator facilitates the
+computation of complete attention by allowing the recursive merging of attention states.
+
+Suppose we define :math:`s_i = \mathbf{q}\mathbf{k}_i^T` as the pre-softmax attention
+score between the query :math:`\mathbf{q}` and the key :math:`\mathbf{k}_i`. The Self-Attention
+score on index :math:`i` can be generalized to index set :math:`I`:
+
+.. math::
+
+ s(I)=\log\left(\sum_{i\in I}\exp\left(s_i\right)\right)
+
+We can also generalize the value on index :math:`i` to index set :math:`I`:
+
+.. math::
+
+ \mathbf{v}(I)=\frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}
+
+The *attention state* of the index set :math:`I` can be defined as a tuple :math:`(s(I), \mathbf{v}(I))`.
+
+Then we can define the **merge** operator :math:`\oplus` of two attention states as:
+
+.. math::
+
+ \begin{bmatrix}\mathbf{v}(I\cup J)\\s(I\cup J)\end{bmatrix}=\begin{bmatrix}\mathbf{v}(I)\\s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathbf{v}(J)\\s(J)\end{bmatrix}=\begin{bmatrix} \frac{\mathbf{v}(I)\exp(s(I)) + \mathbf{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))} \\ \log(\exp(s(I)) + \exp(s(J))) \end{bmatrix}
+
+The **attention state** on the entire sequence can be defined as:
+
+.. math::
+
+ \begin{bmatrix}\mathbf{v}(\{1,2,\dots, n\})\\s(\{1,2,\dots, n\})\end{bmatrix} = \bigoplus_{i=1}^{n} \begin{bmatrix}\mathbf{v}_i\\s_i\end{bmatrix}
+
+Then :math:`\mathbf{v}(\{1,2,\dots, n\})` is the final attention output.
+
+.. note::
+
+ The generalized score :math:`s` is also known as log-sum-exp (``lse`` for short).
+
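+The merge operator is straightforward to express numerically. The following is an
+illustrative reference sketch written in PyTorch (not the FlashInfer kernel; see
+:ref:`api-merge-states` for the actual APIs):
+
+.. code-block:: python
+
+   import torch
+
+   def merge_state_reference(v_a, s_a, v_b, s_b):
+       # v_a, v_b: attention outputs, shape (seq_len, num_heads, head_dim)
+       # s_a, s_b: logsumexp values,  shape (seq_len, num_heads)
+       s_merged = torch.logaddexp(s_a, s_b)
+       # renormalization weights of the two partial outputs
+       w_a = torch.exp(s_a - s_merged).unsqueeze(-1)
+       w_b = torch.exp(s_b - s_merged).unsqueeze(-1)
+       v_merged = v_a * w_a + v_b * w_b
+       return v_merged, s_merged
+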
+Applications
+------------
+
+Note that the :math:`\oplus` operator is **commutative** and **associative**, which means we can
+safely offload the self-attention computation on a subset of KV to different devices
+and **merge** the results **in any order**.
+
+There are several interesting applications of this recursive form of self-attention in FlashInfer so far:
+
+Shared-Prefix Batch Decoding
+  Many LLM applications involve batch decoding with a shared long prompt; FlashInfer decomposes attention
+  on the entire KV-Cache into shared-prefix attention and unique-suffix attention.
+  This decomposition enables the offloading of these components to different kernel implementations, resulting
+  in a remarkable 30x acceleration in scenarios with long context and large batch size.
+  Check `our blog post `_ for more details about this application,
+  and :ref:`api-cascade-attention` on how to use this feature in FlashInfer.
+
+KV Sequence Parallelism
+  For long-context LLM inference/serving, the batch size and number of heads per GPU are limited by GPU memory,
+  and the default parallelism strategy cannot use all SMs in the GPU, which results in suboptimal performance.
+  Inspired by the `Split-K `_ trick
+  in GEMM optimizations, FlashInfer partitions the KV sequence dimension, dispatches the attention computations to
+  different thread-blocks, and merges them in a second pass. The same idea was also proposed in Flash-Decoding; you can
+  check their great `blog post `_ for visualizations and more details.
+
+Related APIs
+------------
+
+FlashInfer exposes several APIs to facilitate the recursive attention computation:
+
+- :ref:`api-merge-states` defines the operators to merge attention states.
+- :ref:`apiprefill` and :ref:`apidecode` define operators that return attention states (APIs
+  with the suffix ``_return_lse`` return both the attention output :math:`v` and the score :math:`s`).
+
diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py
index 596ce38d..1d5b8890 100644
--- a/python/flashinfer/cascade.py
+++ b/python/flashinfer/cascade.py
@@ -15,10 +15,20 @@
"""
import math
from typing import Optional
-
import torch
-from . import _kernels
+try:
+ from . import _kernels
+except ImportError as e:
+ import os
+ import logging
+
+ if os.environ.get("BUILD_DOC", "0") == "1":
+ _kernels = None
+ logging.warning("Kernels are not loaded in documentation build mode.")
+ else:
+ raise e
+
from .decode import (
batch_decode_with_padded_kv_cache_return_lse,
BatchDecodeWithPagedKVCacheWrapper,
@@ -39,31 +49,33 @@
def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
- r"""Merge the attention output (V) and the logsumexp value (S) from the two KV-segments.
+ r"""Merge the attention output ``V`` and the logsumexp value ``S`` from the two
+ KV-segments.
+    Check :ref:`our tutorial <recursive-attention>` for the mathematical details.
Parameters
----------
v_a : torch.Tensor
- The attention output from the KV segment A.
- Shape: [seq_len, num_heads, head_dim]
+ The attention output from the KV segment ``A``, shape:
+ ``[seq_len, num_heads, head_dim]``.
s_a : torch.Tensor
- The logsumexp value from the KV segment A. Expected to be a float32 tensor.
- Shape: [seq_len, num_heads]
+        The logsumexp value from the KV segment ``A``, expected to be a float32 tensor,
+ shape: ``[seq_len, num_heads]``.
v_b : torch.Tensor
- The attention output from the KV segment B.
- Shape: [seq_len, num_heads, head_dim]
+ The attention output from the KV segment ``B``,
+ shape: ``[seq_len, num_heads, head_dim]``.
s_b : torch.Tensor
- The logsumexp value from the KV segment B. Expected to be a float32 tensor.
- Shape: [seq_len, num_heads]
+ The logsumexp value from the KV segment ``B``, expected to be a float32 tensor,
+ shape: ``[seq_len, num_heads]``
Returns
-------
V : torch.Tensor
- The merged attention output (equivalent to attention with merged KV-segment [A: B]).
- Shape: [batch_size, num_heads, head_dim]
+ The merged attention output (equivalent to attention with merged KV-segment
+ ``[A: B]``), shape: ``[batch_size, num_heads, head_dim]``.
S : torch.Tensor
- The logsumexp value from the merged KV-segment [A: B].
- Shape: [batch_size, num_heads]
+ The logsumexp value from the merged KV-segment ``[A: B]``, shape:
+ ``[batch_size, num_heads]``.
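+
+    Examples
+    --------
+    An illustrative example with arbitrary sizes (shapes follow the parameter
+    documentation above):
+
+    >>> import torch
+    >>> import flashinfer
+    >>> seq_len = 2048
+    >>> num_heads = 32
+    >>> head_dim = 128
+    >>> v_a = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    >>> s_a = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    >>> v_b = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    >>> s_b = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    >>> v_merged, s_merged = flashinfer.merge_state(v_a, s_a, v_b, s_b)
+    >>> v_merged.shape
+    torch.Size([2048, 32, 128])
+    >>> s_merged.shape
+    torch.Size([2048, 32])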
"""
return _kernels.merge_state(v_a, s_a, v_b, s_b)
@@ -71,46 +83,48 @@ def merge_state(
def merge_state_in_place(
v: torch.Tensor, s: torch.Tensor, v_other: torch.Tensor, s_other: torch.Tensor
):
- r"""Merge the self-attention state (v, s) with another state (v_other, s_other) in-place.
+ r"""Merge the self-attention state ``(v, s)`` with another state
+ ``(v_other, s_other)`` in-place.
Parameters
----------
v : torch.Tensor
- The partial v to be updated in-place.
- Shape: (seq_len, num_heads, head_dim)
+ The partial attention output to be updated in-place, shape:
+ ``(seq_len, num_heads, head_dim)``.
s : torch.Tensor
- The partial logsumexpr value to be updated in-place, expected to be a float32 tensor.
- Shape: (seq_len, num_heads)
+        The partial logsumexp value to be updated in-place, expected to be a float32
+ tensor, shape: ``(seq_len, num_heads)``.
v_other : torch.Tensor
- The other v to be merged.
- Shape: (seq_len, num_heads, head_dim)
+ The other attention output to be merged, shape:
+ ``(seq_len, num_heads, head_dim)``.
s_other : torch.Tensor
- The other logsumexp value to be merged, expected to be a float32 tensor.
- Shape: (seq_len, num_heads)
+ The other logsumexp value to be merged, expected to be a float32 tensor,
+ shape: ``(seq_len, num_heads)``.
"""
_kernels.merge_state_in_place(v, s, v_other, s_other)
def merge_states(v: torch.Tensor, s: torch.Tensor):
- r"""Merge the attention output (V) and the logsumexp value (S) from multiple KV-segments.
+ r"""Merge the attention output ``V`` and the logsumexp value ``S`` from multiple
+ KV-segments.
Parameters
----------
v : torch.Tensor
- The attention output from the KV segments.
- Shape: [seq_len, num_kv_segments, num_heads, head_dim]
+ The attention output from the KV segments, shape:
+ ``[seq_len, num_kv_segments, num_heads, head_dim]``.
s : torch.Tensor
- The logsumexp value from the KV segments.
- Shape: [seq_len, num_kv_segments, num_heads]
+ The logsumexp value from the KV segments, shape:
+ ``[seq_len, num_kv_segments, num_heads]``, expected
+ to be a float32 tensor.
Returns
-------
V : torch.Tensor
- The merged attention output.
- Shape: [seq_len, num_heads, head_dim]
+ The merged attention output, shape: ``[seq_len, num_heads, head_dim]``.
S : torch.Tensor
- The logsumexp value from the merged KV-segments.
- Shape: [seq_len, num_heads]
+ The logsumexp value from the merged KV-segments, shape:
+ ``[seq_len, num_heads]``.
"""
return _kernels.merge_states(v, s)
@@ -127,40 +141,49 @@ def batch_decode_with_shared_prefix_padded_kv_cache(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
- r"""Batch decode with shared prefix padded KV cache.
+ r"""Decode attention between queries and shared prefix kv-cache for batch of
+ requests.
Parameters
----------
q : torch.Tensor
- Shape: [batch_size, num_qo_heads, head_dim]
+ The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
k_shared : torch.Tensor
- Shape: [shared_prefix_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, shared_prefix_len, head_dim] if HND
+ The shared prefix key tensor, shape:
+ ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``,
+ or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is
+ ``HND``.
v_shared : torch.Tensor
- Shape: [shared_prefix_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, shared_prefix_len, head_dim] if HND
+ The shared prefix value tensor, shape:
+ ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``,
+ or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is
+ ``HND``.
k_unique : torch.Tensor
- Shape: [batch_size, unique_len, num_kv_heads, head_dim] if NHD
- [batch_size, num_kv_heads, unique_len, head_dim] if HND
+        The per-request unique suffix key tensor, shape:
+ ``[batch_size, unique_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
+ ``NHD``, or ``[batch_size, num_kv_heads, unique_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
v_unique : torch.Tensor
- Shape: [batch_size, unique_len, num_kv_heads, head_dim] if NHD
- [batch_size, num_kv_heads, unique_len, head_dim] if HND
+        The per-request unique suffix value tensor, shape:
+ ``[batch_size, unique_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
+ ``NHD``, or ``[batch_size, num_kv_heads, unique_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
kv_layout : str
- The layout of the input k/v tensors, could be either "NHD" or "HND".
+ The layout of the kv-cache, could be either "NHD" or "HND".
allow_fp16_qk_reduction : bool
- Whether to use f16 for qk reduction (could be significantly faster for GeForce cards, at
- the cost of slight precision loss).
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
sm_scale : Optional[float]
- The scale of softmax, if not provided, will be set to 1 / sqrt(head_dim)
+ The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``
rope_scale : Optional[float]
- The scale used in RoPE interpolation, if not provided, will be set to 1.0.
+ The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
rope_theta : Optional[float]
- The theta used in RoPE, if not provided, will be set to 1e4.
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
Returns
-------
V : torch.Tensor
- Shape: [batch_size, num_heads, head_dim]
+ The attention output, shape: ``[batch_size, num_heads, head_dim]``
"""
check_kv_layout(kv_layout)
V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
@@ -190,6 +213,19 @@ def batch_decode_with_shared_prefix_padded_kv_cache(
class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
+ r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch
+ of requests.
+
+    Check :ref:`our tutorial <page-layout>` for page table layout.
+
+ Note
+ ----
+ To accelerate computation, FlashInfer's shared prefix batch decode attention creates
+    some auxiliary data structures; these data structures can be reused across multiple
+ batch decode attention calls (e.g. different Transformer layers). This wrapper class
+ manages the lifecycle of these data structures.
+ """
+
def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
self._batch_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
@@ -197,6 +233,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
self._kv_layout = kv_layout
def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
+ r"""Reset the workspace buffer.
+
+ Parameters
+ ----------
+ new_workspace_buffer : torch.Tensor
+ The new workspace buffer, the device of the new workspace buffer should
+ be the same as the device of the input tensors.
+ """
self._batch_decode_wrapper.reset_workspace_buffer(new_workspace_buffer)
def begin_forward(
@@ -210,6 +254,43 @@ def begin_forward(
page_size: int,
data_type: str = "float16",
):
+ r"""Create auxiliary data structures for shared-prefix batch decode for multiple
+ forward calls within the same decode step.
+
+ Parameters
+ ----------
+ indptr : torch.Tensor
+ The indptr of the paged kv cache, shape: ``[batch_size + 1]``
+ indices : torch.Tensor
+            The page indices of the paged kv cache, shape: ``[indptr[-1]]``
+ last_page_len : torch.Tensor
+ The number of entries in the last page of each request in the paged kv
+ cache, shape: ``[batch_size]``
+ num_qo_heads : int
+ The number of query/output heads
+ num_kv_heads : int
+ The number of key/value heads
+ head_dim : int
+ The dimension of the heads
+ page_size : int
+ The page size of the paged kv cache
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ data_type : Union[str, torch.dtype]
+ The data type of the paged kv cache
+
+ Note
+ ----
+ The :meth:`begin_forward` method should be called before any :meth:`forward` or
+ :meth:`forward_return_lse` calls,
+ auxiliary data structures will be created during this call and cached for
+ multiple forward calls.
+
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
+ is not equal to ``num_kv_heads``, the function will use
+ `grouped query attention `_.
+ """
self._batch_decode_wrapper.begin_forward(
unique_kv_indptr,
unique_kv_indices,
@@ -223,6 +304,7 @@ def begin_forward(
)
def end_forward(self):
+ r"""Clear auxiliary data structures created by :meth:`begin_forward`."""
self._batch_decode_wrapper.end_forward()
def forward(
@@ -235,6 +317,44 @@ def forward(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch decode attention between queries and shared-prefix paged
+ kv-cache.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
+ k_shared : torch.Tensor
+ The shared prefix key tensor, shape:
+ ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
+ ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ v_shared : torch.Tensor
+ The shared prefix value tensor, shape:
+ ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
+ ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ unique_kv_data : torch.Tensor
+            A 5-D tensor of paged kv-cache data storing the per-request unique suffix
+            key and value tensors, shape:
+            ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+            :attr:`kv_layout` is ``NHD``, or
+            ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+            :attr:`kv_layout` is ``HND``.
+ allow_fp16_qk_reduction : bool
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to
+ ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ V : torch.Tensor
+ The attention output, shape: ``[batch_size, num_heads, head_dim]``
+ """
V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
q,
k_shared,
@@ -258,13 +378,45 @@ def forward(
class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
+ r"""Wrapper class for prefill/append attention with shared-prefix paged kv-cache for
+ batch of requests.
+
+    Check :ref:`our tutorial <page-layout>` for paged kv-cache layout.
+
+ Note
+ ----
+ To accelerate computation, FlashInfer's shared-prefix batch prefill/append attention
+    operators create some auxiliary data structures; these data structures can be
+ reused across multiple prefill/append attention calls (e.g. different Transformer
+ layers). This wrapper class manages the lifecycle of these data structures.
+ """
+
def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
+        r"""Constructor of :class:`BatchPrefillWithSharedPrefixPagedKVCacheWrapper`.
+
+ Parameters
+ ----------
+ workspace_buffer : torch.Tensor
+ The user reserved workspace buffer used to store auxiliary data structures,
+ recommended size is 16MB, the device of the workspace buffer should be the
+ same as the device of the input tensors.
+ kv_layout : str
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ """
self._batch_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
self._kv_layout = kv_layout
def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
+ r"""Reset the workspace buffer.
+
+ Parameters
+ ----------
+ new_workspace_buffer : torch.Tensor
+ The new workspace buffer, the device of the new workspace buffer should
+ be the same as the device of the input tensors.
+ """
self._batch_prefill_wrapper.reset_workspace_buffer(new_workspace_buffer)
def begin_forward(
@@ -276,6 +428,36 @@ def begin_forward(
num_qo_heads: int,
num_kv_heads: int,
):
+ r"""Create auxiliary data structures for shared-prefix batch prefill/append
+ attention for multiple forward calls within the same prefill/append step.
+
+ Parameters
+ ----------
+ qo_indptr : torch.Tensor
+ The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
+ paged_kv_indptr : torch.Tensor
+ The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
+ paged_kv_indices : torch.Tensor
+ The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``.
+ paged_kv_last_page_len : torch.Tensor
+ The number of entries in the last page of each request in the paged
+ kv-cache, shape: ``[batch_size]``.
+ num_qo_heads : int
+ The number of query/output heads.
+ num_kv_heads : int
+ The number of key/value heads.
+
+ Notes
+ -----
+ The :meth:`begin_forward` method should be called before any :meth:`forward`
+ or :meth:`forward_return_lse` calls, auxiliary data structures will be created
+ during this call and cached for multiple forward calls.
+
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
+ is not equal to ``num_kv_heads``, the function will use
+ `grouped query attention `_.
+ """
+
self._batch_prefill_wrapper.begin_forward(
qo_indptr,
paged_kv_indptr,
@@ -286,6 +468,7 @@ def begin_forward(
)
def end_forward(self):
+ r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
self._batch_prefill_wrapper.end_forward()
def forward(
@@ -299,6 +482,46 @@ def forward(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch prefill/append attention between query and shared-prefix paged
+ kv-cache.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
+ k_shared : torch.Tensor
+ The shared prefix key tensor, shape:
+ ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
+ ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+        v_shared : torch.Tensor
+ The shared prefix value tensor, shape:
+ ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is
+ ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ unique_kv_data : torch.Tensor
+            A 5-D tensor of paged kv-cache data storing the per-request unique suffix
+            key and value tensors, shape:
+            ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+            :attr:`kv_layout` is ``NHD``, or
+            ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ causal : bool
+ Whether to apply causal mask on the attention matrix.
+ allow_fp16_qk_reduction : bool
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to
+ ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ V : torch.Tensor
+ The attention output, shape: ``[qo_indptr[-1], num_heads, head_dim]``.
+ """
V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
q,
k_shared,
diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index b2014a2c..d4f42cc7 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -15,10 +15,21 @@
"""
import math
from typing import Optional, Union
-
import torch
-from . import _kernels
+try:
+ from . import _kernels
+except ImportError as e:
+ import os
+ import logging
+
+ if os.environ.get("BUILD_DOC", "0") == "1":
+ _kernels = None
+ logging.warning("Kernels are not loaded in documentation build mode.")
+ else:
+ raise e
+
+
from .utils import (
RotaryMode,
TensorLayout,
@@ -43,35 +54,64 @@ def single_decode_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
- rotary_mode: str = "NONE",
kv_layout: str = "NHD",
+ rotary_mode: str = "NONE",
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
- r"""Single request decode with KV cache.
+ r"""Decode attention with KV Cache for single request, return attention output.
Parameters
----------
q : torch.Tensor
- Shape: [num_qo_heads, head_dim]
+ The query tensor, shape: ``[num_qo_heads, head_dim]``.
k : torch.Tensor
- Shape: [kv_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, kv_len, head_dim] if HND
+ The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is
+ ``HND``.
v : torch.Tensor
- Shape: [kv_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, kv_len, head_dim] if HND
- rotary_mode : str
- Whether to apply rotary embeddings inside attention kernels, could be
- "NONE" or "LLAMA".
+ The value tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if
+ :attr:`kv_layout` is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
kv_layout : str
- The layout of the input k/v tensors, could be either "NHD" or "HND".
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
sm_scale : Optional[float]
- The scale of softmax, if not provided, will be set to 1 / sqrt(head_dim)
+ The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
- The scale used in RoPE interpolation, if not provided, will be set to 1.0.
+ The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
rope_theta : Optional[float]
- The theta used in RoPE, if not provided, will be set to 1e4.
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ torch.Tensor
+ The attention output, shape: ``[num_qo_heads, head_dim]``
+
+ Examples
+ --------
+
+ >>> import torch
+ >>> import flashinfer
+ >>> kv_len = 4096
+ >>> num_qo_heads = 32
+ >>> num_kv_heads = 32
+ >>> head_dim = 128
+ >>> q = torch.randn(num_qo_heads, head_dim).half().to("cuda:0")
+ >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> o = flashinfer.single_decode_with_kv_cache(q, k, v)
+ >>> o.shape
+ torch.Size([32, 128])
+
+ Notes
+ -----
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
+ not equal to ``num_kv_heads``, the function will use
+ `grouped query attention `_.
"""
check_rotary_mode(rotary_mode)
check_kv_layout(kv_layout)
@@ -106,23 +146,45 @@ def batch_decode_with_padded_kv_cache(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
- r"""Batch decode with padded KV cache.
+ r"""Decode attention with padded KV cache for batch of requests, return attention
+ output.
Parameters
----------
q : torch.Tensor
- Shape: [batch_size, num_qo_heads, head_dim]
+ The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
k_padded : torch.Tensor
- Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD
- [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND
+ The padded key tensor, shape:
+ ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
v_padded : torch.Tensor
- Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD
- [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND
+ The padded value tensor, shape:
+ ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ kv_layout : str
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ sm_scale : Optional[float]
+ The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
Returns
-------
- V : torch.Tensor
- Shape: [batch_size, num_heads, head_dim]
+ torch.Tensor
+ The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
+
+ Notes
+ -----
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
+ not equal to ``num_kv_heads``, the function will use
+ `grouped query attention `_.
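+
+    Examples
+    --------
+    An illustrative example with arbitrary sizes, assuming the default ``NHD`` layout:
+
+    >>> import torch
+    >>> import flashinfer
+    >>> batch_size = 7
+    >>> padded_kv_len = 1024
+    >>> num_qo_heads = 32
+    >>> num_kv_heads = 32
+    >>> head_dim = 128
+    >>> q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
+    >>> k_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+    >>> v_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+    >>> o = flashinfer.batch_decode_with_padded_kv_cache(q, k_padded, v_padded)
+    >>> o.shape
+    torch.Size([7, 32, 128])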
"""
if sm_scale is None:
head_dim = q.shape[-1]
@@ -154,36 +216,48 @@ def batch_decode_with_padded_kv_cache_return_lse(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
- r"""
+    r"""Decode attention with padded KV cache for batch of requests, return attention
+    output and logsumexp of attention scores.
+
Parameters
----------
q : torch.Tensor
- Shape: [batch_size, num_qo_heads, head_dim]
+ The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
k_padded : torch.Tensor
- Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD
- [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND
+ The padded key tensor, shape:
+ ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
v_padded : torch.Tensor
- Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD
- [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND
- kv_layout: str
- The layout of the input k_padded/v_padded tensors, could be either
- "NHD" or "HND"
- rotary_mode: str
- Whether to apply rotary embeddings inside attention kernels, could be
- "NONE" or "LLAMA".
- sm_scale: Optional[float]
- The scale of softmax, if not provided, will be set to 1 / sqrt(head_dim)
- rope_scale: Optional[float]
- The scale used in RoPE interpolation, if not provided, will be set to 1.0.
- rope_theta: Optional[float]
- The theta used in RoPE, if not provided, will be set to 1e4.
+ The padded value tensor, shape:
+ ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ kv_layout : str
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ sm_scale : Optional[float]
+ The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
Returns
-------
V : torch.Tensor
- Shape: [batch_size, num_heads, head_dim]
+        The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
S : torch.Tensor
- Shape: [batch_size, num_heads]
+        The logsumexp of attention scores, shape: ``[batch_size, num_qo_heads]``.
+
+ Notes
+ -----
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
+ not equal to ``num_kv_heads``, the function will use
+ `grouped query attention `_.
"""
if sm_scale is None:
head_dim = q.shape[-1]
@@ -206,15 +280,31 @@ def batch_decode_with_padded_kv_cache_return_lse(
class BatchDecodeWithPagedKVCacheWrapper:
- r"""Wrapper class for batch_decode_with_paged_kv_cache kernel.
+ r"""Wrapper class for decode attention with paged kv-cache (first proposed in
+ `vLLM `_) for batch of requests.
+
+    Check :ref:`our tutorial <page-layout>` for page table layout.
- To accelerate computation, FlashInfer's batch decode operators creates some
+ Note
+ ----
+ To accelerate computation, FlashInfer's batch decode attention creates some
auxiliary data structures, these data structures can be reused across multiple
- batch decode calls (e.g. different Transformer layers). This wrapper class manages
- the lifecycle of these data structures.
+ batch decode attention calls (e.g. different Transformer layers). This wrapper class
+ manages the lifecycle of these data structures.
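+
+    Examples
+    --------
+    An illustrative sketch of the wrapper lifecycle, with made-up sizes and page table
+    contents:
+
+    >>> import torch
+    >>> import flashinfer
+    >>> num_qo_heads, num_kv_heads, head_dim = 64, 8, 128
+    >>> page_size, max_num_pages, batch_size = 16, 16, 4
+    >>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+    >>> # in this sketch each request owns 4 consecutive pages
+    >>> kv_page_indptr = torch.arange(0, 17, 4, dtype=torch.int32, device="cuda:0")
+    >>> kv_page_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda:0")
+    >>> kv_last_page_len = torch.full((batch_size,), 5, dtype=torch.int32, device="cuda:0")
+    >>> kv_data = torch.randn(
+    ...     max_num_pages, 2, page_size, num_kv_heads, head_dim,
+    ...     dtype=torch.float16, device="cuda:0")
+    >>> wrapper.begin_forward(
+    ...     kv_page_indptr, kv_page_indices, kv_last_page_len,
+    ...     num_qo_heads, num_kv_heads, head_dim, page_size)
+    >>> q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
+    >>> o = wrapper.forward(q, kv_data)  # auxiliary structures are reused across layers
+    >>> o.shape
+    torch.Size([4, 64, 128])
+    >>> wrapper.end_forward()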
"""
def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
+ r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`.
+
+ Parameters
+ ----------
+ workspace_buffer : torch.Tensor
+ The user reserved workspace buffer used to store auxiliary data structures,
+ recommended size is 16MB, the device of the workspace buffer should be the
+ same as the device of the input tensors.
+ kv_layout : str
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ """
check_kv_layout(kv_layout)
self._kv_layout = kv_layout
self._workspace_buffer = workspace_buffer
@@ -226,6 +316,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
self._paged_kv_last_page_len = None
def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
+ r"""Reset the workspace buffer.
+
+ Parameters
+ ----------
+ new_workspace_buffer : torch.Tensor
+ The new workspace buffer, the device of the new workspace buffer should
+ be the same as the device of the input tensors.
+ """
self._workspace_buffer = new_workspace_buffer
def begin_forward(
@@ -240,9 +338,41 @@ def begin_forward(
rotary_mode: str = "NONE",
data_type: Union[str, torch.dtype] = "float16",
):
- r"""The begin_forward method should be called before any batch decode calls,
- auxiliary data structures will be created during this call and cached for
- multiple forward calls.
+ r"""Create auxiliary data structures for batch decode for multiple forward calls
+ within the same decode step.
+
+ Parameters
+ ----------
+ indptr : torch.Tensor
+ The indptr of the paged kv cache, shape: ``[batch_size + 1]``
+ indices : torch.Tensor
+            The page indices of the paged kv cache, shape: ``[indptr[-1]]``
+ last_page_len : torch.Tensor
+ The number of entries in the last page of each request in the paged kv
+ cache, shape: ``[batch_size]``
+ num_qo_heads : int
+ The number of query/output heads
+ num_kv_heads : int
+ The number of key/value heads
+ head_dim : int
+ The dimension of the heads
+ page_size : int
+ The page size of the paged kv cache
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ data_type : Union[str, torch.dtype]
+ The data type of the paged kv cache
+
+ Note
+ ----
+ The :meth:`begin_forward` method should be called before any :meth:`forward` or
+ :meth:`forward_return_lse` calls, auxiliary data structures will be created
+ during this call and cached for multiple forward calls.
+
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
+ is not equal to ``num_kv_heads``, the function will use
+ `grouped query attention `_.
"""
self._paged_kv_indptr = indptr
self._paged_kv_indices = indices
@@ -270,7 +400,7 @@ def begin_forward(
)
def end_forward(self):
- r"""The end_forward method can clear the cached data structures."""
+ r"""Clear auxiliary data structures created by :meth:`begin_forward`."""
self._paged_kv_indptr = None
self._paged_kv_indices = None
self._paged_kv_last_page_len = None
@@ -284,6 +414,32 @@ def forward(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch decode attention between query and paged kv cache.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``
+ paged_kv_data : torch.Tensor
+ A 5-D tensor of the reserved paged kv-cache data, shape:
+ ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+ :attr:`kv_layout` is ``NHD``, or
+ ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to
+ ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ torch.Tensor
+ The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
+ """
check_rotary_mode(rotary_mode)
if rope_scale is None:
rope_scale = 1.0
@@ -310,6 +466,35 @@ def forward_return_lse(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch decode attention with paged kv cache, return attention output
+ and logsumexp of attention scores.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``
+ paged_kv_data : torch.Tensor
+ A 5-D tensor of the reserved paged kv-cache data, shape:
+ ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+ :attr:`kv_layout` is ``NHD``, or
+ ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to
+ ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ V : torch.Tensor
+ The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
+ S : torch.Tensor
+            The logsumexp of attention scores, shape: ``[batch_size, num_qo_heads]``.
+ """
check_rotary_mode(rotary_mode)
if rope_scale is None:
rope_scale = 1.0
diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py
index 9160907e..d7944a78 100644
--- a/python/flashinfer/page.py
+++ b/python/flashinfer/page.py
@@ -15,8 +15,17 @@
"""
import torch
-from . import _kernels
-from .utils import TensorLayout, check_kv_layout
+try:
+ from . import _kernels
+except ImportError as e:
+ import os
+ import logging
+
+ if os.environ.get("BUILD_DOC", "0") == "1":
+ _kernels = None
+ logging.warning("Kernels are not loaded in documentation build mode.")
+ else:
+        raise e
+
+from .utils import TensorLayout, check_kv_layout
def append_paged_kv_cache(
@@ -29,6 +38,40 @@ def append_paged_kv_cache(
kv_last_page_len: torch.Tensor,
kv_layout: str = "NHD",
):
+ r"""Append a batch of key-value pairs to a paged key-value cache.
+
+ Parameters
+ ----------
+ append_key : torch.Tensor
+ The key tensor to append in ragged tensor format, shape:
+ ``[append_indptr[-1], num_kv_heads, head_dim]``.
+ append_value : torch.Tensor
+ The value tensor to append in ragged tensor format, shape:
+ ``[append_indptr[-1], num_kv_heads, head_dim]``.
+ append_indptr : torch.Tensor
+ The indptr tensor of the key-value pairs to append, shape: ``[batch_size + 1]``.
+ kv_data : torch.Tensor
+ The 5-D tensor of the paged key-value cache, shape:
+ ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+ :attr:`kv_layout` is ``NHD``, or
+        ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+        :attr:`kv_layout` is ``HND``.
+ kv_indices : torch.Tensor
+ The page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``.
+ kv_indptr : torch.Tensor
+ The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
+ kv_last_page_len : torch.Tensor
+ The number of entries in the last page of each request in the paged kv cache,
+ shape: ``[batch_size]``.
+ kv_layout : str
+ The layout of the paged kv-cache, either ``NHD`` or ``HND``.
+
+ Notes
+ -----
+    The function assumes that the space for the appended k/v has already been allocated,
+    which means :attr:`kv_indices`, :attr:`kv_indptr`, :attr:`kv_last_page_len` have
+    already incorporated the appended k/v.
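+
+    Examples
+    --------
+    An illustrative sketch with made-up sizes; tokens are appended to an initially empty
+    paged kv-cache whose page table already covers the appended tokens:
+
+    >>> import torch
+    >>> import flashinfer
+    >>> num_kv_heads, head_dim, page_size, max_num_pages = 32, 128, 16, 8
+    >>> kv_data = torch.zeros(
+    ...     max_num_pages, 2, page_size, num_kv_heads, head_dim,
+    ...     dtype=torch.float16, device="cuda:0")
+    >>> # two requests appending 45 and 12 new tokens respectively
+    >>> append_indptr = torch.tensor([0, 45, 57], dtype=torch.int32, device="cuda:0")
+    >>> append_key = torch.randn(57, num_kv_heads, head_dim).half().to("cuda:0")
+    >>> append_value = torch.randn(57, num_kv_heads, head_dim).half().to("cuda:0")
+    >>> # request 0 uses pages [0, 1, 2] (45 = 16 + 16 + 13), request 1 uses page [3]
+    >>> kv_indptr = torch.tensor([0, 3, 4], dtype=torch.int32, device="cuda:0")
+    >>> kv_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda:0")
+    >>> kv_last_page_len = torch.tensor([13, 12], dtype=torch.int32, device="cuda:0")
+    >>> flashinfer.append_paged_kv_cache(
+    ...     append_key, append_value, append_indptr, kv_data,
+    ...     kv_indices, kv_indptr, kv_last_page_len)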
+ """
check_kv_layout(kv_layout)
_kernels.append_paged_kv_cache(
append_key,
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index 76d2e1ac..0aaac623 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -15,10 +15,20 @@
"""
import math
from typing import Optional
-
import torch
-from . import _kernels
+try:
+ from . import _kernels
+except ImportError as e:
+ import os
+ import logging
+
+ if os.environ.get("BUILD_DOC", "0") == "1":
+ _kernels = None
+ logging.warning("Kernels are not loaded in documentation build mode.")
+ else:
+ raise e
+
from .utils import (
RotaryMode,
TensorLayout,
@@ -54,39 +64,70 @@ def single_prefill_with_kv_cache(
k: torch.Tensor,
v: torch.Tensor,
causal: bool = False,
- rotary_mode: str = "NONE",
kv_layout: str = "NHD",
+ rotary_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
- r"""Single request prefill with KV cache kernel.
+ r"""Prefill/Append attention with KV cache for single request, return the attention
+ output.
Parameters
----------
q : torch.Tensor
- Shape: [qo_len, num_qo_heads, head_dim] if NHD
- [num_qo_heads, qo_len, head_dim] if HND
+ The query tensor, shape: ``[qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
- Shape: [kv_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, kv_len, head_dim] if HND
+ The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is
+ ``HND``.
v : torch.Tensor
- Shape: [kv_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, kv_len, head_dim] if HND
+        The value tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD``, ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is
+ ``HND``.
causal : bool
Whether to apply causal mask to the attention matrix.
- rotary_mode : str
- Whether to apply rotary embeddings inside attention kernels, could be
- "NONE" or "LLAMA".
kv_layout : str
- The layout of the input k/v tensors, could be either "NHD" or "HND".
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
allow_fp16_qk_reduction : bool
- Whether to use f16 for qk reduction (could be significantly faster for GeForce cards, at
- the cost of precision loss).
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
rope_scale : Optional[float]
The scale used in RoPE interpolation, if not provided, will be set to 1.0.
rope_theta : Optional[float]
The theta used in RoPE, if not provided, will be set to 1e4.
+
+ Returns
+ -------
+ torch.Tensor
+ The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``.
+
+ Examples
+ --------
+
+ >>> import torch
+ >>> import flashinfer
+ >>> qo_len = 128
+ >>> kv_len = 4096
+ >>> num_qo_heads = 32
+ >>> num_kv_heads = 4
+ >>> head_dim = 128
+ >>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0")
+ >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True,
+    ...     allow_fp16_qk_reduction=True)
+ >>> o.shape
+ torch.Size([128, 32, 128])
+
+ Notes
+ -----
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
+ not equal to ``num_kv_heads``, the function will use
+ `grouped query attention `_.
"""
check_rotary_mode(rotary_mode)
check_kv_layout(kv_layout)
@@ -115,49 +156,73 @@ def single_prefill_with_kv_cache_return_lse(
k: torch.Tensor,
v: torch.Tensor,
causal: bool = False,
- rotary_mode: str = "NONE",
kv_layout: str = "NHD",
+ rotary_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
- r"""Single request prefill with KV cache kernel, return logsumexp value.
+ r"""Prefill/Append attention with KV cache for single request, return attention
+ output and logsumexp of attention scores.
Parameters
----------
q : torch.Tensor
- Shape: [qo_len, num_qo_heads, head_dim] if NHD
- [num_qo_heads, qo_len, head_dim] if HND
+ The query tensor, shape: ``[qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
- Shape: [kv_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, kv_len, head_dim] if HND
+ The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is
+ ``HND``.
v : torch.Tensor
- Shape: [kv_len, num_kv_heads, head_dim] if NHD
- [num_kv_heads, kv_len, head_dim] if HND
+        The value tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout`
+ is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is
+ ``HND``.
causal : bool
Whether to apply causal mask to the attention matrix.
- rotary_mode : str
- Whether to apply rotary embeddings inside attention kernels, could be
- "NONE" or "LLAMA".
kv_layout : str
- The layout of the input k/v tensors, could be either "NHD" or "HND".
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
allow_fp16_qk_reduction : bool
- Whether to use f16 for qk reduction (could be significantly faster for GeForce cards, at
- the cost of precision loss).
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
rope_scale : Optional[float]
- The scale used in RoPE interpolation, if not provided, will be set to 1.0.
+ The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
rope_theta : Optional[float]
- The theta used in RoPE, if not provided, will be set to 1e4.
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
Returns
-------
V : torch.Tensor
- The attention output.
- Shape: [qo_len, num_qo_heads, head_dim] if NHD
- [num_qo_heads, qo_len, head_dim] if HND
+ The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``.
S : torch.Tensor
- The logsumexp value.
- Shape: [qo_len, num_qo_heads]
+ The logsumexp of attention scores, shape: ``[qo_len, num_qo_heads]``.
+
+ Examples
+ --------
+
+ >>> import torch
+ >>> import flashinfer
+ >>> qo_len = 128
+ >>> kv_len = 4096
+ >>> num_qo_heads = 32
+ >>> num_kv_heads = 4
+ >>> head_dim = 128
+ >>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0")
+ >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> V, S = flashinfer.single_prefill_with_kv_cache_return_lse(q, k, v, causal=True)
+ >>> V.shape
+ torch.Size([128, 32, 128])
+ >>> S.shape
+ torch.Size([128, 32])
+
+ Notes
+ -----
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is
+ not equal to ``num_kv_heads``, the function will use
+ `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
"""
check_rotary_mode(rotary_mode)
check_kv_layout(kv_layout)
@@ -184,9 +249,31 @@ def single_prefill_with_kv_cache_return_lse(
class BatchPrefillWithPagedKVCacheWrapper:
- r"""Wrapper class of batch_prefill_with_paged_kv_cache kernel."""
+ r"""Wrapper class for prefill/append attention with paged kv-cache for batch of
+ requests.
+
+ Check :ref:`our tutorial` for page table layout.
+
+ Note
+ ----
+ To accelerate computation, FlashInfer's batch prefill/append attention operators
+ create auxiliary data structures that can be reused across multiple
+ prefill/append attention calls (e.g. different Transformer layers). This wrapper
+ class manages the lifecycle of these data structures.
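+
+ Examples
+ --------
+ A minimal sketch of the intended workflow; the batch size, page size, tensor
+ shapes and the ``uint8`` workspace dtype below are illustrative assumptions,
+ not requirements:
+
+ >>> import torch
+ >>> import flashinfer
+ >>> num_qo_heads, num_kv_heads, head_dim, page_size = 32, 4, 128, 16
+ >>> # a single request with 64 query tokens and 5 pages of kv-cache (80 entries)
+ >>> qo_indptr = torch.tensor([0, 64], dtype=torch.int32, device="cuda:0")
+ >>> paged_kv_indptr = torch.tensor([0, 5], dtype=torch.int32, device="cuda:0")
+ >>> paged_kv_indices = torch.arange(5, dtype=torch.int32, device="cuda:0")
+ >>> paged_kv_last_page_len = torch.tensor([16], dtype=torch.int32, device="cuda:0")
+ >>> q = torch.randn(64, num_qo_heads, head_dim).half().to("cuda:0")
+ >>> paged_kv_data = torch.randn(
+ ...     5, 2, page_size, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+ >>> wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+ >>> # create the auxiliary data structures once per prefill/append step
+ >>> wrapper.begin_forward(qo_indptr, paged_kv_indptr, paged_kv_indices,
+ ...     paged_kv_last_page_len, num_qo_heads, num_kv_heads)
+ >>> o = wrapper.forward(q, paged_kv_data, causal=True)
+ >>> o.shape
+ torch.Size([64, 32, 128])
+ >>> # the cached data structures can be reused across layers; clear them when done
+ >>> wrapper.end_forward()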
+ """
def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
+ r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`.
+
+ Parameters
+ ----------
+ workspace_buffer : torch.Tensor
+ The user-reserved workspace buffer used to store auxiliary data structures;
+ the recommended size is 16MB, and the buffer should reside on the same
+ device as the input tensors.
+ kv_layout : str
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ """
check_kv_layout(kv_layout)
self._kv_layout = kv_layout
self._workspace_buffer = workspace_buffer
@@ -199,6 +286,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
self._paged_kv_last_page_len = None
def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
+ r"""Reset the workspace buffer.
+
+ Parameters
+ ----------
+ new_workspace_buffer : torch.Tensor
+ The new workspace buffer; it should reside on the same device as the
+ input tensors.
+ """
self._workspace_buffer = new_workspace_buffer
def begin_forward(
@@ -210,6 +305,35 @@ def begin_forward(
num_qo_heads: int,
num_kv_heads: int,
):
+ r"""Create auxiliary data structures for batch prefill/append attention for
+ multiple forward calls within the same prefill/append step.
+
+ Parameters
+ ----------
+ qo_indptr : torch.Tensor
+ The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
+ paged_kv_indptr : torch.Tensor
+ The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
+ paged_kv_indices : torch.Tensor
+ The page indices of the paged kv-cache, shape: ``[paged_kv_indptr[-1]]``.
+ paged_kv_last_page_len : torch.Tensor
+ The number of entries in the last page of each request in the paged
+ kv-cache, shape: ``[batch_size]``.
+ num_qo_heads : int
+ The number of query/output heads.
+ num_kv_heads : int
+ The number of key/value heads.
+
+ Notes
+ -----
+ The :meth:`begin_forward` method should be called before any :meth:`forward` or
+ :meth:`forward_return_lse` calls; the auxiliary data structures will be created
+ during this call and cached for multiple forward calls.
+
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
+ is not equal to ``num_kv_heads``, the function will use
+ `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
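+
+ Examples
+ --------
+ A minimal sketch of how the indptr/indices arrays relate to per-request sizes;
+ the request lengths and page size below are illustrative assumptions:
+
+ >>> import torch
+ >>> # two requests with 7 and 5 query tokens, so qo_indptr is their prefix sum
+ >>> qo_indptr = torch.tensor([0, 7, 7 + 5], dtype=torch.int32, device="cuda:0")
+ >>> # request 0 uses pages [0, 1, 2], request 1 uses pages [3, 4]
+ >>> paged_kv_indptr = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda:0")
+ >>> paged_kv_indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int32, device="cuda:0")
+ >>> # with page_size 16, the requests hold 40 and 20 kv entries, so their last
+ >>> # pages contain 8 and 4 entries respectively
+ >>> paged_kv_last_page_len = torch.tensor([8, 4], dtype=torch.int32, device="cuda:0")
+
+ These arrays, together with the head counts, are the arguments expected by
+ :meth:`begin_forward`.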
+ """
batch_size = len(qo_indptr) - 1
self._qo_indptr = qo_indptr
self._paged_kv_indptr = paged_kv_indptr
@@ -220,6 +344,7 @@ def begin_forward(
)
def end_forward(self):
+ r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
self._qo_indptr = None
self._paged_kv_indptr = None
self._paged_kv_indices = None
@@ -236,6 +361,37 @@ def forward(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch prefill/append attention between query and paged kv-cache.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``
+ paged_kv_data : torch.Tensor
+ A 5-D tensor of the reserved paged kv-cache data, shape:
+ ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]``
+ if :attr:`kv_layout` is ``NHD``, or
+ ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]``
+ if :attr:`kv_layout` is ``HND``.
+ causal : bool
+ Whether to apply causal mask to the attention matrix.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ allow_fp16_qk_reduction : bool
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to
+ ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ torch.Tensor
+ The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
+ """
check_rotary_mode(rotary_mode)
if rope_scale is None:
rope_scale = 1.0
@@ -267,6 +423,40 @@ def forward_return_lse(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch prefill/append attention paged kv-cache.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``
+ paged_kv_data : torch.Tensor
+ A 5-D tensor of the reserved paged kv-cache data, shape:
+ ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]``
+ if :attr:`kv_layout` is ``NHD``, or
+ ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+ :attr:`kv_layout` is ``HND``.
+ causal : bool
+ Whether to apply causal mask to the attention matrix.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ allow_fp16_qk_reduction : bool
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to
+ ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ V : torch.Tensor
+ The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
+ S : torch.Tensor
+ The logsumexp of attention scores, shape: ``[qo_indptr[-1], num_qo_heads]``.
+ """
check_rotary_mode(rotary_mode)
if rope_scale is None:
rope_scale = 1.0
@@ -290,9 +480,31 @@ def forward_return_lse(
class BatchPrefillWithRaggedKVCacheWrapper:
- r"""Wrapper class of batch_prefill_with_ragged_kv_cache kernel."""
+ r"""Wrapper class for prefill/append attention with ragged (tensor) kv-cache for
+ batch of requests.
+
+ Check :ref:`our tutorial` for ragged kv-cache layout.
+
+ Note
+ ----
+ To accelerate computation, FlashInfer's batch prefill/append attention operators
+ create auxiliary data structures that can be reused across multiple
+ prefill/append attention calls (e.g. different Transformer layers). This wrapper
+ class manages the lifecycle of these data structures.
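+
+ Examples
+ --------
+ A minimal sketch of the intended workflow; the tensor shapes and the ``uint8``
+ workspace dtype below are illustrative assumptions, not requirements:
+
+ >>> import torch
+ >>> import flashinfer
+ >>> num_qo_heads, num_kv_heads, head_dim = 32, 4, 128
+ >>> # a single request with 64 query tokens attending to 128 kv entries
+ >>> qo_indptr = torch.tensor([0, 64], dtype=torch.int32, device="cuda:0")
+ >>> kv_indptr = torch.tensor([0, 128], dtype=torch.int32, device="cuda:0")
+ >>> q = torch.randn(64, num_qo_heads, head_dim).half().to("cuda:0")
+ >>> k = torch.randn(128, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> v = torch.randn(128, num_kv_heads, head_dim).half().to("cuda:0")
+ >>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+ >>> wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")
+ >>> wrapper.begin_forward(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads)
+ >>> o = wrapper.forward(q, k, v, causal=True)
+ >>> o.shape
+ torch.Size([64, 32, 128])
+ >>> wrapper.end_forward()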
+ """
def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
+ r"""Constructor of :class:`BatchDecodeWithRaggedKVCacheWrapper`.
+
+ Parameters
+ ----------
+ workspace_buffer : torch.Tensor
+ The user-reserved workspace buffer used to store auxiliary data structures;
+ the recommended size is 16MB, and the buffer should reside on the same
+ device as the input tensors.
+ kv_layout : str
+ The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+ """
check_kv_layout(kv_layout)
self._kv_layout = kv_layout
self._workspace_buffer = workspace_buffer
@@ -303,6 +515,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
self._kv_indptr = None
def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
+ r"""Reset the workspace buffer.
+
+ Parameters
+ ----------
+ new_workspace_buffer : torch.Tensor
+ The new workspace buffer; it should reside on the same device as the
+ input tensors.
+ """
self._workspace_buffer = new_workspace_buffer
def begin_forward(
@@ -312,6 +532,30 @@ def begin_forward(
num_qo_heads: int,
num_kv_heads: int,
):
+ r"""Create auxiliary data structures for batch prefill/append attention for
+ multiple forward calls within the same prefill/append step.
+
+ Parameters
+ ----------
+ qo_indptr : torch.Tensor
+ The indptr of the query/output tensor, shape: ``[batch_size + 1]``.
+ kv_indptr : torch.Tensor
+ The indptr of the key/value tensor, shape: ``[batch_size + 1]``.
+ num_qo_heads : int
+ The number of query/output heads.
+ num_kv_heads : int
+ The number of key/value heads.
+
+ Notes
+ -----
+ The :meth:`begin_forward` method should be called before any :meth:`forward` or
+ :meth:`forward_return_lse` calls; the auxiliary data structures will be created
+ during this call and cached for multiple forward calls.
+
+ The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
+ is not equal to ``num_kv_heads``, the function will use
+ `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
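+
+ Examples
+ --------
+ A minimal sketch of how the indptr arrays relate to per-request lengths; the
+ request sizes below are illustrative assumptions:
+
+ >>> import torch
+ >>> # two requests with 7 and 5 query tokens and 40 and 20 kv entries;
+ >>> # indptr arrays are prefix sums of the per-request lengths
+ >>> qo_indptr = torch.tensor([0, 7, 7 + 5], dtype=torch.int32, device="cuda:0")
+ >>> kv_indptr = torch.tensor([0, 40, 40 + 20], dtype=torch.int32, device="cuda:0")
+
+ These arrays, together with the head counts, are the arguments expected by
+ :meth:`begin_forward`.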
+ """
batch_size = len(qo_indptr) - 1
self._qo_indptr = qo_indptr
self._kv_indptr = kv_indptr
@@ -320,6 +564,7 @@ def begin_forward(
)
def end_forward(self):
+ r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
self._qo_indptr = None
self._kv_indptr = None
self._wrapper.end_forward()
@@ -335,6 +580,36 @@ def forward(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch prefill/append attention between query and kv-cache stored in
+ ragged tensor.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``
+ k : torch.Tensor
+ The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]``
+ v : torch.Tensor
+ The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]``
+ causal : bool
+ Whether to apply causal mask to the attention matrix.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ allow_fp16_qk_reduction : bool
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to
+ ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ torch.Tensor
+ The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
+ """
check_rotary_mode(rotary_mode)
if rope_scale is None:
rope_scale = 1.0
@@ -365,6 +640,38 @@ def forward_return_lse(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
):
+ r"""Compute batch prefill/append attention between query and kv-cache stored in
+ ragged tensor. Return attention output and logsumexp of attention scores.
+
+ Parameters
+ ----------
+ q : torch.Tensor
+ The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``
+ k : torch.Tensor
+ The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]``
+ v : torch.Tensor
+ The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]``
+ causal : bool
+ Whether to apply causal mask to the attention matrix.
+ rotary_mode : str
+ Whether to apply RoPE on-the-fly inside attention kernels, could be
+ ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+ allow_fp16_qk_reduction : bool
+ Whether to use f16 for qk reduction (faster at the cost of slight precision
+ loss).
+ rope_scale : Optional[float]
+ The scale used in RoPE interpolation, if not provided, will be set to ``1.0``.
+ rope_theta : Optional[float]
+ The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+ Returns
+ -------
+ V : torch.Tensor
+ The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
+ S : torch.Tensor
+ The logsumexp of attention scores, shape: ``[qo_indptr[-1], num_qo_heads]``.
+ """
check_rotary_mode(rotary_mode)
if rope_scale is None:
rope_scale = 1.0