Skip to content

Commit

Permalink
change up epsilon in layernorm the case of using fp16, thanks to @Vel…
Browse files Browse the repository at this point in the history
…drovive for figuring out this stabilizes training
  • Loading branch information
lucidrains committed Jul 29, 2022
1 parent 748c7fe commit 2d67d58
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
14 changes: 10 additions & 4 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,34 +547,40 @@ def p2_reweigh_loss(self, loss, times):
# diffusion prior

class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False):
def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__()
self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable
self.g = nn.Parameter(torch.ones(dim))

def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps

if self.stable:
x = x / x.amax(dim = -1, keepdim = True).detach()

var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
return (x - mean) * (var + eps).rsqrt() * self.g

class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False):
def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__()
self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps

if self.stable:
x = x / x.amax(dim = 1, keepdim = True).detach()

var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
return (x - mean) * (var + eps).rsqrt() * self.g

class Residual(nn.Module):
def __init__(self, fn):
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.4.0'
__version__ = '1.4.2'

0 comments on commit 2d67d58

Please sign in to comment.