Layernorm changes #681

Open

wants to merge 6 commits into main_perf
131 changes: 73 additions & 58 deletions python/perf-kernels/layernorm.py
@@ -7,50 +7,27 @@
import triton.language as tl


def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"


def get_cuda_autotune_config():
return [
triton.Config({}, num_warps=4, num_stages=1),
triton.Config({}, num_warps=8, num_stages=1),
triton.Config({}, num_warps=16, num_stages=1),
]


def get_hip_autotune_config():
def get_autotune_config():
return [
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=1),
triton.Config({'waves_per_eu': 2}, num_warps=1),
triton.Config({'waves_per_eu': 2}, num_warps=2),
triton.Config({'waves_per_eu': 1}, num_warps=4),
triton.Config({'waves_per_eu': 2}, num_warps=4),
triton.Config({'waves_per_eu': 2}, num_warps=8),
Member

For rmsnorm, waves_per_eu could be 4; are we definitely sure 4 should be in the list?

]
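
Following up on the comment above: if waves_per_eu=4 were kept, the extra entries would just be configs of the form below (illustrative only, not part of this PR):

triton.Config({'waves_per_eu': 4}, num_warps=4),
triton.Config({'waves_per_eu': 4}, num_warps=8),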


def get_autotune_config():
if is_cuda():
return get_cuda_autotune_config()
else:
return get_hip_autotune_config()


@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
BLOCK_SIZE: tl.constexpr):
def layernorm_kernel_blocked_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
BLOCK_SIZE: tl.constexpr):

tl.assume(x_row_stride > 0)
tl.assume(y_row_stride > 0)
#program id
row = tl.program_id(0)
tl.assume(row > 0)
Collaborator
@brunomazzottiamd Dec 20, 2024

[Question] What does tl.assume do? Is it a compile-time or a run-time assertion?

[Question] The kernel launch grid is grid = lambda meta: (n_rows, ). What's the program ID range? Is it [0, n_rows) (open end) or [1, n_rows] (closed end)? If it is the first option, I believe we should check for row >= 0.

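On the first question: tl.assume appears to be a compiler hint (it informs the optimizer and is not checked at run time), and with grid = (n_rows, ) the program ID tl.program_id(0) ranges over [0, n_rows). A minimal, purely illustrative sketch with hypothetical names that are not part of this PR:

import triton
import triton.language as tl


@triton.jit
def _assume_demo(x_ptr, y_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)       # first program gets row == 0, last gets n_rows - 1
    tl.assume(row >= 0)          # optimizer hint only; nothing is asserted at run time
    tl.assume(row_stride > 0)    # helps the compiler reason about the address arithmetic
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * row_stride + cols, mask=mask, other=0.0)
    tl.store(y_ptr + row * row_stride + cols, x, mask=mask)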
x_ptr_start = x_ptr + (row * x_row_stride)
y_ptr_start = y_ptr + (row * y_row_stride)

@@ -60,7 +37,7 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
loop_num_l = loop_num
for b in range(0, loop_num_l):
for b in tl.range(0, loop_num_l, num_stages=3):
col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32) #Unmasked loads
_mean += x_block
@@ -75,7 +52,7 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
#variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
loop_num_l = loop_num
for b in range(0, loop_num_l):
for b in tl.range(0, loop_num_l, num_stages=3):
col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32) #Unmasked loads
x_block = x_block - mean
@@ -92,7 +69,7 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r

#Normalize and store
loop_num_l = loop_num
for b in range(0, loop_num_l):
for b in tl.range(0, loop_num_l, num_stages=3):
col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
w_block = tl.load(w_ptr + col_offsets)
b_block = tl.load(b_ptr + col_offsets)
@@ -112,17 +89,45 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
tl.store(y_ptr_start + col_offsets, y_block, mask=mask)
Member

Do we need to explicitly convert y_block back to y_ptr.dtype.type.element_ty?

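If an explicit cast is wanted, a suggestion-style sketch of the store (assuming the pointee type is reachable as y_ptr.dtype.element_ty in this Triton version):

tl.store(y_ptr_start + col_offsets, y_block.to(y_ptr.dtype.element_ty), mask=mask)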


def layernorm(x, w, b, eps=1e-5):
n_rows, n_cols = x.shape
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
BLOCK_SIZE: tl.constexpr):

MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
y = torch.empty_like(x)
tl.assume(x_row_stride > 0)
tl.assume(y_row_stride > 0)
#program id
row = tl.program_id(0)
tl.assume(row > 0)
x_ptr_start = x_ptr + (row * x_row_stride)
y_ptr_start = y_ptr + (row * y_row_stride)
col_offs = tl.arange(0, BLOCK_SIZE)

#calculate mean
x_ptrs = x_ptr_start + col_offs
mask = col_offs < n_cols
x_block = tl.load(x_ptrs, cache_modifier=".cg", mask=mask, other=0.0).to(tl.float32) #Masked load
mean = tl.sum(x_block, axis=0) / n_cols
_x_block = tl.where(mask, x_block - mean, 0.0)
var = tl.sum(_x_block * _x_block, axis=0) / n_cols
rstd = tl.rsqrt(var + eps)

w_block = tl.load(w_ptr + col_offs, mask=mask, other=0.0)
Member

Do we need to explicitly convert w_block and b_block to tl.float32?

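If an explicit up-cast is wanted, a suggestion-style sketch of the two loads (hypothetical, not part of this PR):

w_block = tl.load(w_ptr + col_offs, mask=mask, other=0.0).to(tl.float32)
b_block = tl.load(b_ptr + col_offs, mask=mask, other=0.0).to(tl.float32)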
b_block = tl.load(b_ptr + col_offs, mask=mask, other=0.0)
y_block = (x_block - mean) * rstd
y_block = y_block * w_block + b_block
tl.store(y_ptr_start + col_offs, y_block, mask=mask)
Member

Do we need to explicitly convert y_block back to y_ptr.dtype.type.element_ty?


num_programs = n_rows

grid = lambda meta: (num_programs, )
layernorm_kernel[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps, BLOCK_SIZE)
def layernorm(x, y, w, b, eps=1e-5):
n_rows, n_cols = x.shape

grid = lambda meta: (n_rows, )
if n_cols <= 8192:
layernorm_kernel_impl[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps,
BLOCK_SIZE=triton.next_power_of_2(n_cols))
else:
layernorm_kernel_blocked_impl[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps, BLOCK_SIZE=2048)

return y

@@ -138,10 +143,11 @@ def run_layernorm(M, N):
print(f"Running Layernorm on shape ({M},{N})")
torch.manual_seed(0)
x = torch.randn(M, N, device='cuda')
y = torch.empty_like(x)
w_shape = (N, )
w = torch.rand(w_shape, device='cuda')
b = torch.rand(w_shape, device='cuda')
y_triton = layernorm(x, w, b)
y_triton = layernorm(x, y, w, b)

return y_triton

@@ -152,10 +158,11 @@ def run_layernorm(M, N):
def test_layernorm(M, N, eps=1e-5):
torch.manual_seed(0)
x = torch.randn(M, N, device='cuda')
y = torch.empty_like(x)
w_shape = (N, )
w = torch.rand(w_shape, device='cuda')
b = torch.rand(w_shape, device='cuda')
y_triton = layernorm(x, w, b, eps)
y_triton = layernorm(x, y, w, b, eps)
y_torch = torch.nn.functional.layer_norm(x, w_shape, w, b, eps)

assert torch.allclose(y_triton, y_torch, rtol=1e-05, atol=1e-06)
@@ -167,7 +174,10 @@ def test_layernorm(M, N, eps=1e-5):

def run_benchmark(args):
config = []
if (args.M_benchmark):
sweep_m = args.M_step != 0
sweep_n = args.N_step != 0
x_vals_list = []
if (sweep_m):
Collaborator

[Non-blocking] Do we need (...) in Python if statements? Is it a coding pattern for Booleans? Being naive, I would just write if sweep_m:.

The same suggestion can be applied to line 190.

val = args.M_start
x_vals_list = []
while val <= args.M_end:
@@ -177,12 +187,17 @@ def run_benchmark(args):
plot_name = str("layernorm-performance_" + args.dtype + "_N" + str(args.N_start) + "_M" + str(args.M_start) +
"-" + str(args.M_end) + "-" + str(args.M_step))
x_names = ['M']
else:
x_vals_list = [i for i in range(args.N_start, args.N_end, args.N_step)]
elif (sweep_n):
x_vals_list = [i for i in range(args.N_start, args.N_end + 1, args.N_step)]
mn_args = {'M': args.M_start}
plot_name = str("layernorm-performance_" + args.dtype + "_M" + str(args.M_start) + "_N" + str(args.N_start) +
"-" + str(args.N_end) + "-" + str(args.N_step))
x_names = ['N']
else:
x_vals_list.append(args.N_start)
x_names = ['N']
mn_args = {'M': args.M_start}
plot_name = str("layernorm-performance" + "_M" + str(args.M_start) + "_N" + str(args.N_start))
Collaborator

[Non-blocking] I think f-string interpolation makes this piece of code more readable:

plot_name = f"layernorm-performance_M{args.M_start}_N{args.N_start}"

Probably the same suggestion can be applied in other assignments to plot_name as well (lines 193, 187).

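For illustration, the other two plot_name assignments in this diff might read as follows with the same rewrite (a sketch, not part of the PR):

plot_name = (f"layernorm-performance_{args.dtype}_N{args.N_start}"
             f"_M{args.M_start}-{args.M_end}-{args.M_step}")
plot_name = (f"layernorm-performance_{args.dtype}_M{args.M_start}"
             f"_N{args.N_start}-{args.N_end}-{args.N_step}")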
dtype = arg_to_torch_dtype[args.dtype]

print(plot_name)
@@ -205,6 +220,7 @@ def run_benchmark(args):
@triton.testing.perf_report(config)
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=dtype)
y = torch.empty_like(x)
w_shape = (N, )
w = torch.rand(w_shape, device='cuda', dtype=dtype)
b = torch.rand(w_shape, device='cuda', dtype=dtype)
@@ -213,7 +229,7 @@ def benchmark(M, N, provider):
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch_layernorm(x, w, b))
if provider == 'triton':
ms = triton.testing.do_bench(lambda: layernorm(x, w, b))
ms = triton.testing.do_bench(lambda: layernorm(x, y, w, b))
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)

@@ -227,13 +243,12 @@ def parse_args():
)

parser.add_argument('-M', "--M_start", default="1", type=int)
parser.add_argument('-Ms', "--M_step", default="2", type=int)
parser.add_argument('-Me', "--M_end", default="512", type=int)
parser.add_argument('-Mb', "--M_benchmark", default=False, type=bool)
parser.add_argument('-Ms', "--M_step", default="0", type=int)
parser.add_argument('-Me', "--M_end", default="0", type=int)

parser.add_argument('-N', "--N_start", default="1024", type=int)
parser.add_argument('-Ns', "--N_step", default="2048", type=int)
parser.add_argument('-Ne', "--N_end", default="65536", type=int)
parser.add_argument('-N', "--N_start", default="65536", type=int)
Collaborator

[Non-blocking] I think default argparse values do not necessarily need to be strings; you can use 65536 instead of "65536", 0 instead of "0", and so on...

parser.add_argument('-Ns', "--N_step", default="0", type=int)
parser.add_argument('-Ne', "--N_end", default="0", type=int)
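
A sketch of the same arguments with integer defaults, per the suggestion above:

parser.add_argument('-N', "--N_start", default=65536, type=int)
parser.add_argument('-Ns', "--N_step", default=0, type=int)
parser.add_argument('-Ne', "--N_end", default=0, type=int)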

parser.add_argument('-d', "--dtype", default="fp16")
parser.add_argument('-nb', "--no_benchmark", default=False, type=bool)