Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Layernorm changes #681

Open
wants to merge 6 commits into
base: main_perf
Choose a base branch
from
Open

Layernorm changes #681

wants to merge 6 commits into from

Conversation

vgokhale
Copy link
Collaborator

  1. Add support to specify steps in benchmark or single value
  2. Add / remove autotune configs
  3. Added a non-blocked implementation for small shapes

@vgokhale vgokhale self-assigned this Dec 12, 2024
@vgokhale vgokhale requested review from scxiao and brunomazzottiamd and removed request for scxiao December 16, 2024 18:10
parser.add_argument('-N', "--N_start", default="1024", type=int)
parser.add_argument('-Ns', "--N_step", default="2048", type=int)
parser.add_argument('-Ne', "--N_end", default="65536", type=int)
parser.add_argument('-N', "--N_start", default="65536", type=int)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Non-blocking] I think default argparse values do not necessarily need to be strings, you can use 65536 instead of "65536", 0 instead of "0" and so on...

x_vals_list.append(args.N_start)
x_names = ['N']
mn_args = {'M': args.M_start}
plot_name = str("layernorm-performance" + "_M" + str(args.M_start) + "_N" + str(args.N_start))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Non-blocking] I think f-string interpolation makes this piece of code more readable:

plot_name = f"layernorm-performance_M{args.M_start}_N{args.N_start}"

Probably the same suggestion can be applied in other assignments to plot_name as well (lines 193, 187).

sweep_m = args.M_step != 0
sweep_n = args.N_step != 0
x_vals_list = []
if (sweep_m):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Non-blocking] Do we need (...) in Python if statements? Is it a coding pattern for Booleans? Being naive I would just do if sweep_m:.

The same suggestions cab be applied to line 190.

#program id
row = tl.program_id(0)
tl.assume(row > 0)
Copy link
Collaborator

@brunomazzottiamd brunomazzottiamd Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Question] What tl.assume does? Is it a compile time or run time assertion?

[Question] The kernel launch grid is grid = lambda meta: (n_rows, ). What's the program ID range? Is it [0, n_rows) (open end) or [1, n_rows] (closed end)? If it is the first option, I believe that we should check for row >= 0.

triton.Config({'waves_per_eu': 2}, num_warps=2),
triton.Config({'waves_per_eu': 1}, num_warps=4),
triton.Config({'waves_per_eu': 2}, num_warps=4),
triton.Config({'waves_per_eu': 2}, num_warps=8),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for rmsnorm, waves_per_eu could be 4, are we definitely sure 4 should be in the lists ?

b_block = tl.load(b_ptr + col_offs, mask=mask, other=0.0)
y_block = (x_block - mean) * rstd
y_block = y_block * w_block + b_block
tl.store(y_ptr_start + col_offs, y_block, mask=mask)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need explicitly convert y_block back to y_ptr.dtype.type.element_ty ?

@@ -112,17 +89,45 @@ def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_r
tl.store(y_ptr_start + col_offsets, y_block, mask=mask)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need explicitly convert y_block back to y_ptr.dtype.type.element_ty ?

var = tl.sum(_x_block * _x_block, axis=0) / n_cols
rstd = tl.rsqrt(var + eps)

w_block = tl.load(w_ptr + col_offs, mask=mask, other=0.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need covert explicitly for w_block and b_block to tl.float32

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants