Skip to content

Commit

Permalink
[Kernel][Triton][AMD] Use block size heuristic for avg 2.8x speedup f…
Browse files Browse the repository at this point in the history
…or int8 models (#11698)

Signed-off-by: Randall Smith <[email protected]>
  • Loading branch information
rasmith authored Jan 8, 2025
1 parent 56fe4c2 commit 526de82
Showing 1 changed file with 16 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor,
bias: Optional[torch.Tensor] = None,
block_size_m: int = 32,
block_size_n: int = 32,
block_size_k: int = 32) -> torch.Tensor:
block_size_k: int = 32,
use_heuristic=True) -> torch.Tensor:
M, K = input.shape
N = weight.shape[1]

Expand All @@ -152,6 +153,20 @@ def triton_scaled_mm(input: torch.Tensor,

has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1

if use_heuristic:
is_small_N = N < 8192
next_power_of_2_M = max(32, triton.next_power_of_2(M))
if next_power_of_2_M <= 32:
tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
elif next_power_of_2_M <= 64:
tile_shape = (64, 64, 256)
elif next_power_of_2_M <= 128:
tile_shape = (64, 128, 128)
else:
tile_shape = (128, 128, 128)

block_size_m, block_size_n, block_size_k = tile_shape

block_size_sa = 1 if has_scalar(scale_a) else block_size_m
block_size_sb = 1 if has_scalar(scale_b) else block_size_n

Expand Down

0 comments on commit 526de82

Please sign in to comment.