diff --git a/reformer_pytorch/reformer_pytorch.py b/reformer_pytorch/reformer_pytorch.py index a3e0d4f..56f4f46 100644 --- a/reformer_pytorch/reformer_pytorch.py +++ b/reformer_pytorch/reformer_pytorch.py @@ -654,6 +654,7 @@ def rotate_every_two(x): return rearrange(x, '... d j -> ... (d j)') def apply_rotary_pos_emb(qk, sinu_pos): + sinu_pos = sinu_pos.type(qk.dtype) sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2) sin, cos = sinu_pos.unbind(dim = -2) sin, cos = map(lambda t: repeat(t, 'n d -> n (d j)', j = 2), (sin, cos)) diff --git a/setup.py b/setup.py index 65ce35b..2078b85 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'reformer_pytorch', packages = find_packages(exclude=['examples', 'pretraining']), - version = '1.4.3', + version = '1.4.4', license='MIT', description = 'Reformer, the Efficient Transformer, Pytorch', author = 'Phil Wang',