diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 923de0835..e96f1bd2b 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -8,7 +8,7 @@ from torch.nn import functional as F from torch.utils.cpp_extension import load from megatron import mpu -from mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region +from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region class WKV(torch.autograd.Function): """ @@ -207,7 +207,7 @@ def __init__(self, neox_args, layer_number, init_method): neox_args=neox_args, input_size=neox_args.dim_att, output_size=neox_args.hidden_size, - gather_output=False, + gather_output=True, init_method=init_method, bias=False, ) @@ -216,19 +216,19 @@ def __init__(self, neox_args, layer_number, init_method): neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.dim_att, - gather_output=False, + gather_output=True, init_method=init_method, bias=False, ) - self.gate = mpu.RowParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.dim_att, - input_is_parallel=True, - init_method=init_method, - parallel_output=False, - bias=False - ) + #self.gate = mpu.RowParallelLinear( + # neox_args=neox_args, + # input_size=neox_args.hidden_size, + # output_size=neox_args.dim_att, + # input_is_parallel=True, + # init_method=init_method, + # parallel_output=False, + # bias=False + # ) self.ln_x = nn.GroupNorm( neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) ) @@ -252,7 +252,8 @@ def jit_func(self, x): r, _ = self.receptance(xr) k, _ = self.key(xk) v, _ = self.value(xv) - g, _ = F.silu(self.gate(xg)) + gated, _ = self.gate(xg) + g = F.silu(gated) ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 w = self.time_decay + ww @@ -271,12 +272,15 @@ def jit_func_2(self, x, g): def forward(self, x): B, T, C = x.size() + C_tp = C//mpu.get_model_parallel_world_size() H = self.neox_args.num_attention_heads + H_tp = H//mpu.get_model_parallel_world_size() r, k, v, g, w = self.jit_func(x) print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") - x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) - x = reduce_from_model_parallel_region(x) + x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=self.time_faaaa) + x = gather_from_model_parallel_region(x) + print(f"size of x after kernel: {x.size()}") return self.jit_func_2(x, g) @@ -304,11 +308,11 @@ def __init__(self, neox_args, layer_number, init_method): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + #self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False) self.key = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, - output_size=neox_args.dim_ffn, + output_size=neox_args.ffn_dim, gather_output=False, init_method=init_method, bias=False, @@ -324,10 +328,10 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False ) - #self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + #self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False) self.value = mpu.RowParallelLinear( neox_args=neox_args, - input_size=neox_args.dim_ffn, + input_size=neox_args.ffn_dim, output_size=neox_args.hidden_size, input_is_parallel=True, init_method=init_method, @@ -350,7 +354,7 @@ class RWKVResidualLayer(nn.Module): """ RWKV layer definition """ - + def __init__(self, neox_args, init_method, layer_number): super().__init__() self.neox_args = neox_args @@ -446,4 +450,6 @@ def forward(self, args): assert len(args) == 2 hidden_states, mask = args neox_args = self.neox_args + if self.layer_number == 0: + hidden_states = hidden_states.transpose(0,1) return super().forward(hidden_states), mask