From 4d7f46862060185567e96cac7bc8187fedc39890 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 13 Dec 2024 11:44:16 +0530 Subject: [PATCH] Fix FlashInfer + Medusa bug (#715) --- server/lorax_server/models/flash_causal_lm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2db2cbe08..4b1f27ef4 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1508,10 +1508,10 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> if FLASH_INFER: block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=batch.input_lengths, - cache_lengths=batch.cache_lengths, - input_lengths_tensor=batch.input_lengths_tensor, - cache_lengths_tensor=batch.cache_lengths_tensor, + input_lengths=input_lengths.tolist(), + cache_lengths=cache_lengths_tensor.tolist(), + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, max_current_length=max_s, )