feat: support LoRA with speculative decoding
Signed-off-by: Sungjae Lee <[email protected]>
llsj14 committed Jan 12, 2025
1 parent 4b657d3 commit 3756596
Showing 1 changed file with 37 additions and 3 deletions.
vllm/spec_decode/spec_decode_worker.py (+37 −3)
@@ -646,8 +646,20 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 prepare_prefill_hidden_states(
                     sampler_output.prefill_hidden_states)

+            # The current LoRA adapter is configured for the target model.
+            # Applying it to the draft model might cause errors, so it is temporarily disabled.
+            original_lora_requests = [
+                metadata.lora_request for metadata in execute_model_req.seq_group_metadata_list
+            ]
+            for metadata in execute_model_req.seq_group_metadata_list:
+                metadata.lora_request = None
+
             self.proposer_worker.execute_model(execute_model_req)

+            # Restore the original LoRA information in execute_model_req after proposals are generated.
+            for metadata, original_lora_request in zip(execute_model_req.seq_group_metadata_list, original_lora_requests):
+                metadata.lora_request = original_lora_request
+
         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
             execute_model_req=execute_model_req, sampler_output=sampler_output)
                                     if self._disable_logprobs else
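For context, both _run_no_spec and _run_speculative_decoding_step apply the same save/disable/restore dance around the draft-model call. The sketch below expresses that pattern as a reusable context manager; it is illustrative only, not part of this commit, and the helper name disabled_lora is a hypothetical choice.

# Illustrative sketch (not from the commit): the save/disable/restore
# pattern as a context manager; "disabled_lora" is a hypothetical name.
from contextlib import contextmanager


@contextmanager
def disabled_lora(seq_group_metadata_list):
    """Temporarily clear lora_request on each sequence group metadata
    and restore the original values on exit, even if an error occurs."""
    original_lora_requests = [
        metadata.lora_request for metadata in seq_group_metadata_list
    ]
    for metadata in seq_group_metadata_list:
        metadata.lora_request = None
    try:
        yield
    finally:
        for metadata, original in zip(seq_group_metadata_list,
                                      original_lora_requests):
            metadata.lora_request = original

With such a helper, the draft-model call could read:

    with disabled_lora(execute_model_req.seq_group_metadata_list):
        self.proposer_worker.execute_model(execute_model_req)

The try/finally guarantees the target-model LoRA requests are restored even if the proposer raises.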
@@ -720,10 +732,22 @@ def _run_speculative_decoding_step(
         self.previous_hidden_states = None

         with Timer() as proposal_timer:
+            # The current LoRA adapter is configured for the target model.
+            # Applying it to the draft model might cause errors, so it is temporarily disabled.
+            original_lora_requests = [
+                metadata.lora_request for metadata in execute_model_req.seq_group_metadata_list
+            ]
+            for metadata in execute_model_req.seq_group_metadata_list:
+                metadata.lora_request = None
+
             # Generate proposals using draft worker.
             proposals = self.proposer_worker.get_spec_proposals(
                 execute_model_req, self._seq_with_bonus_token_in_last_step)
+
+            # Restore the original LoRA information in execute_model_req after proposals are generated.
+            for metadata, original_lora_request in zip(execute_model_req.seq_group_metadata_list, original_lora_requests):
+                metadata.lora_request = original_lora_request

         if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
@@ -1099,12 +1123,22 @@ def _vocab_size(self) -> int:
         """Get the vocab size of the model and make sure it's consistent between
         draft and target workers.
         """
-        vocab_sizes = [
+        orig_vocab_sizes = [
             worker.vocab_size
             for worker in [self.proposer_worker, self.scorer_worker]
         ]
-        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
-        return vocab_sizes[0]
+        assert all(orig_vocab_sizes[0] == vocab_size for vocab_size in orig_vocab_sizes)
+
+        # If LoRA is enabled, additional padding is applied to the vocabulary size
+        # for kernel compatibility.
+        lora_vocab_padding_sizes = [
+            worker.lora_config.lora_vocab_padding_size
+            for worker in [self.proposer_worker, self.scorer_worker]
+            if worker.lora_config is not None and worker.lora_config.lora_vocab_padding_size is not None
+        ]
+        assert all(lora_vocab_padding_sizes[0] == vocab_size for vocab_size in lora_vocab_padding_sizes)
+
+        return orig_vocab_sizes[0] + lora_vocab_padding_sizes[0]

     @property
     def rank(self):
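To make the new return value concrete, here is a small self-contained sketch of the padded-vocab computation. The dummy worker and config classes and the example numbers (a 32000-token vocabulary with a LoRA vocabulary padding of 256) are assumptions for illustration; the sketch also falls back to the unpadded size when neither worker carries a LoRA config.

# Illustrative sketch (assumed numbers, dummy objects): deriving the
# padded vocab size from the proposer and scorer workers.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _LoRAConfig:            # stand-in for the worker's lora_config
    lora_vocab_padding_size: Optional[int]


@dataclass
class _Worker:                # stand-in for proposer/scorer workers
    vocab_size: int
    lora_config: Optional[_LoRAConfig]


def padded_vocab_size(proposer: _Worker, scorer: _Worker) -> int:
    vocab_sizes = [w.vocab_size for w in (proposer, scorer)]
    assert all(v == vocab_sizes[0] for v in vocab_sizes)

    paddings = [
        w.lora_config.lora_vocab_padding_size
        for w in (proposer, scorer)
        if w.lora_config is not None
        and w.lora_config.lora_vocab_padding_size is not None
    ]
    if not paddings:  # no LoRA configured on either worker
        return vocab_sizes[0]
    assert all(p == paddings[0] for p in paddings)
    return vocab_sizes[0] + paddings[0]


# Example: a 32000-token vocab with LoRA padding 256 yields 32256.
draft = _Worker(vocab_size=32000, lora_config=_LoRAConfig(256))
target = _Worker(vocab_size=32000, lora_config=_LoRAConfig(256))
print(padded_vocab_size(draft, target))  # 32256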