-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Layernorm changes #681
base: main_perf
Are you sure you want to change the base?
Layernorm changes #681
Conversation
vgokhale
commented
Dec 12, 2024
- Add support to specify steps in benchmark or single value
- Add / remove autotune configs
- Added a non-blocked implementation for small shapes
parser.add_argument('-N', "--N_start", default="1024", type=int) | ||
parser.add_argument('-Ns', "--N_step", default="2048", type=int) | ||
parser.add_argument('-Ne', "--N_end", default="65536", type=int) | ||
parser.add_argument('-N', "--N_start", default="65536", type=int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Non-blocking] I think default argparse
values do not necessarily need to be strings, you can use 65536
instead of "65536"
, 0
instead of "0"
and so on...
x_vals_list.append(args.N_start) | ||
x_names = ['N'] | ||
mn_args = {'M': args.M_start} | ||
plot_name = str("layernorm-performance" + "_M" + str(args.M_start) + "_N" + str(args.N_start)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Non-blocking] I think f-string interpolation makes this piece of code more readable:
plot_name = f"layernorm-performance_M{args.M_start}_N{args.N_start}"
Probably the same suggestion can be applied in other assignments to plot_name
as well (lines 193, 187).
sweep_m = args.M_step != 0 | ||
sweep_n = args.N_step != 0 | ||
x_vals_list = [] | ||
if (sweep_m): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Non-blocking] Do we need (...)
in Python if
statements? Is it a coding pattern for Booleans? Being naive I would just do if sweep_m:
.
The same suggestions cab be applied to line 190.
#program id | ||
row = tl.program_id(0) | ||
tl.assume(row > 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Question] What tl.assume
does? Is it a compile time or run time assertion?
[Question] The kernel launch grid is grid = lambda meta: (n_rows, )
. What's the program ID range? Is it [0, n_rows) (open end) or [1, n_rows] (closed end)? If it is the first option, I believe that we should check for row >= 0
.
triton.Config({'waves_per_eu': 2}, num_warps=2), | ||
triton.Config({'waves_per_eu': 1}, num_warps=4), | ||
triton.Config({'waves_per_eu': 2}, num_warps=4), | ||
triton.Config({'waves_per_eu': 2}, num_warps=8), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for rmsnorm, waves_per_eu could be 4, are we definitely sure 4 should be in the lists ?
b_block = tl.load(b_ptr + col_offs, mask=mask, other=0.0) | ||
y_block = (x_block - mean) * rstd | ||
y_block = y_block * w_block + b_block | ||
tl.store(y_ptr_start + col_offs, y_block, mask=mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need explicitly convert y_block back to y_ptr.dtype.type.element_ty ?
@@ -112,17 +89,45 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r | |||
tl.store(y_ptr_start + col_offsets, y_block, mask=mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need explicitly convert y_block back to y_ptr.dtype.type.element_ty ?
var = tl.sum(_x_block * _x_block, axis=0) / n_cols | ||
rstd = tl.rsqrt(var + eps) | ||
|
||
w_block = tl.load(w_ptr + col_offs, mask=mask, other=0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need covert explicitly for w_block and b_block to tl.float32