
Commit

inital tp commits
jahatef committed Jun 4, 2024
1 parent d037756 commit 4c7cb11
Showing 3 changed files with 47 additions and 22 deletions.
1 change: 1 addition & 0 deletions megatron/model/gpt2_model.py
@@ -258,6 +258,7 @@ def init_specs(self):
                 LayerSpec(
                     RWKVResidualLayerPipe,
                     neox_args=self.neox_args,
+                    init_method=self.init_method,
                     layer_number=i,
                 )
             )
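For context on why this one-line change also requires the signature change in rwkv.py below: LayerSpec (from deepspeed.pipe) does not construct the module immediately; it records the class together with its keyword arguments and instantiates it later via build(), so every kwarg listed here must be accepted by RWKVResidualLayerPipe.__init__, which is exactly the signature change made in rwkv.py. A minimal sketch of that deferred construction, using a hypothetical ToyLayer rather than anything from this repo, and assuming the stock deepspeed.pipe.LayerSpec API:

import torch.nn as nn
from deepspeed.pipe import LayerSpec

class ToyLayer(nn.Module):
    # stand-in for RWKVResidualLayerPipe; not part of gpt-neox
    def __init__(self, init_method=None, layer_number=0):
        super().__init__()
        self.init_method = init_method
        self.layer_number = layer_number

spec = LayerSpec(ToyLayer, init_method=nn.init.xavier_uniform_, layer_number=3)
layer = spec.build()        # ToyLayer is only constructed here, with the stored kwargs
print(layer.layer_number)   # 3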
60 changes: 45 additions & 15 deletions megatron/model/rwkv/v6/rwkv.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 from torch.nn import functional as F
 from torch.utils.cpp_extension import load
-
+from megatron import mpu
 
 class WKV(torch.autograd.Function):
     """
@@ -179,7 +179,7 @@ def __init__(self, neox_args, layer_number):
 
         self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
         self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False)
-        self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
+        self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) # column
         self.ln_x = nn.GroupNorm(
             neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2)
         )
@@ -228,15 +228,19 @@ def forward(self, x):
         return self.jit_func_2(x, g)
 
 
-class RWKV_ChannelMix(nn.Module):
+class ParallelRWKV_ChannelMix(nn.Module):
     """
     Channel Mix layer. The ffn in RWKV
     """
 
-    def __init__(self, neox_args, layer_number):
+    def __init__(self, neox_args, layer_number, init_method):
         super().__init__()
         self.neox_args = neox_args
         self.layer_number = layer_number
+
+        world_size = mpu.get_model_parallel_world_size()
+        self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
+
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 
         with torch.no_grad(): # fancy init of time_mix
@@ -247,29 +251,54 @@ def __init__(self, neox_args, layer_number):
             self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
             self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
 
-        self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
-        self.receptance = nn.Linear(
-            neox_args.hidden_size, neox_args.hidden_size, bias=False
-        )
-        self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
-
+        #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
+        self.key = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.dim_ffn,
+            gather_output=False,
+            init_method=init_method,
+            bias=False,
+        )
+        #self.receptance = nn.Linear(
+        #    neox_args.hidden_size, neox_args.hidden_size, bias=False
+        #)
+        self.receptance = mpu.ColumnParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.hidden_size,
+            output_size=neox_args.hidden_size,
+            gather_output=True,
+            init_method=init_method,
+            bias=False
+        )
+        #self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
+        self.value = mpu.RowParallelLinear(
+            neox_args=neox_args,
+            input_size=neox_args.dim_ffn,
+            output_size=neox_args.hidden_size,
+            input_is_parallel=True,
+            init_method=init_method,
+            parallel_output=False,
+            bias=False
+        )
     def forward(self, x):
         xx = self.time_shift(x) - x
         xk = x + xx * self.time_maa_k
         xr = x + xx * self.time_maa_r
 
-        k = self.key(xk)
+        k, _ = self.key(xk)
         k = torch.relu(k) ** 2
-        kv = self.value(k)
-        return torch.sigmoid(self.receptance(xr)) * kv
+        kv, _ = self.value(k)
+        receptance, _ = self.receptance(xr)
+        return torch.sigmoid(receptance) * kv
 
 
 class RWKVResidualLayer(nn.Module):
     """
     RWKV layer definition
     """
 
-    def __init__(self, neox_args, layer_number):
+    def __init__(self, neox_args, init_method, layer_number):
         super().__init__()
         self.neox_args = neox_args
         self.layer_number = layer_number
@@ -288,6 +317,7 @@ def __init__(self, neox_args, layer_number):
         self.num_attention_heads = neox_args.num_attention_heads
         assert neox_args.dim_att % self.num_attention_heads == 0
 
+        self.init_method = init_method
         if neox_args.attention_dropout > 0:
             self.drop0 = nn.Dropout(p=neox_args.attention_dropout)
 
@@ -296,7 +326,7 @@ def __init__(self, neox_args, layer_number):
 
         self.att = RWKV_TimeMix(neox_args, layer_number)
 
-        self.ffn = RWKV_ChannelMix(neox_args, layer_number)
+        self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)
 
         if neox_args.attention_dropout > 0:
             self.drop0 = nn.Dropout(p=neox_args.attention_dropout)
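Stepping back from the hunks above: ParallelRWKV_ChannelMix applies the usual Megatron-style split to the RWKV feed-forward. key becomes a ColumnParallelLinear whose output stays sharded (gather_output=False), the elementwise relu(k) ** 2 runs independently on each shard, and value becomes a RowParallelLinear (input_is_parallel=True) whose all-reduce sums the per-rank partial products back to the full hidden size. receptance gathers its output so the sigmoid gate can multiply the full-width kv, and, as the new tuple unpacking in forward indicates, the mpu linears return an (output, bias) pair. The following single-process sketch, in plain PyTorch with toy sizes and names of my own choosing (no mpu involved), shows why summing the row-parallel partial outputs reproduces the unsharded FFN:

import torch

torch.manual_seed(0)
tokens, hidden, ffn, world = 4, 8, 16, 2      # toy sizes; ffn divides evenly across "ranks"
x = torch.randn(tokens, hidden)
W_key = torch.randn(ffn, hidden)              # full key weight, (out_features, in_features)
W_val = torch.randn(hidden, ffn)              # full value weight

# Unsharded reference: value(relu(key(x)) ** 2)
ref = (torch.relu(x @ W_key.t()) ** 2) @ W_val.t()

# Column-parallel key: each rank owns a slice of the ffn outputs.
# Row-parallel value: each rank owns the matching slice of ffn inputs.
out = torch.zeros_like(ref)
shard = ffn // world
for rank in range(world):
    cols = slice(rank * shard, (rank + 1) * shard)
    k_local = torch.relu(x @ W_key[cols].t()) ** 2   # stays sharded, as with gather_output=False
    out += k_local @ W_val[:, cols].t()              # the sum stands in for the all-reduce

print(torch.allclose(ref, out, atol=1e-5))           # True

Because the squared ReLU acts on each feature independently, no communication is needed between the two projections; only the row-parallel value layer has to reduce across ranks.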
8 changes: 1 addition & 7 deletions megatron/neox_arguments/arguments.py
@@ -1066,17 +1066,11 @@ def calculate_derived(self):
             if isinstance(self.zero_stage, int):
                 assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba"
             assert (
-                self.hidden_dropout == 0.0,
+                self.hidden_dropout != 0.0,
             ), "Mamba does not yet have dropout implemented"
         if "rwkv" in self.attention_config:
-            assert (
-                not self.is_pipe_parallel and self.model_parallel_size == 1
-            ), "RWKV not currently compatible with parallelism"
             if isinstance(self.zero_stage, int):
                 assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV"
-            assert (
-                self.hidden_dropout == 0.0,
-            ), "RWKV does not yet have dropout implemented"
 
         # Sparsity config
         if self.sparsity_config is None:
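One detail worth noting about the Mamba check that survives this hunk: with the trailing comma, the condition is wrapped in a one-element tuple, and assert applied to a non-empty tuple always passes regardless of the value inside, so the check cannot actually fire (recent CPython versions emit a SyntaxWarning for exactly this pattern). A two-line illustration of that Python behavior:

assert (1 == 2,), "never raised: a non-empty tuple is truthy"
print("reached, even though 1 == 2 is False")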
