From 2d67d5821eaa1a3759a38eb401634946ad860d8b Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 29 Jul 2022 12:41:02 -0700 Subject: [PATCH] change up epsilon in layernorm the case of using fp16, thanks to @Veldrovive for figuring out this stabilizes training --- dalle2_pytorch/dalle2_pytorch.py | 14 ++++++++++---- dalle2_pytorch/version.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index ce53fb05..8e496040 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -547,34 +547,40 @@ def p2_reweigh_loss(self, loss, times): # diffusion prior class LayerNorm(nn.Module): - def __init__(self, dim, eps = 1e-5, stable = False): + def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False): super().__init__() self.eps = eps + self.fp16_eps = fp16_eps self.stable = stable self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): + eps = self.eps if x.dtype == torch.float32 else self.fp16_eps + if self.stable: x = x / x.amax(dim = -1, keepdim = True).detach() var = torch.var(x, dim = -1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = -1, keepdim = True) - return (x - mean) * (var + self.eps).rsqrt() * self.g + return (x - mean) * (var + eps).rsqrt() * self.g class ChanLayerNorm(nn.Module): - def __init__(self, dim, eps = 1e-5, stable = False): + def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False): super().__init__() self.eps = eps + self.fp16_eps = fp16_eps self.stable = stable self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) def forward(self, x): + eps = self.eps if x.dtype == torch.float32 else self.fp16_eps + if self.stable: x = x / x.amax(dim = 1, keepdim = True).detach() var = torch.var(x, dim = 1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = 1, keepdim = True) - return (x - mean) * (var + self.eps).rsqrt() * self.g + return (x - mean) * (var + eps).rsqrt() * self.g class Residual(nn.Module): def __init__(self, fn): diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 96e3ce8d..98d186be 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.4.0' +__version__ = '1.4.2'