Medusa max_draft_len overhead impact #2506

Closed
ValeGian opened this issue Nov 26, 2024 · 2 comments

ValeGian commented Nov 26, 2024

System Info

  • CPU architecture: x86_64
  • GPU: 8x NVIDIA H200
  • Libraries:
      • TensorRT-LLM: v0.14.0
      • CUDA: 12.4
  • NVIDIA driver version: 550.127.05

Setup Info

I'm attempting to use Medusa with TensorRT-LLM to accelerate inference of a fine-tuned Llama 3.1 70B model originally in FP16 precision. To achieve this, I first converted the model to FP8 precision and built it using the following commands:

quantize.py --model_dir=<FINE-TUNED MODEL DIR> --dtype=float16 --tp_size=1 --output_dir=<QUANTIZED MODEL DIR> --qformat=fp8 --kv_cache_dtype=fp8 --calib_dataset=<CALIB DATASET> --calib_size=512 --batch_size=8 --calib_max_seq_length=1024

trtllm-build --checkpoint_dir=<QUANTIZED MODEL DIR> --max_beam_width=1 --max_seq_len=131072 --max_input_len=130560 --max_num_tokens=32768 --max_batch_size=8 --context_fmha=enable --output_dir=<OUT DIR> --use_fp8_context_fmha=disable

I used this FP8 model to distill a dataset and then trained 3 Medusa heads. When evaluated on a validation dataset, the Medusa heads achieved the following token prediction accuracies with respect to the tokens generated by the original FP16 fine-tuned model:

TopK=0
> Head 0 Accuracy=0.6837761270606081
> Head 1 Accuracy=0.32617484167971394
> Head 2 Accuracy=0.1807497640902462

TopK=4
> Head 0 Accuracy=0.8547368890673448
> Head 1 Accuracy=0.5451475708643937
> Head 2 Accuracy=0.35749212612899667

These results indicate that the Medusa heads are predicting the base model's tokens with reasonable accuracy.
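
For reference, the per-head accuracy above is computed roughly as follows (a simplified sketch of my evaluation code, not part of TensorRT-LLM; head_logits and target_tokens are placeholders, and "TopK=k" means the target token is among the head's k+1 highest-scoring tokens, so TopK=0 is plain greedy accuracy):

import torch

def head_topk_accuracy(head_logits: torch.Tensor, target_tokens: torch.Tensor, k: int) -> float:
    # head_logits:   [num_positions, vocab_size] logits of one Medusa head
    # target_tokens: [num_positions] tokens the FP16 base model produced at the
    #                offset this head is trained to predict
    topk = head_logits.topk(k + 1, dim=-1).indices            # [num_positions, k+1]
    hits = (topk == target_tokens.unsqueeze(-1)).any(dim=-1)  # [num_positions]
    return hits.float().mean().item()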

Next, I built an FP8 model with Medusa heads and set max_draft_len=1:

quantize.py --model_dir=<FINE-TUNED MODEL DIR> --dtype=float16 --tp_size=1 --output_dir=<QUANTIZED MODEL DIR> --qformat=fp8 --kv_cache_dtype=fp8 --calib_dataset=<CALIB DATASET> --calib_size=512 --batch_size=8 --calib_max_seq_length=1024 --max_draft_len=1 --num_medusa_heads=3 --num_medusa_layers=1 --medusa_model_dir=<MEDUSA MODEL DIR>

trtllm-build --checkpoint_dir=<QUANTIZED MODEL DIR> --max_beam_width=1 --max_seq_len=131072 --max_input_len=130560 --max_num_tokens=32768 --max_batch_size=8 --context_fmha=enable --output_dir=<OUT DIR> --use_fp8_context_fmha=disable --speculative_decoding_mode=medusa --max_draft_len=1
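
As I understand it (my reading, not something I found in the docs), each entry in medusa_choices is one node of the draft tree, i.e. one draft token per step, so max_draft_len just needs to be at least len(medusa_choices). A quick sketch of how I size it:

import ast

def required_max_draft_len(medusa_choices_str: str) -> int:
    # Each path in the choices list identifies one node of the Medusa tree,
    # and each node contributes one draft token per decoding step.
    return len(ast.literal_eval(medusa_choices_str))

required_max_draft_len("[[0]]")               # 1 -> matches --max_draft_len=1
required_max_draft_len("[[0], [0, 0], [1]]")  # 3 -> would need --max_draft_len >= 3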

Running the engine built with Medusa and a comparable engine built without Medusa in a framework that uses TensorRT-LLM's in-flight batching implementation, I observed the following p99 inference latencies:

  • FP8 model without Medusa: 2.526s
  • FP8 model with Medusa and medusa_choices="[[0]]": 2.271s

I'm setting medusa_choices in the code as follows:

import ast

import tensorrt_llm.bindings.executor as trtllm

# medusa_choices is the string passed on the command line, e.g. "[[0]]"
decoding_config = trtllm.DecodingConfig()
if medusa_choices is not None:
    decoding_config.medusa_choices = ast.literal_eval(medusa_choices)

executor_config = trtllm.ExecutorConfig(
    max_beam_width=max_beam_width,
    max_batch_size=max_batch_size,
    max_num_tokens=max_num_tokens,
    batching_type=trtllm.BatchingType.INFLIGHT,
    scheduler_config=trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT),
    kv_cache_config=kv_cache_config,
    decoding_config=decoding_config,
    enable_chunked_context=enable_chunked_context,
    gpu_weights_percent=1
)

session = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, executor_config)

and creating the trtllm.Request like so:

import tensorrt_llm.bindings.executor as trtllm

tokens = self.tokenizer.encode(prompt, add_special_tokens=True,
                               max_length=self.config.build_config.max_input_len, truncation=True)

output_config = trtllm.OutputConfig()
output_config.exclude_input_from_output = True

sampling_conf = trtllm.SamplingConfig(
    # Medusa path runs greedy (top_k=1, beam_width=1); the non-Medusa baseline uses sampling
    temperature=1 if self.medusa else 0.1,
    top_k=1 if self.medusa else 50,
    top_p=0.9,
    random_seed=self.seed,
    beam_width=1 if self.medusa else self.max_beam_width
)

trt_request = trtllm.Request(
    input_token_ids = tokens,
    max_new_tokens = self.max_output_len,
    pad_id = self.tokenizer.pad_token_id,
    end_id = self.tokenizer.eos_token_id,
    streaming = True,
    sampling_config = sampling_conf,
    output_config = output_config
) 
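
Requests are then submitted and timed roughly like this (a simplified sketch of my benchmarking loop; error handling and concurrency are omitted, and build_request is just a stand-in for the trtllm.Request construction shown above):

import time
import numpy as np

latencies = []
for prompt in prompts:                          # prompts: my evaluation set
    trt_request = build_request(prompt)         # Request construction as shown above
    start = time.perf_counter()
    request_id = session.enqueue_request(trt_request)
    finished = False
    while not finished:                         # drain the streamed responses
        for response in session.await_responses(request_id):
            if response.has_error():
                raise RuntimeError(response.error_msg)
            finished = response.result.is_final
    latencies.append(time.perf_counter() - start)

p99 = np.percentile(latencies, 99)              # the p99 values reported above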

Questions

  1. When I build a similar engine with max_draft_len=17 and run it with the same medusa_choices="[[0]]", I notice a clear increase in inference latency (p99 of 2.918s). Is this expected behavior due to the increased max_draft_len, even though I'm specifying to use only topk 0 of the first head?
  2. Do you have any benchmarks that demonstrate the overhead introduced by increasing the Medusa choice tree size (and the max_draft_len with it)?
ValeGian changed the title from "Medusa quality impact + max_draft_len overhead impact" to "Medusa max_draft_len overhead impact" on Nov 27, 2024
hello-11 added the question (Further information is requested) and triaged (Issue has been triaged by maintainers) labels on Dec 2, 2024
@rakib-hasan

Hi @ValeGian
For (1), I think this is expected, and it comes down to compile-time-known vs. runtime-known dimensions.
With max_draft_len=1, TRT can pick a kernel that needs no extra overhead (e.g. no loop over draft positions).
With max_draft_len=17 and only one Medusa choice used at runtime, TRT still has to build an engine that is valid for every draft length from 1 to 17. It may therefore pick a different kernel, with an extra loop and an optimization strategy that balances performance across the whole 1-17 range rather than being tuned for a draft length of exactly 1. That extra flexibility can cost some performance.
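
To illustrate the idea with plain TensorRT (just a toy sketch, assuming TensorRT 10.x; this is not what TRT-LLM does internally): a dimension that is only known at runtime is covered by an optimization profile over its whole range, and the selected kernels must stay valid for every value in that range:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()

# -1: the "draft length" dimension is only known at runtime
x = network.add_input("draft_tokens", trt.float32, (-1, 4096))
network.mark_output(network.add_identity(x).get_output(0))

config = builder.create_builder_config()
profile = builder.create_optimization_profile()
# Like max_draft_len=17: the engine must be valid (and reasonably fast)
# for every length in [1, 17], even if only length 1 is ever used at runtime.
profile.set_shape("draft_tokens", min=(1, 4096), opt=(17, 4096), max=(17, 4096))
config.add_optimization_profile(profile)
engine = builder.build_serialized_network(network, config)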

For (2), unfortunately we do not have any benchmarks yet that demonstrate the impact of these parameters. We may add some in the near future.


ValeGian commented Dec 3, 2024

Thank you for the answer! Closing the issue.

ValeGian closed this as completed on Dec 3, 2024