diff --git a/nuwa_pytorch/nuwa_pytorch.py b/nuwa_pytorch/nuwa_pytorch.py
index 80c661b..bc4ec48 100644
--- a/nuwa_pytorch/nuwa_pytorch.py
+++ b/nuwa_pytorch/nuwa_pytorch.py
@@ -276,10 +276,10 @@ def __init__(
         self.chunk_size = chunk_size

         self.net = nn.Sequential(
-            nn.Linear(dim, inner_dim * 2),
+            nn.Linear(dim, inner_dim * 2, bias = False),
             GEGLU(),
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim)
+            nn.Linear(inner_dim, dim, bias = False)
         )

     def forward(self, x):
@@ -315,7 +315,7 @@ def __init__(
         self.dropout = nn.Dropout(dropout)
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

     def forward(
         self,
@@ -638,7 +638,7 @@ def __init__(
         self.talking_heads = nn.Conv3d(heads, heads, 1, bias = False)
         self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # handle variables for unfold

@@ -787,7 +787,7 @@ def __init__(
         self.dropout = nn.Dropout(dropout)
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # handle variables for 2d unfold

@@ -938,7 +938,7 @@

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         self.null_k = nn.Parameter(torch.randn(heads, dim_head))
         self.null_v = nn.Parameter(torch.randn(heads, dim_head))
@@ -1821,7 +1821,7 @@ def __init__(
             sparse_3dna_rel_pos_bias = sparse_3dna_rel_pos_bias
         )

-        self.to_logits = nn.Linear(dim, num_image_tokens)
+        self.to_logits = nn.Linear(dim, num_image_tokens, bias = False)

     def embed_text(self, text, mask = None):
         batch, seq_len, device = *text.shape, text.device
@@ -2090,8 +2090,8 @@ def __init__(
             sparse_2dna_rel_pos_bias = sparse_2dna_rel_pos_bias
         )

-        self.to_video_logits = nn.Linear(dim, num_image_tokens)
-        self.to_audio_logits = nn.Linear(dim, num_audio_tokens)
+        self.to_video_logits = nn.Linear(dim, num_image_tokens, bias = False)
+        self.to_audio_logits = nn.Linear(dim, num_audio_tokens, bias = False)

     def embed_text(self, text, mask = None):
         batch, seq_len, device = *text.shape, text.device
@@ -2413,7 +2413,7 @@ def __init__(
             sparse_3dna_attn = True
         )

-        self.to_logits = nn.Linear(dim, num_image_tokens)
+        self.to_logits = nn.Linear(dim, num_image_tokens, bias = False)

     def embed_sketch(self, sketch, mask = None):
         batch, frames, channels, image_size, _, device = *sketch.shape, sketch.device
diff --git a/setup.py b/setup.py
index 06d089a..6aa542e 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.7.2',
+  version = '0.7.3',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',