Skip to content

Commit

Permalink
update flag for turning on axial positional embedding and update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 25, 2021
1 parent a751fe2 commit 2cbc36b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,11 @@ y = model(x, keys = c, input_mask = i_mask, context_mask = c_mask)

## Positional Embeddings

<a href="https://github.com/AranKomat">Aran</a> has informed me that the Reformer team used axial position embeddings with great results on longer sequences. I tested it out and indeed it works very well! So well in fact that I have decided to make this the default. You can adjust the shape and dimension of the axial embeddings by following the instructions below.
The default positional embedding uses <a href="https://arxiv.org/abs/2104.09864">rotary embeddings</a>.

However, <a href="https://github.com/AranKomat">Aran</a> has informed me that the Reformer team used axial position embeddings with great results on longer sequences.

You can turn on axial positional embedding and adjust the shape and dimension of the axial embeddings by following the instructions below.

```python
import torch
Expand All @@ -169,8 +172,8 @@ model = ReformerLM(
ff_chunks = 8,
attn_chunks = 2,
causal = True,
axial_position_emb = True, # set this to True
axial_position_shape = (128, 64), # the shape must multiply up to the max_seq_len (128 x 64 = 8192)
axial_position_dims = (512, 512) # the dims must sum up to the model dimensions (512 + 512 = 1024)
)

x = torch.randint(0, 20000, (1, 8192)).long()
Expand Down
10 changes: 5 additions & 5 deletions reformer_pytorch/reformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def forward(self, x, **kwargs):
return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

class ReformerLM(nn.Module):
def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, rotary_emb = True, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, axial_position_emb = False, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
super().__init__()
emb_dim = default(emb_dim, dim)
self.max_seq_len = max_seq_len
Expand All @@ -725,15 +725,15 @@ def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64
self.pos_emb = Always(0)
self.layer_pos_emb = Always(None)

if rotary_emb:
self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
if axial_position_emb:
axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)
elif absolute_position_emb:
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
elif fixed_position_emb:
self.pos_emb = FixedPositionalEmbedding(emb_dim)
else:
axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)
self.layer_pos_emb = FixedPositionalEmbedding(dim_head)

self.reformer = Reformer(dim, depth, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0., layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
self.norm = nn.LayerNorm(dim)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'reformer_pytorch',
packages = find_packages(exclude=['examples', 'pretraining']),
version = '1.4.2',
version = '1.4.3',
license='MIT',
description = 'Reformer, the Efficient Transformer, Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 2cbc36b

Please sign in to comment.