Skip to content

Commit

Permalink
improve rms norm code quality (#43)
Browse files Browse the repository at this point in the history
## Summary
improve rms norm code quality (#43)

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

I am preparing for the CUDA MODE talk, and I found the RMS norm code is not very
readable.

1. Variable names are not precise.
2. The math formula is incorrect **in the comment**.
3. Comments are not informative.

This PR improves the above.

Also reran the RMS norm benchmark to ensure that performance does not degrade. (HF's
memory usage decreases a bit, probably due to the torch version.)

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence


```
make checkstyle test test-convergence
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Skipped 1 files
All done! ✨ 🍰 ✨
48 files left unchanged.
pytest --disable-warnings test/ --ignore=test/convergence
============================================================ test session starts ============================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 111 items                                                                                                                         

test/transformers/test_cross_entropy.py ..........................................................                                    [ 52%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                           [ 57%]
test/transformers/test_geglu.py ........                                                                                              [ 64%]
test/transformers/test_rms_norm.py ................                                                                                   [ 79%]
test/transformers/test_rope.py ............                                                                                           [ 90%]
test/transformers/test_swiglu.py ........                                                                                             [ 97%]
test/transformers/test_transformers_monkey_patch.py .                                                                                 [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                            [100%]

====================================================== 111 passed in 61.55s (0:01:01) =======================================================
HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence
============================================================ test session starts ============================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 8 items                                                                                                                           

test/convergence/test_mini_models.py ......                                                                                           [ 75%]
test/convergence/test_mini_models_no_logits.py ..                                                                                     [100%]

======================================================= 8 passed in 98.51s (0:01:38) ========================================================
```
  • Loading branch information
ByronHsu authored Aug 19, 2024
1 parent ce7735d commit 01cae6e
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 34 deletions.
12 changes: 6 additions & 6 deletions benchmark/rms_norm_memory/rmsnorm-full-memory-benchmark.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
N,Liger,Hugging Face
1024.0,32.0,87.6
2048.0,64.0,175.2
4096.0,128.1,350.5
8192.0,236.1,700.9
16384.0,368.2,1401.8
32768.0,624.4,2803.6
1024.0,32.0,79.6
2048.0,64.0,159.2
4096.0,128.1,318.5
8192.0,236.1,636.9
16384.0,368.2,1273.8
32768.0,624.4,2547.6
Binary file modified benchmark/rms_norm_memory/rmsnorm-full-memory-benchmark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions benchmark/rms_norm_speed/rmsnorm-bwd-speed-benchmark.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
N,Liger,Hugging Face
1024.0,0.2,0.2
2048.0,0.2,0.4
2048.0,0.1,0.4
4096.0,0.2,0.7
8192.0,0.2,1.4
16384.0,0.3,2.6
32768.0,1.1,5.0
32768.0,1.0,5.0
Binary file modified benchmark/rms_norm_speed/rmsnorm-bwd-speed-benchmark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 6 additions & 6 deletions benchmark/rms_norm_speed/rmsnorm-full-speed-benchmark.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
N,Liger,Hugging Face
1024.0,0.6,0.4
2048.0,0.6,0.5
4096.0,0.6,1.0
8192.0,0.6,1.9
16384.0,0.6,3.6
32768.0,1.3,6.9
1024.0,0.3,0.3
2048.0,0.3,0.5
4096.0,0.4,1.0
8192.0,0.3,1.9
16384.0,0.4,3.6
32768.0,1.2,6.9
Binary file modified benchmark/rms_norm_speed/rmsnorm-full-speed-benchmark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified benchmark/rms_norm_speed/rmsnorm-fwd-speed-benchmark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
58 changes: 38 additions & 20 deletions src/liger_kernel/ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ def _rms_norm_forward(
BLOCK_SIZE: tl.constexpr,
):
"""
y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
Reference:
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
3. https://arxiv.org/pdf/1910.07467
"""

row_idx = tl.program_id(0)
Expand All @@ -36,16 +39,17 @@ def _rms_norm_forward(
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

row_var = tl.sum(X_row * X_row, axis=0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
inv_rms = tl.math.rsqrt(mean_square + eps)

# trick: row_var is tiny compared to X_row because it just has one per row we can save 4 ops (*, sum, /, rqrt) if we cache it
tl.store(r_ptr, inv_var)
# We can save time by caching rms with minimal memory overhead
# because rms is much smaller compared to X_row, as rms is for each row.
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
tl.store(r_ptr, inv_rms)

normed = X_row * inv_var
Y_row = X_row * inv_rms * W_row

output = normed * W_row
tl.store(Y_ptr + col_offsets, output, mask=mask)
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@triton.jit
Expand All @@ -65,9 +69,10 @@ def _rms_norm_backward(
BLOCK_SIZE: tl.constexpr,
):
"""
dx = (1 / var(x)) * (dy * w - (1/N) * (dy * w) dot x) * x
dw = sum(dy * (x / var(x)))
dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
dw = sum(dy * (x / RMS)). summation over BxT dimension
"""

row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
Expand All @@ -81,33 +86,41 @@ def _rms_norm_backward(
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

# Get saved row variance
inv_var = tl.load(r_ptr)

normed = X_row * inv_var
# Get cached rms
inv_rms_row = tl.load(r_ptr)

dY_W = dY_row * W_row
dY_normed = dY_row * normed

rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
tl.store(dY_ptr + col_offsets, output, mask=mask)
dX_row = (inv_rms_row) * (
dY_row * W_row
- (1 / n_cols)
* inv_rms_row
* inv_rms_row
* tl.sum(dY_row * W_row * X_row, axis=0)
* X_row
)
tl.store(dY_ptr + col_offsets, dX_row, mask=mask)

# calculate the gradient of W
tl.store(dW_ptr + col_offsets, dY_normed, mask=mask)
dW_row = dY_row * X_row * inv_rms_row
tl.store(dW_ptr + col_offsets, dW_row, mask=mask)


class LigerRMSNormFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, eps):
"""
X: (B, T, H) or (BxT, H)
W: (H,)
"""

shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)

Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# r is to cache (1/rms) for each row
r = torch.empty(n_rows, dtype=X.dtype, device=X.device)

# Check constraints.
Expand Down Expand Up @@ -139,13 +152,18 @@ def forward(ctx, X, W, eps):
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
"""
Y: (B, T, H) or (BxT, H)
"""

shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
X, W, r = ctx.saved_tensors
n_rows, n_cols = dY.shape
dW = torch.zeros_like(X)

# Here we use dY to store the value of dX to save memory
_rms_norm_backward[(n_rows,)](
dY,
dY.stride(0),
Expand Down

0 comments on commit 01cae6e

Please sign in to comment.