Add RMSNorm #834

Open · wants to merge 2 commits into main
129 changes: 129 additions & 0 deletions tests/test_triton_rmsnorm.py
@@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging

import pytest
import torch
from torch.cuda.amp.autocast_mode import autocast

import xformers

try:
    from xformers.triton import FusedRMSNorm
    from xformers.triton.utils import gpu_capabilities_older_than_70

    _triton_available = xformers._is_triton_available()
except ImportError:
    logging.warning("Triton is not available, some optimizations will not be tested.")
    _triton_available = False

# Testing odd shapes on purpose
SHAPES = [
    (384, 128),
    (8, 384, 128),
    (8, 784, 512),
    (4, 2048, 384),
    (4, 3136, 1024),
    (2, 1024, 2048),
    (2, 2048, 4096),
    (2, 4096, 4096),
    (1, 2048, 12288),
]


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.ones(normalized_shape, device=device, dtype=dtype)
        )
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
    not _triton_available or gpu_capabilities_older_than_70(),
    reason="Triton requires a SM70+ GPU",
)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("amp", [True, False])
def test_rmsnorm_parity(shape, amp):
    """Check that the PyTorch and Triton RMSNorm implementations give the same results"""

    # Get the same inputs
    torch.random.manual_seed(0)
    X = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True)

    torch.random.manual_seed(0)
    X_ = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True)

    eps = 1e-4

    # Initialize the two layers; the weight is 1 by default, so there is no randomness
    torch_rmsnorm = RMSNorm(X.shape[-1], eps=eps).to("cuda")
    triton_rmsnorm = FusedRMSNorm(X.shape[-1], eps=eps).to("cuda")

    with autocast(enabled=amp):
        assert torch.allclose(X, X_)  # sanity checking, else all hell breaks loose

        # Check the forward pass
        y_torch = torch_rmsnorm(X)
        y_triton = triton_rmsnorm(X_)
        assert torch.allclose(
            y_torch.norm(), y_triton.norm(), atol=1e-3
        ), f"{torch.norm(y_torch)} vs. {torch.norm(y_triton)}"

        # Check that the backward pass also gives the same result
        loss_torch = torch.norm(y_torch)
        loss_torch.backward()

        loss_triton = torch.norm(y_triton)
        loss_triton.backward()

        print(torch.norm(y_torch), torch.norm(y_triton))

        print(y_torch[0, :])
        print(y_triton[0, :])

        # There are 2 items to check:
        # - gradient on the inputs
        assert torch.allclose(
            X.grad, X_.grad
        ), f"Inputs grad mismatch: {torch.norm(X.grad)} vs. {torch.norm(X_.grad)}"

        # - gradient on the rmsnorm weight
        assert torch.allclose(
            torch_rmsnorm.weight.grad, triton_rmsnorm.weight.grad, atol=1e-3
        ), (
            f"Weight grad mismatch: {torch.norm(torch_rmsnorm.weight.grad)} vs."
            + f" {torch.norm(triton_rmsnorm.weight.grad)}"
        )


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
def test_no_contiguous(dtype):
    """Check that we don't choke on non-contiguous tensors"""
    shape = (8, 384, 128)

    # Get the same inputs
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)

    X = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True, dtype=dtype)
    X = X.transpose(2, 1).contiguous().transpose(2, 1)

    assert not X.is_contiguous()

    triton_rmsnorm = FusedRMSNorm(X.shape[-1]).to(device="cuda", dtype=dtype)
    _ = triton_rmsnorm(X)
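
For reference, both the eager RMSNorm baseline above and the fused Triton kernel exercised by these tests compute the same per-row operation: unlike LayerNorm there is no mean subtraction and no bias, only a root-mean-square rescale followed by a learnable gain,

\[
\mathrm{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\tfrac{1}{N}\sum_{i=1}^{N} x_i^{2} + \epsilon}}.
\]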

2 changes: 2 additions & 0 deletions xformers/triton/__init__.py
@@ -12,6 +12,7 @@
    from .dropout import FusedDropoutBias, dropout  # noqa
    from .fused_linear_layer import FusedLinear  # noqa
    from .layer_norm import FusedLayerNorm, layer_norm  # noqa
    from .rms_norm import FusedRMSNorm  # noqa
    from .softmax import log_softmax, softmax  # noqa

    __all__ = [
@@ -21,6 +22,7 @@
        "FusedDropoutBias",
        "FusedLinear",
        "FusedLayerNorm",
        "FusedRMSNorm",
        "layer_norm",
    ]
except ImportError:
148 changes: 148 additions & 0 deletions xformers/triton/k_rms_norm.py
@@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: This comes almost as-is from the Triton layer norm tutorial
# https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py


import triton
import triton.language as tl


# fmt: off
@triton.jit
def rms_norm_fw(X, Y, W, V, stride, N, eps, BLOCK_SIZE_N: tl.constexpr):
    # fmt: on
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # Move to this row
    x_ptrs = X + row * stride + cols
    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_var = tl.sum(x * x, axis=0) / N
    rstd = 1.0 / tl.sqrt(x_var + eps)

    y = x * rstd
    tl.store(V + row, rstd)

    mask = cols < N
    w = tl.load(W + cols, mask=mask, other=1.0)
    y = y * w

    y_ptrs = Y + row * stride + cols
    tl.store(y_ptrs, y, mask=mask)


# Backward pass (DX + partial DW)
# fmt: off
@triton.jit
def rms_norm_bwd_dx_fused(
    DX, DY, DW,
    X, W, V,
    Lock, stride, N,
    # META-parameters
    GROUP_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # fmt: on

    # position of elements processed by this program
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # offset data pointers to start at the row of interest
    x_ptrs = X + row * stride + cols
    dy_ptrs = DY + row * stride + cols

    # load data to SRAM
    x = tl.load(x_ptrs, mask=mask, other=0)
    dy = tl.load(dy_ptrs, mask=mask, other=0)
    rstd = tl.load(V + row)

    # compute dx
    xhat = x * rstd

    w = tl.load(W + cols, mask=mask, other=0)
    wdy = w * dy

    xhat = tl.where(mask, xhat, 0.0)
    wdy = tl.where(mask, wdy, 0.0)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    dx = (wdy - xhat * mean1) * rstd

    # write-back dx
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N  # re-materialize the mask to save registers
    dx_ptrs = DX + row * stride + cols
    tl.store(dx_ptrs, dx, mask=mask)

    # accumulate partial sums for dw
    partial_dw = (dy * xhat).to(w.dtype)

    # offset the locks and the weight gradient pointer:
    # each kernel instance accumulates partial sums for
    # DW into one of GROUP_SIZE_M independent buffers;
    # these buffers stay in L2, which allows this kernel
    # to be fast
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M

    # - wait for a lock on the accumulated dw
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)

    # - we got the lock, accumulate this kernel's results with
    #   the stored values
    dw_ptrs = DW + lock_id * N + cols

    if count == 0:
        # the first store doesn't accumulate
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(dw_ptrs, mask=mask, other=0.0)

    tl.store(dw_ptrs, partial_dw, mask=mask)

    # release the lock
    tl.atomic_xchg(Lock, 0)


# Backward pass (total DW)
# fmt: off
@triton.jit
def rms_norm_bwd_dw(
    DW, FINAL_DW,
    M, N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # fmt: on
    pid = tl.program_id(0)

    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_cols = cols < N

    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        offs = rows[:, None] * N + cols[None, :]
        mask_rm = rows < M

        dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)

    sum_dw = tl.sum(dw, axis=0)

    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_cols = cols < N

    tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols)
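
Reading the two backward kernels together: with \(\hat{x} = x \cdot \mathrm{rstd}\), rms_norm_bwd_dx_fused computes, per row,

\[
dx = \mathrm{rstd}\left( w \odot dy - \hat{x} \cdot \frac{1}{N}\sum_{i} \hat{x}_i \,(w \odot dy)_i \right),
\qquad
dw_{\text{partial}} = dy \odot \hat{x},
\]

scattering each row's dw contribution into one of GROUP_SIZE_M lock-protected buffers; rms_norm_bwd_dw then sums those buffers over rows to produce the final weight gradient.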

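The xformers/triton/rms_norm.py module that actually exposes FusedRMSNorm is not part of the diff shown here. Purely as an illustration of how the three kernels above could be wired into autograd, here is a minimal sketch; the class name, grid sizes, and buffer shapes below are assumptions for illustration, not the PR's code:

# Hypothetical sketch only -- not taken from this PR.
import torch
import triton

from xformers.triton.k_rms_norm import (
    rms_norm_bwd_dw,
    rms_norm_bwd_dx_fused,
    rms_norm_fw,
)


class _RMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, eps):
        # flatten the leading dimensions: the kernels work on 2D (rows, features)
        x_arg = x.reshape(-1, x.shape[-1]).contiguous()
        M, N = x_arg.shape

        y = torch.empty_like(x_arg)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        BLOCK_SIZE_N = triton.next_power_of_2(N)

        # one program instance per row
        rms_norm_fw[(M,)](
            x_arg, y, weight, rstd,
            x_arg.stride(0), N, eps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        ctx.save_for_backward(x_arg, weight, rstd)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        return y.reshape_as(x)

    @staticmethod
    def backward(ctx, dy):
        x_arg, weight, rstd = ctx.saved_tensors
        M, N = x_arg.shape
        dy_arg = dy.reshape(-1, N).contiguous()

        dx = torch.empty_like(dy_arg)

        # GROUP_SIZE_M partial buffers + 2 * GROUP_SIZE_M ints (locks and counts)
        GROUP_SIZE_M = 32
        partial_dw = torch.zeros((GROUP_SIZE_M, N), dtype=torch.float32, device=dy.device)
        locks = torch.zeros((2 * GROUP_SIZE_M,), dtype=torch.int32, device=dy.device)

        rms_norm_bwd_dx_fused[(M,)](
            dx, dy_arg, partial_dw,
            x_arg, weight, rstd,
            locks, x_arg.stride(0), N,
            GROUP_SIZE_M=GROUP_SIZE_M,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
        )

        # reduce the GROUP_SIZE_M partial buffers into the final weight gradient
        dw = torch.empty((N,), dtype=weight.dtype, device=dy.device)
        rms_norm_bwd_dw[(triton.cdiv(N, 128),)](
            partial_dw, dw,
            GROUP_SIZE_M, N,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_N=128,
        )

        return dx.reshape_as(dy), dw, None

A FusedRMSNorm module would then just hold the weight parameter (and eps) and call _RMSNorm.apply(x, self.weight, self.eps) in its forward.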