From abeedc8cb60180892f1c19b42548c9833abe4f5b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 11 Nov 2024 07:21:56 -0800
Subject: [PATCH] address https://github.com/lucidrains/x-transformers/issues/291

---
 setup.py                     |  2 +-
 x_transformers/continuous.py | 23 ++++++++++++++++++++---
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/setup.py b/setup.py
index 87f8a991..1cd7a24f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.10',
+  version = '1.42.11',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/continuous.py b/x_transformers/continuous.py
index 57c94a65..a2b0c409 100644
--- a/x_transformers/continuous.py
+++ b/x_transformers/continuous.py
@@ -2,7 +2,8 @@
 from torch import nn
 import torch.nn.functional as F
 
-from einops import pack, repeat, unpack
+import einx
+from einops import reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
     AttentionLayers,
@@ -24,6 +25,15 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def masked_mean(t, mask):
+    t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
+
+    num = reduce(t, 'b n d -> b', 'sum')
+    den = mask.sum(dim = -1)
+
+    masked_average = num / den.clamp(min = 1.)
+    return masked_average
+
 # main classes
 
 class ContinuousTransformerWrapper(nn.Module):
@@ -169,12 +179,15 @@ def __init__(
         net: ContinuousTransformerWrapper,
         ignore_index = -100,
         pad_value = 0,
-        loss_fn = nn.MSELoss(reduction = 'none')
+        loss_fn = nn.MSELoss(reduction = 'none'),
+        equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
+
         self.loss_fn = loss_fn
+        self.equal_loss_weight_batch = equal_loss_weight_batch
 
     @torch.no_grad()
     def generate(self, start_tokens, seq_len, **kwargs):
@@ -222,6 +235,10 @@ def forward(self, x, **kwargs):
 
         if exists(mask):
             assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
-            loss = loss[mask]
+
+            if self.equal_loss_weight_batch:
+                loss = masked_mean(loss, mask)
+            else:
+                loss = loss[mask]
 
         return loss.mean()
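
Note on what the new flag changes: the following is a minimal, self-contained PyTorch sketch (toy shapes, no x-transformers dependency; all variable names here are illustrative) contrasting the existing per-token reduction `loss[mask].mean()` with the per-sequence averaging the new `masked_mean` helper performs when `equal_loss_weight_batch = True`.

    import torch

    # toy shapes: batch of 2 sequences, max length 4, continuous feature dim 3
    loss = torch.rand(2, 4, 3)   # unreduced loss, e.g. from nn.MSELoss(reduction = 'none')
    mask = torch.tensor([
        [True, True, True, True],     # full-length sequence
        [True, False, False, False],  # sequence with a single valid token
    ])

    # previous behaviour (equal_loss_weight_batch = False):
    # every valid token contributes equally, so longer sequences dominate the batch loss
    per_token_loss = loss[mask].mean()

    # new behaviour (equal_loss_weight_batch = True), mirroring the masked_mean helper:
    # zero out padded positions, sum each sequence's loss, divide by its valid-token count,
    # then average the per-sequence values so each sequence carries equal weight
    zeroed = loss.masked_fill(~mask[..., None], 0.)
    per_sequence = zeroed.sum(dim = (1, 2)) / mask.sum(dim = -1).clamp(min = 1)
    equal_weighted_loss = per_sequence.mean()

    print(per_token_loss, equal_weighted_loss)

With the patch applied, one would opt in by constructing ContinuousAutoregressiveWrapper with equal_loss_weight_batch = True and passing a mask at forward time; the default of False keeps the previous per-token weighting, so existing training code is unaffected.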