diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2db2cbe08..4b1f27ef4 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1508,10 +1508,10 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> if FLASH_INFER: block_tables = block_tables_to_ragged( block_tables=block_tables, - input_lengths=batch.input_lengths, - cache_lengths=batch.cache_lengths, - input_lengths_tensor=batch.input_lengths_tensor, - cache_lengths_tensor=batch.cache_lengths_tensor, + input_lengths=input_lengths.tolist(), + cache_lengths=cache_lengths_tensor.tolist(), + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, max_current_length=max_s, )