From a751fe2eb939dcdd81b736b2f67e745dc8472a09 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 9 May 2021 07:49:27 -0700
Subject: [PATCH] remove unused keyword argument on Reformer class

---
 README.md                            | 1 -
 reformer_pytorch/reformer_pytorch.py | 4 ++--
 setup.py                             | 2 +-
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 9e0f1ef..59fd33f 100644
--- a/README.md
+++ b/README.md
@@ -73,7 +73,6 @@ from reformer_pytorch import Reformer
 model = Reformer(
     dim = 512,
     depth = 12,
-    max_seq_len = 8192,
     heads = 8,
     lsh_dropout = 0.1,
     causal = True
diff --git a/reformer_pytorch/reformer_pytorch.py b/reformer_pytorch/reformer_pytorch.py
index e65dcff..4b7dd32 100644
--- a/reformer_pytorch/reformer_pytorch.py
+++ b/reformer_pytorch/reformer_pytorch.py
@@ -665,7 +665,7 @@ def apply_rotary_pos_emb(qk, sinu_pos):
 # reformer lm
 
 class Reformer(nn.Module):
-    def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
+    def __init__(self, dim, depth, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
         super().__init__()
         self.dim = dim
         self.depth = depth
@@ -735,7 +735,7 @@ def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64
         axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
         self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)
 
-        self.reformer = Reformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0., layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
+        self.reformer = Reformer(dim, depth, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0., layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
         self.norm = nn.LayerNorm(dim)
 
         if return_embeddings:
diff --git a/setup.py b/setup.py
index 7725728..638f844 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'reformer_pytorch',
   packages = find_packages(exclude=['examples', 'pretraining']),
-  version = '1.4.1',
+  version = '1.4.2',
   license='MIT',
   description = 'Reformer, the Efficient Transformer, Pytorch',
   author = 'Phil Wang',
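
Below is a minimal usage sketch of the Reformer class after this change, assuming the import path and argument values shown in the README hunk above; the input tensor shape is illustrative. The standalone Reformer block no longer accepts max_seq_len, since it never used it; the language-model wrapper whose __init__ appears in the second hunk still takes max_seq_len, because that value feeds its axial positional embedding shape.

    import torch
    from reformer_pytorch import Reformer

    # attention/feedforward stack only; the sequence length is no longer
    # passed at construction (it was the unused keyword removed above)
    model = Reformer(
        dim = 512,
        depth = 12,
        heads = 8,
        lsh_dropout = 0.1,
        causal = True
    )

    x = torch.randn(1, 8192, 512)  # (batch, seq, dim) embeddings
    y = model(x)                   # output keeps the same shape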