Implement weight sharing #2

Open · wants to merge 6 commits into base: main
39 changes: 31 additions & 8 deletions dalle_pytorch/dalle_pytorch.py
@@ -55,6 +55,20 @@ def top_k(logits, thres = 0.5):
probs.scatter_(1, ind, val)
return probs

class SharedEmbedding(nn.Embedding):
def __init__(self, linear, start_index, end_index, **kwargs):
super().__init__(end_index - start_index, linear.weight.shape[1], **kwargs)
del self.weight

self.linear = linear
self.start_index = start_index
self.end_index = end_index

def forward(self, input):
return F.embedding(
input, self.linear.weight[self.start_index:self.end_index], self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)

# discrete vae class

class ResBlock(nn.Module):
@@ -326,7 +340,10 @@ def __init__(
stable = False,
sandwich_norm = False,
shift_tokens = True,
rotary_emb = True
rotary_emb = True,
shared_attn_ids = None,
shared_ff_ids = None,
share_input_output_emb = False,
):
super().__init__()
assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -338,9 +355,6 @@ def __init__(

num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)

self.text_emb = nn.Embedding(num_text_tokens, dim)
self.image_emb = nn.Embedding(num_image_tokens, dim)

self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)

@@ -374,7 +388,9 @@ def __init__(
stable = stable,
sandwich_norm = sandwich_norm,
shift_tokens = shift_tokens,
rotary_emb = rotary_emb
rotary_emb = rotary_emb,
shared_attn_ids = shared_attn_ids,
shared_ff_ids = shared_ff_ids,
)

self.stable = stable
@@ -387,6 +403,13 @@ def __init__(
nn.Linear(dim, self.total_tokens),
)

if share_input_output_emb:
self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
else:
self.text_emb = nn.Embedding(num_text_tokens, dim)
self.image_emb = nn.Embedding(num_image_tokens, dim)

seq_range = torch.arange(seq_len)
logits_range = torch.arange(total_tokens)

@@ -417,7 +440,7 @@ def generate_texts(
text_tokens = torch.tensor([[0]]).cuda()
else:
text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)

for _ in range(text_tokens.shape[1], text_seq_len):
device = text_tokens.device

@@ -443,9 +466,9 @@ def generate_texts(
filtered_logits = top_k(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim = -1)
sample = torch.multinomial(probs, 1)

text_tokens = torch.cat((text_tokens, sample), dim=-1)

padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
return text_tokens, texts
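With share_input_output_emb enabled, the embedding tables added above stop owning their own parameters: each lookup reads a row slice of the to_logits projection weight, so the input embeddings and the output logits are tied to a single matrix. A minimal standalone sketch of that tying, with hypothetical sizes and a bare nn.Linear standing in for the to_logits Sequential (it assumes the SharedEmbedding class added in this file is importable):

    import torch
    import torch.nn as nn
    from dalle_pytorch.dalle_pytorch import SharedEmbedding

    dim, num_text_tokens, num_image_tokens = 512, 10000, 8192   # hypothetical sizes
    total_tokens = num_text_tokens + num_image_tokens

    to_logits = nn.Linear(dim, total_tokens)                     # output projection, one row per token
    text_emb = SharedEmbedding(to_logits, 0, num_text_tokens)    # reads rows [0, num_text_tokens)
    image_emb = SharedEmbedding(to_logits, num_text_tokens, total_tokens)

    tokens = torch.randint(0, num_text_tokens, (1, 16))
    out = text_emb(tokens)            # embeds by indexing into to_logits.weight
    assert out.shape == (1, 16, dim)
    # No separate embedding parameters exist, so gradients from the input lookup
    # and from the output projection both accumulate in to_logits.weight.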
38 changes: 27 additions & 11 deletions dalle_pytorch/transformer.py
@@ -1,3 +1,4 @@
from collections.abc import Iterable
from functools import partial
from itertools import islice, cycle

@@ -21,9 +22,7 @@ def default(val, d):
return val if exists(val) else d

def cast_tuple(val, depth = 1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else (val,) * depth
return val if isinstance(val, Iterable) else (val,) * depth

# classes

@@ -150,7 +149,9 @@ def __init__(
stable = False,
sandwich_norm = False,
shift_tokens = False,
rotary_emb = True
rotary_emb = True,
shared_attn_ids = None,
shared_ff_ids = None,
):
super().__init__()
layers = nn.ModuleList([])
@@ -160,7 +161,13 @@ def __init__(
attn_types = cast_tuple(attn_types)
attn_type_layer = islice(cycle(attn_types), depth)

for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
shared_attn_ids = cycle(default(shared_attn_ids, range(depth)))
shared_ff_ids = cycle(default(shared_ff_ids, range(depth)))
shared_attn_layers = {}
shared_ff_layers = {}

for (ind, sparse_attn, attn_type, attn_id, ff_id) in \
zip(range(depth), sparse_layer, attn_type_layer, shared_attn_ids, shared_ff_ids):
if attn_type == 'full':
attn_class = partial(Attention, stable = stable)
elif attn_type == 'sparse':
@@ -176,12 +183,21 @@ def __init__(
else:
raise ValueError(f'attention type "{attn_type}" is not valid')

if attn_type != 'mlp':
attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
else:
attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)

ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None))
if not exists(attn):
if attn_type != 'mlp':
attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
else:
attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
shared_attn_layers[attn_id] = (attn, attn_type)
elif attn_type != reused_attn_type:
raise ValueError('attn_types do not match shared_attn_ids '
f'(ind = {ind}, attn_type = "{attn_type}", reused_attn_type = "{reused_attn_type}")')

ff = shared_ff_layers.get(ff_id)
if not exists(ff):
ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
shared_ff_layers[ff_id] = ff

if shift_tokens:
attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
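The construction loop above builds each attention and feed forward module once per id and reuses the same instance whenever a later depth maps to the same id; shared_attn_ids and shared_ff_ids are cycled over the depth and default to range(depth), i.e. no sharing. A minimal sketch of that reuse pattern, with a plain nn.Linear standing in for the real Attention/FeedForward modules and the attn_type consistency check omitted:

    from itertools import cycle
    import torch.nn as nn

    def build_shared_layers(depth, shared_ids=None, dim=512):
        # No sharing by default: each depth gets its own id and therefore its own module.
        ids = cycle(shared_ids if shared_ids is not None else range(depth))
        cache, layers = {}, []
        for _, layer_id in zip(range(depth), ids):
            layer = cache.get(layer_id)
            if layer is None:
                layer = nn.Linear(dim, dim)   # stand-in for Attention / FeedForward
                cache[layer_id] = layer
            layers.append(layer)
        return nn.ModuleList(layers)

    layers = build_shared_layers(depth=4, shared_ids=(0, 0, 1, 1))
    assert layers[0] is layers[1] and layers[2] is layers[3]   # same instances, shared weights
    assert layers[0] is not layers[2]

Because the same module instance is registered at several positions, nn.ModuleList counts its parameters only once, which is where the parameter savings come from.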
20 changes: 14 additions & 6 deletions train_dalle.py
@@ -46,9 +46,9 @@
help='path to your folder of images and text for learning the DALL-E')

parser.add_argument(
'--wds',
type = str,
default='',
'--wds',
type = str,
default='',
help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
)

@@ -134,6 +134,10 @@

model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true')

model_group.add_argument('--shared_attn_ids', default = None, type = str, help = 'Comma separated list of shared attention layer ids. Default: sharing is disabled')

model_group.add_argument('--shared_ff_ids', default = None, type = str, help = 'Comma separated list of shared feed forward layer ids. Default: sharing is disabled')

args = parser.parse_args()

# helpers
@@ -191,6 +195,8 @@ def cp_path_to_dir(cp_path, tag):
ROTARY_EMB = args.rotary_emb

ATTN_TYPES = tuple(args.attn_types.split(','))
SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if exists(args.shared_attn_ids) else None
SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if exists(args.shared_ff_ids) else None

DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

@@ -303,6 +309,8 @@ def cp_path_to_dir(cp_path, tag):
stable=STABLE,
shift_tokens=SHIFT_TOKENS,
rotary_emb=ROTARY_EMB,
shared_attn_ids=SHARED_ATTN_IDS,
shared_ff_ids=SHARED_FF_IDS,
)
resume_epoch = 0

@@ -368,7 +376,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
if myimg not in item:
return False
return True

w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
filtered_dataset = w_dataset.select(filter_dataset)
ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True)
@@ -600,7 +608,7 @@ def save_model(path, epoch=0):

if i % SAVE_EVERY_N_STEPS == 0:
save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)

if i % 100 == 0:
if distr_backend.is_root_worker():
sample_text = text[:1]
@@ -633,7 +641,7 @@ def save_model(path, epoch=0):
distr_scheduler.step(avg_loss)

save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)

if distr_backend.is_root_worker():
# save trained model to wandb as an artifact every epoch's end

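The two new flags arrive as comma separated strings, are split into tuples of id strings, and are passed straight through to DALLE (and on to Transformer), where equal ids mean shared layers. A minimal sketch of that path; the id values are hypothetical and stand for a run such as train_dalle.py --shared_attn_ids 0,0,1,1 --shared_ff_ids 0,0,1,1:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--shared_attn_ids', default=None, type=str)
    parser.add_argument('--shared_ff_ids', default=None, type=str)
    args = parser.parse_args(['--shared_attn_ids', '0,0,1,1', '--shared_ff_ids', '0,0,1,1'])

    SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if args.shared_attn_ids is not None else None
    SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if args.shared_ff_ids is not None else None

    print(SHARED_ATTN_IDS)   # ('0', '0', '1', '1') -- the ids stay strings; they are only used as dict keys
    # dalle = DALLE(..., shared_attn_ids=SHARED_ATTN_IDS, shared_ff_ids=SHARED_FF_IDS)

Since the ids are cycled over the depth inside Transformer, the lists may also be shorter than the number of layers; the sharing pattern then simply repeats.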