Skip to content

Commit

Permalink
Fix adapter speculation
Browse files — browse the repository at this point in the history
  • Loading branch information
tgaddair committed Apr 2, 2024
1 parent a19f62a commit fd95c3e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
20 changes: 17 additions & 3 deletions server/lorax_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,23 @@ def generate_token(
  batch.block_tables_tensor = block_tables_tensor
  batch.slots = slots

- # Assign pointers to LoRA weights
+ # Update adapter indices for speculative tokens (if present)
+ adapter_meta = batch.adapter_meta
+ if batch.speculative_ids is not None:
+ B, speculative_length = batch.speculative_ids.shape
+ new_length = speculative_length + 1
+ adapter_indices = adapter_meta.adapter_indices.unsqueeze(-1).expand(B, new_length).reshape(-1)
+ adapter_segments = adapter_meta.adapter_segments * new_length
+ adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_meta.adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_meta.segment_indices,
+ )
+
+ # Assign pointers to adapter weights
  # TODO(travis): don't update this if indices haven't changed
- adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, self.batched_lora_weights)
+ adapter_data = AdapterBatchData.from_meta(adapter_meta, self.batched_lora_weights)

  try:
  out, speculative_logits = self.forward(batch, adapter_data)
Expand Down Expand Up @@ -1084,7 +1098,7 @@ def generate_token(
idx += 1

cumulative_length += input_length

# Set values in batch
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.position_ids = next_position_ids + accepted_ids
Expand Down
2 changes: 1 addition & 1 deletion server/lorax_server/utils/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def forward_layer_type(
rank_segments.segment_ends,
self.layer_id,
)

if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
Expand Down

0 comments on commit fd95c3e

Please sign in to comment.