Merge pull request #38 from flowersteam/#37
Adding a log_softmax in LogScoringModuleFn to properly handle logits
ClementRomac authored Feb 22, 2024
2 parents ba394ec + d33d9a3 commit c82d1b1
Showing 1 changed file with 2 additions and 0 deletions.
@@ -1,5 +1,6 @@
 from . import BaseModuleFunction
 import torch
+from torch.nn.functional import log_softmax

 class LogScoringModuleFn(BaseModuleFunction):
     def __init__(self, pad_token, model_type, pre_encoded_input):
@@ -25,6 +26,7 @@ def forward(self, forward_outputs, minibatch, tokenized_contexts, **kwargs):
         logits = forward_outputs["logits"][:, :-1, :]  # skip </s> token appended by tokenizer
         output_tokens = minibatch["decoder_input_ids"][:, 1:]  # skip pad token

+        logits = log_softmax(logits, dim=-1)
         tokens_logprobs = \
             torch.gather(logits, 2, output_tokens[:, :, None]).squeeze(-1).to(torch.float32)  # filter with sequence tokens
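For context, a minimal standalone sketch of why the added log_softmax matters (not part of the commit; the toy tensors and names batch, seq_len, and vocab stand in for the module's real inputs). Gathering from raw logits returns unnormalized scores, whereas applying log_softmax over the vocabulary dimension first yields per-token log-probabilities, so summing along the sequence gives the log-likelihood of the token sequence.

import torch
from torch.nn.functional import log_softmax

batch, seq_len, vocab = 2, 4, 10                        # toy dimensions
logits = torch.randn(batch, seq_len, vocab)             # stand-in for forward_outputs["logits"]
output_tokens = torch.randint(vocab, (batch, seq_len))  # stand-in for decoder_input_ids

# Before the fix: gathering raw logits yields unnormalized scores that are
# not comparable across contexts.
raw_scores = torch.gather(logits, 2, output_tokens[:, :, None]).squeeze(-1)

# After the fix: log_softmax normalizes over the vocabulary, so each gathered
# value is a proper log-probability and their sum over the sequence dimension
# is the sequence log-likelihood.
logprobs = log_softmax(logits, dim=-1)
tokens_logprobs = torch.gather(logprobs, 2, output_tokens[:, :, None]).squeeze(-1)
sequence_logprob = tokens_logprobs.sum(dim=1)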
