diff --git a/flamingo_pytorch/flamingo_pytorch.py b/flamingo_pytorch/flamingo_pytorch.py index 945ac4c..1e0a7b3 100644 --- a/flamingo_pytorch/flamingo_pytorch.py +++ b/flamingo_pytorch/flamingo_pytorch.py @@ -176,7 +176,7 @@ def forward( # any text without a preceding media needs to have attention zeroed out text_without_media_mask = text_time == 0 text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1') - attn.masked_fill(text_without_media_mask, 0.) + attn = attn.masked_fill(text_without_media_mask, 0.) out = einsum('... i j, ... j d -> ... i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') diff --git a/setup.py b/setup.py index 8e7f26a..698de6a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'flamingo-pytorch', packages = find_packages(exclude=[]), - version = '0.1.1', + version = '0.1.2', license='MIT', description = 'Flamingo - Pytorch', author = 'Phil Wang',