
Commit

add option to use layernorm instead of rmsnorm
lucidrains committed Jan 4, 2024
1 parent 4dc07b0 commit e1dc178
Showing 2 changed files with 15 additions and 4 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'soft-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.1.5',
   license='MIT',
   description = 'Soft MoE - Pytorch',
   author = 'Phil Wang',
soft_moe_pytorch/soft_moe.py (17 changes: 14 additions & 3 deletions)
@@ -58,6 +58,15 @@ def gumbel_noise(t):
 
 # norm
 
+class LayerNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
 class RMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
@@ -274,16 +283,18 @@ def __init__(
         dropout = 0.,
         geglu = False,
         is_distributed = None,
-        offload_unused_experts_to_cpu = True
+        offload_unused_experts_to_cpu = True,
+        use_layernorm = False
     ):
         super().__init__()
         assert exists(seq_len) ^ exists(num_slots), 'either seq_len, or num_slots must be passed into SoftMoE'
 
         num_slots = default(num_slots, seq_len // num_experts)
 
-        self.norm = RMSNorm(dim)
+        norm_klass = LayerNorm if use_layernorm else RMSNorm
+        self.norm = norm_klass(dim)
 
-        self.slot_norm = RMSNorm(dim)
+        self.slot_norm = norm_klass(dim)
         self.slot_embeds = nn.Parameter(torch.randn(num_experts, num_slots, dim))
 
         expert_klass = GLUFeedForward if geglu else FeedForward
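The added LayerNorm keeps a learnable gamma but registers beta as a zero buffer, so it behaves as a bias-free LayerNorm; passing use_layernorm = True selects it in place of RMSNorm for both the input norm and the slot norm. A minimal usage sketch follows, assuming the dim / seq_len / num_experts constructor arguments documented for SoftMoE and purely illustrative sizes:

import torch
from soft_moe_pytorch import SoftMoE

# sketch only: dim / seq_len / num_experts are assumed constructor arguments;
# use_layernorm is the flag introduced in this commit (defaults to False, i.e. RMSNorm)
moe = SoftMoE(
    dim = 512,
    seq_len = 1024,
    num_experts = 4,
    use_layernorm = True
)

x = torch.randn(1, 1024, 512)
out = moe(x)   # expected to match the input shape: (1, 1024, 512)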
