diff --git a/.gitignore b/.gitignore index 1266124b..e42c8015 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,16 @@ python/csrc/generated/ python/flashinfer/_build_meta.py +# Generated documentation files +docs/generated + +# DS_Store files +.DS_store + +# Microbenchmark files microbenchmark/ + +# vscode .vscode/ # Byte-compiled / optimized / DLL files diff --git a/docs/_static/FlashInfer-black-background.png b/docs/_static/FlashInfer-black-background.png new file mode 100644 index 00000000..79eccd90 Binary files /dev/null and b/docs/_static/FlashInfer-black-background.png differ diff --git a/docs/_static/FlashInfer-white-background.png b/docs/_static/FlashInfer-white-background.png new file mode 100644 index 00000000..d0936895 Binary files /dev/null and b/docs/_static/FlashInfer-white-background.png differ diff --git a/docs/api/python/cascade.rst b/docs/api/python/cascade.rst new file mode 100644 index 00000000..dd095a6c --- /dev/null +++ b/docs/api/python/cascade.rst @@ -0,0 +1,38 @@ +.. _apicascade: + +flashinfer.cascade +================== + +.. currentmodule:: flashinfer.cascade + +.. _api-merge-states: + +Merge Attention States +---------------------- + +.. autosummary:: + :toctree: ../../generated + + merge_state + merge_state_in_place + merge_states + +.. _api-cascade-attention: + +Cascade Attention +----------------- + +.. autosummary:: + :toctree: ../../generated + + batch_decode_with_shared_prefix_padded_kv_cache + + +Cascade Attention Wrapper Classes +--------------------------------- + +.. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper + :members: + +.. autoclass:: BatchPrefillWithSharedPrefixPagedKVCacheWrapper + :members: diff --git a/docs/api/python/decode.rst b/docs/api/python/decode.rst new file mode 100644 index 00000000..ca972664 --- /dev/null +++ b/docs/api/python/decode.rst @@ -0,0 +1,26 @@ +.. _apidecode: + +flashinfer.decode +================= + +.. currentmodule:: flashinfer.decode + +Single Request Decoding +----------------------- + +.. autosummary:: + :toctree: ../../generated + + single_decode_with_kv_cache + +Batch Decoding +-------------- + +.. autosummary:: + :toctree: ../../generated + + batch_decode_with_padded_kv_cache + batch_decode_with_padded_kv_cache_return_lse + +.. autoclass:: BatchDecodeWithPagedKVCacheWrapper + :members: diff --git a/docs/api/python/page.rst b/docs/api/python/page.rst new file mode 100644 index 00000000..66e64f68 --- /dev/null +++ b/docs/api/python/page.rst @@ -0,0 +1,16 @@ +.. _apipage: + +flashinfer.page +=============== + +Kernels to manipulte paged kv-cache. + +.. currentmodule:: flashinfer.page + +Append new K/V tensors to Paged KV-Cache +---------------------------------------- + +.. autosummary:: + :toctree: ../../generated + + append_paged_kv_cache diff --git a/docs/api/python/prefill.rst b/docs/api/python/prefill.rst new file mode 100644 index 00000000..9f50f195 --- /dev/null +++ b/docs/api/python/prefill.rst @@ -0,0 +1,27 @@ +.. _apiprefill: + +flashinfer.prefill +================== + +Attention kernels for prefill & append attention in both single request and batch serving setting. + +.. currentmodule:: flashinfer.prefill + +Single Request Prefill/Append Attention +--------------------------------------- + +.. autosummary:: + :toctree: ../../generated + + single_prefill_with_kv_cache + single_prefill_with_kv_cache_return_lse + +Batch Prefill/Append Attention +------------------------------ + +.. autoclass:: BatchPrefillWithPagedKVCacheWrapper + :members: + +.. 
autoclass:: BatchPrefillWithRaggedKVCacheWrapper
+    :members:
+
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index 534fc4e3..55345c89 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,7 +1,7 @@
 import os
 import sys
-import tlcpack_sphinx_addon
+# import tlcpack_sphinx_addon
 # Configuration file for the Sphinx documentation builder.
 #
 # For the full list of built-in configuration values, see the documentation:
@@ -10,6 +10,10 @@
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 
+sys.path.insert(0, os.path.abspath("../python"))
+os.environ["BUILD_DOC"] = "1"
+autodoc_mock_imports = ["torch"]
+
 project = 'FlashInfer'
 author = "FlashInfer Contributors"
 footer_copyright = '2023-2024, {}'.format(author)
@@ -22,11 +26,10 @@
 extensions = [
     "sphinx_tabs.tabs",
-    "sphinx_toolbox.collapse",
-    "sphinxcontrib.httpdomain",
     "sphinx.ext.autodoc",
     "sphinx.ext.napoleon",
-    "sphinx_reredirects",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.mathjax",
 ]
 
 source_suffix = [".rst"]
@@ -44,11 +47,7 @@
 # -- Options for HTML output ----------------------------------------------
 
-# The theme is set by the make target
-import sphinx_rtd_theme
-
-html_theme = "sphinx_rtd_theme"
-html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+html_theme = "furo"  # "sphinx_rtd_theme"
 
 templates_path = []
 
@@ -60,27 +59,8 @@
     "logo_only": True,
 }
 
-header_links = [
-    ("Home", "https://flashinfer.ai"),
-    ("Github", "https://github.com/flashinfer-ai/flashinfer"),
-    ("Discussions", "https://github.com/orgs/flashinfer-ai/discussions"),
-]
-
-html_context = {
-    "footer_copyright": footer_copyright,
-    "footer_note": footer_note,
-    "header_links": header_links,
-    "display_github": True,
-    "github_user": "flashinfer-ai",
-    "github_repo": "flashinfer",
-    "github_version": "main/docs/",
-    "theme_vcs_pageview_mode": "edit",
-    # "header_logo": "/path/to/logo",
-    # "header_logo_link": "",
-    # "version_selecter": "",
+html_static_path = ["_static"]
+html_theme_options = {
+    "light_logo": "FlashInfer-white-background.png",
+    "dark_logo": "FlashInfer-black-background.png",
 }
-
-# add additional overrides
-templates_path += [tlcpack_sphinx_addon.get_templates_path()]
-html_static_path += [tlcpack_sphinx_addon.get_static_path()]
-
diff --git a/docs/index.rst b/docs/index.rst
index 3a23a81a..d171a797 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -6,15 +6,29 @@
 Welcome to FlashInfer's documentation!
 ======================================
 
+`Blog `_ | `Discussion Forum `_ | `GitHub `_
+
+FlashInfer is a library for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, PageAttention and LoRA. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.
+
 .. toctree::
    :maxdepth: 2
-   :caption: Contents:
+   :caption: Get Started
 
+   installation
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
 
-Indices and tables
-==================
+   tutorials/recursive_attention
+   tutorials/kv_layout
+
+.. toctree::
+   :maxdepth: 2
+   :caption: PyTorch API Reference
 
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
+   api/python/decode
+   api/python/prefill
+   api/python/cascade
+   api/python/page
+
\ No newline at end of file
diff --git a/docs/installation.rst b/docs/installation.rst
new file mode 100644
index 00000000..3ad79832
--- /dev/null
+++ b/docs/installation.rst
@@ -0,0 +1,47 @@
+.. _installation:
+
+Installation
+============
+
+Python Package
+--------------
+FlashInfer is available as a Python package, built on top of `PyTorch `_ to
+easily integrate with your Python applications.
+
+Prerequisites
+^^^^^^^^^^^^^
+
+- OS: Linux only
+- Python: 3.10, 3.11
+- PyTorch CUDA 11.8/12.1
+   - Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version.
+- Supported GPU architectures: sm_80, sm_86, sm_89, sm_90 (sm_75 support is work in progress).
+
+Quick Start
+^^^^^^^^^^^
+
+.. tabs::
+    .. tab:: PyTorch CUDA 11.8
+
+        .. code-block:: bash
+
+            pip install flashinfer -i https://flashinfer.ai/whl/cu118/
+
+    .. tab:: PyTorch CUDA 12.1
+
+        .. code-block:: bash
+
+            pip install flashinfer -i https://flashinfer.ai/whl/cu121/
+
+
+C++ API
+-------
+
+FlashInfer is a header-only library with only CUDA/C++ standard library dependencies
+that can be directly integrated into your C++ project without installation.
+
+You can check our `unittest and benchmarks `_ to see how to use our C++ APIs at the moment.
+
+.. note::
+    The ``nvbench`` and ``googletest`` dependencies in the ``3rdparty`` directory are only
+    used to compile unittests and benchmarks, and are not required by the library itself.
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 0604736c..7717b04c 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,9 +1,7 @@
 sphinx-tabs == 3.4.1
-sphinx-rtd-theme
-sphinx == 5.2.3
+sphinx == 7.2.6
 sphinx-toolbox == 3.4.0
-tlcpack-sphinx-addon==0.2.2
-sphinxcontrib_httpdomain==1.8.1
-sphinxcontrib-napoleon==0.7
-sphinx-reredirects==0.1.2
-
+sphinxcontrib_httpdomain == 1.8.1
+sphinxcontrib-napoleon == 0.7
+sphinx-reredirects == 0.1.2
+furo == 2024.01.29
diff --git a/docs/tutorials/kv_layout.rst b/docs/tutorials/kv_layout.rst
new file mode 100644
index 00000000..8ee06282
--- /dev/null
+++ b/docs/tutorials/kv_layout.rst
@@ -0,0 +1,98 @@
+.. _kv-layout:
+
+KV-Cache Layout in FlashInfer
+=============================
+
+Layout: NHD/HND
+---------------
+
+FlashInfer provides two layouts for the last 3 dimensions in KV-Cache: ``NHD`` and ``HND``:
+
+- ``NHD``: the last 3 dimensions are organized as ``(seq_len, num_heads, head_dim)``.
+- ``HND``: the last 3 dimensions are organized as ``(num_heads, seq_len, head_dim)``.
+
+The ``NHD`` layout is more natural because it's consistent with the output of
+:math:`xW_k` and :math:`xW_v` without transpose. The ``HND`` layout is more friendly
+to GPU implementations when the KV-Cache uses a low-precision data type (e.g. fp8).
+In practice we don't observe a significant performance difference between these two layouts
+on fp16 KV-Cache, so we prioritize the ``NHD`` layout for better readability. FlashInfer implements
+attention kernels on both layouts and provides an option to select between them (``NHD``
+by default).
+
+.. _ragged-layout:
+
+Ragged Tensor
+-------------
+
+In batched inference/serving, the input sequence length may vary across different samples.
+When there is no need to change the sequence length (e.g. in the prefilling stage), we can use a ``RaggedTensor`` to store
+the key/value tensors in KV-Cache:
+
+.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/ragged.png
+  :width: 400
+  :align: center
+  :alt: Data structure of Ragged KV-Cache.
+
+The keys (or values) of all requests are packed into a single ``data`` tensor without padding.
+We use an ``indptr`` array (``num_requests+1`` elements, the first element is always zero)
+to store the variable sequence length of each request
+(``indptr[i+1]-indptr[i]`` is the sequence length of request ``i``); the ``data`` tensor has
+shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.
+
+We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.
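+
+The following sketch (plain PyTorch indexing, not a FlashInfer API) illustrates how the
+``indptr`` array is built from per-request sequence lengths and how the keys of a single
+request are sliced out of the packed ``data`` tensor:
+
+.. code-block:: python
+
+    import torch
+
+    num_heads, head_dim = 32, 128
+    seq_lens = [5, 3, 8]  # per-request sequence lengths (illustrative)
+    indptr = [0]
+    for seq_len in seq_lens:
+        indptr.append(indptr[-1] + seq_len)  # indptr = [0, 5, 8, 16]
+
+    # packed keys of all requests in NHD layout, no padding
+    data = torch.randn(indptr[-1], num_heads, head_dim, dtype=torch.float16)
+
+    keys_1 = data[indptr[1]:indptr[2]]  # keys of request 1, shape (3, num_heads, head_dim)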
+
+FlashInfer APIs
+~~~~~~~~~~~~~~~
+
+FlashInfer provides :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper` to compute
+the prefill attention between queries stored in a ragged tensor and keys/values stored in ragged
+KV-Cache.
+
+.. _page-layout:
+
+Page Table
+----------
+
+When the KV-Cache is dynamic (e.g. in the append or decode stage), packing all keys/values is not
+efficient because the sequence length per request changes over time. `vLLM `_
+proposes to organize the KV-Cache as a Page Table. In FlashInfer, we treat the page table as
+a block sparse matrix (each used page can be viewed as a non-zero block in the block sparse matrix)
+and use the `CSR format `_
+to index the pages in KV-Cache.
+
+.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/page_layout.png
+  :width: 800
+  :align: center
+  :alt: Data structure of Paged KV-Cache.
+
+For each request, we keep a record of its ``page_indices`` and ``last_page_length``.
+The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed as:
+``[0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), ...]``.
+The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``.
+The overall ``last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_length``.
+
+The ``kv_data`` tensor is a 5-D tensor with shape (in ``NHD`` layout):
+
+.. code::
+
+    (max_num_pages, 2, page_size, num_heads, head_dim)
+
+where ``max_num_pages`` is the maximum number of pages used by all requests, and ``page_size`` is the number of tokens
+we fit into each page. ``2`` is the number of slots in each page (the first one for keys, the second one for values).
+
+FlashInfer APIs
+~~~~~~~~~~~~~~~
+
+:meth:`flashinfer.page.append_paged_kv_cache` can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
+(the pages for these appended keys/values must be allocated prior to calling this API).
+
+:class:`BatchDecodeWithPagedKVCacheWrapper` and :class:`BatchPrefillWithPagedKVCacheWrapper` implement the decode attention
+and prefill/append attention between queries stored in ragged tensors and keys/values stored in paged KV-Cache.
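+
+The sketch below (plain Python, not a FlashInfer API) shows how the flattened ``kv_indptr``,
+``kv_page_indices`` and ``kv_last_page_lens`` arrays described above can be assembled from
+per-request page tables (page ids and last-page lengths are illustrative):
+
+.. code-block:: python
+
+    # per-request page tables and number of valid entries in each request's last page
+    page_indices = [[0, 3, 5], [1], [2, 4]]
+    last_page_length = [2, 7, 1]
+
+    kv_indptr = [0]
+    for pages in page_indices:
+        kv_indptr.append(kv_indptr[-1] + len(pages))                # [0, 3, 4, 6]
+    kv_page_indices = [p for pages in page_indices for p in pages]  # [0, 3, 5, 1, 2, 4]
+    kv_last_page_lens = list(last_page_length)                      # [2, 7, 1]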
+
+FAQ
+^^^
+
+How does FlashInfer manage the KV-Cache?
+  FlashInfer itself is not responsible for managing the page table (popping and allocating new pages, etc.); we leave the strategy
+  to the user: different serving engines might have different strategies to manage the page table. FlashInfer is only responsible
+  for computing the attention between queries and keys/values stored in KV-Cache.
diff --git a/docs/tutorials/recursive_attention.rst b/docs/tutorials/recursive_attention.rst
new file mode 100644
index 00000000..702d306e
--- /dev/null
+++ b/docs/tutorials/recursive_attention.rst
@@ -0,0 +1,80 @@
+.. _recursive-attention:
+
+Attention States and Recursive Form of Self-Attention
+=====================================================
+
+
+FlashInfer introduces the concept of **attention states**, which fully characterize
+the attention between a query and a set of key/value pairs. We further define a
+**merge** operator on the **attention states**. This merge operator facilitates the
+computation of complete attention by allowing the recursive merging of attention states.
+
+Suppose we define :math:`s_i = \mathbf{q}\mathbf{k}_i^T` as the pre-softmax attention
+score between the query :math:`\mathbf{q}` and the key :math:`\mathbf{k}_i`. The self-attention
+score on index :math:`i` can be generalized to an index set :math:`I`:
+
+.. math::
+
+  s(I)=\log\left(\sum_{i\in I}\exp\left(s_i\right)\right)
+
+We can also generalize the value on index :math:`i` to an index set :math:`I`:
+
+.. math::
+
+  \mathbf{v}(I)=\frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}
+
+The *attention state* of the index set :math:`I` can be defined as a tuple :math:`(s(I), \mathbf{v}(I))`.
+
+Then we can define the **merge** operator :math:`\oplus` of two attention states as:
+
+.. math::
+
+  \begin{bmatrix}\mathbf{v}(I\cup J)\\s(I\cup J)\end{bmatrix}=\begin{bmatrix}\mathbf{v}(I)\\s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathbf{v}(J)\\s(J)\end{bmatrix}=\begin{bmatrix} \frac{\mathbf{v}(I)\exp(s(I)) + \mathbf{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))} \\ \log(\exp(s(I)) + \exp(s(J))) \end{bmatrix}
+
+The **attention state** on the entire sequence can be defined as:
+
+.. math::
+
+  \begin{bmatrix}\mathbf{v}(\{1,2,\dots, n\})\\s(\{1,2,\dots, n\})\end{bmatrix} = \bigoplus_{i=1}^{n} \begin{bmatrix}\mathbf{v}_i\\s_i\end{bmatrix}
+
+Then :math:`\mathbf{v}(\{1,2,\dots, n\})` is the final attention output.
+
+.. note::
+
+  The generalized score :math:`s` is also known as log-sum-exp (``lse`` for short).
+
+Applications
+------------
+
+Note that the :math:`\oplus` operator is **commutative** and **associative**, which means we can
+safely offload the self-attention computation on a subset of KV to different devices
+and **merge** the results **in any order**.
+
+There are several interesting applications of this recursive form of self-attention in FlashInfer so far:
+
+Shared-Prefix Batch Decoding
+  Many LLM applications involve batch decoding with a shared long prompt. FlashInfer decomposes attention
+  on the entire KV-Cache into shared-prefix attention and unique-suffix attention.
+  This decomposition enables the offloading of these components to different kernel implementations, resulting
+  in a remarkable 30x acceleration in scenarios with long context and large batch size.
+  Check `our blog post `_ for more details about this application,
+  and :ref:`api-cascade-attention` on how to use this feature in FlashInfer.
+
+KV Sequence Parallelism
+  For long context LLM inference/serving, the batch size and number of heads per GPU are limited by the GPU memory,
+  and the default parallelism strategy cannot use all SMs in the GPU, which results in suboptimal performance.
+  Inspired by the `Split-K `_ trick
+  in GEMM optimizations, FlashInfer partitions the KV sequence dimension, dispatches the attention computations to
+  different thread-blocks, and merges them in a second pass. The same idea was also proposed in Flash-Decoding; you can
+  check their great `blog post `_ for visualizations and more details.
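+
+The following reference sketch (plain PyTorch, not the FlashInfer kernel itself) implements the
+:math:`\oplus` operator defined above on a pair of attention states; :meth:`flashinfer.cascade.merge_state`
+provides a fused CUDA implementation of the same computation:
+
+.. code-block:: python
+
+    import torch
+
+    def merge_state_reference(v_a, s_a, v_b, s_b):
+        # v_*: [seq_len, num_heads, head_dim], s_*: [seq_len, num_heads] log-sum-exp (float32)
+        s_max = torch.maximum(s_a, s_b)   # subtract the max for numerical stability
+        w_a = torch.exp(s_a - s_max)
+        w_b = torch.exp(s_b - s_max)
+        v = (v_a * w_a[..., None] + v_b * w_b[..., None]) / (w_a + w_b)[..., None]
+        s = s_max + torch.log(w_a + w_b)  # log(exp(s_a) + exp(s_b))
+        return v, s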
+ +Related APIs +------------ + +FlashInfer exposes several APIs to facilitate the recursive attention computation: + +- :ref:`api-merge-states` defines the operators to merge attention states. +- :ref:`apiprefill` and :ref:`apidecode` defines operators that returns attention states (APIs + with suffix ``_return_lse`` returns both attention output :math:`v` and score :math:`s`). + diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 596ce38d..1d5b8890 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -15,10 +15,20 @@ """ import math from typing import Optional - import torch -from . import _kernels +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + from .decode import ( batch_decode_with_padded_kv_cache_return_lse, BatchDecodeWithPagedKVCacheWrapper, @@ -39,31 +49,33 @@ def merge_state( v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor ): - r"""Merge the attention output (V) and the logsumexp value (S) from the two KV-segments. + r"""Merge the attention output ``V`` and the logsumexp value ``S`` from the two + KV-segments. + Check :ref:`our tutorial ` on the mathematical details. Parameters ---------- v_a : torch.Tensor - The attention output from the KV segment A. - Shape: [seq_len, num_heads, head_dim] + The attention output from the KV segment ``A``, shape: + ``[seq_len, num_heads, head_dim]``. s_a : torch.Tensor - The logsumexp value from the KV segment A. Expected to be a float32 tensor. - Shape: [seq_len, num_heads] + The logsumexp value from the KV segment ``A``. expected to be a float32 tensor, + shape: ``[seq_len, num_heads]``. v_b : torch.Tensor - The attention output from the KV segment B. - Shape: [seq_len, num_heads, head_dim] + The attention output from the KV segment ``B``, + shape: ``[seq_len, num_heads, head_dim]``. s_b : torch.Tensor - The logsumexp value from the KV segment B. Expected to be a float32 tensor. - Shape: [seq_len, num_heads] + The logsumexp value from the KV segment ``B``, expected to be a float32 tensor, + shape: ``[seq_len, num_heads]`` Returns ------- V : torch.Tensor - The merged attention output (equivalent to attention with merged KV-segment [A: B]). - Shape: [batch_size, num_heads, head_dim] + The merged attention output (equivalent to attention with merged KV-segment + ``[A: B]``), shape: ``[batch_size, num_heads, head_dim]``. S : torch.Tensor - The logsumexp value from the merged KV-segment [A: B]. - Shape: [batch_size, num_heads] + The logsumexp value from the merged KV-segment ``[A: B]``, shape: + ``[batch_size, num_heads]``. """ return _kernels.merge_state(v_a, s_a, v_b, s_b) @@ -71,46 +83,48 @@ def merge_state( def merge_state_in_place( v: torch.Tensor, s: torch.Tensor, v_other: torch.Tensor, s_other: torch.Tensor ): - r"""Merge the self-attention state (v, s) with another state (v_other, s_other) in-place. + r"""Merge the self-attention state ``(v, s)`` with another state + ``(v_other, s_other)`` in-place. Parameters ---------- v : torch.Tensor - The partial v to be updated in-place. - Shape: (seq_len, num_heads, head_dim) + The partial attention output to be updated in-place, shape: + ``(seq_len, num_heads, head_dim)``. s : torch.Tensor - The partial logsumexpr value to be updated in-place, expected to be a float32 tensor. 
- Shape: (seq_len, num_heads) + The partial logsumexpr value to be updated in-place, expected to be a float32 + tensor, shape: ``(seq_len, num_heads)``. v_other : torch.Tensor - The other v to be merged. - Shape: (seq_len, num_heads, head_dim) + The other attention output to be merged, shape: + ``(seq_len, num_heads, head_dim)``. s_other : torch.Tensor - The other logsumexp value to be merged, expected to be a float32 tensor. - Shape: (seq_len, num_heads) + The other logsumexp value to be merged, expected to be a float32 tensor, + shape: ``(seq_len, num_heads)``. """ _kernels.merge_state_in_place(v, s, v_other, s_other) def merge_states(v: torch.Tensor, s: torch.Tensor): - r"""Merge the attention output (V) and the logsumexp value (S) from multiple KV-segments. + r"""Merge the attention output ``V`` and the logsumexp value ``S`` from multiple + KV-segments. Parameters ---------- v : torch.Tensor - The attention output from the KV segments. - Shape: [seq_len, num_kv_segments, num_heads, head_dim] + The attention output from the KV segments, shape: + ``[seq_len, num_kv_segments, num_heads, head_dim]``. s : torch.Tensor - The logsumexp value from the KV segments. - Shape: [seq_len, num_kv_segments, num_heads] + The logsumexp value from the KV segments, shape: + ``[seq_len, num_kv_segments, num_heads]``, expected + to be a float32 tensor. Returns ------- V : torch.Tensor - The merged attention output. - Shape: [seq_len, num_heads, head_dim] + The merged attention output, shape: ``[seq_len, num_heads, head_dim]``. S : torch.Tensor - The logsumexp value from the merged KV-segments. - Shape: [seq_len, num_heads] + The logsumexp value from the merged KV-segments, shape: + ``[seq_len, num_heads]``. """ return _kernels.merge_states(v, s) @@ -127,40 +141,49 @@ def batch_decode_with_shared_prefix_padded_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r"""Batch decode with shared prefix padded KV cache. + r"""Decode attention between queries and shared prefix kv-cache for batch of + requests. Parameters ---------- q : torch.Tensor - Shape: [batch_size, num_qo_heads, head_dim] + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. k_shared : torch.Tensor - Shape: [shared_prefix_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, shared_prefix_len, head_dim] if HND + The shared prefix key tensor, shape: + ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is + ``HND``. v_shared : torch.Tensor - Shape: [shared_prefix_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, shared_prefix_len, head_dim] if HND + The shared prefix value tensor, shape: + ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is + ``HND``. k_unique : torch.Tensor - Shape: [batch_size, unique_len, num_kv_heads, head_dim] if NHD - [batch_size, num_kv_heads, unique_len, head_dim] if HND + The request-independent suffix key tensor, shape: + ``[batch_size, unique_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is + ``NHD``, or ``[batch_size, num_kv_heads, unique_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. 
v_unique : torch.Tensor - Shape: [batch_size, unique_len, num_kv_heads, head_dim] if NHD - [batch_size, num_kv_heads, unique_len, head_dim] if HND + The request-independent suffix value tensor, shape: + ``[batch_size, unique_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is + ``NHD``, or ``[batch_size, num_kv_heads, unique_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. kv_layout : str - The layout of the input k/v tensors, could be either "NHD" or "HND". + The layout of the kv-cache, could be either "NHD" or "HND". allow_fp16_qk_reduction : bool - Whether to use f16 for qk reduction (could be significantly faster for GeForce cards, at - the cost of slight precision loss). + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). sm_scale : Optional[float] - The scale of softmax, if not provided, will be set to 1 / sqrt(head_dim) + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)`` rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to 1.0. + The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to 1e4. + The theta used in RoPE, if not provided, will be set to ``1e4``. Returns ------- V : torch.Tensor - Shape: [batch_size, num_heads, head_dim] + The attention output, shape: ``[batch_size, num_heads, head_dim]`` """ check_kv_layout(kv_layout) V_shared, S_shared = single_prefill_with_kv_cache_return_lse( @@ -190,6 +213,19 @@ def batch_decode_with_shared_prefix_padded_kv_cache( class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: + r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch + of requests. + + Check :ref:`our tutorial` for page table layout. + + Note + ---- + To accelerate computation, FlashInfer's shared prefix batch decode attention creates + some auxiliary data structures, these data structures can be reused across multiple + batch decode attention calls (e.g. different Transformer layers). This wrapper class + manages the lifecycle of these data structures. + """ + def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._batch_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout @@ -197,6 +233,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_layout = kv_layout def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. + + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer, the device of the new workspace buffer should + be the same as the device of the input tensors. + """ self._batch_decode_wrapper.reset_workspace_buffer(new_workspace_buffer) def begin_forward( @@ -210,6 +254,43 @@ def begin_forward( page_size: int, data_type: str = "float16", ): + r"""Create auxiliary data structures for shared-prefix batch decode for multiple + forward calls within the same decode step. 
+ + Parameters + ---------- + indptr : torch.Tensor + The indptr of the paged kv cache, shape: ``[batch_size + 1]`` + indices : torch.Tensor + The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` + last_page_len : torch.Tensor + The number of entries in the last page of each request in the paged kv + cache, shape: ``[batch_size]`` + num_qo_heads : int + The number of query/output heads + num_kv_heads : int + The number of key/value heads + head_dim : int + The dimension of the heads + page_size : int + The page size of the paged kv cache + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + data_type : Union[str, torch.dtype] + The data type of the paged kv cache + + Note + ---- + The :meth:`begin_forward` method should be called before any :meth:`forward` or + :meth:`forward_return_lse` calls, + auxiliary data structures will be created during this call and cached for + multiple forward calls. + + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` + is not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. + """ self._batch_decode_wrapper.begin_forward( unique_kv_indptr, unique_kv_indices, @@ -223,6 +304,7 @@ def begin_forward( ) def end_forward(self): + r"""Clear auxiliary data structures created by :meth:`begin_forward`.""" self._batch_decode_wrapper.end_forward() def forward( @@ -235,6 +317,44 @@ def forward( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch decode attention between queries and shared-prefix paged + kv-cache. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. + k_shared : torch.Tensor + The shared prefix key tensor, shape: + ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is + ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. + v_shared : torch.Tensor + The shared prefix value tensor, shape: + ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is + ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. + unique_kv_data : torch.Tensor + A 5-D tensor of paged kv-cache data storing the request-independent suffix + key and value tensors, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``HND``. + allow_fp16_qk_reduction : bool + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + V : torch.Tensor + The attention output, shape: ``[batch_size, num_heads, head_dim]`` + """ V_shared, S_shared = single_prefill_with_kv_cache_return_lse( q, k_shared, @@ -258,13 +378,45 @@ def forward( class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: + r"""Wrapper class for prefill/append attention with shared-prefix paged kv-cache for + batch of requests. + + Check :ref:`our tutorial` for paged kv-cache layout. 
+ + Note + ---- + To accelerate computation, FlashInfer's shared-prefix batch prefill/append attention + operators creates some auxiliary data structures, these data structures can be + reused across multiple prefill/append attention calls (e.g. different Transformer + layers). This wrapper class manages the lifecycle of these data structures. + """ + def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): + r"""Constructor of :class:`BatchDecodeWithSharedPrefixPagedKVCacheWrapper`. + + Parameters + ---------- + workspace_buffer : torch.Tensor + The user reserved workspace buffer used to store auxiliary data structures, + recommended size is 16MB, the device of the workspace buffer should be the + same as the device of the input tensors. + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + """ self._batch_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) self._kv_layout = kv_layout def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. + + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer, the device of the new workspace buffer should + be the same as the device of the input tensors. + """ self._batch_prefill_wrapper.reset_workspace_buffer(new_workspace_buffer) def begin_forward( @@ -276,6 +428,36 @@ def begin_forward( num_qo_heads: int, num_kv_heads: int, ): + r"""Create auxiliary data structures for shared-prefix batch prefill/append + attention for multiple forward calls within the same prefill/append step. + + Parameters + ---------- + qo_indptr : torch.Tensor + The indptr of the query/output tensor, shape: ``[batch_size + 1]``. + paged_kv_indptr : torch.Tensor + The indptr of the paged kv-cache, shape: ``[batch_size + 1]``. + paged_kv_indices : torch.Tensor + The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``. + paged_kv_last_page_len : torch.Tensor + The number of entries in the last page of each request in the paged + kv-cache, shape: ``[batch_size]``. + num_qo_heads : int + The number of query/output heads. + num_kv_heads : int + The number of key/value heads. + + Notes + ----- + The :meth:`begin_forward` method should be called before any :meth:`forward` + or :meth:`forward_return_lse` calls, auxiliary data structures will be created + during this call and cached for multiple forward calls. + + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` + is not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. + """ + self._batch_prefill_wrapper.begin_forward( qo_indptr, paged_kv_indptr, @@ -286,6 +468,7 @@ def begin_forward( ) def end_forward(self): + r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" self._batch_prefill_wrapper.end_forward() def forward( @@ -299,6 +482,46 @@ def forward( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch prefill/append attention between query and shared-prefix paged + kv-cache. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + k_shared : torch.Tensor + The shared prefix key tensor, shape: + ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is + ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. 
+ v_shared ; torch.Tensor + The shared prefix value tensor, shape: + ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is + ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. + unique_kv_data : torch.Tensor + A 5-D tensor of paged kv-cache data storing the request-independent suffix + key and value tensors, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``HND``. + causal : bool + Whether to apply causal mask on the attention matrix. + allow_fp16_qk_reduction : bool + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + V : torch.Tensor + The attention output, shape: ``[qo_indptr[-1], num_heads, head_dim]``. + """ V_shared, S_shared = single_prefill_with_kv_cache_return_lse( q, k_shared, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index b2014a2c..d4f42cc7 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -15,10 +15,21 @@ """ import math from typing import Optional, Union - import torch -from . import _kernels +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + + from .utils import ( RotaryMode, TensorLayout, @@ -43,35 +54,64 @@ def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - rotary_mode: str = "NONE", kv_layout: str = "NHD", + rotary_mode: str = "NONE", sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r"""Single request decode with KV cache. + r"""Decode attention with KV Cache for single request, return attention output. Parameters ---------- q : torch.Tensor - Shape: [num_qo_heads, head_dim] + The query tensor, shape: ``[num_qo_heads, head_dim]``. k : torch.Tensor - Shape: [kv_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, kv_len, head_dim] if HND + The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is + ``HND``. v : torch.Tensor - Shape: [kv_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, kv_len, head_dim] if HND - rotary_mode : str - Whether to apply rotary embeddings inside attention kernels, could be - "NONE" or "LLAMA". + The value tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. kv_layout : str - The layout of the input k/v tensors, could be either "NHD" or "HND". + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). sm_scale : Optional[float] - The scale of softmax, if not provided, will be set to 1 / sqrt(head_dim) + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. 
rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to 1.0. + The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to 1e4. + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + torch.Tensor + The attention output, shape: ``[num_qo_heads, head_dim]`` + + Examples + -------- + + >>> import torch + >>> import flashinfer + >>> kv_len = 4096 + >>> num_qo_heads = 32 + >>> num_kv_heads = 32 + >>> head_dim = 128 + >>> q = torch.randn(num_qo_heads, head_dim).half().to("cuda:0") + >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") + >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") + >>> o = flashinfer.single_decode_with_kv_cache(q, k, v) + >>> o.shape + torch.Size([32, 128]) + + Notes + ----- + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is + not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. """ check_rotary_mode(rotary_mode) check_kv_layout(kv_layout) @@ -106,23 +146,45 @@ def batch_decode_with_padded_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r"""Batch decode with padded KV cache. + r"""Decode attention with padded KV cache for batch of requests, return attention + output. Parameters ---------- q : torch.Tensor - Shape: [batch_size, num_qo_heads, head_dim] + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. k_padded : torch.Tensor - Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD - [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND + The padded key tensor, shape: + ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. v_padded : torch.Tensor - Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD - [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND + The padded value tensor, shape: + ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. Returns ------- - V : torch.Tensor - Shape: [batch_size, num_heads, head_dim] + torch.Tensor + The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. + + Notes + ----- + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is + not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. 
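+
+    Examples
+    --------
+    A minimal usage sketch that mirrors the single-request example above (tensor sizes are
+    illustrative):
+
+    >>> import torch
+    >>> import flashinfer
+    >>> batch_size = 7
+    >>> padded_kv_len = 4096
+    >>> num_qo_heads = 32
+    >>> num_kv_heads = 32
+    >>> head_dim = 128
+    >>> q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
+    >>> k_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+    >>> v_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0")
+    >>> o = flashinfer.batch_decode_with_padded_kv_cache(q, k_padded, v_padded)
+    >>> o.shape
+    torch.Size([7, 32, 128])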
""" if sm_scale is None: head_dim = q.shape[-1] @@ -154,36 +216,48 @@ def batch_decode_with_padded_kv_cache_return_lse( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r""" + r"""Decode attention with padded KV cache for batch of requests, return attention + output and logsumexp of attention scores, return attention output and logsumexp of + attention scores. + Parameters ---------- q : torch.Tensor - Shape: [batch_size, num_qo_heads, head_dim] + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. k_padded : torch.Tensor - Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD - [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND + The padded key tensor, shape: + ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. v_padded : torch.Tensor - Shape: [batch_size, padded_seq_len, num_kv_heads, head_dim] if NHD - [batch_size, num_kv_heads, padded_seq_len, head_dim] if HND - kv_layout: str - The layout of the input k_padded/v_padded tensors, could be either - "NHD" or "HND" - rotary_mode: str - Whether to apply rotary embeddings inside attention kernels, could be - "NONE" or "LLAMA". - sm_scale: Optional[float] - The scale of softmax, if not provided, will be set to 1 / sqrt(head_dim) - rope_scale: Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to 1.0. - rope_theta: Optional[float] - The theta used in RoPE, if not provided, will be set to 1e4. + The padded value tensor, shape: + ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if + :attr:`kv_layout` is ``HND``. + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. Returns ------- V : torch.Tensor - Shape: [batch_size, num_heads, head_dim] + The attention output, shape: [batch_size, num_qo_heads, head_dim] S : torch.Tensor - Shape: [batch_size, num_heads] + The logsumexp of attention scores, Shape: [batch_size, num_qo_heads] + + Notes + ----- + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is + not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. """ if sm_scale is None: head_dim = q.shape[-1] @@ -206,15 +280,31 @@ def batch_decode_with_padded_kv_cache_return_lse( class BatchDecodeWithPagedKVCacheWrapper: - r"""Wrapper class for batch_decode_with_paged_kv_cache kernel. + r"""Wrapper class for decode attention with paged kv-cache (first proposed in + `vLLM `_) for batch of requests. + + Check :ref:`our tutorial` for page table layout. - To accelerate computation, FlashInfer's batch decode operators creates some + Note + ---- + To accelerate computation, FlashInfer's batch decode attention creates some auxiliary data structures, these data structures can be reused across multiple - batch decode calls (e.g. different Transformer layers). 
This wrapper class manages - the lifecycle of these data structures. + batch decode attention calls (e.g. different Transformer layers). This wrapper class + manages the lifecycle of these data structures. """ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): + r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. + + Parameters + ---------- + workspace_buffer : torch.Tensor + The user reserved workspace buffer used to store auxiliary data structures, + recommended size is 16MB, the device of the workspace buffer should be the + same as the device of the input tensors. + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + """ check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer @@ -226,6 +316,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._paged_kv_last_page_len = None def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. + + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer, the device of the new workspace buffer should + be the same as the device of the input tensors. + """ self._workspace_buffer = new_workspace_buffer def begin_forward( @@ -240,9 +338,41 @@ def begin_forward( rotary_mode: str = "NONE", data_type: Union[str, torch.dtype] = "float16", ): - r"""The begin_forward method should be called before any batch decode calls, - auxiliary data structures will be created during this call and cached for - multiple forward calls. + r"""Create auxiliary data structures for batch decode for multiple forward calls + within the same decode step. + + Parameters + ---------- + indptr : torch.Tensor + The indptr of the paged kv cache, shape: ``[batch_size + 1]`` + indices : torch.Tensor + The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` + last_page_len : torch.Tensor + The number of entries in the last page of each request in the paged kv + cache, shape: ``[batch_size]`` + num_qo_heads : int + The number of query/output heads + num_kv_heads : int + The number of key/value heads + head_dim : int + The dimension of the heads + page_size : int + The page size of the paged kv cache + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + data_type : Union[str, torch.dtype] + The data type of the paged kv cache + + Note + ---- + The :meth:`begin_forward` method should be called before any :meth:`forward` or + :meth:`forward_return_lse` calls, auxiliary data structures will be created + during this call and cached for multiple forward calls. + + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` + is not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. """ self._paged_kv_indptr = indptr self._paged_kv_indices = indices @@ -270,7 +400,7 @@ def begin_forward( ) def end_forward(self): - r"""The end_forward method can clear the cached data structures.""" + r"""Clear auxiliary data structures created by :meth:`begin_forward`.""" self._paged_kv_indptr = None self._paged_kv_indices = None self._paged_kv_last_page_len = None @@ -284,6 +414,32 @@ def forward( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch decode attention between query and paged kv cache. 
+ + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` + paged_kv_data : torch.Tensor + A 5-D tensor of the reserved paged kv-cache data, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``HND``. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + torch.Tensor + The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. + """ check_rotary_mode(rotary_mode) if rope_scale is None: rope_scale = 1.0 @@ -310,6 +466,35 @@ def forward_return_lse( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch decode attention with paged kv cache, return attention output + and logsumexp of attention scores. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` + paged_kv_data : torch.Tensor + A 5-D tensor of the reserved paged kv-cache data, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``HND``. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + V : torch.Tensor + The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. + S : torch.Tensor + The logsumexp of attention scores, Shape: ``[batch_size, num_qo_heads]``. + """ check_rotary_mode(rotary_mode) if rope_scale is None: rope_scale = 1.0 diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index 9160907e..d7944a78 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -15,8 +15,17 @@ """ import torch -from . import _kernels -from .utils import TensorLayout, check_kv_layout +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e def append_paged_kv_cache( @@ -29,6 +38,40 @@ def append_paged_kv_cache( kv_last_page_len: torch.Tensor, kv_layout: str = "NHD", ): + r"""Append a batch of key-value pairs to a paged key-value cache. + + Parameters + ---------- + append_key : torch.Tensor + The key tensor to append in ragged tensor format, shape: + ``[append_indptr[-1], num_kv_heads, head_dim]``. + append_value : torch.Tensor + The value tensor to append in ragged tensor format, shape: + ``[append_indptr[-1], num_kv_heads, head_dim]``. + append_indptr : torch.Tensor + The indptr tensor of the key-value pairs to append, shape: ``[batch_size + 1]``. 
+ kv_data : torch.Tensor + The 5-D tensor of the paged key-value cache, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, num_kv_heads, page_size, num_kv_heads]`` if + :attr:`kv_layout` is ``NHD``. + kv_indices : torch.Tensor + The page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``. + kv_indptr : torch.Tensor + The indptr of the paged kv-cache, shape: ``[batch_size + 1]``. + kv_last_page_len : torch.Tensor + The number of entries in the last page of each request in the paged kv cache, + shape: ``[batch_size]``. + kv_layout : str + The layout of the paged kv-cache, either ``NHD`` or ``HND``. + + Notes + ----- + The function assumes that the space for appended k/v have already been allocated, + which means :attr:`kv_indices`, :attr:`kv_indptr`, :attr:`kv_last_page_len` has + incorporated appended k/v. + """ check_kv_layout(kv_layout) _kernels.append_paged_kv_cache( append_key, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 76d2e1ac..0aaac623 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -15,10 +15,20 @@ """ import math from typing import Optional - import torch -from . import _kernels +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + from .utils import ( RotaryMode, TensorLayout, @@ -54,39 +64,70 @@ def single_prefill_with_kv_cache( k: torch.Tensor, v: torch.Tensor, causal: bool = False, - rotary_mode: str = "NONE", kv_layout: str = "NHD", + rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r"""Single request prefill with KV cache kernel. + r"""Prefill/Append attention with KV cache for single request, return the attention + output. Parameters ---------- q : torch.Tensor - Shape: [qo_len, num_qo_heads, head_dim] if NHD - [num_qo_heads, qo_len, head_dim] if HND + The query tensor, shape: ``[qo_len, num_qo_heads, head_dim]``. k : torch.Tensor - Shape: [kv_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, kv_len, head_dim] if HND + The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is + ``HND``. v : torch.Tensor - Shape: [kv_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, kv_len, head_dim] if HND + The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD``, ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is + ``HND``. causal : bool Whether to apply causal mask to the attention matrix. - rotary_mode : str - Whether to apply rotary embeddings inside attention kernels, could be - "NONE" or "LLAMA". kv_layout : str - The layout of the input k/v tensors, could be either "NHD" or "HND". + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool - Whether to use f16 for qk reduction (could be significantly faster for GeForce cards, at - the cost of precision loss). + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). 
rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to 1.0. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to 1e4. + + Returns + ------- + torch.Tensor + The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. + + Examples + -------- + + >>> import torch + >>> import flashinfer + >>> qo_len = 128 + >>> kv_len = 4096 + >>> num_qo_heads = 32 + >>> num_kv_heads = 4 + >>> head_dim = 128 + >>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0") + >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") + >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") + >>> o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, + allow_fp16_qk_reduction=True) + >>> o.shape + torch.Size([128, 32, 128]) + + Notes + ----- + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is + not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. """ check_rotary_mode(rotary_mode) check_kv_layout(kv_layout) @@ -115,49 +156,73 @@ def single_prefill_with_kv_cache_return_lse( k: torch.Tensor, v: torch.Tensor, causal: bool = False, - rotary_mode: str = "NONE", kv_layout: str = "NHD", + rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r"""Single request prefill with KV cache kernel, return logsumexp value. + r"""Prefill/Append attention with KV cache for single request, return attention + output and logsumexp of attention scores. Parameters ---------- q : torch.Tensor - Shape: [qo_len, num_qo_heads, head_dim] if NHD - [num_qo_heads, qo_len, head_dim] if HND + The query tensor, shape: ``[qo_len, num_qo_heads, head_dim]``. k : torch.Tensor - Shape: [kv_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, kv_len, head_dim] if HND + The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is + ``HND``. v : torch.Tensor - Shape: [kv_len, num_kv_heads, head_dim] if NHD - [num_kv_heads, kv_len, head_dim] if HND + The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` + is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is + ``HND``. causal : bool Whether to apply causal mask to the attention matrix. - rotary_mode : str - Whether to apply rotary embeddings inside attention kernels, could be - "NONE" or "LLAMA". kv_layout : str - The layout of the input k/v tensors, could be either "NHD" or "HND". + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool - Whether to use f16 for qk reduction (could be significantly faster for GeForce cards, at - the cost of precision loss). + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to 1.0. + The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to 1e4. + The theta used in RoPE, if not provided, will be set to ``1e4``. Returns ------- V : torch.Tensor - The attention output. 
- Shape: [qo_len, num_qo_heads, head_dim] if NHD - [num_qo_heads, qo_len, head_dim] if HND + The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. S : torch.Tensor - The logsumexp value. - Shape: [qo_len, num_qo_heads] + The logsumexp value, shape: ``[qo_len, num_qo_heads]``. + + Examples + -------- + + >>> import torch + >>> import flashinfer + >>> qo_len = 128 + >>> kv_len = 4096 + >>> num_qo_heads = 32 + >>> num_kv_heads = 4 + >>> head_dim = 128 + >>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0") + >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") + >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") + >>> V, S = flashinfer.single_prefill_with_kv_cache_return_lse(q, k, v, causal=True) + >>> V.shape + torch.Size([128, 32, 128]) + >>> S.shape + torch.Size([128, 32]) + + Notes + ----- + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is + not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. """ check_rotary_mode(rotary_mode) check_kv_layout(kv_layout) @@ -184,9 +249,31 @@ def single_prefill_with_kv_cache_return_lse( class BatchPrefillWithPagedKVCacheWrapper: - r"""Wrapper class of batch_prefill_with_paged_kv_cache kernel.""" + r"""Wrapper class for prefill/append attention with paged kv-cache for batch of + requests. + + Check :ref:`our tutorial` for page table layout. + + Note + ---- + To accelerate computation, FlashInfer's batch prefill/append attention operators + create some auxiliary data structures, and these data structures can be reused across + multiple prefill/append attention calls (e.g. different Transformer layers). This + wrapper class manages the lifecycle of these data structures. + """ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): + r"""Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`. + + Parameters + ---------- + workspace_buffer : torch.Tensor + The user reserved workspace buffer used to store auxiliary data structures; + the recommended size is 16MB, and the device of the workspace buffer should be the + same as the device of the input tensors. + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + """ check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer @@ -199,6 +286,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._paged_kv_last_page_len = None def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. + + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer; the device of the new workspace buffer should + be the same as the device of the input tensors. + """ self._workspace_buffer = new_workspace_buffer def begin_forward( @@ -210,6 +305,35 @@ def begin_forward( qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, ): + r"""Create auxiliary data structures for batch prefill/append attention for + multiple forward calls within the same prefill/append step. + + Parameters + ---------- + qo_indptr : torch.Tensor + The indptr of the query/output tensor, shape: ``[batch_size + 1]``. + paged_kv_indptr : torch.Tensor + The indptr of the paged kv-cache, shape: ``[batch_size + 1]``. + paged_kv_indices : torch.Tensor + The page indices of the paged kv-cache, shape: ``[paged_kv_indptr[-1]]``.
+ paged_kv_last_page_len : torch.Tensor + The number of entries in the last page of each request in the paged + kv-cache, shape: ``[batch_size]``. + num_qo_heads : int + The number of query/output heads. + num_kv_heads : int + The number of key/value heads. + + Notes + ----- + The :meth:`begin_forward` method should be called before any :meth:`forward` or + :meth:`forward_return_lse` calls; auxiliary data structures will be created + during this call and cached for multiple forward calls. + + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` + is not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. + """ batch_size = len(qo_indptr) - 1 self._qo_indptr = qo_indptr self._paged_kv_indptr = paged_kv_indptr @@ -220,6 +344,7 @@ def begin_forward( ) def end_forward(self): + r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" self._qo_indptr = None self._paged_kv_indptr = None self._paged_kv_indices = None @@ -236,6 +361,37 @@ def forward( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch prefill/append attention between query and paged kv-cache. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` + paged_kv_data : torch.Tensor + A 5-D tensor of the reserved paged kv-cache data, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` + if :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` + if :attr:`kv_layout` is ``HND``. + causal : bool + Whether to apply causal mask to the attention matrix. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + allow_fp16_qk_reduction : bool + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + torch.Tensor + The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + """ check_rotary_mode(rotary_mode) if rope_scale is None: rope_scale = 1.0 @@ -267,6 +423,40 @@ def forward_return_lse( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch prefill/append attention between query and paged kv-cache, + return attention output and logsumexp of attention scores. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` + paged_kv_data : torch.Tensor + A 5-D tensor of the reserved paged kv-cache data, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` + if :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``HND``. + causal : bool + Whether to apply causal mask to the attention matrix. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + allow_fp16_qk_reduction : bool + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``.
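The usage sketch below is an editorial illustration and not part of the patch; it only shows how the ``BatchPrefillWithPagedKVCacheWrapper`` methods documented in the hunks above (constructor, ``begin_forward``, ``forward``/``forward_return_lse``, ``end_forward``) fit together in a single prefill/append step. The batch size, page size, page-table contents, and the ``uint8`` workspace dtype are assumptions chosen for the example, not values mandated by the API::

    import torch
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

    num_qo_heads, num_kv_heads, head_dim, page_size = 32, 4, 128, 16
    max_num_pages = 128

    # 16MB workspace buffer (size recommended by the docstring; uint8 dtype assumed here).
    workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, kv_layout="NHD")

    # CSR-style page table for a batch of 4 requests: request i owns pages
    # paged_kv_indices[paged_kv_indptr[i] : paged_kv_indptr[i + 1]].
    qo_indptr = torch.tensor([0, 33, 44, 55, 66], dtype=torch.int32, device="cuda:0")
    paged_kv_indptr = torch.tensor([0, 17, 29, 44, 48], dtype=torch.int32, device="cuda:0")
    paged_kv_indices = torch.arange(48, dtype=torch.int32, device="cuda:0")
    paged_kv_last_page_len = torch.tensor([1, 7, 14, 4], dtype=torch.int32, device="cuda:0")

    q = torch.randn(66, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    kv_data = torch.randn(
        max_num_pages, 2, page_size, num_kv_heads, head_dim,
        dtype=torch.float16, device="cuda:0",
    )

    # One prefill/append step: create the auxiliary structures once, reuse them for
    # every layer's forward call, then release them.
    wrapper.begin_forward(
        qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
        num_qo_heads, num_kv_heads,
    )
    o = wrapper.forward(q, kv_data, causal=True)                # [66, 32, 128]
    V, S = wrapper.forward_return_lse(q, kv_data, causal=True)  # S: [66, 32]
    wrapper.end_forward()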
+ + Returns + ------- + V : torch.Tensor + The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + S : torch.Tensor + The logsumexp of attention scores, shape: + ``[qo_indptr[-1], num_qo_heads]``. + """ check_rotary_mode(rotary_mode) if rope_scale is None: rope_scale = 1.0 @@ -290,9 +480,31 @@ def forward_return_lse( class BatchPrefillWithRaggedKVCacheWrapper: - r"""Wrapper class of batch_prefill_with_ragged_kv_cache kernel.""" + r"""Wrapper class for prefill/append attention with ragged (tensor) kv-cache for + batch of requests. + + Check :ref:`our tutorial` for ragged kv-cache layout. + + Note + ---- + To accelerate computation, FlashInfer's batch prefill/append attention operators + create some auxiliary data structures, and these data structures can be reused across + multiple prefill/append attention calls (e.g. different Transformer layers). This + wrapper class manages the lifecycle of these data structures. + """ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): + r"""Constructor of :class:`BatchPrefillWithRaggedKVCacheWrapper`. + + Parameters + ---------- + workspace_buffer : torch.Tensor + The user reserved workspace buffer used to store auxiliary data structures; + the recommended size is 16MB, and the device of the workspace buffer should be the + same as the device of the input tensors. + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + """ check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer @@ -303,6 +515,14 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_indptr = None def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. + + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer; the device of the new workspace buffer should + be the same as the device of the input tensors. + """ self._workspace_buffer = new_workspace_buffer def begin_forward( @@ -312,6 +532,30 @@ def begin_forward( num_qo_heads: int, num_kv_heads: int, ): + r"""Create auxiliary data structures for batch prefill/append attention for + multiple forward calls within the same prefill/append step. + + Parameters + ---------- + qo_indptr : torch.Tensor + The indptr of the query/output tensor, shape: ``[batch_size + 1]``. + kv_indptr : torch.Tensor + The indptr of the key/value tensor, shape: ``[batch_size + 1]``. + num_qo_heads : int + The number of query/output heads. + num_kv_heads : int + The number of key/value heads. + + Notes + ----- + The :meth:`begin_forward` method should be called before any :meth:`forward` or + :meth:`forward_return_lse` calls; auxiliary data structures will be created + during this call and cached for multiple forward calls. + + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` + is not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. + """ batch_size = len(qo_indptr) - 1 self._qo_indptr = qo_indptr self._kv_indptr = kv_indptr @@ -320,6 +564,7 @@ def begin_forward( ) def end_forward(self): + r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" self._qo_indptr = None self._kv_indptr = None self._wrapper.end_forward() @@ -335,6 +580,36 @@ def forward( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch prefill/append attention between query and kv-cache stored in + ragged tensor.
+ + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` + k : torch.Tensor + The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` + v : torch.Tensor + The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` + causal : bool + Whether to apply causal mask to the attention matrix. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + allow_fp16_qk_reduction : bool + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + torch.Tensor + The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + """ check_rotary_mode(rotary_mode) if rope_scale is None: rope_scale = 1.0 @@ -365,6 +640,38 @@ def forward_return_lse( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): + r"""Compute batch prefill/append attention between query and kv-cache stored in + ragged tensor. Return attention output and logsumexp of attention scores. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` + k : torch.Tensor + The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` + v : torch.Tensor + The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` + causal : bool + Whether to apply causal mask to the attention matrix. + rotary_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + allow_fp16_qk_reduction : bool + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + V : torch.Tensor + The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + S : torch.Tensor + The logsumexp of attention scores, shape: + ``[qo_indptr[-1], num_qo_heads]``. + """ check_rotary_mode(rotary_mode) if rope_scale is None: rope_scale = 1.0
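As an editorial companion to the hunks above (again, not part of the patch), a minimal sketch of ``BatchPrefillWithRaggedKVCacheWrapper`` follows; the sequence lengths, head counts, and the ``uint8`` workspace dtype are assumptions chosen for illustration::

    import torch
    from flashinfer.prefill import BatchPrefillWithRaggedKVCacheWrapper

    num_qo_heads, num_kv_heads, head_dim = 32, 4, 128

    workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    wrapper = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout="NHD")

    # Requests are packed back-to-back along the first dimension: request i occupies
    # rows qo_indptr[i]:qo_indptr[i+1] of q and rows kv_indptr[i]:kv_indptr[i+1] of k/v.
    qo_indptr = torch.tensor([0, 32, 64, 96, 128], dtype=torch.int32, device="cuda:0")
    kv_indptr = torch.tensor([0, 512, 1024, 1536, 2048], dtype=torch.int32, device="cuda:0")

    q = torch.randn(128, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    k = torch.randn(2048, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
    v = torch.randn(2048, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")

    wrapper.begin_forward(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads)
    o = wrapper.forward(q, k, v, causal=True)                  # [128, 32, 128]
    V, S = wrapper.forward_return_lse(q, k, v, causal=True)    # S: [128, 32]
    wrapper.end_forward()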