Commit

make format
Signed-off-by: Sungjae Lee <[email protected]>
llsj14 committed Jan 12, 2025
1 parent 3756596 commit 20fbcaa
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions vllm/spec_decode/spec_decode_worker.py
@@ -646,18 +646,22 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 prepare_prefill_hidden_states(
                     sampler_output.prefill_hidden_states)

-            # The current LoRA adapter is configured for the target model.
-            # Applying it to the draft model might cause errors, so it is temporarily disabled.
+            # The current LoRA adapter is configured for the target model.
+            # Applying it to the draft model might cause errors.
             original_lora_requests = [
-                metadata.lora_request for metadata in execute_model_req.seq_group_metadata_list
+                metadata.lora_request
+                for metadata in execute_model_req.seq_group_metadata_list
             ]
             for metadata in execute_model_req.seq_group_metadata_list:
                 metadata.lora_request = None

             self.proposer_worker.execute_model(execute_model_req)

-            # Restore the original LoRA information in execute_model_req after proposals are generated.
-            for metadata, original_lora_request in zip(execute_model_req.seq_group_metadata_list, original_lora_requests):
+            # Restore the original LoRA information in execute_model_req
+            # after proposals are generated.
+            for metadata, original_lora_request in zip(
+                    execute_model_req.seq_group_metadata_list,
+                    original_lora_requests):
                 metadata.lora_request = original_lora_request

         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
@@ -732,10 +736,11 @@ def _run_speculative_decoding_step(
         self.previous_hidden_states = None

         with Timer() as proposal_timer:
-            # The current LoRA adapter is configured for the target model.
-            # Applying it to the draft model might cause errors, so it is temporarily disabled.
+            # The current LoRA adapter is configured for the target model.
+            # Applying it to the draft model might cause errors.
             original_lora_requests = [
-                metadata.lora_request for metadata in execute_model_req.seq_group_metadata_list
+                metadata.lora_request
+                for metadata in execute_model_req.seq_group_metadata_list
             ]
             for metadata in execute_model_req.seq_group_metadata_list:
                 metadata.lora_request = None
@@ -744,8 +749,11 @@
             proposals = self.proposer_worker.get_spec_proposals(
                 execute_model_req, self._seq_with_bonus_token_in_last_step)

-        # Restore the original LoRA information in execute_model_req after proposals are generated.
-        for metadata, original_lora_request in zip(execute_model_req.seq_group_metadata_list, original_lora_requests):
+        # Restore the original LoRA information in execute_model_req
+        # after proposals are generated.
+        for metadata, original_lora_request in zip(
+                execute_model_req.seq_group_metadata_list,
+                original_lora_requests):
             metadata.lora_request = original_lora_request

         if not self._allow_zero_draft_token_step and proposals.no_proposals:
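
The hunks above in _run_no_spec and _run_speculative_decoding_step share one pattern: the per-sequence lora_request values are saved, cleared before the proposer (draft) worker runs so that the target model's LoRA adapter is not applied to the draft model, and restored once proposals have been generated. Below is a minimal, self-contained sketch of that pattern, not the vLLM code itself: _SeqGroupMetadata is a hypothetical stand-in for SequenceGroupMetadata, and the context-manager/try-finally form is an illustrative alternative to the inline save/restore used in the diff.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class _SeqGroupMetadata:
    # Illustrative stand-in for vLLM's SequenceGroupMetadata; only the
    # field touched by this diff is modeled here.
    lora_request: Optional[str] = None


@contextmanager
def lora_disabled(
        seq_group_metadata_list: List[_SeqGroupMetadata]) -> Iterator[None]:
    # Save the current per-sequence LoRA requests, clear them while the
    # draft step runs, and restore them afterwards even if the step fails.
    original_lora_requests = [
        m.lora_request for m in seq_group_metadata_list
    ]
    for m in seq_group_metadata_list:
        m.lora_request = None
    try:
        yield
    finally:
        for m, original in zip(seq_group_metadata_list,
                               original_lora_requests):
            m.lora_request = original


if __name__ == "__main__":
    groups = [_SeqGroupMetadata("adapter-a"), _SeqGroupMetadata(None)]
    with lora_disabled(groups):
        # The proposer/draft model would run here without any LoRA applied.
        assert all(m.lora_request is None for m in groups)
    assert groups[0].lora_request == "adapter-a"  # restored afterwards
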
@@ -1127,16 +1135,19 @@ def _vocab_size(self) -> int:
             worker.vocab_size
             for worker in [self.proposer_worker, self.scorer_worker]
         ]
-        assert all(orig_vocab_sizes[0] == vocab_size for vocab_size in orig_vocab_sizes)
+        assert all(orig_vocab_sizes[0] == vocab_size
+                   for vocab_size in orig_vocab_sizes)

-        # If LoRA is enabled, additional padding is applied to the vocabulary size
-        # for kernel compatibility.
+        # If LoRA is enabled, additional padding is applied
+        # to the vocabulary size for kernel compatibility.
         lora_vocab_padding_sizes = [
             worker.lora_config.lora_vocab_padding_size
             for worker in [self.proposer_worker, self.scorer_worker]
-            if worker.lora_config is not None and worker.lora_config.lora_vocab_padding_size is not None
+            if worker.lora_config is not None
+            and worker.lora_config.lora_vocab_padding_size is not None
         ]
-        assert all(lora_vocab_padding_sizes[0] == vocab_size for vocab_size in lora_vocab_padding_sizes)
+        assert all(lora_vocab_padding_sizes[0] == vocab_size
+                   for vocab_size in lora_vocab_padding_sizes)

         return orig_vocab_sizes[0] + lora_vocab_padding_sizes[0]

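The _vocab_size hunk above asserts that the proposer and scorer workers agree on the base vocabulary size and, when LoRA is configured, on the LoRA vocabulary padding added for kernel compatibility, then returns their sum. The following is a minimal sketch of that check detached from the worker objects; the function name and the fallback to zero padding when no LoRA config is present are assumptions of this example, not part of the diff.

from typing import List, Optional


def combined_vocab_size(vocab_sizes: List[int],
                        lora_paddings: List[Optional[int]]) -> int:
    # All workers must report the same base vocabulary size.
    assert all(v == vocab_sizes[0] for v in vocab_sizes)
    # Workers with LoRA enabled must agree on the padding added
    # for kernel compatibility.
    paddings = [p for p in lora_paddings if p is not None]
    assert all(p == paddings[0] for p in paddings)
    return vocab_sizes[0] + (paddings[0] if paddings else 0)


# Example: a 32000-token vocabulary padded by 256 for LoRA kernels.
print(combined_vocab_size([32000, 32000], [256, 256]))    # 32256
print(combined_vocab_size([32000, 32000], [None, None]))  # 32000 (no LoRA)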
