Commit 5f89ed8: merge

jahatef committed Nov 5, 2024
1 parent 97c7915 commit 5f89ed8
Showing 1 changed file with 4 additions and 4 deletions.

megatron/model/rwkv/v6/rwkv.py (4 additions, 4 deletions)

@@ -284,7 +284,7 @@ def forward(self, x):
         r, k, v, g, w = self.jit_func(x)
         print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n")
 
-        x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=self.time_faaaa)
+        x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H, C_tp // H))
         print(f"size of x after kernel: {x.size()}")
         x = gather_from_model_parallel_region(x)
         print(f"size of x after allgather: {x.size()}")
@@ -297,7 +297,6 @@ class ParallelRWKV_ChannelMix(nn.Module):
     Channel Mix layer. The ffn in RWKV
     """
 
-    def __init__(self, neox_args, layer_number, init_method):
     def __init__(self, neox_args, layer_number, init_method):
         super().__init__()
         self.neox_args = neox_args

@@ -377,8 +376,9 @@ def __init__(self, neox_args, init_method, layer_number):
         self.bf16 = neox_args.precision == "bfloat16"
         assert (
             neox_args.intermediate_size == None or neox_args.expansion_factor == None
-        ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"
-        if not hasattr(neox_args, "dim_att"):
+        ), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv"
+        if not neox_args.dim_att:
+            print("replacing dim_att")
             neox_args.dim_att = neox_args.hidden_size
         if neox_args.intermediate_size:
             neox_args.ffn_dim = neox_args.intermediate_size
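
A note on the guard change in this hunk: if the args object always defines dim_att (e.g. with a default of None, which the move away from hasattr suggests), the old check never fired and the hidden-size fallback was silently skipped. A small illustration of the difference, using a hypothetical Args class rather than the repository's NeoXArgs:

class Args:
    dim_att = None       # attribute exists but is unset
    hidden_size = 512

args = Args()
print(hasattr(args, "dim_att"))   # True -> the old guard skipped the fallback
if not args.dim_att:              # the new guard fires on None (and 0)
    args.dim_att = args.hidden_size
print(args.dim_att)               # 512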