diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index 9e643874a..e866ffb40 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -258,6 +258,7 @@ def init_specs(self):
                     LayerSpec(
                         RWKVResidualLayerPipe,
                         neox_args=self.neox_args,
+                        init_method=self.init_method,
                         layer_number=i,
                     )
                 )
diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py
index 5d4e0d144..eaeec4ad7 100644
--- a/megatron/model/rwkv/v6/rwkv.py
+++ b/megatron/model/rwkv/v6/rwkv.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 from torch.nn import functional as F
 from torch.utils.cpp_extension import load
-
+from megatron import mpu

 class WKV(torch.autograd.Function):
     """
@@ -179,7 +179,7 @@ def __init__(self, neox_args, layer_number):
         self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
         self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False)
-        self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
+        self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)  # column
         self.ln_x = nn.GroupNorm(
             neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2)
         )

@@ -228,15 +228,19 @@ def forward(self, x):
         return self.jit_func_2(x, g)


-class RWKV_ChannelMix(nn.Module):
+class ParallelRWKV_ChannelMix(nn.Module):
     """
     Channel Mix layer.
     The ffn in RWKV
     """

-    def __init__(self, neox_args, layer_number):
+    def __init__(self, neox_args, layer_number, init_method):
         super().__init__()
         self.neox_args = neox_args
         self.layer_number = layer_number
+
+        world_size = mpu.get_model_parallel_world_size()
+        self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
+
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         with torch.no_grad():  # fancy init of time_mix
@@ -247,21 +251,46 @@ def __init__(self, neox_args, layer_number):
             self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
             self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

-        self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
-        self.receptance = nn.Linear(
-            neox_args.hidden_size, neox_args.hidden_size, bias=False
-        )
-        self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
-
+        # self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
+        self.key = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.dim_ffn,
+            gather_output=False,
+            init_method=init_method,
+            bias=False,
+        )
+        # self.receptance = nn.Linear(
+        #     neox_args.hidden_size, neox_args.hidden_size, bias=False
+        # )
+        self.receptance = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.hidden_size,
+            gather_output=True,
+            init_method=init_method,
+            bias=False,
+        )
+        # self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
+        self.value = mpu.RowParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.dim_ffn,
+            output_size=neox_args.hidden_size,
+            input_is_parallel=True,
+            init_method=init_method,
+            parallel_output=False,
+            bias=False,
+        )
     def forward(self, x):
         xx = self.time_shift(x) - x
         xk = x + xx * self.time_maa_k
         xr = x + xx * self.time_maa_r

-        k = self.key(xk)
+        k, _ = self.key(xk)
         k = torch.relu(k) ** 2
-        kv = self.value(k)
-        return torch.sigmoid(self.receptance(xr)) * kv
+        kv, _ = self.value(k)
+        receptance, _ = self.receptance(xr)
+        return torch.sigmoid(receptance) * kv


 class RWKVResidualLayer(nn.Module):
@@ -269,7 +298,7 @@ class RWKVResidualLayer(nn.Module):
     RWKV layer definition
     """

-    def __init__(self, neox_args, layer_number):
+    def __init__(self, neox_args, init_method, layer_number):
         super().__init__()
         self.neox_args = neox_args
         self.layer_number = layer_number
@@ -288,6 +317,7 @@ def __init__(self, neox_args, layer_number):
         self.num_attention_heads = neox_args.num_attention_heads
         assert neox_args.dim_att % self.num_attention_heads == 0
+        self.init_method = init_method

         if neox_args.attention_dropout > 0:
             self.drop0 = nn.Dropout(p=neox_args.attention_dropout)
@@ -296,7 +326,7 @@ def __init__(self, neox_args, layer_number):
         self.att = RWKV_TimeMix(neox_args, layer_number)

-        self.ffn = RWKV_ChannelMix(neox_args, layer_number)
+        self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)

         if neox_args.attention_dropout > 0:
             self.drop0 = nn.Dropout(p=neox_args.attention_dropout)
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index ff4f4bc21..3dda6489d 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -1066,17 +1066,11 @@ def calculate_derived(self):
             if isinstance(self.zero_stage, int):
                 assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba"
             assert (
-                self.hidden_dropout == 0.0,
+                self.hidden_dropout == 0.0
             ), "Mamba does not yet have dropout implemented"

         if "rwkv" in self.attention_config:
-            assert (
-                not self.is_pipe_parallel and self.model_parallel_size == 1
-            ), "RWKV not currently compatible with parallelism"
             if isinstance(self.zero_stage, int):
                 assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV"
-            assert (
-                self.hidden_dropout == 0.0,
-            ), "RWKV does not yet have dropout implemented"

         # Sparsity config
         if self.sparsity_config is None:
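Note on the tensor-parallel split above (not part of the patch): `ParallelRWKV_ChannelMix` follows the usual Megatron column-parallel then row-parallel MLP pattern, which is why `gather_output=False` on `key` pairs with `input_is_parallel=True` on `value`, and why the elementwise `relu(k) ** 2` can be applied to each `dim_ffn` shard independently before the row-parallel reduce. Below is a minimal single-process sketch in plain PyTorch (no `mpu` or distributed setup; `hidden`, `dim_ffn`, and `world_size` are illustrative values, not NeoX defaults) showing that the sharded computation reproduces the unpartitioned FFN:

```python
# Single-process sketch of the column/row split used by ParallelRWKV_ChannelMix.
# Assumed toy sizes; the only requirement is that world_size divides dim_ffn.
import torch

torch.manual_seed(0)
hidden, dim_ffn, world_size = 8, 32, 4
x = torch.randn(2, 5, hidden)            # (batch, seq, hidden)

w_key = torch.randn(dim_ffn, hidden)     # full "key" weight (hidden -> dim_ffn)
w_value = torch.randn(hidden, dim_ffn)   # full "value" weight (dim_ffn -> hidden)

# Unpartitioned reference: key -> relu(.)**2 -> value
kv_ref = (torch.relu(x @ w_key.t()) ** 2) @ w_value.t()

# "Column parallel" key: each rank keeps a slice of dim_ffn output rows
#   (gather_output=False), applies the elementwise nonlinearity locally,
# "row parallel" value: each rank keeps the matching slice of dim_ffn input
#   columns (input_is_parallel=True); summing the per-rank partial products
#   plays the role of the all-reduce in RowParallelLinear.
kv_tp = sum(
    (torch.relu(x @ w_k.t()) ** 2) @ w_v.t()
    for w_k, w_v in zip(
        w_key.chunk(world_size, dim=0),
        w_value.chunk(world_size, dim=1),
    )
)

print(torch.allclose(kv_ref, kv_tp, atol=1e-5))  # True: shards recombine exactly
```

The `receptance` projection, by contrast, uses `gather_output=True` so the sigmoid gate sees the full `hidden_size` output that multiplies the reduced `kv`.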