From fd95c3ecd074417808d1a7a52ee61e9d8b172d51 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 2 Apr 2024 13:59:16 -0700 Subject: [PATCH] Fix adapter speculation --- server/lorax_server/models/flash_causal_lm.py | 20 ++++++++++++++++--- server/lorax_server/utils/layers.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 6b8892295..c27ba2461 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -969,9 +969,23 @@ def generate_token( batch.block_tables_tensor = block_tables_tensor batch.slots = slots - # Assign pointers to LoRA weights + # Update adapter indices for speculative tokens (if present) + adapter_meta = batch.adapter_meta + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, + ) + + # Assign pointers to adapter weights # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.batched_lora_weights) + adapter_data = AdapterBatchData.from_meta(adapter_meta, self.batched_lora_weights) try: out, speculative_logits = self.forward(batch, adapter_data) @@ -1084,7 +1098,7 @@ def generate_token( idx += 1 cumulative_length += input_length - + # Set values in batch batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.position_ids = next_position_ids + accepted_ids diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index f1101a651..3cfbfe226 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -558,7 +558,7 @@ def forward_layer_type( rank_segments.segment_ends, self.layer_id, ) - + if end_idx - start_idx != result.shape[1]: result[:, start_idx:end_idx] += proj else: