Layernorm changes #681
base: main_perf
@@ -7,50 +7,27 @@
import triton.language as tl


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def get_cuda_autotune_config():
    return [
        triton.Config({}, num_warps=4, num_stages=1),
        triton.Config({}, num_warps=8, num_stages=1),
        triton.Config({}, num_warps=16, num_stages=1),
    ]


def get_hip_autotune_config():
def get_autotune_config():
    return [
        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
        triton.Config({'waves_per_eu': 1}, num_warps=1),
        triton.Config({'waves_per_eu': 2}, num_warps=1),
        triton.Config({'waves_per_eu': 2}, num_warps=2),
        triton.Config({'waves_per_eu': 1}, num_warps=4),
        triton.Config({'waves_per_eu': 2}, num_warps=4),
        triton.Config({'waves_per_eu': 2}, num_warps=8),
    ]


def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        return get_hip_autotune_config()


@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
                     BLOCK_SIZE: tl.constexpr):
def layernorm_kernel_blocked_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
                                  BLOCK_SIZE: tl.constexpr):

    tl.assume(x_row_stride > 0)
    tl.assume(y_row_stride > 0)
    #program id
    row = tl.program_id(0)
    tl.assume(row > 0)
[Question] The kernel launch grid is
    x_ptr_start = x_ptr + (row * x_row_stride)
    y_ptr_start = y_ptr + (row * y_row_stride)
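The question above is presumably aimed at the assumption itself: the launch grid further down is (n_rows, ), so tl.program_id(0) takes values in [0, n_rows) and the first row gets id 0, which contradicts row > 0. A hypothetical adjustment (not part of this PR) would hint non-negativity instead:

    row = tl.program_id(0)
    tl.assume(row >= 0)  # program ids start at 0; only non-negativity is guaranteed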
@@ -60,7 +37,7 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    loop_num_l = loop_num
    for b in range(0, loop_num_l):
    for b in tl.range(0, loop_num_l, num_stages=3):
        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32)  #Unmasked loads
        _mean += x_block
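The switch from range to tl.range(..., num_stages=3) asks the compiler to software-pipeline these reduction loops, overlapping the unmasked loads of upcoming iterations with the current iteration's accumulation. Annotated, the new loop reads (same names as the kernel):

    for b in tl.range(0, loop_num_l, num_stages=3):  # pipeline up to 3 loop stages
        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32)  # load for iteration b
        _mean += x_block  # accumulate while later loads are in flight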
@@ -75,7 +52,7 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
    #variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    loop_num_l = loop_num
    for b in range(0, loop_num_l):
    for b in tl.range(0, loop_num_l, num_stages=3):
        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32)  #Unmasked loads
        x_block = x_block - mean
@@ -92,7 +69,7 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
    #Normalize and store
    loop_num_l = loop_num
    for b in range(0, loop_num_l):
    for b in tl.range(0, loop_num_l, num_stages=3):
        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        w_block = tl.load(w_ptr + col_offsets)
        b_block = tl.load(b_ptr + col_offsets)
@@ -112,17 +89,45 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
        tl.store(y_ptr_start + col_offsets, y_block, mask=mask)
do we need to explicitly convert y_block back to y_ptr.dtype.type.element_ty?
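On the explicit-cast question: tl.store generally casts the stored value to the destination pointer's element type, but an explicit cast documents the intent. A minimal sketch using the usual pointer-element-type idiom (hypothetical, not part of this PR):

    y_block = y_block.to(y_ptr.dtype.element_ty)  # explicit downcast before the store
    tl.store(y_ptr_start + col_offsets, y_block, mask=mask)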
def layernorm(x, w, b, eps=1e-5):
    n_rows, n_cols = x.shape
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
                          BLOCK_SIZE: tl.constexpr):

    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    y = torch.empty_like(x)
    tl.assume(x_row_stride > 0)
    tl.assume(y_row_stride > 0)
    #program id
    row = tl.program_id(0)
    tl.assume(row > 0)
    x_ptr_start = x_ptr + (row * x_row_stride)
    y_ptr_start = y_ptr + (row * y_row_stride)
    col_offs = tl.arange(0, BLOCK_SIZE)

    #calculate mean
    x_ptrs = x_ptr_start + col_offs
    mask = col_offs < n_cols
    x_block = tl.load(x_ptrs, cache_modifier=".cg", mask=mask, other=0.0).to(tl.float32)  #Unmasked loads
    mean = tl.sum(x_block, axis=0) / n_cols
    _x_block = tl.where(mask, x_block - mean, 0.0)
    var = tl.sum(_x_block * _x_block, axis=0) / n_cols
    rstd = tl.rsqrt(var + eps)
    w_block = tl.load(w_ptr + col_offs, mask=mask, other=0.0)
do we need to convert w_block and b_block explicitly to tl.float32?
    b_block = tl.load(b_ptr + col_offs, mask=mask, other=0.0)
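If the explicit upcast asked about above is wanted, a sketch of the two loads with the cast added (hypothetical, not in this PR):

    w_block = tl.load(w_ptr + col_offs, mask=mask, other=0.0).to(tl.float32)
    b_block = tl.load(b_ptr + col_offs, mask=mask, other=0.0).to(tl.float32)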
    y_block = (x_block - mean) * rstd
    y_block = y_block * w_block + b_block
    tl.store(y_ptr_start + col_offs, y_block, mask=mask)
do we need to explicitly convert y_block back to y_ptr.dtype.type.element_ty?
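The same explicit-cast question applies to this store; the sketch above covers it. For reference, both kernels implement the standard row-wise layer norm that the test below checks against torch.nn.functional.layer_norm; a hedged PyTorch sketch of the same math with float32 accumulation (names hypothetical):

    import torch

    def layernorm_reference(x, w, b, eps=1e-5):
        xf = x.float()
        mean = xf.mean(dim=-1, keepdim=True)                  # per-row mean
        var = ((xf - mean) ** 2).mean(dim=-1, keepdim=True)   # biased variance, as in the kernels
        rstd = torch.rsqrt(var + eps)
        return ((xf - mean) * rstd * w.float() + b.float()).to(x.dtype)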
    num_programs = n_rows

    grid = lambda meta: (num_programs, )
    layernorm_kernel[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps, BLOCK_SIZE)
def layernorm(x, y, w, b, eps=1e-5):
    n_rows, n_cols = x.shape

    grid = lambda meta: (n_rows, )
    if n_cols <= 8192:
        layernorm_kernel_impl[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps,
                                    BLOCK_SIZE=triton.next_power_of_2(n_cols))
    else:
        layernorm_kernel_blocked_impl[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps, BLOCK_SIZE=2048)

    return y
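Note the API change: layernorm now takes a preallocated output tensor y instead of allocating one internally, and dispatches on the row width. A hypothetical call site, with shapes chosen to hit both paths:

    x = torch.randn(8, 4096, device='cuda', dtype=torch.float16)
    y = torch.empty_like(x)
    w = torch.rand(4096, device='cuda', dtype=torch.float16)
    b = torch.rand(4096, device='cuda', dtype=torch.float16)
    layernorm(x, y, w, b)        # n_cols <= 8192: single-pass layernorm_kernel_impl

    x2 = torch.randn(8, 32768, device='cuda', dtype=torch.float16)
    y2 = torch.empty_like(x2)
    w2 = torch.rand(32768, device='cuda', dtype=torch.float16)
    b2 = torch.rand(32768, device='cuda', dtype=torch.float16)
    layernorm(x2, y2, w2, b2)    # n_cols > 8192: blocked kernel with BLOCK_SIZE=2048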
@@ -138,10 +143,11 @@ def run_layernorm(M, N):
    print(f"Running Layernorm on shape ({M},{N})")
    torch.manual_seed(0)
    x = torch.randn(M, N, device='cuda')
    y = torch.empty_like(x)
    w_shape = (N, )
    w = torch.rand(w_shape, device='cuda')
    b = torch.rand(w_shape, device='cuda')
    y_triton = layernorm(x, w, b)
    y_triton = layernorm(x, y, w, b)

    return y_triton
@@ -152,10 +158,11 @@ def run_layernorm(M, N):
def test_layernorm(M, N, eps=1e-5):
    torch.manual_seed(0)
    x = torch.randn(M, N, device='cuda')
    y = torch.empty_like(x)
    w_shape = (N, )
    w = torch.rand(w_shape, device='cuda')
    b = torch.rand(w_shape, device='cuda')
    y_triton = layernorm(x, w, b, eps)
    y_triton = layernorm(x, y, w, b, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, w, b, eps)

    assert torch.allclose(y_triton, y_torch, rtol=1e-05, atol=1e-06)
@@ -167,7 +174,10 @@ def test_layernorm(M, N, eps=1e-5):
def run_benchmark(args):
    config = []
    if (args.M_benchmark):
    sweep_m = args.M_step != 0
    sweep_n = args.N_step != 0
    x_vals_list = []
    if (sweep_m):
[Non-blocking] Do we need ... The same suggestions can be applied to line 190.
        val = args.M_start
        x_vals_list = []
        while val <= args.M_end:
@@ -177,12 +187,17 @@ def run_benchmark(args):
        plot_name = str("layernorm-performance_" + args.dtype + "_N" + str(args.N_start) + "_M" + str(args.M_start) +
                        "-" + str(args.M_end) + "-" + str(args.M_step))
        x_names = ['M']
    else:
        x_vals_list = [i for i in range(args.N_start, args.N_end, args.N_step)]
    elif (sweep_n):
        x_vals_list = [i for i in range(args.N_start, args.N_end + 1, args.N_step)]
        mn_args = {'M': args.M_start}
        plot_name = str("layernorm-performance_" + args.dtype + "_M" + str(args.M_start) + "_N" + str(args.N_start) +
                        "-" + str(args.N_end) + "-" + str(args.N_step))
        x_names = ['N']
    else:
        x_vals_list.append(args.N_start)
        x_names = ['N']
        mn_args = {'M': args.M_start}
        plot_name = str("layernorm-performance" + "_M" + str(args.M_start) + "_N" + str(args.N_start))
[Non-blocking] I think f-string interpolation makes this piece of code more readable: plot_name = f"layernorm-performance_M{args.M_start}_N{args.N_start}" Probably the same suggestion can be applied in other assignments to
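Applying the suggested f-string style to the other plot_name assignments in this hunk would read (a sketch of the reviewer's suggestion, not part of the PR):

    plot_name = f"layernorm-performance_{args.dtype}_N{args.N_start}_M{args.M_start}-{args.M_end}-{args.M_step}"
    plot_name = f"layernorm-performance_{args.dtype}_M{args.M_start}_N{args.N_start}-{args.N_end}-{args.N_step}"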
    dtype = arg_to_torch_dtype[args.dtype]

    print(plot_name)
@@ -205,6 +220,7 @@ def run_benchmark(args):
    @triton.testing.perf_report(config)
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device='cuda', dtype=dtype)
        y = torch.empty_like(x)
        w_shape = (N, )
        w = torch.rand(w_shape, device='cuda', dtype=dtype)
        b = torch.rand(w_shape, device='cuda', dtype=dtype)
@@ -213,7 +229,7 @@ def benchmark(M, N, provider):
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch_layernorm(x, w, b))
        if provider == 'triton':
            ms = triton.testing.do_bench(lambda: layernorm(x, w, b))
            ms = triton.testing.do_bench(lambda: layernorm(x, y, w, b))
        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)
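As a concrete check of the gbps formula: it counts one read of x plus one write of y and ignores the (negligible) w and b traffic. For example, at M=4096, N=8192 in fp16, 2 * x.nelement() * x.element_size() = 2 * 33,554,432 * 2 bytes, which is about 0.134 GB, so a 0.1 ms run reports roughly 1342 GB/s.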
@@ -227,13 +243,12 @@ def parse_args():
    )

    parser.add_argument('-M', "--M_start", default="1", type=int)
    parser.add_argument('-Ms', "--M_step", default="2", type=int)
    parser.add_argument('-Me', "--M_end", default="512", type=int)
    parser.add_argument('-Mb', "--M_benchmark", default=False, type=bool)
    parser.add_argument('-Ms', "--M_step", default="0", type=int)
    parser.add_argument('-Me', "--M_end", default="0", type=int)

    parser.add_argument('-N', "--N_start", default="1024", type=int)
    parser.add_argument('-Ns', "--N_step", default="2048", type=int)
    parser.add_argument('-Ne', "--N_end", default="65536", type=int)
    parser.add_argument('-N', "--N_start", default="65536", type=int)
[Non-blocking] I think default
    parser.add_argument('-Ns', "--N_step", default="0", type=int)
    parser.add_argument('-Ne', "--N_end", default="0", type=int)

    parser.add_argument('-d', "--dtype", default="fp16")
    parser.add_argument('-nb', "--no_benchmark", default=False, type=bool)
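With the new defaults, a step of 0 disables the corresponding sweep, so the script benchmarks a single shape unless -Ms or -Ns is given. Hypothetical invocations (the script name here is assumed):

    python layernorm.py                                       # single shape: M=1, N=65536
    python layernorm.py -M 4096 -N 1024 -Ns 1024 -Ne 16384    # sweep N at fixed M=4096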
for rmsnorm, waves_per_eu could be 4, are we definitely sure 4 should be in the lists?
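For context, the waves_per_eu=4 entries from the previous get_hip_autotune_config are no longer in the new list; if they turn out to matter, the hypothetical re-additions would be:

    triton.Config({'waves_per_eu': 4}, num_warps=4),
    triton.Config({'waves_per_eu': 4}, num_warps=8),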