From 20fbcaad41d0e06e87676ff16c77e3fbb276ec27 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Sun, 12 Jan 2025 06:07:38 +0000
Subject: [PATCH] make format

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
 vllm/spec_decode/spec_decode_worker.py | 41 ++++++++++++++++----------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 0ed688bd8252f..87408195cbc6b 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -646,18 +646,22 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
             prepare_prefill_hidden_states(
                 sampler_output.prefill_hidden_states)
 
-        # The current LoRA adapter is configured for the target model.
-        # Applying it to the draft model might cause errors, so it is temporarily disabled.
+        # The current LoRA adapter is configured for the target model.
+        # Applying it to the draft model might cause errors.
         original_lora_requests = [
-            metadata.lora_request for metadata in execute_model_req.seq_group_metadata_list
+            metadata.lora_request
+            for metadata in execute_model_req.seq_group_metadata_list
         ]
         for metadata in execute_model_req.seq_group_metadata_list:
             metadata.lora_request = None
 
         self.proposer_worker.execute_model(execute_model_req)
 
-        # Restore the original LoRA information in execute_model_req after proposals are generated.
-        for metadata, original_lora_request in zip(execute_model_req.seq_group_metadata_list, original_lora_requests):
+        # Restore the original LoRA information in execute_model_req
+        # after proposals are generated.
+        for metadata, original_lora_request in zip(
+                execute_model_req.seq_group_metadata_list,
+                original_lora_requests):
             metadata.lora_request = original_lora_request
 
         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
@@ -732,10 +736,11 @@ def _run_speculative_decoding_step(
         self.previous_hidden_states = None
 
         with Timer() as proposal_timer:
-            # The current LoRA adapter is configured for the target model.
-            # Applying it to the draft model might cause errors, so it is temporarily disabled.
+            # The current LoRA adapter is configured for the target model.
+            # Applying it to the draft model might cause errors.
             original_lora_requests = [
-                metadata.lora_request for metadata in execute_model_req.seq_group_metadata_list
+                metadata.lora_request
+                for metadata in execute_model_req.seq_group_metadata_list
             ]
             for metadata in execute_model_req.seq_group_metadata_list:
                 metadata.lora_request = None
@@ -744,8 +749,11 @@ def _run_speculative_decoding_step(
             proposals = self.proposer_worker.get_spec_proposals(
                 execute_model_req, self._seq_with_bonus_token_in_last_step)
 
-            # Restore the original LoRA information in execute_model_req after proposals are generated.
-            for metadata, original_lora_request in zip(execute_model_req.seq_group_metadata_list, original_lora_requests):
+            # Restore the original LoRA information in execute_model_req
+            # after proposals are generated.
+            for metadata, original_lora_request in zip(
+                    execute_model_req.seq_group_metadata_list,
+                    original_lora_requests):
                 metadata.lora_request = original_lora_request
 
         if not self._allow_zero_draft_token_step and proposals.no_proposals:
@@ -1127,16 +1135,19 @@ def _vocab_size(self) -> int:
             worker.vocab_size
             for worker in [self.proposer_worker, self.scorer_worker]
         ]
-        assert all(orig_vocab_sizes[0] == vocab_size for vocab_size in orig_vocab_sizes)
+        assert all(orig_vocab_sizes[0] == vocab_size
+                   for vocab_size in orig_vocab_sizes)
 
-        # If LoRA is enabled, additional padding is applied to the vocabulary size
-        # for kernel compatibility.
+        # If LoRA is enabled, additional padding is applied
+        # to the vocabulary size for kernel compatibility.
         lora_vocab_padding_sizes = [
             worker.lora_config.lora_vocab_padding_size
             for worker in [self.proposer_worker, self.scorer_worker]
-            if worker.lora_config is not None and worker.lora_config.lora_vocab_padding_size is not None
+            if worker.lora_config is not None
+            and worker.lora_config.lora_vocab_padding_size is not None
         ]
-        assert all(lora_vocab_padding_sizes[0] == vocab_size for vocab_size in lora_vocab_padding_sizes)
+        assert all(lora_vocab_padding_sizes[0] == vocab_size
+                   for vocab_size in lora_vocab_padding_sizes)
 
         return orig_vocab_sizes[0] + lora_vocab_padding_sizes[0]
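
The disable/restore LoRA bookkeeping above appears twice, in _run_no_spec and in _run_speculative_decoding_step. As an illustrative sketch only (not part of this patch), the same pattern could be wrapped in a context manager so the restore also runs if the proposer worker raises. The helper name `disabled_lora_requests` is hypothetical; `ExecuteModelRequest` is the vLLM type already used in the diff.

    from contextlib import contextmanager
    from typing import Iterator

    from vllm.sequence import ExecuteModelRequest


    @contextmanager
    def disabled_lora_requests(
            execute_model_req: ExecuteModelRequest) -> Iterator[None]:
        """Temporarily clear per-sequence LoRA requests so the draft
        (proposer) model runs without the target model's adapter."""
        # Save the LoRA request attached to each sequence group.
        original_lora_requests = [
            metadata.lora_request
            for metadata in execute_model_req.seq_group_metadata_list
        ]
        for metadata in execute_model_req.seq_group_metadata_list:
            metadata.lora_request = None
        try:
            yield
        finally:
            # Restore the original LoRA information, even on error paths.
            for metadata, original_lora_request in zip(
                    execute_model_req.seq_group_metadata_list,
                    original_lora_requests):
                metadata.lora_request = original_lora_request

Usage would mirror the call sites in the diff, for example:

    with disabled_lora_requests(execute_model_req):
        self.proposer_worker.execute_model(execute_model_req)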