Add rmsnorm kernel #633
Conversation
Force-pushed from 2794514 to b356942
        triton.Config({}, num_warps=16, num_stages=1),
    ]

def get_hip_autotune_config():
Nothing wrong here. Just wondering if these configs are comprehensive enough to cover the typical use cases you have encountered in actual models.
Good question. I just came up with these based on what I thought made sense; there was no systematic way of deriving them. I was thinking maybe I could add an argument that takes a custom config file, so one doesn't need to touch the code in this file when other configs need to be benchmarked.
Actually, I noticed a small bug in the CUDA configs: there are repeated entries.
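For what it's worth, a rough sketch of that custom-config idea; the file name, JSON schema, and helper function below are hypothetical, not part of this PR:

import json
import triton

# Hypothetical: load extra autotune configs from a user-supplied JSON file,
# where each entry looks like {"num_warps": 8, "num_stages": 2}.
def load_custom_configs(path):
    with open(path) as f:
        entries = json.load(f)
    return [
        triton.Config({}, num_warps=e["num_warps"], num_stages=e["num_stages"])
        for e in entries
    ]

The default config list could then simply be extended with whatever the file provides before it is handed to triton.autotune.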
python/perf-kernels/rmsnorm.py (Outdated)
row = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0)
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
row_norm = row * row  # square each value
row_norm = tl.sum(row_norm, axis=-1)  # sum across columns (axis=-1)
For RMSNorm, is it always safe to accumulate FP16 values with an FP16 accumulator? Do we need an FP32 accumulator here? I'm assuming that tl.sum of FP16 values is also an FP16.
For instance, if we square 7 three-digit numbers and then add them together, we're already overflowing FP16.
I'm not that familiar with normalization layers; maybe I'm just too conservative and thinking too much about the general case.
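A quick NumPy illustration of that concern (not from the PR, just to make the overflow concrete):

import numpy as np

# FP16 tops out at 65504, so an FP16 accumulator overflows after summing just
# a few squares of three-digit values.
x = np.full(7, 100.0, dtype=np.float16)

acc_fp16 = np.float16(0.0)
for v in x:
    acc_fp16 = np.float16(acc_fp16 + v * v)  # each partial sum rounded back to FP16
print(acc_fp16)  # inf

acc_fp32 = np.float32(0.0)
for v in x:
    acc_fp32 += np.float32(v) * np.float32(v)
print(acc_fp32)  # 70000.0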
I need to think about this case a bit more.
Let me try to give you some food for thought...
Use a wider data type as accumulator
Our GEMM kernel does this:
# INT32 accumulator for INT8 data and FP32 accumulator for everything else.
acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
You can do something similar. I think it's safe to accumulate using FP32.
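Here is a rough sketch of what that could look like for this kernel; the kernel name, signature, strides, and variable names are assumptions based on the quoted snippet, not the PR's actual code:

import triton
import triton.language as tl

# Hedged sketch: RMSNorm forward with an FP32 accumulator, mirroring the GEMM trick.
@triton.jit
def rmsnorm_fwd_fp32_acc(x_ptr, g_ptr, out_ptr, n_cols, eps,
                         row_stride, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row_start_ptr = x_ptr + row_idx * row_stride

    # Upcast to FP32 before squaring and reducing.
    row = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
    g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

    mean_sq = tl.sum(row * row, axis=-1) / n_cols
    rms = tl.sqrt(mean_sq + eps)
    out = (row / rms) * g

    # Cast back to the output element type on store.
    tl.store(out_ptr + row_idx * row_stride + col_offsets,
             out.to(out_ptr.type.element_ty), mask=mask)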
Use a scale factor
You can scale down the numbers before squaring them and scale up the mean of squares afterwards. Here is a NumPy prototype of the idea:
import numpy as np

def mean_sqr(x):
    return np.sum(x * x) / len(x)

# Scale factor can be a constant or computed on the fly.
def mean_sqr_with_scale_factor(x, scale_factor):
    x = (1 / scale_factor) * x
    mean_sqr = np.sum(x * x) / len(x)
    return (scale_factor * scale_factor) * mean_sqr

def compute_scale_factor(x):
    max_x = np.max(np.abs(x))
    return np.exp(np.floor(np.log(max_x))) if max_x > 0 else 1

np.random.seed(42)
x = np.random.uniform(size=4096, low=-500.0, high=500.0).astype(np.float32)
print(mean_sqr(x))
print(mean_sqr_with_scale_factor(x, compute_scale_factor(x)))
Larger scale factors reduce the risk of overflow but increase the chance of losing precision. Smaller scale factors retain precision but may not prevent overflow effectively if the numbers are too large.
Inspect PyTorch implementation
Try to find out what PyTorch is doing. If PyTorch doesn't care about sum-of-squares overflow, then we can follow it and not care either.
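For reference, one way to probe this empirically (assuming PyTorch >= 2.4, where F.rms_norm exists, as mentioned later in this thread, and a GPU; the shapes are arbitrary):

import torch
import torch.nn.functional as F

# Feed values whose squared sum blows past the FP16 limit and check whether the
# built-in op overflows or upcasts internally.
x = torch.full((1, 4096), 100.0, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
y = F.rms_norm(x, (4096,), weight=w, eps=1e-6)
print(y.isfinite().all())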
Seeing this PR reminded me that I added a kernel without a benchmark and without a test! Shame on me! I think adding a benchmark and a correctness test must be mandatory from now on. Reviewers should reject PRs lacking these features.
Force-pushed from b356942 to c949904
Can we add tests for this kernel?
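Something along these lines, perhaps; the rmsnorm wrapper name, its import path, and its signature are guesses, not the PR's actual API:

import pytest
import torch

# FP32 reference implementation to compare against.
def ref_rmsnorm(x, g, eps=1e-6):
    x32 = x.float()
    rms = torch.sqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return ((x32 / rms) * g.float()).to(x.dtype)

@pytest.mark.parametrize("M, N", [(1, 128), (32, 4096), (8, 8192)])
def test_rmsnorm(M, N):
    torch.manual_seed(0)
    x = torch.randn(M, N, dtype=torch.float16, device="cuda")
    g = torch.randn(N, dtype=torch.float16, device="cuda")
    from rmsnorm import rmsnorm  # hypothetical wrapper exposed by the kernel file
    y = rmsnorm(x, g)
    torch.testing.assert_close(y, ref_rmsnorm(x, g), atol=1e-2, rtol=1e-2)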
Force-pushed from c949904 to f625a47
@rahulbatra85 There are failures in the RMSNorm code.
Yeah, looking.
@micmelesse Ah ok, so the RMS norm layer was added in PyTorch starting with version 2.4. Which Docker image does the CI use?
Force-pushed from f625a47 to fa9cc06
@rahulbatra85 It's
Force-pushed from fa9cc06 to f80aed7
@micmelesse It's passing with the new Docker image.
Adds forward kernel for RMSNorm