diff --git a/configs/rwkv/1.5B.yml b/configs/rwkv/1.5B.yml index 0d97a7861..473bde88e 100644 --- a/configs/rwkv/1.5B.yml +++ b/configs/rwkv/1.5B.yml @@ -1,7 +1,7 @@ { # Parallelism is not yet supported for rwkv "pipe_parallel_size": 1, - "model_parallel_size": 2, + "model_parallel_size": 1, "num_layers": 24, "hidden_size": 2048, @@ -12,7 +12,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-5, - "train_micro_batch_size_per_gpu": 1, + "train_micro_batch_size_per_gpu": 4, "attention_config": [[["rwkv"], 24]], @@ -86,8 +86,8 @@ }, # misc. training settings - "train_iters": 1, - "lr_decay_iters": 1, + "train_iters": 320000, + "lr_decay_iters": 320000, "distributed_backend": "nccl", "lr_decay_style": "constant", "warmup": 0.01, diff --git a/configs/rwkv/430M.yml b/configs/rwkv/430M.yml new file mode 100644 index 000000000..1b3a62dfd --- /dev/null +++ b/configs/rwkv/430M.yml @@ -0,0 +1,103 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + "num_layers": 24, + "hidden_size": 1024, + "num_attention_heads": 16, # head_size = dim_att / num_attention_heads. + # head_size is 64 for all rwkv models + "seq_length": 4096, + "max_position_embeddings": 4096, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 4, + + "attention_config": [[["rwkv"], 24]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + "seed": 1234, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index ddf025a1d..1b6aa9b54 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -74,7 +74,6 @@ def cross_entropy(output, labels, _fp16=False): else: losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels) loss_mask = loss_mask.view(-1) - print(f"model output shape: {output.size()}, loss shape: {losses.size()}") loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() return loss diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 47e06bd0b..3018063af 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -173,9 +173,6 @@ def __init__(self, neox_args, layer_number, init_method): ) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - #self.receptance = nn.Linear( - # neox_args.hidden_size, neox_args.dim_att, bias=False - #) self.receptance = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -184,7 +181,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.key = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -193,7 +189,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.value = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -202,7 +197,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) self.output = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.dim_att, @@ -211,7 +205,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) # column self.gate = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -220,15 +213,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.gate = mpu.RowParallelLinear( - # neox_args=neox_args, - # input_size=neox_args.hidden_size, - # output_size=neox_args.dim_att, - # input_is_parallel=True, - # init_method=init_method, - # parallel_output=False, - # bias=False - # ) self.ln_x = nn.GroupNorm( neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) ) @@ -237,7 +221,6 @@ def jit_func(self, x): B, T, C = x.size() xx = self.time_shift(x) - x - print(x[0,:,1],xx[0,:,1]) xxx = x + xx * self.time_maa_x xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1) @@ -256,7 +239,6 @@ def jit_func(self, x): gated, _ = self.gate(xg) g = F.silu(gated) - print(f"size of ww matmuls: {self.time_decay_w1.size()}, {self.time_decay_w2.size()}") ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 w = self.time_decay + ww w = scatter_to_model_parallel_region(w) @@ -268,9 +250,8 @@ def jit_func_2(self, x, g): x = x.view(B * T, C) x = self.ln_x(x).view(B, T, C) - print(f"shape of x: {x.size()}, shape of g: {g.size()}") x, _ = self.output(x * g) - print(f"new shape of x: {x.size()}") + return x def forward(self, x): @@ -279,15 +260,11 @@ def forward(self, x): H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size() H_tp = H//mpu.get_model_parallel_world_size() - #self.time_faaaa = self.time_faaaa[:self.neox_args.num_attention_heads//2,:] - #self.time_faaaa = scatter_to_model_parallel_region(self.time_faaaa) r, k, v, g, w = self.jit_func(x) - print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H) ) - print(f"size of x after kernel: {x.size()}") + x = gather_from_model_parallel_region(x) - print(f"size of x after allgather: {x.size()}") return self.jit_func_2(x, g) @@ -315,7 +292,6 @@ def __init__(self, neox_args, layer_number, init_method): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - #self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False) self.key = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -324,9 +300,7 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.receptance = nn.Linear( - # neox_args.hidden_size, neox_args.hidden_size, bias=False - #) + self.receptance = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -335,7 +309,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False ) - #self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False) self.value = mpu.RowParallelLinear( neox_args=neox_args, input_size=neox_args.ffn_dim, @@ -345,6 +318,7 @@ def __init__(self, neox_args, layer_number, init_method): parallel_output=False, bias=False ) + def forward(self, x): xx = self.time_shift(x) - x xk = x + xx * self.time_maa_k @@ -355,7 +329,7 @@ def forward(self, x): kv, _ = self.value(k) receptance, _ = self.receptance(xr) retVal = torch.sigmoid(receptance) * kv - print(f"channel mix output size: {retVal.size()}") + return retVal @@ -374,7 +348,6 @@ def __init__(self, neox_args, init_method, layer_number): neox_args.intermediate_size == None or neox_args.expansion_factor == None ), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv" if not neox_args.dim_att: - print("replacing dim_att") neox_args.dim_att = neox_args.hidden_size if neox_args.intermediate_size: neox_args.ffn_dim = neox_args.intermediate_size @@ -450,7 +423,6 @@ def forward(self, x): return x - class RWKVResidualLayerPipe(RWKVResidualLayer): """ RWKV Pipeline Layer @@ -465,5 +437,4 @@ def forward(self, args): hidden_states = super().forward(hidden_states) if self.layer_number == self.neox_args.num_layers-1: hidden_states = hidden_states.transpose(0,1) - print(f"output of model from residual layer pipe: {hidden_states.size()}") return hidden_states, mask