
Commit

add Triton FLA
jahatef committed Nov 10, 2024
1 parent 5a259c0 commit c2d6c85
Showing 3 changed files with 54 additions and 24 deletions.
72 changes: 48 additions & 24 deletions megatron/model/rwkv/v6/rwkv.py
@@ -9,6 +9,15 @@
from torch.utils.cpp_extension import load
from megatron import mpu
from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region, scatter_to_model_parallel_region
try:
    from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6, native_recurrent_rwkv6
    import einops
except ModuleNotFoundError:
    print(
        "Unable to import the RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, "
        "directly from https://github.com/TorchRWKV/flash-linear-attention/tree/stable, or use the CUDA kernels."
    )

class WKV(torch.autograd.Function):
"""
@@ -96,6 +105,18 @@ def backward(ctx, gy):
def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u):
    return WKV.apply(B, T, C, H, r, k, v, w, u)

# torch.compile introduces numerical-precision errors here (observed with torch 2.4),
# so keep this function out of any compiled graph.
@torch.compiler.disable(recursive=True)
def RUN_FLA_CHUNK(B, T, C, H, r, k, v, w, u, h=None, scale=1.0, chunk_size=32):
    # Reshape the flat (B, T, C) projections to the (B, H, T, C//H) head layout.
    r = r.view(B, T, H, -1).transpose(1, 2)
    k = k.view(B, T, H, -1).transpose(1, 2)
    v = v.view(B, T, H, -1).transpose(1, 2)
    w = -torch.exp(w.view(B, T, H, -1).transpose(1, 2))
    # u can be 3D (B, H, -1) or 2D (H, -1) to save VRAM.
    # Change to scale=-1.0 when using fp16; this applies the scale to r and k.
    # For RWKV6, initial_state=None and output_final_state=False.
    o, final_state = chunk_rwkv6(r, k, v, w, u=u, scale=scale, initial_state=h,
                                 output_final_state=False, chunk_size=chunk_size)
    return o.transpose(1, 2).reshape(B, T, C), final_state
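
For context, a minimal sketch of calling RUN_FLA_CHUNK on its own; the shapes, dtypes, and values below are illustrative assumptions, not taken from the training code:

# Hypothetical standalone call; r/k/v/w are (B, T, C) and u is (H, C//H), matching the reshapes above.
B, T, C, H = 2, 128, 512, 8
r, k, v, w = [torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16) for _ in range(4)]
u = torch.randn(H, C // H, device="cuda", dtype=torch.bfloat16)
out, _ = RUN_FLA_CHUNK(B, T, C, H, r, k, v, w, u)  # out has shape (B, T, C)
# Per the comment above, pass scale=-1.0 instead when the inputs are fp16.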

# RWKV6 time mix
class RWKV_TimeMix(nn.Module):
@@ -260,9 +281,11 @@ def forward(self, x):
        H = self.neox_args.num_attention_heads // mpu.get_model_parallel_world_size()

        r, k, v, g, w = self.jit_func(x)

        if self.neox_args.rwkv_fla:
            x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H, C_tp // H))
        else:
            x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H, C_tp // H))

        x = gather_from_model_parallel_region(x)

        return self.jit_func_2(x, g)
@@ -382,27 +405,28 @@ def __init__(self, neox_args, init_method, layer_number):
        self.drop1 = nn.Dropout(p=neox_args.hidden_dropout)

        if layer_number == 0:
            if not self.neox_args.rwkv_fla:
                global wkv_cuda
                # Build the CUDA kernel at runtime. It is compiled against runtime
                # variables (head size and sequence length); ideally it would not be.
                wkv_cuda = load(
                    name="wkv6",
                    sources=[
                        "megatron/model/rwkv/v6/cuda/wkv6_op.cpp",
                        "megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",
                    ],
                    verbose=True,
                    extra_cuda_cflags=[
                        "-res-usage",
                        "--use_fast_math",
                        "-O3",
                        "-Xptxas -O3",
                        "--extra-device-vectorization",
                        f"-D_N_={self.neox_args.head_size}",
                        f"-D_T_={self.neox_args.seq_length}",
                    ],
                )

    def forward(self, x):
        neox_args = self.neox_args
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -277,6 +277,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
    }
    """

    rwkv_fla: bool = False
    """
    If True, use the Triton-based Flash Linear Attention (FLA) implementation of the RWKV kernel; otherwise use the JIT-compiled CUDA kernel.
    """

    num_unique_layers: int = None
    """
    Number of unique transformer layers. num-layers should be divisible by this value. Currently only has an effect when pipe_parallel_size=0.
1 change: 1 addition & 0 deletions requirements/requirements-rwkv.txt
@@ -0,0 +1 @@
rwkv-fla>=0.1.202410200535
