From 3756596dd0b6dae00e35f19202a0434a1a65d726 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Sun, 12 Jan 2025 06:02:09 +0000
Subject: [PATCH] feat: support LoRA with speculative decoding

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/spec_decode/spec_decode_worker.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 53 insertions(+), 3 deletions(-)

diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index e369da1a70c23..0ed688bd8252f 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -646,8 +646,25 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 prepare_prefill_hidden_states(
                     sampler_output.prefill_hidden_states)
 
+            # The current LoRA adapter is configured for the target model.
+            # Applying it to the draft model might cause errors, so it is
+            # temporarily disabled.
+            original_lora_requests = [
+                metadata.lora_request
+                for metadata in execute_model_req.seq_group_metadata_list
+            ]
+            for metadata in execute_model_req.seq_group_metadata_list:
+                metadata.lora_request = None
+
             self.proposer_worker.execute_model(execute_model_req)
 
+            # Restore the original LoRA information in execute_model_req
+            # after proposals are generated.
+            for metadata, original_lora_request in zip(
+                    execute_model_req.seq_group_metadata_list,
+                    original_lora_requests):
+                metadata.lora_request = original_lora_request
+
         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
             execute_model_req=execute_model_req, sampler_output=sampler_output)
                                     if self._disable_logprobs else
@@ -720,10 +737,27 @@ def _run_speculative_decoding_step(
         self.previous_hidden_states = None
 
         with Timer() as proposal_timer:
+            # The current LoRA adapter is configured for the target model.
+            # Applying it to the draft model might cause errors, so it is
+            # temporarily disabled.
+            original_lora_requests = [
+                metadata.lora_request
+                for metadata in execute_model_req.seq_group_metadata_list
+            ]
+            for metadata in execute_model_req.seq_group_metadata_list:
+                metadata.lora_request = None
+
             # Generate proposals using draft worker.
             proposals = self.proposer_worker.get_spec_proposals(
                 execute_model_req, self._seq_with_bonus_token_in_last_step)
 
+        # Restore the original LoRA information in execute_model_req
+        # after proposals are generated.
+        for metadata, original_lora_request in zip(
+                execute_model_req.seq_group_metadata_list,
+                original_lora_requests):
+            metadata.lora_request = original_lora_request
+
         if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
@@ -1099,12 +1133,28 @@ def _vocab_size(self) -> int:
         """Get the vocab size of the model and make sure it's consistent between
         draft and target workers.
         """
-        vocab_sizes = [
+        orig_vocab_sizes = [
             worker.vocab_size
             for worker in [self.proposer_worker, self.scorer_worker]
         ]
-        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
-        return vocab_sizes[0]
+        assert all(orig_vocab_sizes[0] == vocab_size
+                   for vocab_size in orig_vocab_sizes)
+
+        # If LoRA is enabled, additional padding is applied to the vocabulary
+        # size for kernel compatibility.
+        lora_vocab_padding_sizes = [
+            worker.lora_config.lora_vocab_padding_size
+            for worker in [self.proposer_worker, self.scorer_worker]
+            if worker.lora_config is not None
+            and worker.lora_config.lora_vocab_padding_size is not None
+        ]
+        # Fall back to the unpadded size when LoRA is disabled on both workers.
+        if not lora_vocab_padding_sizes:
+            return orig_vocab_sizes[0]
+        assert all(lora_vocab_padding_sizes[0] == padding_size
+                   for padding_size in lora_vocab_padding_sizes)
+
+        return orig_vocab_sizes[0] + lora_vocab_padding_sizes[0]
 
     @property
     def rank(self):
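
Note (reviewer sketch, not part of the patch): the save/clear/restore logic
above is duplicated in _run_no_spec and _run_speculative_decoding_step. Under
the assumption that ExecuteModelRequest and SequenceGroupMetadata are importable
from vllm.sequence and LoRARequest from vllm.lora.request in this version, it
could be factored into a small context manager; the name disable_lora_for_draft
is hypothetical and not part of this patch.

    from contextlib import contextmanager
    from typing import Iterator, List, Optional

    from vllm.lora.request import LoRARequest
    from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata


    @contextmanager
    def disable_lora_for_draft(
            execute_model_req: ExecuteModelRequest) -> Iterator[None]:
        """Temporarily strip LoRA requests while the draft model runs."""
        seq_group_metadata_list: List[SequenceGroupMetadata] = \
            execute_model_req.seq_group_metadata_list
        # Save the per-sequence-group LoRA requests configured for the target.
        original_lora_requests: List[Optional[LoRARequest]] = [
            metadata.lora_request for metadata in seq_group_metadata_list
        ]
        for metadata in seq_group_metadata_list:
            metadata.lora_request = None
        try:
            yield
        finally:
            # Restore even if the proposer raises, so scoring still sees LoRA.
            for metadata, original_lora_request in zip(
                    seq_group_metadata_list, original_lora_requests):
                metadata.lora_request = original_lora_request

Both call sites would then reduce to
"with disable_lora_for_draft(execute_model_req): ...".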
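
Note (usage sketch, not part of the patch): assuming this vLLM version's engine
arguments (enable_lora, speculative_model, num_speculative_tokens) and
placeholder model/adapter names, the combination could be exercised as below;
the draft model proposes tokens without the adapter while the target model
scores with it, as described in the hunks above.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",        # target model (placeholder)
        speculative_model="JackFram/llama-68m",  # draft model (placeholder)
        num_speculative_tokens=5,
        enable_lora=True,
        max_loras=1,
    )

    outputs = llm.generate(
        ["Briefly explain speculative decoding."],
        SamplingParams(temperature=0.0, max_tokens=64),
        # The adapter is applied to the target model only; the path and the
        # adapter name are placeholders.
        lora_request=LoRARequest("my-adapter", 1, "/path/to/lora_adapter"),
    )
    print(outputs[0].outputs[0].text)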