Commit

address #291
lucidrains committed Nov 11, 2024
1 parent 15dad55 commit abeedc8
Showing 2 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.10',
+  version = '1.42.11',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
23 changes: 20 additions & 3 deletions x_transformers/continuous.py
@@ -2,7 +2,8 @@
 from torch import nn
 import torch.nn.functional as F
 
-from einops import pack, repeat, unpack
+import einx
+from einops import reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
     AttentionLayers,
@@ -24,6 +25,15 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def masked_mean(t, mask):
+    t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
+
+    num = reduce(t, 'b n d -> b', 'sum')
+    den = mask.sum(dim = -1)
+
+    masked_average = num / den.clamp(min = 1.)
+    return masked_average
+
 # main classes
 
 class ContinuousTransformerWrapper(nn.Module):
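
For reference, the new masked_mean helper takes an unreduced loss of shape (batch, seq, dim) and a boolean mask of shape (batch, seq), zeroes out the masked positions, and returns one value per sequence: the sum over unmasked tokens (and feature dims) divided by the number of unmasked tokens. A minimal sketch of calling it, with toy tensors that are only illustrative:

import torch
from x_transformers.continuous import masked_mean

loss = torch.ones(2, 4, 8)                         # (batch, seq, dim) unreduced loss
mask = torch.tensor([[True, True, True, True],     # first sequence fully valid
                     [True, False, False, False]]) # second has a single valid token

per_seq = masked_mean(loss, mask)                  # shape (2,), one value per sequence
print(per_seq)                                     # tensor([8., 8.]) for this toy input

Each sequence contributes a single value regardless of its length, which is what lets the wrapper weight sequences equally when requested.
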
@@ -169,12 +179,15 @@ def __init__(
         net: ContinuousTransformerWrapper,
         ignore_index = -100,
         pad_value = 0,
-        loss_fn = nn.MSELoss(reduction = 'none')
+        loss_fn = nn.MSELoss(reduction = 'none'),
+        equal_loss_weight_batch = False # if True and a mask is passed in with variable length sequences, each sequence is weighted equally in the loss (as opposed to each token)
     ):
         super().__init__()
         self.net = net
         self.max_seq_len = net.max_seq_len
+
         self.loss_fn = loss_fn
+        self.equal_loss_weight_batch = equal_loss_weight_batch
 
     @torch.no_grad()
     def generate(self, start_tokens, seq_len, **kwargs):
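
To see why the new flag matters, consider a batch with one long and one short sequence. The default path (loss[mask].mean()) averages over all unmasked tokens, so the long sequence dominates the gradient; with equal_loss_weight_batch = True each sequence is averaged on its own first and the per-sequence means are then averaged, so both contribute equally. A sketch of that arithmetic, using made-up per-token losses collapsed to scalars for simplicity:

import torch

loss_long  = torch.full((1000,), 1.0)   # 1000 valid tokens, per-token loss 1.0
loss_short = torch.full((10,), 5.0)     # 10 valid tokens, per-token loss 5.0

# default: every unmasked token weighted equally -> long sequence dominates
token_weighted = torch.cat([loss_long, loss_short]).mean()                      # ~1.04

# equal_loss_weight_batch = True: average within each sequence first
sequence_weighted = torch.stack([loss_long.mean(), loss_short.mean()]).mean()   # 3.0
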
@@ -222,6 +235,10 @@ def forward(self, x, **kwargs):
 
         if exists(mask):
             assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
-            loss = loss[mask]
+
+            if self.equal_loss_weight_batch:
+                loss = masked_mean(loss, mask)
+            else:
+                loss = loss[mask]
 
         return loss.mean()
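
Putting it together, a short usage sketch of the new option (model sizes and the half-length mask below are arbitrary, not part of the commit):

import torch
from x_transformers import Decoder
from x_transformers.continuous import (
    ContinuousTransformerWrapper,
    ContinuousAutoregressiveWrapper
)

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

wrapper = ContinuousAutoregressiveWrapper(
    model,
    equal_loss_weight_batch = True  # the flag added in this commit
)

x = torch.randn(2, 1024, 32)
mask = torch.ones(2, 1024).bool()
mask[1, 512:] = False               # second sequence is half the length

loss = wrapper(x, mask = mask)      # each sequence weighted equally
loss.backward()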
