From 5f89ed84d0c74804a6c448be3c96e36a1283b0b8 Mon Sep 17 00:00:00 2001
From: "hatef.4"
Date: Tue, 5 Nov 2024 17:20:14 -0500
Subject: [PATCH] merge

---
 megatron/model/rwkv/v6/rwkv.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py
index fa0eaa53f..77521b9aa 100644
--- a/megatron/model/rwkv/v6/rwkv.py
+++ b/megatron/model/rwkv/v6/rwkv.py
@@ -284,7 +284,7 @@ def forward(self, x):
         r, k, v, g, w = self.jit_func(x)
 
         print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n")
-        x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=self.time_faaaa)
+        x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H) )
         print(f"size of x after kernel: {x.size()}")
         x = gather_from_model_parallel_region(x)
         print(f"size of x after allgather: {x.size()}")
@@ -297,7 +297,6 @@ class ParallelRWKV_ChannelMix(nn.Module):
     Channel Mix layer. The ffn in RWKV
     """
 
-    def __init__(self, neox_args, layer_number, init_method):
     def __init__(self, neox_args, layer_number, init_method):
         super().__init__()
         self.neox_args = neox_args
@@ -377,8 +376,9 @@ def __init__(self, neox_args, init_method, layer_number):
         self.bf16 = neox_args.precision == "bfloat16"
         assert (
             neox_args.intermediate_size == None or neox_args.expansion_factor == None
-        ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"
-        if not hasattr(neox_args, "dim_att"):
+        ), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv"
+        if not neox_args.dim_att:
+            print("replacing dim_att")
             neox_args.dim_att = neox_args.hidden_size
         if neox_args.intermediate_size:
             neox_args.ffn_dim = neox_args.intermediate_size
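
For reference, the functional change is the u= argument to RUN_CUDA_RWKV: under tensor parallelism each rank holds only a C_tp = C / world_size slice of the channels, so the full (H, C//H) time_faaaa parameter is flattened, scattered across model-parallel ranks, and reshaped to the (H, C_tp//H) shape the kernel expects locally. Below is a minimal single-process sketch of that shape bookkeeping, with plain torch.chunk slicing standing in for scatter_to_model_parallel_region and assumed values for world_size, H, and C; it is an illustration, not the Megatron collective itself.

import torch

world_size = 2                       # assumed number of tensor-parallel ranks
H, C = 8, 64                         # assumed head count and hidden size
C_tp = C // world_size               # per-rank channel slice, as in the patch

time_faaaa = torch.randn(H, C // H)  # full parameter, shape (H, C//H)

for rank in range(world_size):
    # torch.chunk on the flattened parameter mimics the shard that
    # scatter_to_model_parallel_region would hand to each rank.
    shard = torch.chunk(time_faaaa.view(-1), world_size)[rank]
    u_local = shard.view(H, C_tp // H)  # per-rank shape the kernel expects
    assert u_local.shape == (H, C_tp // H)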