From 8dc93258fe507931e5da77431b5d74f5aa7830d9 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 2 Apr 2024 14:54:09 -0700 Subject: [PATCH] Fixed Generation --- server/lorax_server/models/causal_lm.py | 13 ++++++++----- server/lorax_server/models/seq2seq_lm.py | 13 ++++++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 6a38e2e90..3d8f745e7 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -8,6 +8,7 @@ from lorax_server.models import Model from lorax_server.models.types import ( Batch, + NextTokens, PrefillTokens, AlternativeTokens, Generation, @@ -733,11 +734,13 @@ def generate_token( request.id, prefill_tokens, prefill_tokens_length, - None, - next_token_id_squeezed, - next_token_logprob, - next_token_text, - next_token_id_squeezed.item() in self.all_special_ids, + NextTokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + None, + ), generated_text, ) diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index b4ae92b93..2040f064d 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -11,6 +11,7 @@ GeneratedText, Batch, Generation, + NextTokens, PrefillTokens, ) from lorax_server.pb import generate_pb2 @@ -717,11 +718,13 @@ def generate_token( request.id, prefill_tokens, prefill_tokens_length, - None, - next_token_id_squeezed, - next_token_logprob, - next_token_text, - next_token_id_squeezed.item() in self.all_special_ids, + NextTokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + None, + ), generated_text, )