diff --git a/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py b/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py
index 7ee6046..a9c3ed9 100644
--- a/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py
+++ b/lamorel/src/lamorel/server/llms/module_functions/score_module_function.py
@@ -1,5 +1,6 @@
 from . import BaseModuleFunction
 import torch
+from torch.nn.functional import log_softmax
 
 class LogScoringModuleFn(BaseModuleFunction):
     def __init__(self, pad_token, model_type, pre_encoded_input):
@@ -25,6 +26,7 @@ def forward(self, forward_outputs, minibatch, tokenized_contexts, **kwargs):
 
         logits = forward_outputs["logits"][:, :-1, :]  # skip token appended by tokenizer
         output_tokens = minibatch["decoder_input_ids"][:, 1:]  # skip pad token
+        logits = log_softmax(logits, dim=-1)
         tokens_logprobs = \
             torch.gather(logits, 2, output_tokens[:, :, None]).squeeze(-1).to(torch.float32)  # filter with sequence tokens
 
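
For context, here is a minimal standalone sketch of what this change accomplishes. It is not lamorel code: the tensor shapes follow the diff (batch, seq_len, vocab), but all variable values are made up for illustration. Gathering from raw logits returns unnormalized scores, while applying log_softmax over the vocabulary dimension first yields proper token log-probabilities.

# Minimal sketch (illustrative only, not lamorel code) of the effect of the
# log_softmax call added in this patch.
import torch
from torch.nn.functional import log_softmax

batch, seq_len, vocab = 2, 4, 10
logits = torch.randn(batch, seq_len, vocab)           # (batch, seq_len, vocab)
output_tokens = torch.randint(0, vocab, (batch, seq_len))

# Before the patch: raw logits are gathered; these are unnormalized scores,
# not log-probabilities.
raw_scores = torch.gather(logits, 2, output_tokens[:, :, None]).squeeze(-1)

# After the patch: normalize over the vocabulary dimension first, then gather.
logprobs = log_softmax(logits, dim=-1)
tokens_logprobs = torch.gather(logprobs, 2, output_tokens[:, :, None]).squeeze(-1)

# True log-probabilities are always <= 0 and can be summed over the sequence
# dimension to obtain a valid sequence log-probability.
assert (tokens_logprobs <= 0).all()
sequence_logprob = tokens_logprobs.sum(dim=1)         # shape: (batch,)

Summing per-token log-probabilities (rather than multiplying probabilities) keeps the computation numerically stable, which is presumably why the scoring is done in log space here.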