Encoder-Decoder

This document shows how to build and run an Encoder-Decoder (Enc-Dec) model in TensorRT-LLM on NVIDIA GPUs.

Overview

The TensorRT-LLM Enc-Dec implementation can be found in tensorrt_llm/models/enc_dec/model.py. The TensorRT-LLM Enc-Dec example code is located in examples/enc_dec:

  • trtllm-build builds the TensorRT engine(s) needed to run the Enc-Dec model.
  • run.py runs inference on an example input text.
  • convert_checkpoint.py converts weights from HuggingFace or FairSeq format to TRT-LLM format and splits them for multi-GPU inference. Enc-Dec models come in several specific implementations, such as the popular T5 family (T5, mT5, Flan-T5), the BART family (BART, mBART), and the FairSeq family (WMTs); they are now all handled by this single convert script.

Usage

The TensorRT-LLM Enc-Dec example code is located in examples/enc_dec. It takes a HuggingFace or FairSeq model name as input and builds the corresponding TensorRT engines. Each GPU gets two TensorRT engines, one for the encoder and one for the decoder.

Encoder-Decoder Model Support

The implementation is designed to support generic encoder-decoder models by abstracting the common and derivative components of different model architectures.

It also supports full Tensor Parallelism (TP), Pipeline Parallelism (PP), and a hybrid of the two. Fused Multi-Head Attention (FMHA) is not yet enabled for the T5 family because of its relative attention design.

In this example, we use T5 (t5-small) and Flan-T5 (google/flan-t5-small) to showcase TRT-LLM support on Enc-Dec models. BART models and FairSeq models can follow very similar steps by just replacing the model name.

Download weights from HuggingFace Transformers

git clone https://huggingface.co/t5-small tmp/hf_models/t5-small
git clone https://huggingface.co/google/flan-t5-small tmp/hf_models/flan-t5-small
git clone https://huggingface.co/facebook/bart-large-cnn tmp/hf_models/bart-large-cnn
git clone https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt tmp/hf_models/mbart-large-50-many-to-one-mmt

Convert and Split Weights

The convert_checkpoint.py script converts weights from HuggingFace or FairSeq format to TRT-LLM format, and splits weights for multi-GPU inference. --tp_size specifies the number of GPUs for tensor parallelism during inference. Pipeline Parallelism size can be set with --pp_size for distributed inference.

The HuggingFace or FairSeq checkpoints of the enc-dec models mentioned in this README are all in float32 precision. Use --dtype to set the target inference precision during weight conversion.

After weight conversion, the converted TensorRT-LLM weights and model configuration are saved under the <out_dir>/<tpX> directory. Pass this path as --checkpoint_dir in the engine building phase that follows.
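The build commands later in this README read the encoder and decoder checkpoints from tp<X>/pp<Y>/encoder and tp<X>/pp<Y>/decoder subdirectories of the conversion output directory. The following is a minimal sketch of that layout, inferred from those commands only; the helper name checkpoint_dir is hypothetical and is not part of TensorRT-LLM.

```python
from pathlib import PurePosixPath


def checkpoint_dir(out_dir: str, tp_size: int, pp_size: int, component: str) -> str:
    """Return the per-component checkpoint path as used by the build commands.

    Hypothetical helper: the layout is inferred from the example commands in
    this README, not from the convert_checkpoint.py source.
    """
    if component not in ("encoder", "decoder"):
        raise ValueError(f"component must be 'encoder' or 'decoder', got {component!r}")
    return str(PurePosixPath(out_dir) / f"tp{tp_size}" / f"pp{pp_size}" / component)


# Matches the t5-small example: 4-way TP, no PP, BF16.
print(checkpoint_dir("tmp/trt_models/t5-small/bfloat16", 4, 1, "encoder"))
```

The value it prints is what the example passes to --checkpoint_dir for the encoder engine.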

Take T5 for example:

# Example: build t5-small with 4-way tensor parallelism on a node with 8 GPUs (using only 4 of them, for demonstration purposes), BF16, max beam width 1 (greedy search).
export MODEL_NAME="t5-small" # or "flan-t5-small"
export MODEL_TYPE="t5"
export INFERENCE_PRECISION="bfloat16"
export TP_SIZE=4
export PP_SIZE=1
export WORLD_SIZE=4
export MAX_BEAM_WIDTH=1
python convert_checkpoint.py --model_type ${MODEL_TYPE} \
                --model_dir tmp/hf_models/${MODEL_NAME} \
                --output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \
                --tp_size ${TP_SIZE} \
                --pp_size ${PP_SIZE} \
                --weight_data_type float32 \
                --dtype ${INFERENCE_PRECISION}

Build TensorRT engine(s)

TensorRT-LLM builds TensorRT engine(s) with flexible controls over different types of optimizations. Note that these are just examples to demonstrate multi-GPU inference. For small models like T5-small, a single GPU is usually sufficient.

After engine building, the TensorRT engines are saved under the <out_dir>/<tpX> directory. Pass this path as --engine_dir in the engine running phase that follows. It is recommended to include /<Y-gpu> in the output path, where Y is the total number of GPU ranks in a multi-node, multi-GPU setup, because the same Y GPUs can be run with different TP (Tensor Parallelism) and PP (Pipeline Parallelism) combinations.

Distinguish between X, the TP size, and Y, the total number of GPU ranks:

  • When X = Y, only TP is enabled.
  • When X < Y, both TP and PP are enabled. In that case, make sure you have completed the weight conversion step for TP=X.
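The relationship between X and Y can be checked with a few lines of Python. This is only a sketch, assuming the total number of ranks equals tp_size * pp_size (consistent with the TP_SIZE=4, PP_SIZE=1, WORLD_SIZE=4 example above); the function name parallel_layout is hypothetical.

```python
def parallel_layout(tp_size: int, world_size: int) -> dict:
    """Derive the PP size from X (tp_size) and Y (world_size).

    Assumption: world_size == tp_size * pp_size.
    """
    if tp_size < 1 or world_size < 1:
        raise ValueError("tp_size and world_size must be positive")
    if world_size % tp_size != 0:
        raise ValueError(
            f"world_size={world_size} is not a multiple of tp_size={tp_size}"
        )
    pp_size = world_size // tp_size
    return {
        "tp_size": tp_size,
        "pp_size": pp_size,
        # X = Y means TP only; X < Y means TP and PP together.
        "uses_pipeline_parallelism": pp_size > 1,
    }


print(parallel_layout(tp_size=4, world_size=4))  # TP only
print(parallel_layout(tp_size=4, world_size=8))  # TP and PP
```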

The default value of --max_input_len is 1024. When building DecoderModel, specify the decoder input length with --max_input_len=1 so that generation starts from decoder_start_token_id, which has length 1. If the start token is a single token (the default behavior of T5, BART, etc.), set --max_input_len to 1; if you want decoder-only-style generation, set --max_input_len above 1 to get behavior similar to HF's decoder_forced_input_ids.

DecoderModel takes --max_encoder_input_len and --max_input_len as model inputs. --max_encoder_input_len defaults to 1024 because --max_input_len is 1024 for EncoderModel.

Note: for T5, add --context_fmha disable. The --bert_attention_plugin, --gpt_attention_plugin, --remove_input_padding, and --gemm_plugin options must each be set explicitly, whether you enable or disable them.

# --gpt_attention_plugin is necessary in Enc-Dec.
# Try --gemm_plugin to prevent accuracy issues.
# It is recommended to use --remove_input_padding along with --gpt_attention_plugin for better performance
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/encoder \
                --output_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/encoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width ${MAX_BEAM_WIDTH} \
                --max_batch_size 8 \
                --max_output_len 200 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding enable \
                --context_fmha disable

# For decoder, refer to the above content and set --max_input_len correctly
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/decoder \
                --output_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/decoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width ${MAX_BEAM_WIDTH} \
                --max_batch_size 8 \
                --max_output_len 200 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding enable \
                --context_fmha disable \
                --max_input_len 1
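The two T5 commands above differ only in the encoder/decoder subdirectory and in the extra --max_input_len 1 on the decoder. The sketch below assembles both argument lists from one place so the shared flags cannot drift apart. It uses only flags that appear in the commands above; the function name build_command and its parameters are hypothetical.

```python
import shlex

# Flags shared by the T5 encoder and decoder build commands above.
COMMON_FLAGS = [
    "--paged_kv_cache", "disable",
    "--moe_plugin", "disable",
    "--enable_xqa", "disable",
    "--use_custom_all_reduce", "disable",
    "--max_batch_size", "8",
    "--max_output_len", "200",
    "--remove_input_padding", "enable",
]


def build_command(component, checkpoint_root, engine_root, precision,
                  max_beam_width=1, disable_context_fmha=True):
    """Assemble a trtllm-build argument list for "encoder" or "decoder"."""
    if component not in ("encoder", "decoder"):
        raise ValueError(f"unknown component: {component!r}")
    cmd = [
        "trtllm-build",
        "--checkpoint_dir", f"{checkpoint_root}/{component}",
        "--output_dir", f"{engine_root}/{component}",
        *COMMON_FLAGS,
        "--max_beam_width", str(max_beam_width),
        "--gemm_plugin", precision,
        "--bert_attention_plugin", precision,
        "--gpt_attention_plugin", precision,
    ]
    if disable_context_fmha:
        # Needed for T5 according to the note above.
        cmd += ["--context_fmha", "disable"]
    if component == "decoder":
        # Generation starts from a single decoder start token.
        cmd += ["--max_input_len", "1"]
    return cmd


for part in ("encoder", "decoder"):
    print(shlex.join(build_command(
        part,
        "tmp/trt_models/t5-small/bfloat16/tp4/pp1",
        "tmp/trt_engines/t5-small/4-gpu/bfloat16/tp4",
        "bfloat16",
    )))
```

The printed lines can be compared against the commands above or passed to subprocess.run once the checkpoints exist.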

For BART, --context_fmha can be enabled; trtllm-build enables it by default.

# Example: build bart-large-cnn using a single GPU, FP32, running greedy search
export MODEL_NAME="bart-large-cnn" # or "mbart-large-50-many-to-one-mmt"
export MODEL_TYPE="bart"
export INFERENCE_PRECISION="float32"
export TP_SIZE=1
export PP_SIZE=1
export WORLD_SIZE=1
export MAX_BEAM_WIDTH=1
python convert_checkpoint.py --model_type ${MODEL_TYPE} \
                --model_dir tmp/hf_models/${MODEL_NAME} \
                --output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \
                --tp_size ${TP_SIZE} \
                --pp_size ${PP_SIZE} \
                --weight_data_type float32 \
                --dtype ${INFERENCE_PRECISION}

# Note: non-T5 models can enable FMHA for the encoder part; for FP16/BF16 it is enabled by default
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/encoder \
                --output_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/encoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width ${MAX_BEAM_WIDTH} \
                --max_batch_size 8 \
                --max_output_len 200 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding enable
                # --context_fmha disable should be removed

# Use the same command for decoder engine
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/decoder \
                --output_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/decoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width ${MAX_BEAM_WIDTH} \
                --max_batch_size 8 \
                --max_output_len 200 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding enable \
                --max_input_len 1
                # --context_fmha disable should be removed

Run

Run a TensorRT-LLM Enc-Dec model using the engines generated by trtllm-build. Note that only the TensorRT engine files are needed for model deployment. The previously downloaded model checkpoints and converted weights can be removed.

# Inferencing w/ single GPU greedy search, compare results with HuggingFace FP32
python3 run.py --engine_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE} --engine_name ${MODEL_NAME} --model_name tmp/hf_models/${MODEL_NAME} --max_new_token=64 --num_beams=1 --compare_hf_fp32

# Inferencing w/ 4 GPUs (4-way TP, as configured during the engine building step), greedy search, compare results with HuggingFace FP32
mpirun --allow-run-as-root -np ${WORLD_SIZE} python3 run.py --engine_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE} --engine_name ${MODEL_NAME} --model_name tmp/hf_models/${MODEL_NAME} --max_new_token=64 --num_beams=1 --compare_hf_fp32
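The single-GPU and multi-GPU invocations above differ only in the mpirun prefix. A small sketch of that rule follows, using only the flags shown in the two commands; the helper name run_command is hypothetical.

```python
import shlex


def run_command(engine_dir, engine_name, model_name, world_size,
                max_new_token=64, num_beams=1, compare_hf_fp32=True):
    """Build the run.py command line, adding mpirun when world_size > 1."""
    cmd = [
        "python3", "run.py",
        "--engine_dir", engine_dir,
        "--engine_name", engine_name,
        "--model_name", model_name,
        f"--max_new_token={max_new_token}",
        f"--num_beams={num_beams}",
    ]
    if compare_hf_fp32:
        cmd.append("--compare_hf_fp32")
    if world_size > 1:
        # One process per GPU rank, as in the 4-GPU example above.
        cmd = ["mpirun", "--allow-run-as-root", "-np", str(world_size)] + cmd
    return cmd


print(shlex.join(run_command(
    "tmp/trt_engines/t5-small/4-gpu/bfloat16/tp4",
    "t5-small",
    "tmp/hf_models/t5-small",
    world_size=4,
)))
```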

Benchmark

The benchmark implementation and entrypoint can be found in benchmarks/python/benchmark.py. Specifically, benchmarks/python/enc_dec_benchmark.py is the benchmark script for Encoder-Decoder models.

In benchmarks/python/:

# Example 1: Single-GPU benchmark
python benchmark.py \
    -m t5_small \
    --batch_size "1;8" \
    --input_output_len "60,20;128,20" \
    --dtype float32 \
    --csv # optional

# Example 2: Multi-GPU benchmark
mpirun --allow-run-as-root -np 4 python benchmark.py \
    -m t5_small \
    --batch_size "1;8" \
    --input_output_len "60,20;128,20" \
    --dtype float32 \
    --csv # optional
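The --batch_size and --input_output_len values above are semicolon-separated lists. The sketch below shows one way to read them, assuming that semicolons separate configurations, that each "input,output" pair gives the input and output lengths, and that every batch size is combined with every length pair. These are assumptions about the format, not a copy of the parsing in benchmark.py, and the function names are hypothetical.

```python
from itertools import product


def parse_batch_sizes(spec: str) -> list:
    """Parse a string such as "1;8" into [1, 8]."""
    return [int(item) for item in spec.split(";")]


def parse_input_output_lens(spec: str) -> list:
    """Parse a string such as "60,20;128,20" into [(60, 20), (128, 20)]."""
    pairs = []
    for item in spec.split(";"):
        input_len, output_len = (int(value) for value in item.split(","))
        pairs.append((input_len, output_len))
    return pairs


batch_sizes = parse_batch_sizes("1;8")
length_pairs = parse_input_output_lens("60,20;128,20")

# Assumed sweep: every batch size with every (input, output) length pair.
for batch_size, (input_len, output_len) in product(batch_sizes, length_pairs):
    print(f"batch_size={batch_size} input_len={input_len} output_len={output_len}")
```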

Run BART with LoRA

  • Download the base model and LoRA model from HF:
git clone https://huggingface.co/facebook/bart-large-cnn tmp/hf_models/bart-large-cnn
git clone https://huggingface.co/sooolee/bart-large-cnn-samsum-lora tmp/hf_models/bart-large-cnn-samsum-lora

If you are using custom models, put both the base model and LoRA model directories into tmp/hf_models.

  • Convert and Split Weights, setting --hf_lora_dir.
export INFERENCE_PRECISION="float16"
python convert_checkpoint.py --model_type bart \
                --model_dir tmp/hf_models/bart-large-cnn \
                --output_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION} \
                --tp_size 1 \
                --pp_size 1 \
                --weight_data_type float32 \
                --dtype ${INFERENCE_PRECISION}
  • Build the engines, setting --lora_plugin, --lora_dir, and --lora_target_modules.
trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION}/tp1/pp1/encoder \
                --output_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/encoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width 1 \
                --max_batch_size 8 \
                --max_output_len 200 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding disable \
                --lora_plugin ${INFERENCE_PRECISION} \
                --lora_dir tmp/hf_models/bart-large-cnn-samsum-lora/ \
                --lora_target_modules attn_q attn_v

trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION}/tp1/pp1/decoder \
                --output_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/decoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width 1 \
                --max_batch_size 8 \
                --max_output_len 200 \
                --gemm_plugin ${INFERENCE_PRECISION} \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding disable \
                --max_input_len 1 \
                --lora_plugin ${INFERENCE_PRECISION} \
                --lora_dir tmp/hf_models/bart-large-cnn-samsum-lora/ \
                --lora_target_modules attn_q cross_attn_q attn_v cross_attn_v
  • Run the engine, setting --lora_dir and --lora_task_uids. --lora_task_uids should be a list of uids whose length equals the batch size. The following example is for batch size = 3:
python run.py \
        --engine_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/ \
        --engine_name bart-large-cnn \
        --model_name tmp/hf_models/bart-large-cnn \
        --max_new_token=64 \
        --num_beams=1 \
        --lora_dir tmp/hf_models/bart-large-cnn-samsum-lora/ \
        --lora_task_uids 0 0 0
  • To run with multiple LoRA adapters, append the other LoRA directories to --lora_dir and set each entry of --lora_task_uids to the index of the LoRA directory to use. Set an entry to "-1" to run with the base model:
python run.py \
        --engine_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/ \
        --engine_name bart-large-cnn \
        --model_name tmp/hf_models/bart-large-cnn \
        --max_new_token=64 \
        --num_beams=1 \
        --lora_dir tmp/hf_models/bart-large-cnn-samsum-lora/ ... \
        --lora_task_uids 0 -1 -1 0 0 -1
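The two rules for --lora_task_uids (one uid per batch entry, each uid either -1 or an index into the --lora_dir list) are easy to check before launching a run. The following is a minimal sketch of such a check; the function validate_lora_task_uids is hypothetical and is not part of run.py.

```python
def validate_lora_task_uids(task_uids, batch_size, num_lora_dirs):
    """Check task uids against the rules described above.

    Each uid must be -1 (base model) or a valid index into the list of
    LoRA directories, and there must be exactly one uid per batch entry.
    """
    if len(task_uids) != batch_size:
        raise ValueError(
            f"expected {batch_size} task uids, got {len(task_uids)}"
        )
    for position, uid in enumerate(task_uids):
        if uid != -1 and not 0 <= uid < num_lora_dirs:
            raise ValueError(
                f"uid {uid} at position {position} is not -1 and does not "
                f"index one of the {num_lora_dirs} LoRA directories"
            )
    return True


# Batch of 6 with a single LoRA directory, as in the multi-LoRA example above.
print(validate_lora_task_uids([0, -1, -1, 0, 0, -1], batch_size=6, num_lora_dirs=1))
```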

Reminders

  • Flan-T5 models have known FP16 precision issues that are independent of TRT-LLM, so BF16 precision is recommended. While we work on improving FP16 results, please stay with FP32 or BF16 precision for the Flan-T5 family.
  • Batched/ragged input with beam search has subtle issues that truncate some sequence results. For the time being, follow these rules: (1) batch size = 1 works without problems; (2) padded batched input (i.e., not using the --remove_input_padding flag) works without problems; (3) for ragged batched input (i.e., using --remove_input_padding), use only greedy search for now.
  • For the T5 and Flan-T5 families, which use a relative attention bias design, the relative attention table is split along the num_heads dimension in Tensor Parallelism mode. Therefore, num_heads must be divisible by tp_size. Keep this in mind when setting the TP parameter.
  • For mBART, models that can control output languages (e.g. mbart-large-50-many-to-many-mmt) are not currently supported, as the script does not support ForcedBOSTokenLogitsProcessor to control output languages.
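The beam search and tensor parallelism reminders above can be expressed as two small pre-flight checks. This is a sketch of the rules as written in the list, with hypothetical function names and illustrative head counts that are not read from any model configuration.

```python
def beam_search_is_safe(batch_size: int, remove_input_padding: bool, num_beams: int) -> bool:
    """Apply the batched/ragged beam search rules from the reminders."""
    if num_beams == 1:
        return True  # greedy search is fine in every case
    if batch_size == 1:
        return True  # rule (1): a single sequence has no known problem
    # Rule (2): padded batches are fine; rule (3): ragged batches need greedy.
    return not remove_input_padding


def tp_size_is_valid(num_heads: int, tp_size: int) -> bool:
    """The relative attention table is split along num_heads, so it must divide evenly."""
    return num_heads % tp_size == 0


print(beam_search_is_safe(batch_size=8, remove_input_padding=True, num_beams=4))
print(beam_search_is_safe(batch_size=8, remove_input_padding=False, num_beams=4))
print(tp_size_is_valid(num_heads=8, tp_size=4))  # illustrative head count
print(tp_size_is_valid(num_heads=6, tp_size=4))  # illustrative head count
```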

Attention Scaling Factors

The q_scaling convention in the TRT-LLM plugin is defined as follows:

norm_factor = 1.f / (q_scaling * sqrt(head_size))

In the Multi-Head Attention (MHA) mechanism, the output of the Q*K^T product is scaled by this constant value norm_factor as norm_factor * (Q*K^T) for softmax. This scaling factor can be adjusted or neutralized based on the model's requirements.

Handling in Different Models:

  • BART/FairSeq NMT: For BART models, q_scaling is set to 1.f, which is also the TRT-LLM default, so the norm_factor becomes 1.f / sqrt(head_size). The same applies to FairSeq NMT models.
  • T5: For T5 models, q_scaling is 1.f/sqrt(head_size), leading to a norm_factor of 1.f. This is handled by TRT-LLM's get_offset_q_scaling() function, which reads head_size from the T5 model configuration and sets q_scaling = 1.f/sqrt(head_size) to effectively offset the norm_factor to 1.f.
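The two cases can be verified numerically by evaluating the formula above. The sketch below does that in plain Python; head_size = 64 is an illustrative value chosen for the example, not one read from a model configuration.

```python
import math


def norm_factor(q_scaling: float, head_size: int) -> float:
    """norm_factor = 1 / (q_scaling * sqrt(head_size)), as defined above."""
    return 1.0 / (q_scaling * math.sqrt(head_size))


head_size = 64  # illustrative value

# BART / FairSeq NMT: q_scaling = 1, so the usual 1/sqrt(head_size) scaling applies.
bart_factor = norm_factor(1.0, head_size)

# T5: q_scaling = 1/sqrt(head_size) cancels the square root, leaving no scaling.
t5_factor = norm_factor(1.0 / math.sqrt(head_size), head_size)

print(bart_factor)  # equals 1/sqrt(64)
print(t5_factor)    # equals 1
```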

Run FairSeq NMT (Neural Machine Translation) models

The FairSeq model download and library dependencies differ from the HuggingFace ones. In particular, the recommended docker container setup described in the README ships a custom PyTorch build, and installing FairSeq would force an upgrade of the PyTorch version. As a workaround, we skip the torch and torchaudio dependencies in FairSeq so that everything works inside the TRT-LLM container.

# Download FairSeq weights
# Instructions from: https://github.com/facebookresearch/fairseq/blob/main/examples/translation/README.md#example-usage-cli-tools. Public model checkpoints are also listed there. Here we use WMT'14 Transformer model as an example.
mkdir -p tmp/fairseq_models && curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2 | tar xvjf - -C tmp/fairseq_models  --one-top-level=wmt14 --strip-components 1 --no-same-owner

# Install FairSeq dependencies
# Remove the torch and torchaudio requirements so that FairSeq does not upgrade the base PyTorch
pushd tmp && (git clone https://github.com/facebookresearch/fairseq.git || true) && pushd fairseq && sed -i '/torch>=/d;/torchaudio>=/d' setup.py && pip install -e . && pip install sacremoses subword_nmt && popd && popd

# Convert and Split Weights, single GPU example
export TP_SIZE=1
export PP_SIZE=1
export WORLD_SIZE=1
export INFERENCE_PRECISION="float32"
python convert_checkpoint.py --model_type nmt \
                --model_dir tmp/fairseq_models/wmt14 \
                --output_dir tmp/trt_models/wmt14/${INFERENCE_PRECISION} \
                --tp_size ${TP_SIZE} \
                --pp_size ${PP_SIZE} \
                --weight_data_type float32 \
                --dtype ${INFERENCE_PRECISION}

# Build TensorRT engine(s)
# Note: non-T5 models can enable FMHA for the encoder part, although only FP16/BF16 precisions are valid
trtllm-build --checkpoint_dir tmp/trt_models/wmt14/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/encoder \
                --output_dir tmp/trt_engines/wmt14/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/encoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width 1 \
                --max_batch_size 8 \
                --max_output_len 200 \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding disable

trtllm-build --checkpoint_dir tmp/trt_models/wmt14/${INFERENCE_PRECISION}/tp${TP_SIZE}/pp${PP_SIZE}/decoder \
                --output_dir tmp/trt_engines/wmt14/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE}/decoder \
                --paged_kv_cache disable \
                --moe_plugin disable \
                --enable_xqa disable \
                --use_custom_all_reduce disable \
                --max_beam_width 1 \
                --max_batch_size 8 \
                --max_output_len 200 \
                --bert_attention_plugin ${INFERENCE_PRECISION} \
                --gpt_attention_plugin ${INFERENCE_PRECISION} \
                --remove_input_padding disable \
                --max_input_len 1
# Run
mpirun --allow-run-as-root -np ${WORLD_SIZE} python3 run.py --engine_dir tmp/trt_engines/wmt14/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE} --engine_name wmt14 --model_name tmp/fairseq_models/wmt14/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE} --max_new_token=24 --num_beams=1