Support out-of-place RMSNorm to fix gemma2 (#376)
## Summary

Fix #370

Gemma2 has a convergence issue with the in-place RMSNorm.


![image](https://github.com/user-attachments/assets/f1c8c871-0c59-4d86-929a-152808c54bbd)

Looking at the diagram, the residual sits between the two RMSNorms. In the yellow-highlighted region, you can see that dY is still needed after it has been modified in-place, so the backward pass must compute dX out-of-place.

This does not affect other models because they don't use double RMSNorm.
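
To make the failure mode concrete, here is a minimal plain-PyTorch sketch (not the Triton kernel) of the gemma2-style gradient flow; the toy `rms_norm` helper and the shapes are illustrative:

```python
import torch

def rms_norm(x, w, eps=1e-6):
    # Toy RMSNorm, only used to illustrate where the gradient flows.
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * rstd * w

x = torch.randn(4, 8, requires_grad=True)
w = torch.ones(8)

# Gemma2-style block: the sub-layer output is post-normalized and then added
# back to the residual stream, i.e. the residual sits between two RMSNorms.
h = rms_norm(x, w)          # stand-in for "pre-norm + sub-layer"
y = x + rms_norm(h, w)      # post-norm output, then residual add

# In backward, the upstream gradient dY is consumed twice: once by the
# post-RMSNorm backward and once by the residual add. A kernel that overwrites
# dY with dX in-place therefore corrupts the copy the residual branch reads.
y.sum().backward()
```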

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
ByronHsu authored Nov 12, 2024
1 parent 5ef09d5 commit 563e5e5
Showing 4 changed files with 52 additions and 12 deletions.
33 changes: 27 additions & 6 deletions src/liger_kernel/ops/rms_norm.py
@@ -116,6 +116,8 @@ def _rms_norm_forward_kernel(
def _rms_norm_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
@@ -146,6 +148,8 @@ def _rms_norm_backward_kernel(
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

dY_ptr += row_start * dY_row_stride
dX_ptr += row_start * dX_row_stride

X_ptr += row_start * X_row_stride
RSTD_ptr += row_start

@@ -184,9 +188,10 @@ def _rms_norm_backward_kernel(
# here X_row is already in fp32 (see previous if block)
dW_row += dY_row * (X_row * rstd_row)

tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

dY_ptr += dY_row_stride
dX_ptr += dX_row_stride
X_ptr += X_row_stride
RSTD_ptr += RSTD_row_stride

@@ -251,7 +256,9 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
def rms_norm_backward(
dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
@@ -265,10 +272,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
# Here we use dY to store the value of dX to save memory

if in_place is True:
dX = dY
else:
dX = torch.zeros_like(dY)

_rms_norm_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
@@ -286,8 +300,9 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dX = dY.view(*shape)
dX = dX.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)

return dX, dW


@@ -307,11 +322,15 @@ class LigerRMSNormFunction(torch.autograd.Function):
- 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
- 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
- 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
`in_place` controls whether dY is modified in-place to store dX. It defaults to `True` to save memory. However, in certain cases it can produce incorrect gradients.
For example, gemma2 applies two RMSNorms sequentially with a residual in between. The residual branch still needs dY, so dY cannot be modified in-place.
Therefore, when patching RMSNorm for gemma2, we set `in_place` to `False`.
"""

@staticmethod
@ensure_contiguous
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
"""
X: (B, T, H) or (BxT, H)
W: (H,)
@@ -321,6 +340,7 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
)
ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.in_place = in_place
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, W, RSTD)
@@ -342,5 +362,6 @@ def backward(ctx, dY):
ctx.casting_mode,
ctx.BLOCK_SIZE,
ctx.num_warps,
ctx.in_place,
)
return dX, dW, None, None, None
return dX, dW, None, None, None, None
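
For reference, a minimal usage sketch of the updated autograd function, based on the signature shown above (the import path assumes the standard package layout; the shapes, dtype, and the final `print` are illustrative, and a CUDA device is required since the kernel is written in Triton):

```python
import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True)
# Passing in_place=False makes the backward kernel write dX into a freshly
# allocated buffer instead of overwriting dY (needed for gemma2-style models).
X = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
W = torch.zeros(64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

Y = LigerRMSNormFunction.apply(X, W, 1e-6, 1.0, "gemma", False)
Y.sum().backward()
print(X.grad.shape, W.grad.shape)
```
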
2 changes: 1 addition & 1 deletion src/liger_kernel/transformers/monkey_patch.py
@@ -507,7 +507,7 @@ def apply_liger_kernel_to_gemma2(
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

LigerRMSNormForGemma2 = partial(
LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
)
_patch_rms_norm_module_for_gemma2 = partial(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma"
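
Instantiated, the patched partial above behaves like the sketch below (the import path follows this PR's file layout; the hidden size is illustrative):

```python
from functools import partial

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Same configuration as the LigerRMSNormForGemma2 partial in the patch:
# +1.0 offset, "gemma" casting mode, zero-initialized weight, and an
# out-of-place backward so the incoming gradient is never clobbered.
LigerRMSNormForGemma2 = partial(
    LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
)

norm = LigerRMSNormForGemma2(hidden_size=2304)  # hidden size is illustrative
print(norm)  # repr includes eps, offset, and in_place=False via extra_repr
```
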
14 changes: 11 additions & 3 deletions src/liger_kernel/transformers/rms_norm.py
@@ -6,7 +6,13 @@

class LigerRMSNorm(nn.Module):
def __init__(
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
self,
hidden_size,
eps=1e-6,
offset=0.0,
casting_mode="llama",
init_fn="ones",
in_place=True,
):
super().__init__()
assert init_fn in [
@@ -16,10 +22,11 @@ def __init__(
self.weight = nn.Parameter(
torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
)
self.variance_epsilon, self.offset, self.casting_mode = (
self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
eps,
offset,
casting_mode,
in_place,
)

def forward(self, hidden_states):
@@ -29,7 +36,8 @@ def forward(self, hidden_states):
self.variance_epsilon,
self.offset,
self.casting_mode,
self.in_place,
)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
15 changes: 13 additions & 2 deletions test/transformers/test_rms_norm.py
@@ -100,7 +100,16 @@ def forward(self, x):
(BaseRMSNorm, 0.0, "none"),
],
)
def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode):
@pytest.mark.parametrize(
"in_place",
[
True,
False,
],
)
def test_correctness(
bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place
):
_tensor = torch.randn(bs, sl, hd, device="cuda", dtype=dtype)

h1 = _tensor.clone().requires_grad_(True)
@@ -116,7 +125,9 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m

# triton
triton_rms = (
LigerRMSNorm(hidden_size=hd, offset=offset, casting_mode=casting_mode)
LigerRMSNorm(
hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place
)
.to("cuda")
.to(dtype)
)
