Commit: time-mixing
jahatef committed Oct 11, 2024
1 parent 43d641d commit de02f37
Showing 1 changed file with 26 additions and 20 deletions.
megatron/model/rwkv/v6/rwkv.py: 46 changes (26 additions & 20 deletions)
@@ -8,7 +8,7 @@
 from torch.nn import functional as F
 from torch.utils.cpp_extension import load
 from megatron import mpu
-from mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region
+from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region
 
 class WKV(torch.autograd.Function):
     """
@@ -207,7 +207,7 @@ def __init__(self, neox_args, layer_number, init_method):
             neox_args=neox_args,
             input_size=neox_args.dim_att,
             output_size=neox_args.hidden_size,
-            gather_output=False,
+            gather_output=True,
             init_method=init_method,
             bias=False,
         )
@@ -216,19 +216,19 @@ def __init__(self, neox_args, layer_number, init_method):
             neox_args=neox_args,
             input_size=neox_args.hidden_size,
             output_size=neox_args.dim_att,
-            gather_output=False,
+            gather_output=True,
             init_method=init_method,
             bias=False,
         )
-        self.gate = mpu.RowParallelLinear(
-            neox_args=neox_args,
-            input_size=neox_args.hidden_size,
-            output_size=neox_args.dim_att,
-            input_is_parallel=True,
-            init_method=init_method,
-            parallel_output=False,
-            bias=False
-        )
+        #self.gate = mpu.RowParallelLinear(
+        #    neox_args=neox_args,
+        #    input_size=neox_args.hidden_size,
+        #    output_size=neox_args.dim_att,
+        #    input_is_parallel=True,
+        #    init_method=init_method,
+        #    parallel_output=False,
+        #    bias=False
+        #    )
         self.ln_x = nn.GroupNorm(
             neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2)
         )
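
For readers unfamiliar with the Megatron-style mpu layers touched above: a ColumnParallelLinear splits its output features across model-parallel ranks, and gather_output=True all-gathers the per-rank slices so the caller receives the full output dimension rather than a shard. Below is a minimal single-process sketch of that behavior with made-up sizes; it is an illustration, not code from this commit.

import torch

# Simulate a ColumnParallelLinear with gather_output=True on one process.
# Sizes here are illustrative only.
hidden_size, dim_att, world_size = 16, 32, 4
x = torch.randn(2, 3, hidden_size)               # (batch, seq, hidden)
full_weight = torch.randn(dim_att, hidden_size)  # logical, unsharded weight

# Each rank owns dim_att // world_size of the output features.
weight_shards = full_weight.chunk(world_size, dim=0)
local_outputs = [x @ w.t() for w in weight_shards]   # per-rank partial outputs

# gather_output=True: all-gather the shards along the feature dimension,
# so every rank sees the complete (batch, seq, dim_att) activation.
gathered = torch.cat(local_outputs, dim=-1)
assert gathered.shape == (2, 3, dim_att)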
@@ -252,7 +252,8 @@ def jit_func(self, x):
         r, _ = self.receptance(xr)
         k, _ = self.key(xk)
         v, _ = self.value(xv)
-        g, _ = F.silu(self.gate(xg))
+        gated, _ = self.gate(xg)
+        g = F.silu(gated)
 
         ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
         w = self.time_decay + ww
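
The gate change above is needed because GPT-NeoX's model-parallel linear layers return a (output, bias) tuple rather than a bare tensor (visible in the surrounding `r, _ = self.receptance(xr)` lines), so the activation has to be applied after unpacking; passing the tuple straight to F.silu raises a TypeError. A toy illustration with a stand-in layer that mimics the tuple-returning signature, not code from this commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TupleLinear(nn.Linear):
    # Stand-in for the mpu parallel linears, which return (output, bias).
    def forward(self, x):
        return super().forward(x), None

gate = TupleLinear(16, 16, bias=False)
xg = torch.randn(2, 4, 16)

gated, _ = gate(xg)   # unpack the (output, bias) tuple first ...
g = F.silu(gated)     # ... then apply the activation
# F.silu(gate(xg)) would fail, since gate(xg) is a tuple, not a tensor.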
@@ -271,12 +272,15 @@ def jit_func_2(self, x, g):
 
     def forward(self, x):
         B, T, C = x.size()
+        C_tp = C//mpu.get_model_parallel_world_size()
         H = self.neox_args.num_attention_heads
+        H_tp = H//mpu.get_model_parallel_world_size()
 
         r, k, v, g, w = self.jit_func(x)
         print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n")
-        x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa)
-        x = reduce_from_model_parallel_region(x)
+        x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=self.time_faaaa)
+        x = gather_from_model_parallel_region(x)
+        print(f"size of x after kernel: {x.size()}")
 
         return self.jit_func_2(x, g)
 
@@ -304,11 +308,11 @@ def __init__(self, neox_args, layer_number, init_method):
         self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
         self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
 
-        #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
+        #self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False)
         self.key = mpu.ColumnParallelLinear(
             neox_args=neox_args,
             input_size=neox_args.hidden_size,
-            output_size=neox_args.dim_ffn,
+            output_size=neox_args.ffn_dim,
             gather_output=False,
             init_method=init_method,
             bias=False,
@@ -324,10 +328,10 @@ def __init__(self, neox_args, layer_number, init_method):
             init_method=init_method,
             bias=False
         )
-        #self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
+        #self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False)
         self.value = mpu.RowParallelLinear(
             neox_args=neox_args,
-            input_size=neox_args.dim_ffn,
+            input_size=neox_args.ffn_dim,
             output_size=neox_args.hidden_size,
             input_is_parallel=True,
             init_method=init_method,
@@ -350,7 +354,7 @@ class RWKVResidualLayer(nn.Module):
     """
     RWKV layer definition
     """
-
+
     def __init__(self, neox_args, init_method, layer_number):
         super().__init__()
         self.neox_args = neox_args
@@ -446,4 +450,6 @@ def forward(self, args):
         assert len(args) == 2
         hidden_states, mask = args
         neox_args = self.neox_args
+        if self.layer_number == 0:
+            hidden_states = hidden_states.transpose(0,1)
         return super().forward(hidden_states), mask
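
As a gloss on the new forward path in the time-mixing hunk (a sketch under assumptions, not the commit's code): each model-parallel rank handles C_tp = C // world_size of the attention channels, runs the WKV kernel on its local slice, and the per-rank outputs are then gathered along the channel dimension, which is why reduce_from_model_parallel_region is replaced by gather_from_model_parallel_region. A single-process simulation with a stand-in for RUN_CUDA_RWKV:

import torch

def fake_wkv(r, k, v, w):
    # Stand-in for RUN_CUDA_RWKV: any channel-local computation will do here.
    return r * torch.sigmoid(k) * v * torch.exp(-torch.exp(w))

B, T, C, world_size = 2, 5, 16, 4
C_tp = C // world_size  # channels handled by each model-parallel rank

r = torch.randn(B, T, C)
k = torch.randn(B, T, C)
v = torch.randn(B, T, C)
w = torch.randn(B, T, C)

# Each "rank" computes the kernel on its own channel shard...
local_chunks = [
    fake_wkv(r[..., i * C_tp:(i + 1) * C_tp],
             k[..., i * C_tp:(i + 1) * C_tp],
             v[..., i * C_tp:(i + 1) * C_tp],
             w[..., i * C_tp:(i + 1) * C_tp])
    for i in range(world_size)
]
# ...and the shards are gathered (concatenated) along the channel dimension,
# analogous to gather_from_model_parallel_region on the kernel output.
x = torch.cat(local_chunks, dim=-1)
assert x.shape == (B, T, C)

In the real model each slice lives on a different GPU and the gather is a collective; the loop above only mimics that layout on one process.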
