From d33d9a35ba4efc7dbdaa1c837078c39b779159b7 Mon Sep 17 00:00:00 2001
From: cromac
Date: Thu, 22 Feb 2024 12:05:23 +0100
Subject: [PATCH] Adding a log_softmax in LogScoringModuleFn to properly handle logits

---
 .../server/llms/module_functions/score_module_function.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py b/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py
index 7ee6046..a9c3ed9 100644
--- a/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py
+++ b/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py
@@ -1,5 +1,6 @@
 from . import BaseModuleFunction
 import torch
+from torch.nn.functional import log_softmax
 
 class LogScoringModuleFn(BaseModuleFunction):
     def __init__(self, pad_token, model_type, pre_encoded_input):
@@ -25,6 +26,7 @@ def forward(self, forward_outputs, minibatch, tokenized_contexts, **kwargs):
             logits = forward_outputs["logits"][:, :-1, :]  # skip token appended by tokenizer
             output_tokens = minibatch["decoder_input_ids"][:, 1:]  # skip pad token
 
+        logits = log_softmax(logits, dim=-1)
         tokens_logprobs = \
             torch.gather(logits, 2, output_tokens[:, :, None]).squeeze(-1).to(torch.float32)  # filter with sequence tokens
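
Note: the standalone sketch below (not lamorel code; tensor shapes and names are illustrative) shows why applying log_softmax before torch.gather matters. Raw logits are unnormalized, so gathering them does not yield log-probabilities, and summing them over a sequence is not a log-likelihood; after log_softmax, each gathered value is log P(token | context).

# Minimal illustration of the patched scoring step, assuming made-up shapes.
import torch
from torch.nn.functional import log_softmax

batch, seq_len, vocab = 2, 4, 10
logits = torch.randn(batch, seq_len, vocab)              # raw, unnormalized model scores
output_tokens = torch.randint(vocab, (batch, seq_len))   # token ids of the scored sequence

# Without log_softmax: gathered values are raw logits, not probabilities.
raw_scores = torch.gather(logits, 2, output_tokens[:, :, None]).squeeze(-1)

# With log_softmax (as in the patch): gathered values are per-token log-probs,
# so summing over the sequence gives the sequence log-likelihood.
logprobs = log_softmax(logits, dim=-1)
tokens_logprobs = torch.gather(logprobs, 2, output_tokens[:, :, None]).squeeze(-1)
sequence_logprob = tokens_logprobs.sum(dim=-1)

print(raw_scores.shape, tokens_logprobs.shape, sequence_logprob.shape)
# torch.Size([2, 4]) torch.Size([2, 4]) torch.Size([2])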