diff --git a/docs/api/python/gemm.rst b/docs/api/python/gemm.rst
index 495987b0..41111c9d 100644
--- a/docs/api/python/gemm.rst
+++ b/docs/api/python/gemm.rst
@@ -1,11 +1,22 @@
-.. _apigroup_gemm:
+.. _apigemm:
 
-flashinfer.group_gemm
-=====================
+flashinfer.gemm
+===============
 
-This module provides a set of functions to group GEMM operations.
+.. currentmodule:: flashinfer.gemm
 
-.. currentmodule:: flashinfer.group_gemm
+This module provides a set of GEMM operations.
+
+FP8 Batch GEMM
+--------------
+
+.. autosummary::
+   :toctree: ../../generated
+
+   bmm_fp8
+
+Grouped GEMM
+------------
 
 .. autoclass:: SegmentGEMMWrapper
    :members:
diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py
index 90b29f74..18f07846 100644
--- a/python/flashinfer/cascade.py
+++ b/python/flashinfer/cascade.py
@@ -452,9 +452,10 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
 
     Check :ref:`our tutorial` for page table layout.
 
-    It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
-    multi-level cascade inference, where the KV-Cache of each level is stored in a unified
-    page table. This API will be deprecated in the future.
+    Warning
+    -------
+    This API will be deprecated in the future; please use
+    :class:`MultiLevelCascadeAttentionWrapper` instead.
 
     Example
     -------
diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py
index 81d51f79..b66a5011 100644
--- a/python/flashinfer/gemm.py
+++ b/python/flashinfer/gemm.py
@@ -19,6 +19,7 @@
 import torch
 
 from .utils import get_indptr
+from typing import Optional
 
 # mypy: disable-error-code="attr-defined"
 try:
@@ -204,7 +205,7 @@ def bmm_fp8(
     A_scale: torch.Tensor,
     B_scale: torch.Tensor,
     dtype: torch.dtype,
-    out: torch.Tensor = None,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""BMM FP8
 
@@ -225,13 +226,35 @@ def bmm_fp8(
     dtype: torch.dtype
         out dtype, bf16 or fp16.
 
-    out: torch.Tensor
-        Out tensor, shape (b, m, n), bf16 or fp16.
+    out: Optional[torch.Tensor]
+        Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``.
 
     Returns
     -------
     out: torch.Tensor
         Out tensor, shape (b, m, n), bf16 or fp16.
+
+    Examples
+    --------
+    >>> import torch
+    >>> import flashinfer
+    >>> def to_float8(x, dtype=torch.float8_e4m3fn):
+    ...     finfo = torch.finfo(dtype)
+    ...     abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
+    ...     scale = finfo.max / abs_max
+    ...     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    ...     return x_scl_sat.to(dtype), scale.float().reciprocal()
+    >>>
+    >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
+    >>> # column-major weight
+    >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> weight_fp8, weight_inv_s = to_float8(weight, dtype=torch.float8_e4m3fn)
+    >>> out = flashinfer.bmm_fp8(input_fp8, weight_fp8, input_inv_s, weight_inv_s, torch.bfloat16)
+    >>> out.shape
+    torch.Size([16, 48, 80])
+    >>> out.dtype
+    torch.bfloat16
     """
     if out is None:
         out = torch.empty(
diff --git a/python/tests/test_group_gemm.py b/python/tests/test_group_gemm.py
index 505598cf..f0f5cc07 100644
--- a/python/tests/test_group_gemm.py
+++ b/python/tests/test_group_gemm.py
@@ -38,7 +38,7 @@ def test_segment_gemm(
         pytest.skip("batch_size * num_rows_per_batch too large for test.")
     torch.manual_seed(42)
     workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
-    segment_gemm = flashinfer.group_gemm.SegmentGEMMWrapper(workspace_buffer)
+    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)
     x = (
         (torch.randn(batch_size * num_rows_per_batch, d_in) / 10)
         .to(0)
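For reviewers, a minimal sketch of exercising the wrapper through its renamed import path. Only the constructor appears in this diff; the `run` method name, its `seg_lens`/`weight_column_major` keywords, and the shapes below are assumptions based on the surrounding test file, not part of this change.

```python
import torch
import flashinfer

# 32 MB workspace on device 0, matching test_group_gemm.py above.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
# New import path after this change (was flashinfer.group_gemm).
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)

batch_size, d_in, d_out = 4, 128, 256
# Each segment i contributes seg_lens[i] rows of x, multiplied by weights[i].
seg_lens = torch.tensor([1, 2, 4, 8], dtype=torch.int64, device="cuda:0")
x = torch.randn(int(seg_lens.sum()), d_in, dtype=torch.float16, device="cuda:0")
weights = torch.randn(batch_size, d_in, d_out, dtype=torch.float16, device="cuda:0")

# Hypothetical call: method name and keywords are assumed, not shown in this
# diff; check the SegmentGEMMWrapper docs for the actual API.
y = segment_gemm.run(x, weights, batch_size, weight_column_major=False, seg_lens=seg_lens)
assert y.shape == (int(seg_lens.sum()), d_out)
```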
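The `Optional[torch.Tensor]` annotation on `out` documents the two supported call styles. A short sketch of the preallocated-output path, reusing the shapes from the docstring example; that `bmm_fp8` allocates only when `out is None` is confirmed by the `torch.empty` branch visible in the diff.

```python
import torch
import flashinfer

def to_float8(x, dtype=torch.float8_e4m3fn):
    # Per-batch absmax scaling into the FP8 representable range,
    # as in the docstring example above.
    finfo = torch.finfo(dtype)
    abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
    scale = finfo.max / abs_max
    return (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype), scale.float().reciprocal()

A = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
B = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
A_fp8, A_inv_s = to_float8(A)
B_fp8, B_inv_s = to_float8(B)

# Caller-managed output buffer: no allocation inside bmm_fp8.
out = torch.empty([16, 48, 80], device="cuda", dtype=torch.bfloat16)
flashinfer.bmm_fp8(A_fp8, B_fp8, A_inv_s, B_inv_s, torch.bfloat16, out=out)

# out=None (the new default) allocates the result internally.
y = flashinfer.bmm_fp8(A_fp8, B_fp8, A_inv_s, B_inv_s, torch.bfloat16)
```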