
[Doc] Improve README and documentation. (#106)
yzh119 authored Feb 2, 2024
1 parent a389ed4 commit b30dad3
Showing 3 changed files with 18 additions and 12 deletions.
README.md: 5 changes (4 additions & 1 deletion)
@@ -38,7 +38,10 @@ Using our PyTorch API is the easiest way to get started:
We provide prebuilt wheels for Linux and you can try out FlashInfer with the following command:

```bash
-pip install flashinfer -i https://flashinfer.ai/whl/cu121/ # for CUDA 12.1, use cu118 for CUDA 11.8
+# For CUDA 12.1
+pip install flashinfer -i https://flashinfer.ai/whl/cu121/
+# For CUDA 11.8
+# pip install flashinfer -i https://flashinfer.ai/whl/cu118/
```

or you can build from source:
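Once a wheel is installed, a quick way to verify the setup is a single-request decode attention call. The snippet below is a minimal smoke test written for this page, not part of the commit; it assumes a CUDA-capable GPU and the `flashinfer.single_decode_with_kv_cache` kernel from the FlashInfer docs, with illustrative shape values.

```python
# Illustrative smoke test: run one decode-attention kernel on random tensors.
import torch
import flashinfer

kv_len, num_qo_heads, num_kv_heads, head_dim = 2048, 32, 32, 128
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

o = flashinfer.single_decode_with_kv_cache(q, k, v)  # attention output
print(o.shape)  # expected: torch.Size([32, 128])
```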
docs/index.rst: 4 changes (2 additions & 2 deletions)
@@ -8,7 +8,7 @@ Welcome to FlashInfer's documentation!

`Blog <https://flashinfer.ai/>`_ | `Discussion Forum <https://github.com/orgs/flashinfer-ai/discussions>`_ | `GitHub <https://github.com/flashinfer-ai/flashinfer/>`_

-FlashInfer is a library for Language Languages Models that provides high-performance implementation of LLM GPU kernels such as FlashAttention, PageAttention and LoRA. FlashInfer focus on LLM serving and inference, and delivers state-the-art performance across diverse scenarios.
+FlashInfer is a library for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, PageAttention and LoRA. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.

.. toctree::
:maxdepth: 2
@@ -31,4 +31,4 @@ FlashInfer is a library for Large Language Models that provides high-perform
api/python/prefill
api/python/cascade
api/python/page


docs/tutorials/recursive_attention.rst: 21 changes (12 additions & 9 deletions)
@@ -1,7 +1,7 @@
.. _recursive-attention:

-Attention States and Recursive form of Self-Attention
-=====================================================
+Attention States and Recursive Attention
+========================================


FlashInfer introduces the concept of **attention states**, which fully characterizes
@@ -21,23 +21,26 @@ We can also generalize the value on index :math:`i` to index set :math:`I`:

.. math::
-  \mathbf{v}(I)=\frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}
+  \mathbf{v}(I) = \sum_{i\in I}\textrm{softmax}(s_i) \mathbf{v}_i = \frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}
-The *attention state* of the index set :math:`i` can be defined as a tuple :math:`(s(I), \mathbf{v}(I))`.

-Then we can define the **merge** operator :math:`\oplus` of two attention states as:
+The :math:`softmax` function is restricted to the index set :math:`I`. Note that :math:`\mathbf{v}(\{1,2,\cdots, n\})` is the self-attention output of the entire sequence.
+The *attention state* of the index set :math:`I` can be defined as a tuple :math:`(s(I), \mathbf{v}(I))`; we can then define a binary **merge** operator :math:`\oplus` of two attention states as follows (in practice we subtract the maximum of :math:`s` to guarantee numerical stability; we omit that step here for simplicity):

.. math::
  \begin{bmatrix}\mathbf{v}(I\cup J)\\s(I\cup J)\end{bmatrix}=\begin{bmatrix}\mathbf{v}(I)\\s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathbf{v}(J)\\s(J)\end{bmatrix}=\begin{bmatrix} \frac{\mathbf{v}(I)\exp(s(I)) + \mathbf{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))} \\ \log(\exp(s(I)) + \exp(s(J))) \end{bmatrix}
-The **attention state** on the entire sequence can be defined as:
+The **merge** operator can be generalized to any number of attention state inputs:

.. math::
+  \begin{bmatrix}\mathbf{v}(\bigcup_{i=1}^{n}I_i) \\ s(\bigcup_{i=1}^{n}I_i) \end{bmatrix} = \bigoplus_{i=1}^{n}\begin{bmatrix}\mathbf{v}(I_i) \\ s(I_i)\end{bmatrix} = \begin{bmatrix} \sum_{i=1}^{n} \textrm{softmax}(s(I_i))\mathbf{v}(I_i) \\ \log(\sum_{i=1}^{n} \exp (s(I_i))) \end{bmatrix}
-  \begin{bmatrix}\mathbf{v}(\{1,2,\dots, n\})\\s(\{1,2,\dots, n\})\end{bmatrix} = \bigoplus_{i=1}^{n} \begin{bmatrix}\mathbf{v}_i\\s_i\end{bmatrix}
+The above n-ary merge operator is consistent with the binary merge operator, and we can prove the operator is *commutative* and *associative*. There are different ways to get the attention state of the entire sequence by merging the attention states of index subsets, and the final outcome is mathematically equivalent:

-Then :math:`\mathbf{v}(\{1,2,\dots, n\})` is the final attention output.
+.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/recursive-attention.png
+  :width: 600
+  :align: center
+  :alt: Recursive Attention

.. note::

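The attention-state algebra in the tutorial diff above is easy to sanity-check numerically. The NumPy sketch below is an editorial illustration, not FlashInfer's CUDA implementation; the helper names `attention_state` and `merge` are made up for this example. It builds the state :math:`(s(I), \mathbf{v}(I))` for two index subsets, merges them with :math:`\oplus` (using the max-subtraction trick the tutorial mentions), and checks the result against computing the state over the whole sequence at once.

```python
# Illustrative sketch of attention states and the merge operator.
import numpy as np

def attention_state(s, v):
    """Return (s(I), v(I)) for pre-softmax scores s and value vectors v."""
    m = s.max()                        # subtract the max for numerical stability
    p = np.exp(s - m)
    # s(I) = log sum_i exp(s_i);  v(I) = softmax-weighted sum of values
    return m + np.log(p.sum()), (p[:, None] * v).sum(axis=0) / p.sum()

def merge(a, b):
    """Binary merge operator ⊕ on two attention states."""
    (sa, va), (sb, vb) = a, b
    m = max(sa, sb)                    # same stabilization trick
    wa, wb = np.exp(sa - m), np.exp(sb - m)
    return m + np.log(wa + wb), (wa * va + wb * vb) / (wa + wb)

rng = np.random.default_rng(0)
s = rng.standard_normal(8)             # scores q·k_i for an 8-token sequence
v = rng.standard_normal((8, 4))        # value vectors, head_dim = 4
whole = attention_state(s, v)          # state of the full index set
halves = merge(attention_state(s[:3], v[:3]),   # merge two index subsets
               attention_state(s[3:], v[3:]))
assert np.isclose(whole[0], halves[0]) and np.allclose(whole[1], halves[1])
```

Because the operator is commutative and associative, the same check passes for any partition of the indices, e.g. folding per-chunk states with `functools.reduce(merge, chunk_states)`; this is exactly the property that lets attention over a long KV cache be computed piecewise and combined.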

