diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index bdbf1853..53b7688e 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -34,3 +34,9 @@
     BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
 )
 from .page import append_paged_kv_cache
+
+try:
+    from ._build_meta import __version__ as __version__
+except ImportError:
+    with open("version.txt") as f:
+        __version__ = f.read().strip()
diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py
index d422bf0d..90e4d25e 100644
--- a/python/flashinfer/cascade.py
+++ b/python/flashinfer/cascade.py
@@ -117,7 +117,7 @@ def merge_state_in_place(
     s_other : torch.Tensor
         The other logsumexp value to be merged, expected to be a float32 tensor,
         shape: ``(seq_len, num_heads)``.
-    
+
     Example
     -------
     >>> import torch
@@ -135,7 +135,7 @@
 
 
 def merge_states(v: torch.Tensor, s: torch.Tensor):
-    r"""Merge multiple attention states (v, s). 
+    r"""Merge multiple attention states (v, s).
 
     Parameters
     ----------
@@ -154,7 +154,7 @@ def merge_states(v: torch.Tensor, s: torch.Tensor):
     S : torch.Tensor
         The logsumexp value from the merged KV-segments, shape:
         ``[seq_len, num_heads]``.
-    
+
     Example
     -------
     >>> import torch
@@ -229,7 +229,7 @@ def batch_decode_with_shared_prefix_padded_kv_cache(
     -------
     V : torch.Tensor
         The attention output, shape: ``[batch_size, num_heads, head_dim]``
-    
+
     Example
     -------
     >>> import torch
@@ -312,7 +312,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     ... )
     >>> batch_size = 7
     >>> shared_prefix_len = 8192
-    >>> unique_kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0") 
+    >>> unique_kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
     >>> unique_kv_page_indptr = torch.tensor(
     ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
     ... )
@@ -355,7 +355,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     ...     # compute batch decode attention, reuse auxiliary data structures for all layers
     ...     o = wrapper.forward(q, k_shared, v_shared, unique_kv_data)
     ...     outputs.append(o)
-    ... 
+    ...
     >>> # clear auxiliary data structures
     >>> wrapper.end_forward()
     >>> outputs[0].shape
@@ -547,7 +547,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
     >>> qo_indptr = torch.tensor(
     ...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
     ... )
-    >>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0") 
+    >>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
     >>> paged_kv_indptr = torch.tensor(
     ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
     ... )
@@ -590,7 +590,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
     ...         q, k_shared, v_shared, kv_data, causal=True
     ...     )
     ...     outputs.append(o)
-    ... 
+    ...
     >>> # clear auxiliary data structures
     >>> prefill_wrapper.end_forward()
     >>> outputs[0].shape
diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index eb27526f..dc67024b 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -179,7 +179,7 @@ def batch_decode_with_padded_kv_cache(
     -------
     torch.Tensor
         The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
-    
+
     Examples
     --------
     >>> import torch
@@ -270,7 +270,7 @@ def batch_decode_with_padded_kv_cache_return_lse(
         The attention output, shape: [batch_size, num_qo_heads, head_dim]
     S : torch.Tensor
         The logsumexp of attention scores, Shape: [batch_size, num_qo_heads]
-    
+
     Examples
     --------
     >>> import torch
@@ -342,7 +342,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
     ...     workspace_buffer, "NHD"
     ... )
     >>> batch_size = 7
-    >>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0") 
+    >>> kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
     >>> kv_page_indptr = torch.tensor(
     ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
     ... )
@@ -374,7 +374,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
     ...     # compute batch decode attention, reuse auxiliary data structures for all layers
     ...     o = decode_wrapper.forward(q, kv_data)
     ...     outputs.append(o)
-    ... 
+    ...
     >>> # clear auxiliary data structures
     >>> decode_wrapper.end_forward()
     >>> outputs[0].shape
@@ -589,7 +589,7 @@ def forward_return_lse(
         The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
     S : torch.Tensor
         The logsumexp of attention scores, Shape: ``[batch_size, num_qo_heads]``.
-    
+
     Notes
     -----
     Please refer to the :ref:`tutorial <recursive-attention>` for a detailed
diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py
index 9b6e92f0..d68db529 100644
--- a/python/flashinfer/page.py
+++ b/python/flashinfer/page.py
@@ -67,7 +67,7 @@ def append_paged_kv_cache(
         shape: ``[batch_size]``.
     kv_layout : str
         The layout of the paged kv-cache, either ``NHD`` or ``HND``.
-    
+
     Example
     -------
     >>> import torch
@@ -96,7 +96,7 @@
     >>> # 25 = (2 - 1) * 16 + 9
     >>> # 22 = (2 - 1) * 16 + 6
     >>> kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0")
-    >>> 
+    >>>
     >>> flashinfer.append_paged_kv_cache(
     ...     k_append,
     ...     v_append,
@@ -111,7 +111,7 @@
     -----
     Please refer to the :ref:`tutorial <recursive-attention>` for a detailed
     explanation of the log-sum-exp function and attention states.
-    
+
     The function assumes that the space for appended k/v have already been
     allocated, which means :attr:`kv_indices`, :attr:`kv_indptr`, :attr:`kv_last_page_len`
     has incorporated appended k/v.
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index 9d4eaa6a..0e617a45 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -277,7 +277,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
     >>> qo_indptr = torch.tensor(
     ...     [0, 33, 44, 55, 66, 77, 88, nnz_qo], dtype=torch.int32, device="cuda:0"
     ... )
-    >>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0") 
+    >>> paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
     >>> paged_kv_indptr = torch.tensor(
     ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
     ... )
@@ -308,7 +308,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
     ...         q, kv_data, causal=True
     ...     )
     ...     outputs.append(o)
-    ... 
+    ...
     >>> # clear auxiliary data structures
     >>> prefill_wrapper.end_forward()
     >>> outputs[0].shape
@@ -582,7 +582,7 @@ class BatchPrefillWithRaggedKVCacheWrapper:
     ...         q, k, v, causal=True
     ...     )
     ...     outputs.append(o)
-    ... 
+    ...
     >>> # clear auxiliary data structures
     >>> prefill_wrapper.end_forward()
     >>> outputs[0].shape
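
Note on the __version__ fallback added to python/flashinfer/__init__.py: open("version.txt") resolves against the current working directory of the Python process, not against the package directory, so the fallback only succeeds when the interpreter is started from a directory that contains version.txt. Below is a minimal sketch of a cwd-independent variant; it is hypothetical and not part of this diff, and it assumes version.txt ships next to the package sources (with _build_meta presumably being a module generated at build time).

# Hypothetical variant of the fallback above (not part of this diff):
# locate version.txt relative to this file rather than the process cwd.
import os

try:
    from ._build_meta import __version__ as __version__
except ImportError:
    # Assumes version.txt sits next to this __init__.py in a source checkout.
    _version_file = os.path.join(os.path.dirname(__file__), "version.txt")
    with open(_version_file) as f:
        __version__ = f.read().strip()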