diff --git a/.gitignore b/.gitignore index 3cc202c38..d3ea24ec6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,9 @@ __pycache__/ *.cache *.nsys-rep .VSCodeCounter -build*/ +cpp/build* +build +!tensorrt_llm/bench/build !builders/ *.egg-info/ .coverage diff --git a/README.md b/README.md index e981b1d7e..12ebaee18 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,15 @@ TensorRT-LLM
## Latest News -* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫 -🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647) +* [2024/08/13] 🐍 DIY Code Completion with #Mamba ⚡ #TensorRT #LLM for speed 🤖 NIM for ease ☁️ deploy anywhere +[➡️ link](https://developer.nvidia.com/blog/revolutionizing-code-completion-with-codestral-mamba-the-next-gen-coding-llm/)
- +
+* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫 +🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647) + * [2024/07/30] Introducing🍊 @SliceXAI ELM Turbo 🤖 train ELM once ⚡ #TensorRT #LLM optimize ☁️ deploy anywhere [➡️ link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms) diff --git a/benchmarks/README.md b/benchmarks/README.md index 575769842..00f450319 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -7,5 +7,5 @@ There are currently three workflows to benchmark TensorRT-LLM: - The recommended workflow that uses TensorRT-LLM C++ API and can take advantage of the latest features of TensorRT-LLM. * [Python benchmarks](./python) - The Python benchmarking scripts can only benchmark the Python runtime, which do not support the latest features, such as in-flight batching. -* [The Python benchmarking suite](./suite) +* [The Python benchmarking suite](./Suite.md) - This benchmarking suite is a current work in progress and is prone to large changes. diff --git a/benchmarks/Suite.md b/benchmarks/Suite.md new file mode 100644 index 000000000..f447b73e7 --- /dev/null +++ b/benchmarks/Suite.md @@ -0,0 +1,316 @@ +# TensorRT-LLM Benchmarking + +> [!WARNING] Work in Progress +> This benchmarking suite is a current work in progress and is prone to large changes. + +TensorRT-LLM provides a packaged benchmarking utility that is accessible via the `trtllm-bench` CLI tool. + +#### Supported Networks for Benchmarking + +- [`tiiuae/falcon-180B`](https://huggingface.co/tiiuae/falcon-180B) +- [`meta-llama/Llama-2-7b-hf`](https://huggingface.co/meta-llama/Llama-2-7b-hf) +- [`meta-llama/Llama-2-70b-hf`](https://huggingface.co/meta-llama/Llama-2-70b-hf) +- [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B) +- [`meta-llama/Meta-Llama-3-70B`](https://huggingface.co/meta-llama/Meta-Llama-3-70B) +- [`EleutherAI/gpt-j-6b`](https://huggingface.co/EleutherAI/gpt-j-6b) + +#### Supported Quantization Modes + +TensorRT-LLM supports a number of quantization modes. For more information about quantization, see the +[documentation](https://nvidia.github.io/TensorRT-LLM/precision.html). + +- None (no quantization applied) +- W8A16 +- W4A16 +- W4A16_AWQ +- W4A8_AWQ +- W4A16_GPTQ +- FP8 +- INT8 + +> [!NOTE] Please see the supported quantization methods for each network [here](https://nvidia.github.io/TensorRT-LLM/precision.html#support-matrix) + + +## Inflight Benchmarking with a Dataset + +This section covers how to benchmark TensorRT-LLM using inflight batching. + + +### Quickstart + +For this quick start guide, we will focus on running a short max throughput benchmark on +`meta-llama/Llama-2-7b-hf` on a synthetic dataset with a uniform distribution of prompts and an ISL:OSL +of 128:128. To run the benchmark from start to finish, simply run the following commands: + +```shell +python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1400 > /tmp/synthetic_128_128.txt +trtllm-bench --model meta-llama/Llama-2-7b-hf build --dataset /tmp/synthetic_128_128.txt --quantization FP8 +trtllm-bench --model meta-llama/Llama-2-7b-hf throughput --dataset /tmp/synthetic_128_128.txt --engine-path /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1 +``` + +And that's it!
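The dataset written by the first command is a plain JSON-lines file: each line is one complete request entry in the schema described under "Preparing a Dataset" below. As a rough sketch of that layout only (the token IDs here are random placeholders assuming a ~32k vocabulary, not what `prepare_dataset.py` actually samples), an equivalent file could be written by hand like this:

```python
import json
import random

# Illustrative sketch: write a JSON-lines dataset in the request schema used by
# trtllm-bench. Each line holds a task id, a list of prompt token IDs ("logits"),
# and the number of tokens to generate ("output_tokens"). The token IDs below are
# random placeholders; prepare_dataset.py derives real ones from the tokenizer.
with open("/tmp/synthetic_128_128.txt", "w") as f:
    for task_id in range(1400):
        entry = {
            "task_id": task_id,
            "logits": [random.randint(0, 31999) for _ in range(128)],  # ISL = 128
            "output_tokens": 128,                                      # OSL = 128
        }
        f.write(json.dumps(entry) + "\n")
```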
Once the benchmark completes, a summary of the run's metrics will be printed. + +``` +=========================================================== += ENGINE DETAILS +=========================================================== +Model: meta-llama/Llama-2-7b-hf +Engine Directory: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1 +TensorRT-LLM Version: 0.12.0.dev2024073000 +Dtype: float16 +KV Cache Dtype: FP8 +Quantization: FP8 +Max Input Length: 2048 +Max Sequence Length: 4098 + +=========================================================== += WORLD + RUNTIME INFORMATION +=========================================================== +TP Size: 1 +PP Size: 1 +Max Runtime Batch Size: 4096 +Max Runtime Tokens: 8192 +Scheduling Policy: Guaranteed No Evict +KV Memory Percentage: 99.0% +Issue Rate (req/sec): 3.680275266452667e+18 +=========================================================== += STATISTICS +=========================================================== +Number of requests: 3000 +Average Input Length (tokens): 128.0 +Average Output Length (tokens): 128.0 +Token Throughput (tokens/sec): 23405.927228471104 +Request Throughput (req/sec): 182.8588064724305 +Total Latency (seconds): 16.406100739 +=========================================================== +``` + +### Workflow + +The workflow for `trtllm-bench` is composed of the following steps: + +1. Prepare a dataset to drive the inflight batching benchmark. +2. Build a benchmark engine using the `trtllm-bench build` subcommand. +3. Run the max throughput benchmark using the `trtllm-bench throughput` subcommand. + +#### Preparing a Dataset + +The inflight benchmark utilizes a fixed JSON schema so that it is simple and +straightforward to specify requests. The schema is defined as follows: + +| Key | Required | Type | Description | +| :- | :-: | :-: | :- | +| `task_id` | Y | String | Unique identifier for the request. | +| `prompt` | N* | String | Input text for a generation request. | +| `logits` | N* | List[Integer] | List of logits that make up the request prompt. | +| `output_tokens` | Y | Integer | Number of generated tokens for this request. | + +> [!NOTE] Prompt and logits are mutually exclusive* +> While specifying both `prompt` and `logits` is not required, at least one of them must be provided. +> If `logits` are specified, the `prompt` entry is ignored for request generation. + +Examples of valid entries for the inflight benchmark are: + +- Entries with a human-readable prompt and no logits. +```json +{"task_id": 1, "prompt": "Generate an infinite response to the following: This is the song that never ends, it goes on and on my friend.", "output_tokens": 1000} +{"task_id": 2, "prompt": "Generate an infinite response to the following: Na, na, na, na", "output_tokens": 1000} +``` + +- Entries which contain logits. +```json +{"task_id":0,"logits":[863,22056,25603,11943,8932,13195,3132,25032,21747,22213],"output_tokens":128} +{"task_id":1,"logits":[14480,13598,15585,6591,1252,8259,30990,26778,7063,30065,21764,11023,1418],"output_tokens":128} +``` + +> [!INFO] A whole entry is on a line! +> To make passing data simpler, a complete JSON entry is on each line so that the benchmarker +> can simply read a line and assume a complete entry. When creating a dataset, be sure that a complete +> JSON entry is on every line. + +#### Using `prepare_dataset` to Create Synthetic Datasets + +In order to prepare a synthetic dataset, you can use the provided script in the `benchmarks/cpp` +directory.
For example, to generate a synthetic dataset of 1000 requests with a uniform ISL/OSL of +128/128 for [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), simply run: + +```shell +benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > $PATH_TO_DATASET +``` + +You can redirect the above command's output to a file to reuse the same dataset, or simply pipe its output to the +benchmark script (example below). + +### Building a Benchmark Engine + +Once you have a dataset, the next thing you'll need is an engine to benchmark against. To build a +pre-configured FP8-quantized engine for one of the supported ISL:OSL combinations, run the following +using the dataset you generated with `prepare_dataset.py`: + +```shell +trtllm-bench --model $HF_MODEL_NAME build --dataset $PATH_TO_DATASET --quantization FP8 +``` + +or manually set the max sequence length that you specifically plan to run with: + +```shell +trtllm-bench --model $HF_MODEL_NAME build --max_seq_len $MAX_SEQ_LEN --quantization FP8 +``` + +In either case, the engine will be written to the `/tmp/$HF_MODEL_NAME/tp_1_pp_1/` directory. + +### Running a Max Throughput Benchmark + +The `trtllm-bench` command line tool provides a max throughput benchmark that is accessible via the +`throughput` subcommand. This benchmark tests a TensorRT-LLM engine under maximum load to provide an +upper bound throughput number. + +#### How the Benchmarker Works + +The benchmarker will read in a data file or standard input (stdin) as a stream where a single line contains +a complete JSON request entry. The process that the benchmarker follows is: + +1. Iterate over all input requests. If `logits` is specified, construct the request using the specified +list of logits. Otherwise, tokenize the `prompt` using the tokenizer specified by `--model $HF_MODEL_NAME`. +2. Submit the dataset to the TensorRT-LLM `Executor` API as fast as possible (offline mode). +3. Wait for all requests to return, compute statistics, then report the results. + +To run the benchmarker, run the following with the engine and dataset generated above: + +``` +trtllm-bench --model $HF_MODEL_NAME throughput --dataset $PATH_TO_DATASET --engine_dir /tmp/$HF_MODEL_NAME/tp_1_pp_1/ +``` + +When the benchmark runs, you will see output similar to the following: + +``` +Preparing to run throughput benchmark... +Setting up benchmarker and infrastructure. +Initializing Throughput Benchmark. [rate=%d req/s] +Ready to start benchmark. +Initializing Executor. +[TensorRT-LLM][INFO] Engine version 0.12.0.dev2024073000 found in the config file, assuming engine(s) built by new builder API. +[TensorRT-LLM][INFO] Initializing MPI with thread mode 3 +[TensorRT-LLM][INFO] Initialized MPI +[TensorRT-LLM][INFO] Engine version 0.12.0.dev2024073000 found in the config file, assuming engine(s) built by new builder API.
+[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0 +[TensorRT-LLM][INFO] Rank 0 is using GPU 0 +[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 4096 +[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 4096 +[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 1 +[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 4098 +[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0 +[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: 4098 +[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0 +[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 1 +[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 8192 +[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 4097 = maxSequenceLen - 1 since chunked context is enabled +[TensorRT-LLM][INFO] Capacity Scheduler Policy: GUARANTEED_NO_EVICT +[TensorRT-LLM][INFO] Context Chunking Scheduler Policy: FIRST_COME_FIRST_SERVED +[TensorRT-LLM][INFO] Loaded engine size: 6214 MiB +[TensorRT-LLM][INFO] [MemUsageChange] Allocated 928.77 MiB for execution context memory. +[TensorRT-LLM][INFO] [MS] Running engine with multi stream info +[TensorRT-LLM][INFO] [MS] Number of aux streams is 1 +[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2 +[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream +[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB) +[TensorRT-LLM][INFO] [MS] Running engine with multi stream info +[TensorRT-LLM][INFO] [MS] Number of aux streams is 1 +[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2 +[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream +[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB) +[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 1. Please ensure there are no enqueued operations pending in this context prior to switching profiles +[TensorRT-LLM][INFO] [MS] Running engine with multi stream info +[TensorRT-LLM][INFO] [MS] Number of aux streams is 1 +[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2 +[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream +[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB) +[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 2. Please ensure there are no enqueued operations pending in this context prior to switching profiles +[TensorRT-LLM][INFO] [MS] Running engine with multi stream info +[TensorRT-LLM][INFO] [MS] Number of aux streams is 1 +[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2 +[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream +[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB) +[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 3. 
Please ensure there are no enqueued operations pending in this context prior to switching profiles +[TensorRT-LLM][INFO] [MS] Running engine with multi stream info +[TensorRT-LLM][INFO] [MS] Number of aux streams is 1 +[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2 +[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream +[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB) +[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 4. Please ensure there are no enqueued operations pending in this context prior to switching profiles +[TensorRT-LLM][INFO] [MS] Running engine with multi stream info +[TensorRT-LLM][INFO] [MS] Number of aux streams is 1 +[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2 +[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream +[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB) +[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 5. Please ensure there are no enqueued operations pending in this context prior to switching profiles +[TensorRT-LLM][INFO] [MemUsageChange] Allocated 1.14 GB GPU memory for runtime buffers. +[TensorRT-LLM][INFO] [MemUsageChange] Allocated 4.35 GB GPU memory for decoder. +[TensorRT-LLM][INFO] Memory usage when calculating max tokens in paged kv cache: total: 79.10 GiB, available: 63.62 GiB +[TensorRT-LLM][INFO] Number of blocks in KV cache primary pool: 4607 +[TensorRT-LLM][INFO] Number of blocks in KV cache secondary pool: 0, onboard blocks to primary memory before reuse: true +[TensorRT-LLM][INFO] Max KV cache pages per sequence: 65 +[TensorRT-LLM][INFO] Number of tokens per block: 64. +[TensorRT-LLM][INFO] [MemUsageChange] Allocated 62.99 GiB for max tokens in paged KV cache (294848). +[TensorRT-LLM][INFO] Executor instance created by worker +Starting response daemon...Executor started. + +Request serving started. +Starting statistics collection. +Collecting live stats... +Benchmark started. +Request serving stopped. +Collecting last stats... +Ending statistics collection. +Stop received. +Stopping response parsing. +Collecting last responses before shutdown. +Completed request parsing. +Parsing stopped. +Request generator successfully joined. +Statistics process successfully joined. 
+=========================================================== += ENGINE DETAILS +=========================================================== +Model: meta-llama/Llama-2-7b-hf +Engine Directory: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1 +TensorRT-LLM Version: 0.12.0.dev2024073000 +Dtype: float16 +KV Cache Dtype: FP8 +Quantization: FP8 +Max Input Length: 2048 +Max Sequence Length: 4098 + +=========================================================== += WORLD + RUNTIME INFORMATION +=========================================================== +TP Size: 1 +PP Size: 1 +Max Runtime Batch Size: 4096 +Max Runtime Tokens: 8192 +Scheduling Policy: Guaranteed No Evict +KV Memory Percentage: 99.0% +Issue Rate (req/sec): 3.680275266452667e+18 +=========================================================== += STATISTICS +=========================================================== +Number of requests: 3000 +Average Input Length (tokens): 128.0 +Average Output Length (tokens): 128.0 +Token Throughput (tokens/sec): 23405.927228471104 +Request Throughput (req/sec): 182.8588064724305 +Total Latency (seconds): 16.406100739 +=========================================================== + +Benchmark Shutdown called! +Shutting down ExecutorServer. +[TensorRT-LLM][INFO] Orchestrator sendReq thread exiting +[TensorRT-LLM][INFO] Orchestrator recv thread exiting +Executor shutdown. +[TensorRT-LLM][INFO] Leader sendThread exiting +[TensorRT-LLM][INFO] Leader recvReq thread exiting +``` + +> [!WARNING] Some statistics are not reported. +> There are some statistics that are not reported in the summary (typically as 0.0). These statistics +> are not available currently. diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index d3861a2f3..488f71a19 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -24,6 +24,7 @@ #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/gptJsonConfig.h" @@ -173,6 +174,9 @@ struct BenchmarkParams // Decoding params std::optional>> medusaChoices; + + std::optional executorLookaheadConfig; + std::optional requestLookaheadConfig; }; class InferenceRequestsAsyncSend @@ -509,6 +513,7 @@ class Recorder { if (!mStreaming) { + TLLM_LOG_DEBUG("response.getResult().outputTokenIds"); auto outputTokenIds = response.getResult().outputTokenIds; int32_t outSeqLen = 0; @@ -824,9 +829,11 @@ class ExecutorServer executorConfig.setMaxNumTokens(benchmarkParams.maxNumTokens.value()); } - executorConfig.setDecodingConfig(texec::DecodingConfig( - benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(), - std::nullopt, benchmarkParams.medusaChoices)); + executorConfig.setDecodingConfig( + texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() + : benchmarkParams.executorLookaheadConfig.has_value() ? 
texec::DecodingMode::Lookahead() + : texec::DecodingMode::Auto(), + benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices)); executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig); if (executorModelType == texec::ModelType::kDECODER_ONLY) @@ -910,7 +917,7 @@ class ExecutorServer for (auto const& response : responses) { auto const reqId = response.getRequestId(); - + TLLM_LOG_DEBUG("response.getResult().isFinal"); if (response.getResult().isFinal) { mActiveCount--; @@ -1323,7 +1330,8 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const& ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId, BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr, ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr, - ITensor::SharedPtr const& loraConfig = nullptr) + ITensor::SharedPtr const& loraConfig = nullptr, + std::optional lookaheadConfig = std::nullopt) { auto request = std::make_shared(reqId); auto const& inputIds = sample.inputIds; @@ -1361,6 +1369,10 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const& { request->setLoraConfig(loraConfig); } + if (lookaheadConfig) + { + request->setLookaheadConfig(lookaheadConfig.value()); + } if (streaming) { request->setIsStreaming(true); @@ -1372,18 +1384,20 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW std::optional const& eosId, std::optional const& padId, bool streaming = false, bool const& returnContextLogits = false, bool const& returnGenerationLogits = false, std::optional const& loraConfig = std::nullopt, + std::optional const& lookaheadConfig = std::nullopt, std::optional encoderInputTokenIds = std::nullopt) { auto samplingConfig = texec::SamplingConfig{beamWidth}; auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false}; return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId, - std::nullopt, // badWords - std::nullopt, // stopWords - std::nullopt, // embeddingBias - std::nullopt, // speculativeDecoding - std::nullopt, // pTuning - loraConfig, - std::nullopt, // logitsPostProcessorName + std::nullopt, // badWords + std::nullopt, // stopWords + std::nullopt, // embeddingBias + std::nullopt, // speculativeDecoding + std::nullopt, // pTuning + loraConfig, // loraConfig + lookaheadConfig, // lookaheadConfig + std::nullopt, // logitsPostProcessorName encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt); } @@ -1429,9 +1443,11 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType optionalParams.maxBatchSize = benchmarkParams.maxBatchSize; optionalParams.maxNumTokens = benchmarkParams.maxNumTokens; optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy}; - optionalParams.decodingConfig = texec::DecodingConfig( - benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(), - std::nullopt, benchmarkParams.medusaChoices); + optionalParams.decodingConfig + = texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() + : benchmarkParams.executorLookaheadConfig.has_value() ? 
texec::DecodingMode::Lookahead() + : texec::DecodingMode::Auto(), + benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices); optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig( benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc); @@ -1501,8 +1517,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType ++reqId; if (i == terminateReqId) ++reqId; - auto request = makeRequest( - reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager); + auto request = makeRequest(reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, + padIdTensor, bufferManager, nullptr, nullptr, nullptr, nullptr, benchmarkParams.requestLookaheadConfig); gptServer->enqueue(request); } gptServer->waitForEmpty(); @@ -1517,7 +1533,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType for (std::size_t i = 0; i < numSamples; ++i) { auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, - padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor); + padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor, nullptr, + nullptr, benchmarkParams.requestLookaheadConfig); gptServer->enqueue(request); if (i < numSamples - 1) @@ -1541,7 +1558,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType for (std::size_t i = 0; i < numSamples; ++i) { auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, - padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor); + padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor, + nullptr, nullptr, benchmarkParams.requestLookaheadConfig); gptServer->enqueue(request); } gptServer->waitForEmpty(); @@ -1644,13 +1662,13 @@ void benchmarkExecutor(std::optional const& decoderEngine { Sample s{std::vector{decoderStartTokenId}, 1, static_cast(taskId)}; requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, - loraConfig, std::vector{1, 2, 3, 4, 5})); + loraConfig, std::nullopt, std::vector{1, 2, 3, 4, 5})); } else { Sample s{std::vector{1, 2, 3, 4, 5}, 1, static_cast(taskId)}; requests.emplace_back( - makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig)); + makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig, std::nullopt)); } } executorServer->enqueue(std::move(requests), true); @@ -1668,12 +1686,14 @@ void benchmarkExecutor(std::optional const& decoderEngine { Sample s{std::vector{decoderStartTokenId}, samples[0].outputLen, samples[0].taskId}; requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming, - returnContextLogits, returnGenerationLogits, std::nullopt, samples[0].inputIds)); + returnContextLogits, returnGenerationLogits, std::nullopt, + benchmarkParams.requestLookaheadConfig, samples[0].inputIds)); } else { requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId, - benchmarkParams.streaming, returnContextLogits, returnGenerationLogits)); + benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, std::nullopt, + benchmarkParams.requestLookaheadConfig)); } } executorServer->enqueue(std::move(requests), true); @@ -1699,12 +1719,14 @@ void benchmarkExecutor(std::optional const& 
decoderEngine { Sample s{std::vector{decoderStartTokenId}, samples[i].outputLen, samples[i].taskId}; requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming, - returnContextLogits, returnGenerationLogits, loraConfig, samples[i].inputIds)); + returnContextLogits, returnGenerationLogits, loraConfig, benchmarkParams.requestLookaheadConfig, + samples[i].inputIds)); } else { requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId, - benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig)); + benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig, + benchmarkParams.requestLookaheadConfig)); } } @@ -1789,6 +1811,25 @@ std::vector> parseVectorOfVectors(std::string const& inp return result; } +texec::LookaheadDecodingConfig parseLookaheadConfig(std::string const& input) +{ + std::regex regex("\\[ *(\\d+) *, *(\\d+) *, *(\\d+) *\\]"); + std::smatch match; + if (std::regex_match(input, match, regex)) + { + TLLM_CHECK(match.size() == 4); + auto w = std::stoi(match[1]); + auto n = std::stoi(match[2]); + auto g = std::stoi(match[3]); + return texec::LookaheadDecodingConfig(w, n, g); + } + else + { + TLLM_LOG_WARNING("cannot parse lookahead config from '%s'", input.c_str()); + return texec::LookaheadDecodingConfig(); + } +} + } // namespace int main(int argc, char* argv[]) @@ -1898,6 +1939,14 @@ int main(int argc, char* argv[]) options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation", cxxopts::value()->default_value("false")); + options.add_options()("executor_lookahead_config", + "lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size]", + cxxopts::value()); + + options.add_options()("request_lookahead_config", + "lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size], and each <= " + "executor lookahead config", + cxxopts::value()); auto result = options.parse(argc, argv); @@ -2055,6 +2104,16 @@ int main(int argc, char* argv[]) { benchmarkParams.medusaChoices = parseVectorOfVectors(result["medusa_choices"].as()); } + if (result.count("executor_lookahead_config")) + { + benchmarkParams.executorLookaheadConfig + = parseLookaheadConfig(result["executor_lookahead_config"].as()); + } + if (result.count("request_lookahead_config")) + { + benchmarkParams.requestLookaheadConfig + = parseLookaheadConfig(result["request_lookahead_config"].as()); + } // Argument: multi_block_mode benchmarkParams.multiBlockMode = result["multi_block_mode"].as(); diff --git a/benchmarks/python/all_reduce.py b/benchmarks/python/all_reduce.py index ae7cb8868..d91cdd0d4 100644 --- a/benchmarks/python/all_reduce.py +++ b/benchmarks/python/all_reduce.py @@ -23,7 +23,6 @@ import tensorrt_llm as tllm from tensorrt_llm import Mapping, Tensor -from tensorrt_llm._ipc_utils import peer_access from tensorrt_llm._utils import OMPI_COMM_TYPE_HOST, mpi_comm from tensorrt_llm.functional import AllReduceStrategy, allreduce from tensorrt_llm.plugin.plugin import current_all_reduce_helper @@ -106,18 +105,18 @@ def allreduce_benchmark(dtype: str, _, start = cuda.cuEventCreate(0) _, stop = cuda.cuEventCreate(0) runtimes = [] - with peer_access(mapping): - tllm.mpi_barrier() - - for _ in range(10): - cuda.cuEventRecord(start, stream.cuda_stream) - session.run(inputs=feed_dict, - outputs={"output": output}, - stream=stream.cuda_stream) - cuda.cuEventRecord(stop, stream.cuda_stream) - torch.cuda.synchronize() - _, ms = 
cuda.cuEventElapsedTime(start, stop) - runtimes.append(ms) + + tllm.mpi_barrier() + + for _ in range(10): + cuda.cuEventRecord(start, stream.cuda_stream) + session.run(inputs=feed_dict, + outputs={"output": output}, + stream=stream.cuda_stream) + cuda.cuEventRecord(stop, stream.cuda_stream) + torch.cuda.synchronize() + _, ms = cuda.cuEventElapsedTime(start, stop) + runtimes.append(ms) median_ms = sorted(runtimes)[len(runtimes) // 2] assert torch.allclose(output, (input * world_size)**inner_loop) diff --git a/benchmarks/suite/README.md b/benchmarks/suite/README.md deleted file mode 100644 index bba21609e..000000000 --- a/benchmarks/suite/README.md +++ /dev/null @@ -1,234 +0,0 @@ -# TensorRT-LLM Benchmarking - -> [!WARNING] Work in Progress -> This benchmarking suite is a current work in progress and is prone to large changes. - -This package is the official benchmarking suite for TensorRT-LLM. This benchmark will be updated -as development of TensorRT-LLM continues. - -## Installation - -From this folder, run `pip install -r requirements.txt` to install the extra dependencies required for this tool. - -### Available Build and Benchmark Options - -The following model options are available for benchmarking models. - -| Option | Required | Default | Description | -| :- | :-: | :-: | :- | -| `--model` | Y | - | The name of the model to benchmark. | -| `--dtype` | N | `float16` | The datatype of the weights. | -| `--max-batch-size` | Y | - | The batch size to build the engine with for the benchmark. | -| `--kv-dtype` | N | `float16` | The datatype to store the KV Cache in. | -| `--kv-cache-free-gpu-mem-fraction` | N | `0.98` | The percentage of free memory that the KV cache is allowed to occupy. | -| `--quantization` | N | `None` |The quantization algorithm to be used when benchmarking. See the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html) for more information| -| `--workspace` | N | `/tmp` | The directory to store benchmarking intermediate files. | -| `--tensor-parallel-size` | N | `1` | Number of tensor parallel shards to run the benchmark with. | -| `--pipeline-parallel-size` | N | `1` | Number of pipeline parallel shards to run the benchmark with. | - -#### Supported Networks for Benchmarking - -- [`tiiuae/falcon-7b`](https://huggingface.co/tiiuae/falcon-7b) -- [`tiiuae/falcon-40b`](https://huggingface.co/tiiuae/falcon-40b) -- [`tiiuae/falcon-180B`](https://huggingface.co/tiiuae/falcon-180B) -- [`meta-llama/Llama-2-7b-hf`](https://huggingface.co/meta-llama/Llama-2-7b-hf) -- [`meta-llama/Llama-2-13b-hf`](https://huggingface.co/meta-llama/Llama-2-13b-hf) -- [`meta-llama/Llama-2-70b-hf`](https://huggingface.co/meta-llama/Llama-2-70b-hf) -- [`EleutherAI/gpt-j-6b`](https://huggingface.co/EleutherAI/gpt-j-6b) - -#### Support Quantization Modes - -TensorRT-LLM supports a number of quanization modes. For more information about quantization, see the -[documentation](https://nvidia.github.io/TensorRT-LLM/precision.html). 
- -- None (no quantization applied) -- W8A16 -- W4A16 -- W4A16_AWQ -- W4A8_AWQ -- W4A16_GPTQ -- FP8 -- INT8 - -> [!NOTE] Please see the supported quantization methods for each network [here](https://nvidia.github.io/TensorRT-LLM/precision.html#support-matrix) - -## Static Benchmarking a Network - -In order to benchmark a static batch for a network, run a command like the following: - -```shell -cd tensorrt_llm_bench/ -python benchmark.py --model tiiuae/falcon-7b static --isl 128 --osl 128 --max-batch-size 1 -``` - -This command line will build a unique engine for the configuration and run the benchmark using -the `gptSessionBenchmark` binary. You need to build the TensorRT-LLM wheel with the `--benchmarks` flag for this binary to be compiled: - -```shell -python3 ./scripts/build_wheel.py --benchmarks -``` - -If you've already compiled the wheel without benchmarks, you can build the benchmarking binaries with the following after the fact: - -```shell -pushd cpp/build/ -make -j benchmarks -popd -``` - -The complete list of arguments for static benchmarking are as follows: -| Option | Required | Default | Description | -| :- | :-: | :-: | :- | -| `--isl` | Y | - | The input sequence length to pass in during benchmark. | -| `--osl` | Y | - | The output sequence length to generate in the benchmark. | -| `--gpt-session-path` | N | `../../cpp/build/benchmarks/gptSessionBenchmark` | The path to the built gptSessionBenchmark binary. | -| `--warm-up-runs` | N | `2` | The number of warm up runs to run before benchmarking actual results. | -| `--num-runs` | N | `10` | The number runs to generate benchmarking results from. | -| `--duration` | N | `60` | The minimum iteration time, in seconds, to measure. | - -> [!WARNING] -> `gptSession` will be deprecated for the 1.0 release of TensorRT-LLM. This command line will change in order to match and update benchmarks accordingly. - - -## Inflight Benchmarking with a Dataset - -This section covers how to benchmark TensorRT-LLM using inflight batching. - -### Workflow - -The workflow for inflight batching is slightly different than the [static scenario](#static-benchmarking-a-network) as it requires a workload of requests instead of a single static batch. The following is the workflow for benchmarking using inflight batching: - -1. Prepare a dataset to drive the inflight batching benchmark. -2. Run the `inflight` benchmarking subcommand and provide the dataset from step 1. - -#### Preparing a Dataset - -The inflight benchmark utilizes a fixed JSON schema so that it is simple and -straightforward to specify requests. The schema is defined as follows: - -| Key | Required | Type | Description | -| :- | :-: | :-: | :- | -| `task_id`| Y | String | Unique identifier for the request. | -| `prompt` | N* | String | Input text for a generation request. | -| `logits` | N* | List[Integer] | List of logits that make up the request prompt. | -| `output_tokens` | Y | Integer | Number of generated tokens for this request. | - -> [!NOTE] Prompt and logits are mutually exclusive* -> While having both `prompt` and `logits` is not required, at least one is required. -> If `logits` are specified, the `prompt` entry is ignored for request generation. - -Examples of valid entries for the inflight benchmark are: - -- Entries with a human-readable prompt and no logits. 
-```json -{"task_id": 1, "prompt": "Generate an infinite response to the following: This is the song that never ends, it goes on and on my friend.", "output_tokens": 1000} -{"task_id": 2, "prompt": "Generate an infinite response to the following: Na, na, na, na", "output_tokens": 1000} -``` - -- Entries which contain logits. -```json -{"task_id":0,"logits":[863,22056,25603,11943,8932,13195,3132,25032,21747,22213],"output_tokens":128} -{"task_id":1,"logits":[14480,13598,15585,6591,1252,8259,30990,26778,7063,30065,21764,11023,1418],"output_tokens":128} -``` - -> [!INFO] A whole entry is on a line! -> To make the passing of data simpler, a complete JSON entry is on each line so that the benchmarker -> can simply read a line and assume a complete entry. When creating a dataset, be sure that a complete -> JSON entry is on every line. - -#### Using `prepare_dataset` to Create Synthetic Datasets - -In order to prepare a synthetic dataset, you can use the provided script in the `benchmarks/cpp` -directory. For example, to generate a synthetic dataset of 1000 requests with a uniform ISL/OSL of -128/128 for [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), simply run: - -```shell -benchmarks/cpp/prepare_dataset.py --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 --stdout -``` - -You can pipe the above command to a file to reuse the same dataset, or simply pipe its output to the -benchmark script (example below). - -### Running a Dataset with the Benchmarker - -Once you've generated a dataset (see [above](#preparing-a-dataset)), you can run the benchmarker -in one of two ways: - -```shell -benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE < $DATASET_PATH -``` - -> [!INFO] Alternative to piping. -> There is also a `--dataset` option for `benchmark.py` that can be used instead of piping a file. - -or - -```shell -benchmarks/cpp/prepare_dataset.py --tokenizer $HF_MODEL_NAME --input-mean $ISL --output-mean $OSL --num-requests $NUM_REQUESTS --stdout | benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE --request-rate $REQUEST_RATE -``` - -#### How the Benchmarker Works - -The benchmarker will read in a data file or standard input (stdin) as a stream where a single line contains -a complete JSON request entry. The process that the benchmarker is as follows: - -1. Iterate over all input requests. If `logits` is specified, construct the request using the specified -list of logits. Otherwise, tokenize the `prompt` with as specified by `--model $HF_MODEL_NAME`. -2. Build the TensorRT-LLM engine. -3. Submit the dataset to the TensorRT-LLM `Executor` API at the request rate specified by `--request-rate $REQUEST_RATE` -4. Wait for all requests to return, compute statistics, then report out results. - -When the benchmark runs successfully, you will see a report out of the run similar to the following: - -``` -[RANK 0] Submitting requests... -[RANK 0] Completed request submission. -[RANK 0] Calculating results. -[RANK 0] Reporting... 
-[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_seq_len 256 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000} -=========================================================== -= METADATA -=========================================================== -Model: meta-llama/Llama-2-7b-hf -TP Size: 1 -PP Size: 1 -Scheduling Policy: Max Utilization -In-flight Batcher?: True -Dtype: float16 -KV Cache Dtype: FP8 -Quantization: FP8 -KV Memory Percentage: 98.0% - -=========================================================== -= ENGINE DETAILS -=========================================================== -Engine Directory: /tmp/meta-llama/llama-2-7b-hf -Max Batch Size: 1024 -Total Input Length: 128000 -Total Output Length: 128000 -Max Tokens: 8000 - -=========================================================== -= STATISTICS -=========================================================== -Throughput (tokens/second): 17634.422523488243 -Total Latency (ms): 7258.5309 -First Token Latency (ms): 0.0 -Token-to-token Latency (ms): 0.0 -Peak GPU Memory Usage (GB): 0.0 - -=========================================================== -= COMMANDS -=========================================================== -Build: trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_seq_len 256 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16 -Benchmark: - -[RANK 0] Terminating. -``` - -> [!WARNING] Some statistics are not reported. -> There are some statistics that are not reported in the summary (typically as 0.0). These statistics -> are not available currently. - - -That's it! -- you've successfully benchmarked TensorRT-LLM! 
diff --git a/benchmarks/suite/requirements.txt b/benchmarks/suite/requirements.txt deleted file mode 100644 index e75e33990..000000000 --- a/benchmarks/suite/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -pydantic>=2.2.1 -click-option-group == 0.5.6 -aenum == 3.1.15 diff --git a/benchmarks/suite/tensorrt_llm_bench/__init__.py b/benchmarks/suite/tensorrt_llm_bench/__init__.py deleted file mode 100644 index d6bfb8507..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module for running TensorRT-LLM benchmarks.""" diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmark.py b/benchmarks/suite/tensorrt_llm_bench/benchmark.py deleted file mode 100644 index 65e306459..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/benchmark.py +++ /dev/null @@ -1,125 +0,0 @@ -from pathlib import Path -from typing import get_args - -import click -from ifb import executor_benchmark -from static import static_benchmark -from utils import VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, VALID_QUANT_ALGOS -from utils.dataclasses import BenchmarkConfig - - -@click.group(context_settings={'show_default': True}) -@click.option( - "--model", - "-m", - required=True, - type=str, - help="The Huggingface name of the model to benchmark.", -) -@click.option( - "--max-batch-size", - hidden=True, - default=0, - type=int, - help="Maximum batch size to build the benchmark engine with.", -) -@click.option( - "--kv-dtype", - type=click.Choice(tuple(get_args(VALID_CACHE_DTYPES))), - default="float16", - help="The dtype to store the KV Cache in.", -) -@click.option( - "--dtype", - type=click.Choice(tuple(get_args(VALID_COMPUTE_DTYPES))), - default="float16", - help="Activation and plugin data type.", -) -@click.option( - "--quantization", - "-q", - type=click.Choice(tuple(get_args(VALID_QUANT_ALGOS))), - default="None", - help= - ("The quantization algorithm to be used when benchmarking. 
See the " - "documentations for more information.\n" - " - https://nvidia.github.io/TensorRT-LLM/precision.html" - " - https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/quantization-in-TRT-LLM.md" - ), -) -@click.option( - "--workspace", - "-w", - required=False, - type=click.Path(writable=True, readable=True), - default="/tmp", - help="The directory to store benchmarking intermediate files.", -) -@click.option( - "--tensor-parallel-size", - "-tp", - type=int, - default=1, - required=False, - help="Number of tensor parallel shards to run the benchmark with.", -) -@click.option( - "--pipeline-parallel-size", - "-pp", - type=int, - default=1, - required=False, - help="Number of pipeline parallel shards to run the benchmark with.", -) -@click.option( - "--kv-cache-free-gpu-mem-fraction", - "-kv-mem", - type=float, - default=0.98, - help="The percentage of free memory that the KV Cache is allowed to occupy.", -) -@click.option( - "--build-opts", - type=str, - default="", - required=False, - hidden=True, - help="Passthrough options for trtllm-build to fine-tuning build commands.") -@click.pass_context -def benchmark( - ctx, - model: str, - max_batch_size: int, - workspace: Path, - dtype: str, - kv_dtype: str, - quantization: str, - tensor_parallel_size: int, - pipeline_parallel_size: int, - kv_cache_free_gpu_mem_fraction: float, - build_opts: str, -): - """Utility for using TRT-LLM for benchmarking networks from Huggingface.""" - ctx.obj = BenchmarkConfig( - model=model, - max_batch_size=max_batch_size, - workspace=Path(workspace), - dtype=dtype, - cache_dtype=kv_dtype, - quantization=quantization, - tensor_parallel=tensor_parallel_size, - pipeline_parallel=pipeline_parallel_size, - kv_cache_mem_percentage=kv_cache_free_gpu_mem_fraction, - build_overrides=build_opts.split(), - ) - - # Create the workspace where we plan to store intermediate files. - ctx.obj.workspace.mkdir(parents=True, exist_ok=True) - - -# Add nested subcommands to main benchmark CLI. -benchmark.add_command(static_benchmark) -benchmark.add_command(executor_benchmark) - -if __name__ == "__main__": - benchmark() diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/__init__.py b/benchmarks/suite/tensorrt_llm_bench/benchmarkers/__init__.py deleted file mode 100644 index 2cc2877af..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import List, Protocol - -from utils.dataclasses import BenchmarkResults, InferenceRequest - - -class Benchmarker(Protocol): - """Protocol for defining benchmarking classes for building/benchmarking.""" - - def build(self) -> None: - """Build a model to be benchmarked.""" - ... - - def benchmark(self) -> BenchmarkResults: - """Benchmark the constructed model container by a benchmarker.""" - ... - - -class DatasetBenchmarker(Protocol): - - def benchmark_dataset(self, - dataset: List[InferenceRequest]) -> BenchmarkResults: - """_summary_ - - Args: - dataset (List[InferenceRequest]): List of inference requests to benchmark. - - Returns: - BenchmarkResults: The results of the benchmark run. - """ - ... 
diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/pybind_executor.py b/benchmarks/suite/tensorrt_llm_bench/benchmarkers/pybind_executor.py deleted file mode 100644 index 5742e90a5..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/pybind_executor.py +++ /dev/null @@ -1,146 +0,0 @@ -from datetime import timedelta -from time import sleep, time -from typing import List - -from mpi4py.MPI import COMM_WORLD -from transformers import PreTrainedTokenizer -from utils.dataclasses import BenchmarkConfig, BenchmarkResults -from utils.enums import IFBSchedulingPolicy, ResultsSchedulingPolicy - -from tensorrt_llm.bindings.executor import (Executor, ExecutorConfig, - KvCacheConfig, ModelType, - OutputConfig, Request, - SchedulerConfig) - -from . import InferenceRequest - - -class PybindExecutorBenchmarker: - """Utility class for running inflight benchmarks via the Executor API.""" - - def __init__( - self, - config: BenchmarkConfig, - ): - """Initialize a gptSessionBenchmark instance. - - Args: - config (BenchmarkConfig): Benchmark configuration for build/run. - """ - self.config: BenchmarkConfig = config - - @staticmethod - def get_request(request: InferenceRequest, - tokenizer: PreTrainedTokenizer) -> Request: - return Request( - input_token_ids=request.logits, - max_new_tokens=request.output_tokens, - stop_words=[], - bad_words=[], - streaming=False, - output_config=OutputConfig(exclude_input_from_output=True), - pad_id=tokenizer.pad_token_id, - end_id=tokenizer.eos_token_id, - ) - - def initialize_executor(self) -> Executor: - """ - Initialize an Executor instance. - - Returns: - Executor: An instance of a TensorRT-LLM Executor. - """ - policy = IFBSchedulingPolicy(self.config.scheduling_policy).value - executor_config: ExecutorConfig = ExecutorConfig( - max_beam_width=1, - enable_chunked_context=self.config.chunking, - scheduler_config=SchedulerConfig( - capacity_scheduler_policy=policy, ), - kv_cache_config=KvCacheConfig( - free_gpu_memory_fraction=self.config.kv_cache_mem_percentage, ), - ) - - executor: Executor = Executor( - model_path=self.config.engine_path, - model_type=ModelType.DECODER_ONLY, - executor_config=executor_config, - ) - - return executor - - def benchmark_dataset(self, rate: int, - dataset: List[InferenceRequest]) -> BenchmarkResults: - """Benchmark the Executor Pybind interface. - - Args: - dataset (List[InferenceRequest]): List of inference requests to - benchmark with. - - Returns: - BenchmarkResults: Final results from running the specified dataset. 
- """ - request_ids = [] - num_finished = 0 - num_errored = 0 - num_input_tokens = 0 - num_output_tokens = 0 - delay = 1.0 / float(rate) - last_request = len(dataset) - 1 - bench_result = None - - executor = self.initialize_executor() - if executor.can_enqueue_requests(): - print(f"[RANK {COMM_WORLD.rank}] Submitting requests...") - start = time() - for i, request in enumerate(dataset): - sleep_time = delay if i != last_request else 0 - request_ids.append(executor.enqueue_request(request)) - num_input_tokens += len(request.input_token_ids) - sleep(sleep_time) - print(f"[RANK {COMM_WORLD.rank}] Completed request submission.") - - while num_finished <= last_request: - responses = executor.await_responses(timeout=timedelta( - milliseconds=1)) - for response in responses: - has_error = response.has_error() - num_finished += 1 - num_errored += 1 if has_error else 0 - - if not has_error: - result = response.result - for out_tokens in result.output_token_ids: - num_output_tokens += len(out_tokens) - end = time() - print(f"[RANK {COMM_WORLD.rank}] Calculating results.") - e2e_time = end - start - e2e_time * 1000.0 - policy = ResultsSchedulingPolicy( - IFBSchedulingPolicy(self.config.scheduling_policy).value) - - bench_result = BenchmarkResults( - model=self.config.model, - dtype=self.config.dtype.value, - quantization=str(self.config.quantization.value), - max_batch_size=self.config.max_batch_size, - total_input_tokens=num_input_tokens, - total_output_tokens=num_output_tokens, - tp_size=self.config.tensor_parallel, - pp_size=self.config.pipeline_parallel, - kv_mem_fraction=self.config.kv_cache_mem_percentage, - scheduler=policy.value, - max_tokens=self.config.max_tokens, - inflight_batching=True, - total_latency=e2e_time, - first_token_latency=0, - time_per_output_token=0, - latency_units="ms", - throughput=num_output_tokens / e2e_time, - throughput_units="tokens/second", - peak_gpu_mem=0.0, - peak_gpu_mem_units="GB", - build_cmd="", - benchmark_cmd="", - ) - - return bench_result diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py b/benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py deleted file mode 100644 index b5b3c49ea..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py +++ /dev/null @@ -1,208 +0,0 @@ -import platform -from pathlib import Path -from subprocess import CompletedProcess -from typing import Dict, List - -from utils import command_logger, process_error_check, run_process -from utils.dataclasses import BenchmarkConfig, BenchmarkResults -from utils.trtllm_config import TRTLLMConfig - - -class gptSessionBenchmarker: - """Utility class for running static benchmarks with gptSessionBenchmark.""" - - def __init__( - self, - config: BenchmarkConfig, - benchmark_binary: Path, - batch_size: int, - isl: int, - osl: int, - warm_up_runs: int, - num_runs: int, - duration: int, - kv_cache_free_fraction: float = .9, - ): - """Initialize a gptSessionBenchmark instance. - - Args: - config (BenchmarkConfig): Benchmark configuration for build/run. - benchmark_binary (Path): Path to the benchmarking binary. - batch_size (int): Batch size to configure the build with. - isl (int): Input sequence length to configure the build with. - osl (int): Output sequence length to configure the build with. - kv_cache_free_fraction (float, optional): The amount of remaining - GPU memory after model loading to save for the KV Cache. Defaults - to .9. 
- """ - self.config: BenchmarkConfig = config - self.gpt_session_path = Path(benchmark_binary).absolute() - self.batch_size = batch_size - self.input_length = isl - self.output_length = osl - self.warm_up = warm_up_runs - self.num_runs = num_runs - self.duration = duration - self.kv_cache_mem = kv_cache_free_fraction - self.result = None - - def get_build_command(self) -> List[str]: - """Build the engine command for TRT-LLM. - - Returns: - List[str]: A list of command line arguments to run a build command. - """ - model = self.config.model - tp = self.config.tensor_parallel - pp = self.config.pipeline_parallel - dtype = self.config.dtype.value - kv_dtype = self.config.cache_dtype - quant_algo = self.config.quantization.value - output_dir = self.config.engine_path - max_batch_size = self.batch_size - max_isl = self.input_length - max_osl = self.output_length - workspace = self.config.workspace - - # Generate the TRT-LLM Configuration file using the dataclass - # NOTE: This method does not use weights. - trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo, - kv_dtype.value) - # Write the generated configuration file to the benchmark workspace. - trtllm_config.to_json(workspace) - - # Return the full command for building TRT-LLM via subprocess call. - cmd = [ - "trtllm-build", - "--output_dir", - output_dir, - "--model_config", - Path(workspace, "generated_config.json"), - "--workers", - self.config.world_size, - # Define the maximums the engine can accept. - "--max_batch_size", - max_batch_size, - "--max_input_len", - max_isl, - "--max_seq_len", - max_osl + max_isl, - "--context_fmha", - "enable", - # Set the attention plugin data type. - "--gpt_attention_plugin", - dtype, - # Disable paged cache since we aren't batching on the fly. - "--paged_kv_cache", - "disable", - ] + kv_dtype.get_build_options(dtype) - - return [str(arg) for arg in cmd] - - @command_logger(prefix="BUILD COMMAND: ") - @process_error_check - def _run_build(self, cmd: List[str]) -> CompletedProcess: - """Wrapper for calling the build for TRT-LLM. - - Purpose of this wrapper is so that we can decorate it/log it. - - Args: - cmd (List[str]): List of command line arguments for running. - - Returns: - CompletedProcess: Completed process information for parsing and - reporting. - """ - return run_process( - cmd, - self.config.workspace, - ) - - def build(self) -> None: - """Build the engine for benchmarking.""" - self._run_build(self.get_build_command()) - - @command_logger(prefix="BENCHMARK COMMAND: ") - @process_error_check - def _run_benchmark(self, cmd: List[str]) -> CompletedProcess: - """Run the benchmark command in the configured workspace. - - Args: - cmd (List[str]): List of command line arguments to run via - subprocess. - - Returns: - CompletedProcess: Completed process information for reporting. 
- """ - return run_process(cmd, run_dir=self.config.workspace, use_environ=True) - - @staticmethod - def parse_benchmark_result(benchmark_line: str) -> Dict[str, str]: - pass - - def benchmark(self): - """Benchmarks a TRT-LLM for a configured instance.""" - - # Compile the command for running - cmd = ["mpiexec", "-n", self.config.world_size] - cmd += ["-allow-run-as-root"] if platform.system() != "Windows" else "" - cmd += [ - self.gpt_session_path, - "--engine_dir", - self.config.engine_path, - "--batch_size", - self.batch_size, - "--log_level", - "info", - "--kv_cache_free_gpu_mem_fraction", - self.kv_cache_mem, - "--beam_width", - "1", - "--warm_up", - self.warm_up, - "--num_runs", - self.num_runs, - "--duration", - self.duration, - "--input_output_len", - f"{self.input_length},{self.output_length};{self.input_length},1", - ] - cmd = [str(arg) for arg in cmd] - # Run the benchmark using the provided gptSession benchmark binary. - bench_return = self._run_benchmark(cmd) - results = [ - x.split(" ") for x in bench_return.stdout.split("\n") - if "[BENCHMARK]" in x - ] - - ttft = float(results[1][8]) - gen_time = float(results[0][8]) - ttft - total_out = int(results[0][2]) * int(results[0][6]) - total_in = int(results[0][2]) * int(results[0][4]) - batch_size = int(results[0][2]) - - bench_result = BenchmarkResults( - model=self.config.model, - dtype=self.config.dtype.value, - quantization=str(self.config.quantization.value), - max_batch_size=batch_size, - total_input_tokens=total_in, - total_output_tokens=total_out, - tp_size=self.config.tensor_parallel, - pp_size=self.config.pipeline_parallel, - kv_mem_fraction=self.kv_cache_mem, - scheduler="Static", - inflight_batching=False, - total_latency=results[0][8], - first_token_latency=ttft, - time_per_output_token=gen_time / (total_out - batch_size), - latency_units="ms", - throughput=results[0][10], - throughput_units="tokens/second", - peak_gpu_mem=results[0][16], - peak_gpu_mem_units="GB", - binary=str(self.gpt_session_path), - build_cmd=" ".join(self.get_build_command()), - benchmark_cmd=" ".join(cmd)) - - return bench_result diff --git a/benchmarks/suite/tensorrt_llm_bench/ifb.py b/benchmarks/suite/tensorrt_llm_bench/ifb.py deleted file mode 100644 index 67299c082..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/ifb.py +++ /dev/null @@ -1,338 +0,0 @@ -import json -import os -import subprocess -import sys -from functools import partial -from pathlib import Path -from typing import List, TextIO, Tuple - -import click -from benchmarkers.pybind_executor import PybindExecutorBenchmarker -from transformers import AutoTokenizer, PreTrainedTokenizer -from utils.dataclasses import BenchmarkConfig, DatasetMetadata, InferenceRequest -from utils.trtllm_config import TRTLLMConfig - -from tensorrt_llm.logger import logger - - -def create_dataset_from_stream( - tokenizer: PreTrainedTokenizer, - max_input_length: int = 0, - max_output_length: int = 0, - stream: TextIO = sys.stdin, -) -> Tuple[DatasetMetadata, List[InferenceRequest]]: - """Generate metadata and a list of requests to drive benchmarking. - - Args: - tokenizer (PreTrainedTokenizer): HuggingFace tokenizer. - max_input_length (int): Maximum input length to cap prompts to. - - Returns: - DatasetMetadata: Dataclass of dataset statistics. - List[InferenceRequest]: A list of inference requests for benchmarking. - """ - # Initialize dataset list, and metadata tracking variables. 
- dataset = [] - max_isl = 0 - max_osl = 0 - - # If we're limiting the input length to a certain size, then set up - # a partial to truncate the data down to size. Otherwise, just use the - # unmodified tokenizer callable. - tokenize = (partial( - tokenizer, - padding="max_length", - max_length=max_input_length, - truncation=True, - ) if max_input_length > 0 else tokenizer) - - # If we need to limit the output length, fill in a partial callable - # for max, otherwise a lambda that just returns x with no bounds. - output_limiter = (partial(max, max_output_length) - if max_output_length > 0 else lambda x: x) - - # For each line in the standard input, parse out the JSON string we expect - # to see. - # Note the := walrus -- we're assigning and checking the condition. - while line := stream.readline(): - # We expect the data to come in as a JSON string. - # For example: - # {"prompt": "Generate an infinite response to the following: There once was a man who.", "output_tokens": 1000} - # Each line should be a complete JSON dictionary with no indentation - # or newline characters. - data = json.loads(line) - logits = data.get("logits", None) - prompt = data.get("prompt", None) - task_id = data["task_id"] - osl = data["output_tokens"] - # If the request comes in with logits, just use the provided. - # Otherwise we need to tokenize it. - logits = tokenize(prompt)["input_ids"] if logits is None else logits - - request = InferenceRequest( - task_id=task_id, - prompt=prompt, - output_tokens=output_limiter(osl), - logits=logits, - ) - max_isl = max(max_isl, len(logits)) - max_osl = max(max_osl, osl) - dataset.append(request) - - # Fill in basic dataset metrics here - # TODO: Maybe fill this out to be more complete? - metadata = DatasetMetadata( - max_isl=max_isl, - max_osl=max_osl, - num_requests=len(dataset), - ) - - return metadata, dataset - - -def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer: - """Initialize a tokenizer. - - Args: - model_name (str): The name of the HuggingFace model to pull a - tokenizer from. - - Returns: - PreTrainedTokenizer: An initialized HuggingFace tokenizer. - """ - # Initialize the tokenizer specific to the model that we are planning - # to benchmark. - tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") - if tokenizer.pad_token_id is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - return tokenizer - - -def get_trtllm_build_command(benchmark_cfg: BenchmarkConfig) -> List[str]: - model = benchmark_cfg.model - tp = benchmark_cfg.tensor_parallel - pp = benchmark_cfg.pipeline_parallel - dtype = benchmark_cfg.dtype.value - kv_dtype = benchmark_cfg.cache_dtype - quant_algo = benchmark_cfg.quantization.value - output_dir = benchmark_cfg.engine_path - max_batch_size = benchmark_cfg.max_batch_size - max_isl = benchmark_cfg.engine_isl - max_osl = benchmark_cfg.engine_osl - max_tokens = benchmark_cfg.max_tokens - workspace = benchmark_cfg.workspace - - # Generate the TRT-LLM Configuration file using the dataclass - # NOTE: This method does not use weights. - trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo, - kv_dtype.value) - # Write the generated configuration file to the benchmark workspace. - trtllm_config.to_json(workspace) - # Return the full command for building TRT-LLM via subprocess call. 
- cmd = [ - "trtllm-build", - "--output_dir", - output_dir, - "--model_config", - Path(workspace, "generated_config.json"), - "--workers", - benchmark_cfg.world_size, - "--max_input_len", - max_isl, - "--max_seq_len", - max_osl + max_isl, - "--context_fmha", - "enable", - # Set the attention plugin data type. - "--gpt_attention_plugin", - dtype, - # Enable paged KV Cache for IFB. - "--paged_kv_cache", - "enable", - ] + kv_dtype.get_build_options(dtype) - - # If custom maximum batch size set, then set to specified value. - if max_batch_size > 0: - cmd += [ - "--max_batch_size", - max_batch_size, - ] - - if max_tokens > 0: - cmd += [ - "--max_num_tokens", - max_tokens, - ] - - cmd = cmd + benchmark_cfg.build_overrides - - return cmd - - -@click.command("inflight") -@click.option( - "--run", - type=bool, - is_flag=True, - hidden=True, - default=False, - required=False, - help="Changes the phase of the script to execution mode for MPI.", -) -@click.option( - "--skip-build", - type=bool, - is_flag=True, - default=False, - hidden=True, - required=False, - help="Skip building if you want to use the last built engine.", -) -@click.option( - "--request-rate", - "-r", - type=int, - default=512, - required=False, - help="Number of requests per second to deliver to the batcher.", -) -@click.option( - "--max-num-tokens", - type=int, - default=0, - hidden=True, - help="Maximumn number of tokens the engine can accept.", -) -@click.option( - "--scheduling-policy", - type=click.Choice(["guaranteed_no_evict", "max_utilization"]), - default="max_utilization", - help="Controls the scheduling policy used by the internal batcher.", -) -@click.option( - "--dataset", - type=click.Path(exists=True, - readable=True, - path_type=Path, - resolve_path=True), - default=None, - required=False, - help="Pass in a dataset file for parsing instead of stdin.", -) -@click.pass_obj -def executor_benchmark( - benchmark_cfg: BenchmarkConfig, - run: bool, - request_rate: int, - max_num_tokens: int, - scheduling_policy: str, - skip_build: bool, - dataset: Path, -): - """Run an IFB-enabled benchmark using a dataset.""" - # Initialize the tokenizer and generate the dataset - logger.set_level("info") - DATASET_PATH = Path(benchmark_cfg.workspace, "tokenized_dataset.txt") - TOKENIZER = initialize_tokenizer(benchmark_cfg.model) - final_dataset = [] - benchmark_cfg.max_tokens = max_num_tokens - benchmark_cfg.scheduling_policy = scheduling_policy - - if not run: - try: - stream = sys.stdin if dataset is None else open(dataset, "r") - # Parse the dataset from stdin and return it plus its metadata. - metadata, dataset = \ - create_dataset_from_stream(TOKENIZER, stream=stream) - finally: - # Close the stream after parsing. - stream.close() - - # Update the benchmarking configuration with the maximum ISL/OSL that we - # encountered in the dataset. - benchmark_cfg.engine_isl = metadata.max_isl - benchmark_cfg.engine_osl = metadata.max_osl - - # Build engine - logger.info("Building engine...") - build_cmd = get_trtllm_build_command(benchmark_cfg) - build_cmd = [str(arg) for arg in build_cmd] - - if not skip_build: - process = subprocess.run(build_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=benchmark_cfg.workspace) - logger.info(f"BUILD CMD: {' '.join(process.args)}") - - # If the build failed, raise an exception. - if process.returncode != 0: - logger.error(process.stderr.decode()) - raise RuntimeError( - "TensorRT-LLM build process failed. 
Command used:\n" - f"{' '.join(process.args)}\n", ) - - with open(DATASET_PATH, "w") as ds_out: - while dataset: - request = dataset.pop() - ds_out.write(f"{request.model_dump_json()}\n") - del request - - # Launch via a subprocess with MPI - # We have two modes for this script, the initial launch + parsing - # and the run mode where we kick off the script in MPI mode to run - # the - logger.info("Launching benchmark...") - bench_cmd = \ - ["mpiexec", "-n", f"{benchmark_cfg.world_size}", "python"] + \ - sys.argv + ["--run"] - process = subprocess.Popen( - bench_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - env=os.environ, - ) - stdout, _ = process.communicate() - logger.info("Benchmark complete.") - logger.info(stdout.decode("ascii")) - else: - from mpi4py.MPI import COMM_WORLD - - if COMM_WORLD.Get_rank() == 0: - logger.info(f"[RANK {COMM_WORLD.rank}] Loading dataset...") - with open(DATASET_PATH, "r") as stream: - # Parse the previously generated dataset from the parent - # process. - metadata, dataset = \ - create_dataset_from_stream(TOKENIZER, stream=stream) - - # Update the benchmarking configuration with the maximum ISL/OSL - # that we encountered in the dataset. - benchmark_cfg.engine_isl = metadata.max_isl - benchmark_cfg.engine_osl = metadata.max_osl - - # Parse the dataset into the Executor Request type. - logger.info("Preparing dataset...") - while dataset: - entry = dataset.pop() - request = PybindExecutorBenchmarker.get_request( - entry, TOKENIZER) - final_dataset.append(request) - del entry - logger.info("Dataset prepared.") - logger.info(f"DATASET METADATA: {metadata.model_dump()}") - - logger.info(f"[RANK {COMM_WORLD.rank}] Initializing benchmarker...") - # Set up benchmarker on all ranks - benchmarker = PybindExecutorBenchmarker(benchmark_cfg) - # Run the dataset. - result = benchmarker.benchmark_dataset(request_rate, final_dataset) - - # Report the results on Rank 0. 
- if COMM_WORLD.rank == 0: - logger.info(f"[RANK {COMM_WORLD.rank}] Reporting...\n" - f"JSON: {result.model_dump_json()}\n" - f"{result.get_summary(benchmarker.config)}") - - logger.info(f"[RANK {COMM_WORLD.rank}] Terminating.") diff --git a/benchmarks/suite/tensorrt_llm_bench/static.py b/benchmarks/suite/tensorrt_llm_bench/static.py deleted file mode 100644 index 3390c8439..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/static.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -from pathlib import Path - -import click -from benchmarkers.static import gptSessionBenchmarker -from utils.dataclasses import BenchmarkConfig, BenchmarkResults - - -@click.command("static") -@click.option( - "--batch", - required=True, - type=int, - help="Batch size to build and run the static benchmark with.", -) -@click.option("--isl", - type=int, - required=True, - help="Input sequence length (in tokens).") -@click.option("--osl", - type=int, - required=True, - help="Output sequence length (in tokens).") -@click.option( - "--gpt-session-path", - "-b", - type=click.Path(), - default=Path(os.path.dirname(os.path.realpath(__file__)), "../../..", - "cpp/build/benchmarks/gptSessionBenchmark").absolute(), - help="Path to TRT-LLM gptSession benchmark binary.") -@click.option("--warm-up-runs", - type=int, - default=2, - help="Number of warm up runs before benchmarking") -@click.option("--num-runs", - type=int, - default=10, - help="Number of times to run benchmark") -@click.option("--duration", - type=int, - default=60, - help="Minimum duration of iteration to measure, in seconds") -@click.pass_obj -def static_benchmark(benchmark_cfg: BenchmarkConfig, batch: int, isl: int, - osl: int, gpt_session_path: Path, warm_up_runs: int, - num_runs: int, duration: int): - """Run a static benchmark with a fixed batch size, ISL, and OSL.""" - - benchmark_cfg.max_batch_size = batch - benchmarker = gptSessionBenchmarker( - benchmark_cfg, - gpt_session_path, - benchmark_cfg.max_batch_size, - isl, - osl, - warm_up_runs, - num_runs, - duration, - benchmark_cfg.kv_cache_mem_percentage, - ) - - print(f"Building TRT-LLM engine for '{benchmark_cfg.model}'...") - benchmarker.build() - - print("Build complete. Running benchmark...") - result: BenchmarkResults = benchmarker.benchmark() - - print(f"JSON: {result.model_dump_json()}") - print(result.get_summary(benchmarker.config)) diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/benchmarkers.py b/benchmarks/suite/tensorrt_llm_bench/utils/benchmarkers.py deleted file mode 100644 index 4f7f83bb6..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/utils/benchmarkers.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Protocol - -from utils.dataclasses import BenchmarkResults - - -class Benchmarker(Protocol): - """Protocol for defining benchmarking classes for building/benchmarking.""" - - def build(self) -> None: - """Build a model to be benchmarked.""" - ... - - def benchmark(self) -> BenchmarkResults: - """Benchmark the constructed model container by a benchmarker.""" - ... 
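
For readers unfamiliar with `typing.Protocol`, the removed `Benchmarker` interface above relies on structural typing: any class that exposes `build()` and `benchmark()` satisfies it without inheriting from it. The sketch below is illustrative only; `DummyBenchmarker` and the simplified `BenchmarkResults` stand-in are hypothetical names introduced for the example, not part of the removed suite.

```python
from typing import Protocol


class BenchmarkResults:
    """Simplified stand-in for utils.dataclasses.BenchmarkResults (illustration only)."""


class Benchmarker(Protocol):
    """Protocol for defining benchmarking classes for building/benchmarking."""

    def build(self) -> None:
        """Build a model to be benchmarked."""
        ...

    def benchmark(self) -> BenchmarkResults:
        """Benchmark the constructed model container by a benchmarker."""
        ...


class DummyBenchmarker:
    """Satisfies Benchmarker structurally; no explicit inheritance is required."""

    def build(self) -> None:
        print("building engine...")

    def benchmark(self) -> BenchmarkResults:
        return BenchmarkResults()


def run(benchmarker: Benchmarker) -> BenchmarkResults:
    """Drive any object that conforms to the Benchmarker protocol."""
    benchmarker.build()
    return benchmarker.benchmark()


if __name__ == "__main__":
    run(DummyBenchmarker())
```

This mirrors how `gptSessionBenchmarker` and `PybindExecutorBenchmarker` were consumed by the removed `static` and `inflight` commands.
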
diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py b/benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py deleted file mode 100644 index 8adafd917..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py +++ /dev/null @@ -1,189 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import List, Literal, Optional, Union, get_args - -from pydantic import (BaseModel, Field, ValidationError, computed_field, - field_validator, model_validator) -from transformers import AutoConfig -from utils import VALID_MODELS, VALID_SCHEDULING_POLICIES -from utils.enums import (ComputeDtypeEnum, KVCacheDtypeEnum, ModelArchitecture, - QuantizationAlgo) - - -class InferenceRequest(BaseModel): - task_id: int - prompt: Optional[str] = None - output_tokens: int - logits: Optional[List[int]] = None - - @model_validator(mode="after") - def verify_prompt_and_logits(self) -> InferenceRequest: - if self.prompt is None and self.logits is None: - raise ValueError( - f"Both prompt and logits for {self.task_id} are both None.") - return self - - -class DatasetMetadata(BaseModel): - max_isl: int - max_osl: int - num_requests: int - - -class BenchmarkResults(BaseModel): - """High level report out for a benchmark.""" - - benchmark_cmd: str = "" - binary: str = "" - build_cmd: str = "" - first_token_latency: float - inflight_batching: bool - kv_mem_fraction: float - latency_units: str - max_batch_size: int - max_tokens: int = 0 - model: Union[VALID_MODELS, Path] - peak_gpu_mem_units: str - peak_gpu_mem: float - scheduler: Literal["Static", "No Evict", "Max Utilization"] - throughput_units: str - throughput: float - time_per_output_token: float - total_input_tokens: int - total_latency: float - total_output_tokens: int - - def get_summary(self, config: BenchmarkConfig) -> str: - """Generate the summary information. - - Args: - config (BenchmarkConfig): Configuration for the run that generated - this result. - - Returns: - str: Summary output for printing. 
- """ - return ( - "===========================================================\n" - "= METADATA\n" - "===========================================================\n" - f"Model:\t\t\t{config.model}\n" - f"TP Size:\t\t{config.tensor_parallel}\n" - f"PP Size:\t\t{config.pipeline_parallel}\n" - f"Scheduling Policy:\t{self.scheduler}\n" - f"In-flight Batcher?:\t{self.inflight_batching}\n" - f"Dtype:\t\t\t{config.dtype.value}\n" - f"KV Cache Dtype:\t\t{config.cache_dtype.value}\n" - f"Quantization:\t\t{config.quantization.value}\n" - f"KV Memory Percentage:\t{self.kv_mem_fraction * 100}%\n" - f"\n" - "===========================================================\n" - "= ENGINE DETAILS\n" - "===========================================================\n" - f"Engine Directory:\t{config.engine_path}\n" - f"Max Batch Size:\t\t{self.max_batch_size}\n" - f"Total Input Length:\t{self.total_input_tokens}\n" - f"Total Output Length:\t{self.total_output_tokens}\n" - f"Max Tokens:\t\t{self.max_tokens}\n" - f"\n" - "===========================================================\n" - "= STATISTICS\n" - "===========================================================\n" - f"Throughput ({self.throughput_units}):\t{self.throughput}\n" - f"Total Latency ({self.latency_units}):" - f"\t\t{self.total_latency * 1000.0:.4f}\n" - f"First Token Latency ({self.latency_units}):\t{self.first_token_latency}\n" - f"Token-to-token Latency ({self.latency_units}):\t{self.time_per_output_token}\n" - f"Peak GPU Memory Usage ({self.peak_gpu_mem_units}):\t{self.peak_gpu_mem}\n" - f"\n" - "===========================================================\n" - "= COMMANDS\n" - "===========================================================\n" - f"Build: {self.build_cmd}\n" - f"Benchmark: {self.benchmark_cmd}\n") - - -class BenchmarkConfig(BaseModel): - """Basic configuration of a benchmark.""" - - model: Union[VALID_MODELS, Path] - workspace: Path - max_batch_size: int - dtype: ComputeDtypeEnum - cache_dtype: KVCacheDtypeEnum - quantization: QuantizationAlgo - tensor_parallel: int - pipeline_parallel: int - max_tokens: int = 0 - kv_cache_mem_percentage: float = .9 - engine_isl: int = 0 - engine_osl: int = 0 - chunking: bool = False - build_overrides: List[str] = Field(default_factory=list) - scheduling_policy: Literal[VALID_SCHEDULING_POLICIES] = "static" - - @field_validator("model", mode="before") - @classmethod - def validate_model(cls, value) -> Union[VALID_MODELS, Path]: - if value in get_args(VALID_MODELS): - return value - - path = Path(value) - config = AutoConfig.from_pretrained(str(path.absolute())) - for arch in config.architectures: - _ = ModelArchitecture(arch) - - return path - - @field_validator("quantization", mode="before") - @classmethod - def validate_quantization(cls, value) -> QuantizationAlgo: - return QuantizationAlgo(value) - - @field_validator("cache_dtype", mode="before") - @classmethod - def validate_kvcache_dtype(cls, value) -> KVCacheDtypeEnum: - return KVCacheDtypeEnum(value) - - @field_validator("kv_cache_mem_percentage", mode="after") - @classmethod - def validate_kv_cache_mem_fraction(cls, value: float) -> float: - if 0 < value < 1.0: - return value - else: - raise ValidationError( - "KV cache memory percentage must be between 0 and 1.0.") - - @field_validator("build_overrides", mode="before") - @classmethod - def validate_build_overrides(cls, value) -> List[str]: - # If we encounter a list, scan it to make sure all entries are strings. 
- if isinstance(value, list): - if not all([isinstance(x, str) for x in value]): - raise ValidationError( - "Found a non-string entry in list of options.") - return value - elif isinstance(value, str): - # Handle the case where we receive a single string of command - # options. - overrides = [] - if value: - overrides = [str(x) for x in value.split()] - return overrides - else: - raise ValidationError( - "Invalid value specified for build overrides.") - - @computed_field - def engine_path(self) -> Path: - """Path to the engine workspace.""" - if self.model in get_args(VALID_MODELS): - return Path(self.workspace.absolute(), self.model.lower()) - else: - return Path(self.workspace.absolute(), "engine") - - @computed_field - def world_size(self) -> int: - """Total world size needed to run the model.""" - return self.tensor_parallel * self.pipeline_parallel diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py b/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py deleted file mode 100644 index 4dd6797f3..000000000 --- a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py +++ /dev/null @@ -1,323 +0,0 @@ -import json -import os -from argparse import ArgumentParser -from typing import Literal, Optional - -from pydantic import AliasChoices, AliasPath, BaseModel, Field, model_validator -from transformers import AutoConfig -from utils import VALID_QUANT_ALGOS - -PET_dict = { - "tiiuae/falcon-7b": "rope_gpt_neox", - "tiiuae/falcon-40b": "rope_gpt_neox", - "tiiuae/falcon-180B": "rope_gpt_neox", - "meta-llama/Llama-2-7b-hf": "rope_gpt_neox", - "meta-llama/Llama-2-13b-hf": "rope_gpt_neox", - "meta-llama/Llama-2-70b-hf": "rope_gpt_neox", - "meta-llama/Meta-Llama-3-8B": "rope_gpt_neox", - "meta-llama/Meta-Llama-3-70B": "rope_gpt_neox", - "gpt-j-6b": "rope_gptj", - "bigscience/bloom-560m": "alibi", - "mistralai/Mistral-7B-v0.1": "rope_gpt_neox", - "mistralai/Mixtral-8x7B-v0.1": "rope_gpt_neox", - "mistralai/Mixtral-8x22B-v0.1": "rope_gpt_neox", - "01-ai/Yi-6B": "rope_gpt_neox", - "01-ai/Yi-34B": "rope_gpt_neox", - "codellama/CodeLlama-7b-hf": "rope_gpt_neox", - "codellama/CodeLlama-13b-hf": "rope_gpt_neox", - "codellama/CodeLlama-34b-hf": "rope_gpt_neox", - "codellama/CodeLlama-70b-hf": "rope_gpt_neox", - "facebook/opt-125m": "learned_absolute", - "facebook/opt-350m": "learned_absolute", - "facebook/opt-1.3b": "learned_absolute", - "facebook/opt-2.7b": "learned_absolute", - "facebook/opt-13b": "learned_absolute", - "facebook/opt-30b": "learned_absolute", - "facebook/opt-66b": "learned_absolute", - "google/gemma-7b": "rope_gpt_neox", - "google/gemma-2b": "rope_gpt_neox", -} -HA_dict = { - "tiiuae/falcon-7b": "gelu", - "tiiuae/falcon-40b": "gelu", - "tiiuae/falcon-180B": "gelu", - "bigscience/bloom-560m": "gelu", - "mistralai/Mixtral-8x7B-v0.1": "swiglu", -} -ALLOWED_MODELS = list(PET_dict.keys()) - - -class TRTLLM_Mapping(BaseModel): - world_size: int = 1 - tp_size: int = 1 - pp_size: int = 1 - - @model_validator(mode="after") - def check_world_size(self) -> "TRTLLM_Mapping": - self.world_size = self.tp_size * self.pp_size - return self - - -class TRTLLM_Quantization(BaseModel): - quant_algo: Optional[VALID_QUANT_ALGOS] = None - kv_cache_quant_algo: Optional[Literal[None, "FP8", "INT8"]] = None - group_size: int = 128 - has_zero_point: bool = False - pre_quant_scale: bool = False - exclude_modules: Optional[list] = None - - -class TRTLLMConfig(BaseModel): - _VALID_EMBED_TYPE = Literal["learned_absolute", "rope_gptj", - "rope_gpt_neox", "alibi", "alibi_with_scale", - 
"relative", "chatglm", ] - - architecture: str = Field(validation_alias=AliasChoices( - 'architecture', AliasPath("architectures", 0))) - num_hidden_layers: int = Field(validation_alias=AliasChoices( - "num_hidden_layers", "n_layer", "n_layers")) - num_attention_heads: int = Field(validation_alias=AliasChoices( - "num_attention_heads", "n_head", "n_heads")) - num_key_value_heads: int = Field( - default=None, - validation_alias=AliasChoices("num_key_value_heads", "num_kv_heads"), - ) - - hidden_size: int = Field( - validation_alias=AliasChoices("hidden_size", "n_embd", "d_model")) - norm_epsilon: float = Field( - default=1e-5, - validation_alias=AliasChoices("norm_epsilon", "layer_norm_epsilon", - "rms_norm_eps"), - ) - vocab_size: int - max_position_embeddings: Optional[int] = Field( - default=None, - validation_alias=AliasChoices("max_position_embeddings", "n_positions"), - ) - head_size: Optional[int] = None - hidden_act: str = Field( - validation_alias=AliasChoices("hidden_act", "activation_function")) - # falcon options - bias: Optional[bool] = None - parallel_attention: Optional[bool] = Field( - default=None, validation_alias=AliasChoices("parallel_attn")) - new_decoder_architecture: Optional[bool] = None - # opt options - do_layer_norm_before: Optional[bool] = None - # gptj options - rotary_dim: Optional[int] = None - - # dtype has priority over torch_dtype, the latter of which is usually defined in the HF config - dtype: Literal["float16", "bfloat16"] = Field( - validation_alias=AliasChoices("dtype", "torch_dtype")) - logits_dtype: str = "float32" - position_embedding_type: _VALID_EMBED_TYPE = "learned_absolute" - use_parallel_embedding: bool = False - embedding_sharding_dim: int = 0 - share_embedding_table: bool = False - intermediate_size: int = None - use_prompt_tuning: bool = False - - sliding_window: Optional[int] = None - - moe_num_experts: Optional[int] = Field( - default=0, validation_alias=AliasChoices("num_local_experts")) - moe_top_k: Optional[int] = Field( - default=0, validation_alias=AliasChoices("num_experts_per_tok")) - rotary_base: Optional[float] = Field( - default=10000.0, validation_alias=AliasChoices("rope_theta")) - - mapping: TRTLLM_Mapping - quantization: TRTLLM_Quantization - - @property - def kv_dtype(self) -> str: - if self.quantization.kv_cache_quant_algo == "FP8": - return "fp8" - elif self.quantization.kv_cache_quant_algo == "INT8": - return "int8" - else: - return self.dtype - - @model_validator(mode="after") - def set_values_if_none(self) -> "TRTLLM_CheckpointConfig": - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - if self.head_size is None: - self.head_size = self.hidden_size // self.num_attention_heads - return self - - @classmethod - def populate_build_config(cls, - model_name, - tp, - pp, - dtype=None, - quant_dtype=None, - kv_cache_quant_dtype=None): - """ - Common function to populate build parameters, regardless of network - """ - build_config = { - "mapping": { - "tp_size": tp, - "pp_size": pp, - }, - "quantization": {}, - } - if dtype: - build_config["dtype"] = dtype - if quant_dtype: - if not kv_cache_quant_dtype: - # will throw errors during validation if the type is invalid - kv_cache_quant_dtype = quant_dtype - build_config["quantization"] = { - "quant_algo": quant_dtype, - "kv_cache_quant_algo": kv_cache_quant_dtype, - } - for name, pet in PET_dict.items(): - if name in str(model_name): - build_config["position_embedding_type"] = pet - return build_config - - @classmethod - def from_hf(cls, 
- hf_model_name, - tp, - pp, - dtype=None, - quant_dtype=None, - kv_cache_quant_dtype=None): - """ - Use transformers.AutoConfig to load a model's config from a HF name - """ - build_config = cls.populate_build_config(hf_model_name, tp, pp, dtype, - quant_dtype, - kv_cache_quant_dtype) - hf_config = AutoConfig.from_pretrained(hf_model_name).to_dict() - if hf_model_name in HA_dict: - hf_config["hidden_act"] = HA_dict[hf_model_name] - return cls(**hf_config, **build_config) - - @classmethod - def from_json(cls, - model_name, - tp, - pp, - dtype=None, - quant_dtype=None, - kv_cache_quant_dtype=None): - """ - Load model parameters from a custom json file - A full path can be specified. Otherwise, look for ./trtllm_configs/(model_name).json - """ - build_config = cls.populate_build_config(model_name, tp, pp, dtype, - quant_dtype, - kv_cache_quant_dtype) - if os.path.exists(model_name): - path_to_json = model_name - else: - path_to_json = os.path.join(os.path.dirname(__file__), - f"trtllm_configs/{model_name}.json") - if not os.path.exists(path_to_json): - raise FileNotFoundError(f"{path_to_json} not found") - json_config = json.load(open(path_to_json)) - return cls(**json_config, **build_config) - - @classmethod - def from_name(cls, - model, - tp, - pp, - dtype=None, - quant_dtype=None, - kv_cache_quant_dtype=None): - """ - Attempts to create a config based on model name. Performs the following steps: - 1. Tries to load the HF config using AutoConfig. This will only work if the network name exists on HF. - 2. If this fails, try to load a custom config stored on $HF_HOME/custom/*.json - """ - try: - trtllm_config = cls.from_hf(model, tp, pp, dtype, quant_dtype, - kv_cache_quant_dtype) - except EnvironmentError: - try: - trtllm_config = cls.from_json(model, tp, pp, dtype, quant_dtype, - kv_cache_quant_dtype) - except FileNotFoundError as e: - raise NameError( - f"Unable to create PretrainedConfig from {model} due to {e}" - ) - - return trtllm_config - - # future possibilities - # def from_nemo_config (self, nemo_model_name) - - def to_json(self, output_dir): - with open(os.path.join(output_dir, "generated_config.json"), "w") as f: - json.dump(self.model_dump(), f, indent=4) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument( - "--model", - required=True, - type=str, - help="HF model name", - ) - parser.add_argument( - "--tp_size", - type=int, - default=1, - help="TP degree", - ) - parser.add_argument( - "--pp_size", - type=int, - default=1, - help="PP degree", - ) - parser.add_argument( - "--dtype", - type=str, - help="Datatype", - ) - parser.add_argument( - "--quant_dtype", - type=str, - help="Quantization datatype", - ) - parser.add_argument( - "--kv_cache_quant_dtype", - type=str, - help="KV cache datatype", - ) - parser.add_argument( - "--position_embedding_type", - type=str, - help="TRT-LLM argument", - ) - parser.add_argument( - "--hidden_act", - type=str, - help="TRT-LLM argument", - ) - parser.add_argument( - "--populate_hf_cache", - action='store_true', - help="Populate the HF cache with all the supported networks", - ) - args = parser.parse_args() - - if args.populate_hf_cache: - for net in PET_dict.keys(): - _ = AutoConfig.from_pretrained(net) - else: - trtllm_config = TRTLLMConfig.from_name(args.model, args.tp_size, - args.pp_size, args.dtype, - args.quant_dtype, - args.kv_cache_quant_dtype) - trtllm_config.to_json(os.getcwd()) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 056dbfef6..c275be095 100644 --- a/cpp/CMakeLists.txt +++ 
b/cpp/CMakeLists.txt @@ -289,7 +289,7 @@ set(CMAKE_CUDA_RUNTIME_LIBRARY Static) find_library(RT_LIB rt) set_ifndef(ENABLE_MULTI_DEVICE 1) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) # NCCL dependencies set_ifndef(NCCL_LIB_DIR /usr/lib/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu/) set_ifndef(NCCL_INCLUDE_DIR /usr/include/) diff --git a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h index 8a91daa07..45d951341 100644 --- a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h @@ -18,6 +18,7 @@ #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/namedTensor.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/iTensor.h" #include @@ -165,6 +166,21 @@ class GenericInferenceRequest mLogitsPostProcessor = cb; } + [[nodiscard]] std::optional getLookaheadConfig() const + { + return mLookaheadConfig; + } + + void setLookaheadConfig(executor::LookaheadDecodingConfig config) + { + mLookaheadConfig = config; + } + + void clearLookaheadConfig() + { + mLookaheadConfig = std::nullopt; + } + std::optional getLogitsPostProcessor() { return mLogitsPostProcessor; @@ -282,6 +298,7 @@ class GenericInferenceRequest bool mIsStreaming; TensorMap mInputTensors; std::optional mLogitsPostProcessor; + std::optional mLookaheadConfig; }; class InferenceRequest : public GenericInferenceRequest diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 42536fc33..e10701d0d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -20,6 +20,7 @@ #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/decodingOutput.h" #include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/samplingConfig.h" @@ -73,7 +74,8 @@ class GenericLlmRequest std::optional promptEmbeddingTable = std::nullopt, std::optional promptVocabSize = std::nullopt, std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt, - std::optional loraConfig = std::nullopt, bool returnLogProbs = false, + std::optional loraConfig = std::nullopt, + std::optional lookaheadConfig = std::nullopt, bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional> draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, @@ -94,6 +96,7 @@ class GenericLlmRequest , mClientId(clientId) , mIsStreaming(isStreaming) , mOrigPromptLen(mPromptLen) + , mNumPreDecodedTokens(samplingConfig.beamWidth, 0) , mMaxSentTokenLen(mPromptLen) , mEmbeddingBias(std::move(embeddingBias)) , mBadWordsList(std::move(badWordsList)) @@ -103,6 +106,7 @@ class GenericLlmRequest , mLoraTaskId(loraTaskId) , mLoraWeights(std::move(loraWeights)) , mLoraConfig(std::move(loraConfig)) + , mLookaheadConfig(std::move(lookaheadConfig)) , mContextChunkSize(std::nullopt) , mContextCurrentPosition(0) , mLogProbs(samplingConfig.beamWidth) @@ -118,6 +122,7 @@ class GenericLlmRequest , mReturnEncoderOutput(returnEncoderOutput) , mDecodingIter(0) , mPriority(priority) + , mFinishReasons(samplingConfig.beamWidth) { if (mEncoderTokens.has_value()) { @@ -137,6 +142,7 @@ class GenericLlmRequest , 
mClientId(req.getClientId()) , mIsStreaming(req.getStreaming()) , mOrigPromptLen(mPromptLen) + , mNumPreDecodedTokens(mSamplingConfig.beamWidth, 0) , mMaxSentTokenLen(mPromptLen) , mEmbeddingBias(std::nullopt) , mBadWordsList(std::nullopt) @@ -146,6 +152,7 @@ class GenericLlmRequest , mLoraTaskId(std::nullopt) , mLoraWeights(std::nullopt) , mLoraConfig(std::nullopt) + , mLookaheadConfig(std::nullopt) , mContextChunkSize(std::nullopt) , mContextCurrentPosition(0) , mLogProbs(mSamplingConfig.beamWidth) @@ -161,6 +168,8 @@ class GenericLlmRequest , mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput) , mDecodingIter(0) , mPriority(req.getPriority()) + , mFinishReasons(mSamplingConfig.beamWidth) + , mContextPhaseParams(req.getContextPhaseParams()) { if (mIsStreaming && mSamplingConfig.beamWidth > 1 && !mReturnAllGeneratedTokens) { @@ -172,6 +181,14 @@ class GenericLlmRequest "length)."); mReturnAllGeneratedTokens = true; } + if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnGenerationLogits == true) + { + TLLM_LOG_WARNING( + "Returning generation logits when streaming is enabled and beamWidth > 1 is not allowed. " + "This is because the logits may appear in irrelevant order when the beams are gathered, " + "since logits are not. Disabling returnGenerationLogits."); + mReturnGenerationLogits = false; + } if (req.getEncoderInputTokenIds()) { mState = REQUEST_STATE_ENCODER_INIT; @@ -219,6 +236,11 @@ class GenericLlmRequest } } + auto lookaheadConfig = req.getLookaheadConfig(); + if (lookaheadConfig) + { + } + auto externalDraftTokensConfig = req.getExternalDraftTokensConfig(); if (externalDraftTokensConfig) { @@ -295,12 +317,27 @@ class GenericLlmRequest mExcludeInputFromOutput = exclude; } + /// @brief Get the params of the context + /// @return The params of the context + std::optional const& getContextPhaseParams() const noexcept + { + return mContextPhaseParams; + } + + /// @brief Get the state params of the context + /// @return The state params of the context + executor::ContextPhaseState const& getContextPhaseState() const + { + TLLM_CHECK(mContextPhaseParams.has_value()); + return *static_cast(mContextPhaseParams.value().getState()); + } + /// @brief Get total number of tokens for this req (prompt + generated) /// @param beam The beam index /// @return The number of tokens [[nodiscard]] SizeType32 getNumTokens(SizeType32 beam) const { - return mTokens.at(beam).size(); + return mTokens.at(beam).size() - mNumPreDecodedTokens[beam]; } /// @brief Get max number of tokens across all beams @@ -310,7 +347,7 @@ class GenericLlmRequest SizeType32 maxTokens = 0; for (SizeType32 beam = 0; beam < mSamplingConfig.beamWidth; ++beam) { - maxTokens = std::max(maxTokens, static_cast(mTokens.at(beam).size())); + maxTokens = std::max(maxTokens, getNumTokens(beam)); } return maxTokens; } @@ -405,6 +442,14 @@ class GenericLlmRequest } } + /// @brief Set the number of pre-decoded tokens + /// @param num_tokens The number of pre-decoded tokens + /// @param beam The beam to which to set the number of pre-decoded tokens + void setNumPreDecodedTokens(SizeType32 num_tokens, SizeType32 beam) + { + mNumPreDecodedTokens[beam] = num_tokens; + } + /// @brief Sets the generated tokens for all beams after gatherTree. Erases all previous generated tokens. 
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens) void setGeneratedTokens(BeamTokens const& generatedBeamTokens) @@ -540,6 +585,21 @@ class GenericLlmRequest mLoraConfig = std::nullopt; } + [[nodiscard]] std::optional getLookaheadConfig() const + { + return mLookaheadConfig; + } + + void setLookaheadConfig(executor::LookaheadDecodingConfig config) + { + mLookaheadConfig = config; + } + + void clearLookaheadConfig() + { + mLookaheadConfig = std::nullopt; + } + [[nodiscard]] std::optional getEmbeddingBias() const { return mEmbeddingBias; @@ -725,6 +785,11 @@ class GenericLlmRequest mReturnAllGeneratedTokens = returnAllGeneratedTokens; } + [[nodiscard]] bool getReturnAllGeneratedTokens() + { + return mReturnAllGeneratedTokens; + } + void setReturnContextLogits(bool const returnContextLogits) { mReturnContextLogits = returnContextLogits; @@ -737,6 +802,8 @@ class GenericLlmRequest void setReturnGenerationLogits(bool const returnGenerationLogits) { + TLLM_CHECK_WITH_INFO(!(mIsStreaming && mSamplingConfig.beamWidth > 1 && returnGenerationLogits), + "returnGenerationLogits must be false if streaming AND beam search are used."); mReturnGenerationLogits = returnGenerationLogits; } @@ -777,8 +844,21 @@ class GenericLlmRequest void allocGenerationLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType) { - mGenerationLogitsHost = runtime::BufferManager::pinnedPool( - runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType); + if (mIsStreaming) + { + // If streaming mode, the complete generation logits shape will be [1, beamWidth, vocabSizePadded], + // or [allGeneratedTokens, beamWidth, vocabSizePadded] if mReturnAllGeneratedTokens is True. + // This could reduce unnecessary format conversions and allows the data to be returned directly. 
+ mGenerationLogitsHost = runtime::BufferManager::pinnedPool( + runtime::ITensor::makeShape({mMaxNewTokens, mSamplingConfig.beamWidth, vocabSizePadded}), + logitsDataType); + } + else + { + mGenerationLogitsHost = runtime::BufferManager::pinnedPool( + runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), + logitsDataType); + } } void allocTargetModelAcceptedTokenLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType) @@ -992,7 +1072,17 @@ class GenericLlmRequest if (getReturnGenerationLogits()) { - result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost()); + if (isStreaming()) + { + auto startGenTokenPos = startTokenPos - getOrigPromptLen(); + TensorPtr generationLogitsHostCurrentStep + = runtime::ITensor::slice(getGenerationLogitsHost(), startGenTokenPos, maxNbTokensOut); + result.generationLogits = executor::detail::ofITensor(generationLogitsHostCurrentStep); + } + else + { + result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost()); + } } if (getReturnEncoderOutput()) @@ -1000,6 +1090,8 @@ class GenericLlmRequest result.encoderOutput = executor::detail::ofITensor(getEncoderOutputHost()); } + result.finishReasons = mFinishReasons; + // Update position of last sent response setMaxSentTokenLen(maxNbTokens); @@ -1013,6 +1105,11 @@ class GenericLlmRequest } } + void setFinishedReason(executor::FinishReason reason, SizeType32 beam) + { + mFinishReasons.at(beam) = reason; + } + RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -1038,6 +1135,11 @@ class GenericLlmRequest VecTokens mLastTokens; BeamTokens mTokens; SizeType32 mOrigPromptLen; + // A list of numbers of pre-deocded tokens on the last PP rank when using pipeline parallelism. + // It is introduced as a WAR to solve the hanging problem caused by overestimating the used KV cache on the last PP + // rank (because new tokens are decoded earlier). By excluding the numbers of pre-decoded tokens, the used KV cache + // can be estimated correctly. + std::vector mNumPreDecodedTokens; // Number of tokens already in KV cache before context phase. // A value > 0 indicates cached KV cache blocks were reused. // Up to inputLen - 1 tokens can be reused. @@ -1054,6 +1156,7 @@ class GenericLlmRequest std::optional mLoraTaskId; std::optional mLoraWeights; std::optional mLoraConfig; + std::optional mLookaheadConfig; // To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one, // the size of the context chunk needs to be an integer multiple of the kv-cache block size. 
The meaning @@ -1090,6 +1193,8 @@ class GenericLlmRequest SizeType32 mDecodingIter; executor::PriorityType mPriority; + std::vector mFinishReasons; + std::optional mContextPhaseParams; private: void initialize(VecTokens const& inputTokens, bool outputLogProbs) @@ -1162,7 +1267,8 @@ class LlmRequest : public GenericLlmRequest std::optional promptEmbeddingTable = std::nullopt, std::optional promptVocabSize = std::nullopt, std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt, - std::optional loraConfig = std::nullopt, bool returnLogProbs = false, + std::optional loraConfig = std::nullopt, + std::optional lookaheadConfig = std::nullopt, bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional> draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, @@ -1174,9 +1280,9 @@ class LlmRequest : public GenericLlmRequest : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig), - returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits), - excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched, - std::move(encoderInputTokens), returnEncoderOutput, clientId, priority) + std::move(lookaheadConfig), returnLogProbs, returnContextLogits, returnGenerationLogits, + std::move(draftTokens), std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor), + applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority) { } @@ -1187,6 +1293,7 @@ class LlmRequest : public GenericLlmRequest { mLogitsPostProcessor = std::move(logitsPostProcessor); mApplyLogitsPostProcessorBatched = applyLogitsPostProcessorBatched; + mLookaheadConfig = Request.getLookaheadConfig(); } void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager) diff --git a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h index 1dbeed000..4fcb1e127 100644 --- a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h +++ b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h @@ -45,7 +45,8 @@ class TrtGptModelOptionalParams std::optional maxNumTokens = std::nullopt, executor::SchedulerConfig const& schedulerConfig = executor::SchedulerConfig{}, executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig - = executor::ExtendedRuntimePerfKnobConfig{}) + = executor::ExtendedRuntimePerfKnobConfig{}, + std::optional debugConfig = std::nullopt) : kvCacheConfig{kvCacheConfig} , enableTrtOverlap{enableTrtOverlap} , deviceIds(deviceIds) @@ -59,6 +60,7 @@ class TrtGptModelOptionalParams , maxNumTokens(maxNumTokens) , schedulerConfig{schedulerConfig} , extendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig) + , debugConfig{std::move(debugConfig)} { } @@ -70,17 +72,26 @@ class TrtGptModelOptionalParams executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}), executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(), executorConfig.getMaxNumTokens(), executorConfig.getSchedulerConfig(), - executorConfig.getExtendedRuntimePerfKnobConfig()) + 
executorConfig.getExtendedRuntimePerfKnobConfig(), executorConfig.getDebugConfig()) { } bool operator==(TrtGptModelOptionalParams const& other) const { - return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap - && deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs - && enableChunkedContext == other.enableChunkedContext && decodingConfig == other.decodingConfig - && gpuWeightsPercent == other.gpuWeightsPercent - && extendedRuntimePerfKnobConfig == other.extendedRuntimePerfKnobConfig; + return kvCacheConfig == other.kvCacheConfig // + && enableTrtOverlap == other.enableTrtOverlap // + && deviceIds == other.deviceIds // + && normalizeLogProbs == other.normalizeLogProbs // + && enableChunkedContext == other.enableChunkedContext // + && decodingConfig == other.decodingConfig // + && gpuWeightsPercent == other.gpuWeightsPercent // + && maxBeamWidth == other.maxBeamWidth // + && maxBatchSize == other.maxBatchSize // + && maxNumTokens == other.maxNumTokens // + && schedulerConfig == other.schedulerConfig // + && extendedRuntimePerfKnobConfig == other.extendedRuntimePerfKnobConfig // + && debugConfig == other.debugConfig // + ; } friend std::ostream& operator<<(std::ostream& os, TrtGptModelOptionalParams const& self); @@ -100,6 +111,7 @@ class TrtGptModelOptionalParams std::optional maxNumTokens; executor::SchedulerConfig schedulerConfig; executor::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig; + std::optional debugConfig; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 4ed912100..ed402ca69 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -42,6 +42,7 @@ char const* version() noexcept; class Model; class Serialization; +class ContextPhaseState; /// @brief Sampling configuration class SamplingConfig @@ -233,6 +234,71 @@ class LoraConfig std::optional mConfig; }; +struct LookaheadDecodingConfig +{ + LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize); + + explicit LookaheadDecodingConfig() + : LookaheadDecodingConfig(1, 1, 0) + { + } + + bool operator==(LookaheadDecodingConfig const& other) const; + [[nodiscard]] std::tuple get() const; + [[nodiscard]] SizeType32 getWindowSize() const; + [[nodiscard]] SizeType32 getNgramSize() const; + [[nodiscard]] SizeType32 getVerificationSetSize() const; + + /// @brief return + std::tuple calculateSpeculativeResource() const; + + /// @brief return true when `this` can be executed on resources defined by `that` + bool isLE(LookaheadDecodingConfig const& that) const; + + /// @brief return true when the parameter combination is valid. + static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept; + +private: + friend class Serialization; + + // Number of NGrams in lookahead branch per step. + SizeType32 mWindowSize; + // Number of tokens per NGram. + SizeType32 mNgramSize; + // Number of NGrams in verification branch per step. 
+ SizeType32 mVerificationSetSize; +}; + +class ContextPhaseParams +{ +public: + explicit ContextPhaseParams(VecTokens firstGenTokens); + ContextPhaseParams(VecTokens firstGenTokens, void* state); + + ContextPhaseParams(ContextPhaseParams const&); + ContextPhaseParams(ContextPhaseParams&&); + ContextPhaseParams& operator=(ContextPhaseParams const&); + ContextPhaseParams& operator=(ContextPhaseParams&&); + + [[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept; + + [[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept; + [[nodiscard]] VecTokens popFirstGenTokens() && noexcept; + [[nodiscard]] void const* getState() const noexcept; + [[nodiscard]] void* getState() noexcept; + +private: + friend class Serialization; + static void deleter(void const* data); + using StatePtr = std::unique_ptr; + + /// @brief The first tokens generated by context executor + VecTokens mFirstGenTokens; + + /// @brief Context phase state of this request + StatePtr mState{nullptr, deleter}; +}; + /// @brief A class that holds information about the request class Request { @@ -269,9 +335,11 @@ class Request std::optional externalDraftTokensConfig = std::nullopt, std::optional pTuningConfig = std::nullopt, std::optional loraConfig = std::nullopt, + std::optional lookaheadConfig = std::nullopt, std::optional logitsPostProcessorName = std::nullopt, std::optional encoderInputTokenIds = std::nullopt, std::optional clientId = std::nullopt, - bool returnAllGeneratedTokens = false, PriorityType priority = kDefaultPriority); + bool returnAllGeneratedTokens = false, PriorityType priority = kDefaultPriority, + std::optional contextPhaseParams = std::nullopt); /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor static auto constexpr kBatchedPostProcessorName = "batched"; @@ -295,11 +363,13 @@ class Request [[nodiscard]] std::optional getExternalDraftTokensConfig() const; [[nodiscard]] std::optional getPromptTuningConfig() const; [[nodiscard]] std::optional getLoraConfig() const; + [[nodiscard]] std::optional getLookaheadConfig() const; [[nodiscard]] std::optional getLogitsPostProcessorName() const; [[nodiscard]] std::optional getEncoderInputTokenIds() const; [[nodiscard]] std::optional getClientId() const; [[nodiscard]] PriorityType getPriority() const; [[nodiscard]] bool getReturnAllGeneratedTokens() const; + [[nodiscard]] std::optional const& getContextPhaseParams() const; void setStreaming(bool streaming); void setSamplingConfig(SamplingConfig const& config); @@ -312,11 +382,13 @@ class Request void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig); void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig); void setLoraConfig(LoraConfig const& loraConfig); + void setLookaheadConfig(LookaheadDecodingConfig const& lookaheadConfig); void setLogitsPostProcessorName(std::string const& logitsPostProcessorName); void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds); void setClientId(IdType clientId); void setPriority(PriorityType priority); void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens); + void setContextPhaseParams(ContextPhaseParams contextPhaseParams); private: friend class Serialization; @@ -342,11 +414,20 @@ struct Result /// @brief The context logits. Size [promptLen, vocabSizePadded] std::optional contextLogits; - /// @brief The context logits. Size [beamSize, maxNewTokens, vocabSizePadded] + /// @brief The context logits. 
Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming) + /// or [maxNewTokens, beamSize, vocabSizePadded] (streaming and allGeneratedTokens) + /// or [1, beamSize, vocabSizePadded] (streaming and non-allGeneratedTokens) std::optional generationLogits; /// @brief The encoder output. Size [encoderLen, hiddenSize] std::optional encoderOutput; + + /// @brief The reason why the model stopped generating tokens for each beam in this request. Size [beamSize]. + /// Currently only supported when beamSize is 1 and when using BatchingType::kINFLIGHT. + std::vector finishReasons; + + /// @brief The params of the context phase. + std::optional contextPhaseParams; }; /// @brief Class that holds either an error or a result @@ -370,11 +451,11 @@ class Response /// @brief Get the error msg for this response /// Will throw an exception if hasError is false - [[nodiscard]] std::string getErrorMsg() const; + [[nodiscard]] std::string const& getErrorMsg() const; /// @brief Get the result for this response /// Will throw an exception if hasResult is true - [[nodiscard]] Result getResult() const; + [[nodiscard]] Result const& getResult() const; private: friend class Serialization; @@ -390,6 +471,8 @@ class SchedulerConfig CapacitySchedulerPolicy capacitySchedulerPolicy = CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT, std::optional contextChunkingPolicy = std::nullopt); + bool operator==(SchedulerConfig const& other) const; + [[nodiscard]] CapacitySchedulerPolicy getCapacitySchedulerPolicy() const; [[nodiscard]] std::optional getContextChunkingPolicy() const; @@ -469,17 +552,17 @@ class ExtendedRuntimePerfKnobConfig public: explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = false, bool enableContextFMHAFP32Acc = false); - [[nodiscard]] bool getMultiBlockMode() const; - [[nodiscard]] bool getEnableContextFMHAFP32Acc() const; - - void setMultiBlockMode(bool const multiBlockMode); - void setEnableContextFMHAFP32Acc(bool const enableContextFMHAFP32Acc); - bool operator==(ExtendedRuntimePerfKnobConfig const& other) const { return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc; } + [[nodiscard]] bool getMultiBlockMode() const; + [[nodiscard]] bool getEnableContextFMHAFP32Acc() const; + + void setMultiBlockMode(bool multiBlockMode); + void setEnableContextFMHAFP32Acc(bool enableContextFMHAFP32Acc); + private: friend class Serialization; @@ -490,6 +573,35 @@ class ExtendedRuntimePerfKnobConfig bool mEnableContextFMHAFP32Acc; }; +/// @brief Configuration class for debugging output +class DebugConfig +{ + using StringVec = std::vector; + +public: + explicit DebugConfig(bool dumpInputTensors = false, bool dumpOuputTensors = false, StringVec debugTensorNames = {}); + + bool operator==(DebugConfig const& other) const; + + [[nodiscard]] bool getDumpInputTensors() const; + [[nodiscard]] bool getDumpOutputTensors() const; + [[nodiscard]] StringVec const& getDebugTensorNames() const; + + void setDumpInputTensors(bool dumpInputTensors); + void setDumpOuputTensors(bool dumpOuputTensors); + void setDebugTensorNames(StringVec const& debugTensorNames); + +private: + friend class Serialization; + + /// @brief If true, dump all input tensors. + bool mDumpInputTensors; + /// @brief If true, dump all output tensors. + bool mDumpOuputTensors; + /// @brief If not empty, only dump tensors in this list. 
+ StringVec mDebugTensorNames; +}; + SizeType32 const kDefaultIterStatsMaxIterations = 1000; // Per request stats may have additional overhead due to going through all requests. Turned off by default. SizeType32 const kDefaultRequestStatsMaxIterations = 0; @@ -616,42 +728,43 @@ class PeftCacheConfig std::optional mHostCacheSize; }; -struct LookaheadDecodingConfig -{ - LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize); - - explicit LookaheadDecodingConfig() - : LookaheadDecodingConfig(1, 1, 0) - { - } - - bool operator==(LookaheadDecodingConfig const& other) const; - [[nodiscard]] std::tuple get() const; - [[nodiscard]] SizeType32 getWindowSize() const; - [[nodiscard]] SizeType32 getNgramSize() const; - [[nodiscard]] SizeType32 getVerificationSetSize() const; - - /// @brief return - std::tuple calculateSpeculativeResource() const; - - /// @brief return true when `this` can be executed on resources defined by `that` - bool isLE(LookaheadDecodingConfig const& that) const; - - /// @brief return true when the parameter combination is valid. - static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept; - -private: - friend class Serialization; - - // Number of NGrams in lookahead branch per step. - SizeType32 mWindowSize; - // Number of tokens per NGram. - SizeType32 mNgramSize; - // Number of NGrams in verification branch per step. - SizeType32 mVerificationSetSize; -}; - /// @brief Configuration class for the speculative decoding. +// struct LookaheadDecodingConfig +//{ +// LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize); +// +// explicit LookaheadDecodingConfig() +// : LookaheadDecodingConfig(1, 1, 0) +// { +// } +// +// bool operator==(LookaheadDecodingConfig const& other) const; +// [[nodiscard]] std::tuple get() const; +// [[nodiscard]] SizeType32 getWindowSize() const; +// [[nodiscard]] SizeType32 getNgramSize() const; +// [[nodiscard]] SizeType32 getVerificationSetSize() const; +// +// /// @brief return +// std::tuple calculateSpeculativeResource() const; +// +// /// @brief return true when `this` can be executed on resources defined by `that` +// bool isLE(LookaheadDecodingConfig const& that) const; +// +// /// @brief return true when the parameter combination is valid. +// static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept; +// +// private: +// friend class Serialization; +// +// // Number of NGrams in lookahead branch per step. +// SizeType32 mWindowSize; +// // Number of tokens per NGram. +// SizeType32 mNgramSize; +// // Number of NGrams in verification branch per step. +// SizeType32 mVerificationSetSize; +// }; + +/// @brief Configuration class for the decoding. 
class DecodingConfig { public: @@ -687,6 +800,29 @@ class DecodingConfig std::optional mMedusaChoices; }; +class LogitsPostProcessorConfig +{ +public: + explicit LogitsPostProcessorConfig(std::optional processorMap = std::nullopt, + std::optional processorBatched = std::nullopt, bool replicate = true); + + [[nodiscard]] std::optional getProcessorMap() const; + [[nodiscard]] std::optional getProcessorBatched() const; + [[nodiscard]] bool getReplicate() const; + + void setProcessorMap(LogitsPostProcessorMap const& processorMap); + void setProcessorBatched(LogitsPostProcessorBatched const& processorBatched); + void setReplicate(bool replicate); + +private: + /// @brief mapping from post processor names to non-batched post processors + std::optional mProcessorMap; + /// @brief single batched post processor + std::optional mProcessorBatched; + /// @brief If set to true, logits post processor will run on all TP ranks in last PP rank + bool mReplicate; +}; + /// @brief Configuration class for the model executor class ExecutorConfig { @@ -699,11 +835,11 @@ class ExecutorConfig std::optional maxNumTokens = std::nullopt, std::optional parallelConfig = std::nullopt, std::optional const& peftCacheConfig = std::nullopt, - std::optional logitsPostProcessorMap = std::nullopt, - std::optional logitsPostProcessorBatched = std::nullopt, - bool replicateLogitsPostProcessor = true, std::optional decodingConfig = std::nullopt, - float gpuWeightsPercent = 1, std::optional maxQueueSize = std::nullopt, - ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig()); + std::optional logitsPostProcessorConfig = std::nullopt, + std::optional decodingConfig = std::nullopt, float gpuWeightsPercent = 1, + std::optional maxQueueSize = std::nullopt, + ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig(), + std::optional debugConfig = std::nullopt); [[nodiscard]] SizeType32 getMaxBeamWidth() const; [[nodiscard]] SchedulerConfig getSchedulerConfig() const; @@ -717,13 +853,12 @@ class ExecutorConfig [[nodiscard]] std::optional getMaxNumTokens() const; [[nodiscard]] std::optional getParallelConfig() const; [[nodiscard]] std::optional getPeftCacheConfig() const; - [[nodiscard]] std::optional getLogitsPostProcessorMap() const; - [[nodiscard]] std::optional getLogitsPostProcessorBatched() const; - [[nodiscard]] bool getReplicateLogitsPostProcessor() const; + [[nodiscard]] std::optional getLogitsPostProcessorConfig() const; [[nodiscard]] std::optional getDecodingConfig() const; [[nodiscard]] float getGpuWeightsPercent() const; [[nodiscard]] std::optional getMaxQueueSize() const; [[nodiscard]] ExtendedRuntimePerfKnobConfig getExtendedRuntimePerfKnobConfig() const; + [[nodiscard]] std::optional getDebugConfig() const; void setMaxBeamWidth(SizeType32 maxBeamWidth); void setMaxBatchSize(SizeType32 maxBatchSize); @@ -737,13 +872,12 @@ class ExecutorConfig void setBatchingType(BatchingType batchingType); void setParallelConfig(ParallelConfig const& parallelConfig); void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig); - void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap); - void setLogitsPostProcessorBatched(LogitsPostProcessorBatched const& logitsPostProcessorBatched); - void setReplicateLogitsPostProcessor(bool const replicateLogitsPostProcessor); + void setLogitsPostProcessorConfig(LogitsPostProcessorConfig const& logitsPostProcessorConfig); void setDecodingConfig(DecodingConfig const& decodingConfig); 
void setGpuWeightsPercent(float const& gpuWeightsPercent); void setMaxQueueSize(std::optional const& maxQueueSize); - void setExtendedRuntimePerfKnobConfig(ExtendedRuntimePerfKnobConfig const& ExtendedRuntimePerfKnobConfig); + void setExtendedRuntimePerfKnobConfig(ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig); + void setDebugConfig(DebugConfig const& debugConfig); private: friend class Serialization; @@ -781,10 +915,9 @@ class ExecutorConfig /// @brief The parallel execution configuration. std::optional mParallelConfig; std::optional mPeftCacheConfig; - std::optional mLogitsPostProcessorMap; - std::optional mLogitsPostProcessorBatched; - /// @brief If set to true, logits post processor will run on all TP ranks in last PP rank - bool mReplicateLogitsPostProcessor; + + /// @brief Logits post processor configuration + std::optional mLogitsPostProcessorConfig; /// @brief Decoding configuration. std::optional mDecodingConfig; @@ -797,6 +930,9 @@ class ExecutorConfig /// @brief Config for perf knobs that can be set in runtime. ExtendedRuntimePerfKnobConfig mExtendedRuntimePerfKnobConfig; + + /// @brief Debugging configuration. + std::optional mDebugConfig; }; /// @brief The executor is responsible for receiving new requests and sending responses, and running the inference diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index 4dace9acb..5f29afde2 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -53,6 +53,11 @@ class Serialization static void serialize(LoraConfig const& config, std::ostream& os); [[nodiscard]] static size_t serializedSize(LoraConfig const& config); + // ContextPhaseParams + [[nodiscard]] static ContextPhaseParams deserializeContextPhaseParams(std::istream& is); + static void serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os); + [[nodiscard]] static size_t serializedSize(ContextPhaseParams const& contextPhaseParams); + // Request [[nodiscard]] static Request deserializeRequest(std::istream& is); static void serialize(Request const& request, std::ostream& os); @@ -122,6 +127,11 @@ class Serialization static void serialize(DecodingConfig const& decodingConfig, std::ostream& os); static size_t serializedSize(DecodingConfig const& decodingConfig); + // DebugConfig + static DebugConfig deserializeDebugConfig(std::istream& is); + static void serialize(DebugConfig const& debugConfig, std::ostream& os); + static size_t serializedSize(DebugConfig const& debugConfig); + // ExecutorConfig static ExecutorConfig deserializeExecutorConfig(std::istream& is); static void serialize(ExecutorConfig const& executorConfig, std::ostream& os); diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 4861a39aa..1dce228cd 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -348,6 +348,22 @@ struct RequestStatsPerIteration std::vector requestStats; }; +/// @brief The reason why the model stopped generating tokens for a request. +enum class FinishReason +{ + /// @brief The request is not finished. + kNOT_FINISHED = 0, + + /// @brief The request finished because the end id was generated. + kEND_ID = 1, + + /// @brief The request finished because a stop word was generated. + kSTOP_WORDS = 2, + + /// @brief The request finished because the maximum number of tokens was reached. 
+ kLENGTH = 3, +}; + /// @brief mode of the decoder class DecodingMode { diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h index 6c4d7c805..4f92cbd06 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingInput.h +++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h @@ -80,7 +80,7 @@ class DecodingInput batchSlots; //!< [batchSize], address map of the linear batch id to to the seq slots, int32_t, pinned // optional parameters - TensorConstPtr finished; //!< [batchSize, beamWidth], finished states at current iteration. + TensorConstPtr finishReasons; //!< [batchSize, beamWidth], finished states at current iteration. //!< If true for some request, the decoding step of it is skipped, on gpu TensorConstPtr sequenceLimitLength; //!< [batchSize], on gpu. The maximum sequence length for each sequence in the batch. @@ -129,9 +129,16 @@ class DecodingInput TensorConstPtr seqSlots; //!< [batchSize] }; + struct LookaheadInputs + { + TensorPtr tokensPerStep; + }; + std::optional medusaInputs; std::optional explicitDraftTokensInputs; + + std::optional lookaheadInputs; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/decodingOutput.h b/cpp/include/tensorrt_llm/runtime/decodingOutput.h index c07ae057b..146db40a4 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingOutput.h +++ b/cpp/include/tensorrt_llm/runtime/decodingOutput.h @@ -20,9 +20,15 @@ #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h" #include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/lookaheadBuffers.h" #include #include +namespace tensorrt_llm::batch_manager +{ +class LookaheadDecodingBuffers; +} + namespace tensorrt_llm::runtime { class DecodingOutput @@ -81,10 +87,10 @@ class DecodingOutput // Vector of views on newTokensSteps for each token // optional parameters - TensorPtr finished; // [BS, BM], set to true by decoding if any of the stop conditions are met or if - // DecodingInput.finished is true. In beam search and to determine whether to stop according to - // DecodingInput.sequenceLimitLength - TensorPtr finishedSum; // [BS], the sum of finished sequences per request, in pinned memory + TensorPtr finishReasons; // [BS, BM], set to FinishedState by decoding if any of the stop conditions are met or if + // DecodingInput.finished is true. 
In beam search and to determine whether to stop + // according to DecodingInput.sequenceLimitLength + TensorPtr finishedSum; // [BS], the sum of finished sequences per request, in pinned memory // mandatory parameters for beam search TensorPtr logProbs; // [BS, BM, MSL], must be float* @@ -110,6 +116,8 @@ class DecodingOutput std::optional speculativeDecodingOutputs; std::optional explicitDraftTokensBuffers; + + std::optional lookaheadOutputs; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h index 1753f24c6..8b0dc994b 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h @@ -16,11 +16,13 @@ #pragma once +#include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/layers/decodingParams.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingOutput.h" +#include "tensorrt_llm/runtime/request.h" #include "tensorrt_llm/runtime/samplingConfig.h" #include @@ -52,7 +54,8 @@ class IGptDecoder virtual ~IGptDecoder() = default; virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots, - std::optional const& output = std::nullopt) + std::optional const& output = std::nullopt, + std::optional const> const& requests = std::nullopt) = 0; virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0; @@ -95,7 +98,8 @@ class GptDecoder : public virtual IGptDecoder std::shared_ptr speculativeDecodingModule = nullptr); void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots, - std::optional const& output = std::nullopt) override; + std::optional const& output = std::nullopt, + std::optional const> const& requests = std::nullopt) override; void forwardAsync(DecodingOutput& output, DecodingInput const& input) override; diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index bbef3ae7a..358826f50 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -54,6 +54,8 @@ class GptDecoderBatched : public IGptDecoderBatched void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) override; + void setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) override; + void newBatch( GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override; @@ -77,6 +79,12 @@ class GptDecoderBatched : public IGptDecoderBatched return {mFinished.begin(), mFinished.begin() + mActualBatchSize}; } + //! @returns [batchSize, beamWidth], FinishedState value, on gpu + [[nodiscard]] TensorPtr getFinishReasons() const override + { + return ITensor::slice(mJointDecodingOutput->finishReasons, 0, mActualBatchSize); + } + //! @param batchIdx index of the batch //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without //! padding for request `batchIdx`, on gpu. In case of beam search, contains the ungathered data. @@ -242,6 +250,9 @@ class GptDecoderBatched : public IGptDecoderBatched //! @brief Setup buffers for speculative decoding. void setupSpeculativeDecoding(ModelConfig const& modelConfig); + //! @brief Setup buffers for lookahead decoding. 
+ void setupLookahead(ModelConfig const& modelConfig); + //! @brief Setups decoder internal tensors for new speculative decoding request void newRequestSpeculativeDecoding( SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig); diff --git a/cpp/include/tensorrt_llm/runtime/iBuffer.h b/cpp/include/tensorrt_llm/runtime/iBuffer.h index 46fb3972e..5675de5ad 100644 --- a/cpp/include/tensorrt_llm/runtime/iBuffer.h +++ b/cpp/include/tensorrt_llm/runtime/iBuffer.h @@ -18,6 +18,7 @@ #include "tensorrt_llm/common/arrayView.h" #include "tensorrt_llm/common/dataType.h" +#include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/kernels/kvCacheIndex.h" #include @@ -323,6 +324,12 @@ struct TRTDataType static constexpr auto value = TRTDataType::value; }; +template <> +struct TRTDataType +{ + static constexpr auto value = TRTDataType::value; +}; + template <> struct TRTDataType { diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 4495c102a..11464f80e 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -21,6 +21,7 @@ #include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h" #include "tensorrt_llm/runtime/iStatefulGptDecoder.h" #include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/lookaheadBuffers.h" #include "tensorrt_llm/runtime/request.h" #include "tensorrt_llm/runtime/utils/sessionUtils.h" @@ -100,6 +101,9 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder //! @brief Setup buffers for ExplicitDraftTokens decoding. virtual void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) = 0; + //! @brief Setup buffers for Lookahead decoding. + virtual void setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) = 0; + //! @brief Run one step for all requests without blocking the host process and return the token for synchronization. virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0; @@ -135,6 +139,9 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder //! @returns [batchSize (actual)], marks finished requests (per batch) [[nodiscard]] virtual std::vector getFinished() const = 0; + //! @returns [batchSize, beamWidth], FinishedState value, on gpu + [[nodiscard]] virtual TensorPtr getFinishReasons() const = 0; + //! 
@returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu [[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0; diff --git a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h index f5e0f142d..4719e4902 100644 --- a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h @@ -29,6 +29,11 @@ #include +namespace tensorrt_llm::batch_manager +{ +struct DecoderBuffers; +} + namespace tensorrt_llm::runtime { diff --git a/cpp/include/tensorrt_llm/runtime/iTensor.h b/cpp/include/tensorrt_llm/runtime/iTensor.h index 04937fc0d..faa14ca90 100644 --- a/cpp/include/tensorrt_llm/runtime/iTensor.h +++ b/cpp/include/tensorrt_llm/runtime/iTensor.h @@ -50,6 +50,7 @@ class ITensor : virtual public IBuffer using SharedConstPtr = std::shared_ptr; using Shape = nvinfer1::Dims; using DimType64 = std::remove_reference_t; + using TensorMap = runtime::StringPtrMap; static_assert(std::is_same_v, "This version of TRT-LLM requires TensorRT 10.0 or later."); diff --git a/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h new file mode 100644 index 000000000..56504bd94 --- /dev/null +++ b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/modelConfig.h" +#include "tensorrt_llm/runtime/tllmRuntime.h" +#include "tensorrt_llm/runtime/worldConfig.h" + +namespace tensorrt_llm::runtime +{ + +class LookaheadDecodingBuffers +{ +public: + using SizeType32 = runtime::SizeType32; + using TensorPtr = runtime::ITensor::SharedPtr; + using ITensor = tensorrt_llm::runtime::ITensor; + LookaheadDecodingBuffers( + SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& bufferManager); + TensorPtr generationLengths; // [mMaxNumRequests] + TensorPtr positionOffsets; // [mMaxNumRequests, maxTokensPerStep] + TensorPtr packedMasks; // [mMaxNumRequests, maxTokensPerStep, divUp(maxTokensPerStep, 32)] + TensorPtr positionIds; +}; + +class LookaheadRuntimeBuffers +{ +public: + using SizeType32 = tensorrt_llm::runtime::SizeType32; + using ITensor = tensorrt_llm::runtime::ITensor; + using TensorPtr = runtime::ITensor::SharedPtr; + using TensorMap = runtime::StringPtrMap; + + LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::BufferManager const& manager, + runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, + executor::DecodingConfig const& decodingConfig, runtime::TllmRuntime const& runtime); + + void setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSequences, runtime::ITensor const& requestTypes, + ITensor const& seqSlots, LookaheadDecodingBuffers const& decoderLookaheadBuffers, + runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig, + runtime::WorldConfig const& worldConfig) const; + + void reshape(SizeType32 numCtxSequences, SizeType32 numGenSequences, SizeType32 tokensPerStep); + + void insertInputTensors( + TensorMap& inputBuffers, TensorMap& outputBuffers, runtime::WorldConfig const& worldConfig) const; + +public: + TensorPtr packedMasksDevice; // [forwardBatchSize, tokensPerStep, numPackedMasks], on gpu + TensorPtr generationLengthsDevice; // [forwardBatchSize], on gpu + TensorPtr positionOffsetsDevice; // [forwardBatchSize, tokensPerStep], on gpu + TensorPtr positionIdsDevice; // [forwardBatchSize, tokensPerStep], on gpu + + TensorPtr packedMaskHost; + TensorPtr generationLengthsHost; + TensorPtr positionOffsetsHost; + TensorPtr positionIdsHost; + + TensorPtr packedMaskHostCopy; + TensorPtr generationLengthsHostCopy; + TensorPtr positionOffsetsHostCopy; + TensorPtr positionIdsHostCopy; + + TensorPtr batchSlotsHostCopy; +}; + +} // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h index e3103ea91..8226c411c 100644 --- a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h +++ b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h @@ -108,8 +108,7 @@ class SpeculativeDecodingMode [[nodiscard]] bool constexpr needsDecoderPrologue() const { - // Potentially lookahead should require it too. 
- return anyBitSet(kExplicitDraftTokens); + return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding); } using UnderlyingType = std::uint8_t; diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index df6e62c84..a8b6f276a 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -22,7 +22,7 @@ set(API_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cutlass_extensions/include ${API_INCLUDE_DIR}) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) find_package(MPI REQUIRED) message(STATUS "Using MPI_C_INCLUDE_DIRS: ${MPI_C_INCLUDE_DIRS}") message(STATUS "Using MPI_C_LIBRARIES: ${MPI_C_LIBRARIES}") @@ -269,7 +269,7 @@ set(TRTLLM_LINK_LIBS runtime_src ${DECODER_SHARED_TARGET}) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} ${MPI_C_LIBRARIES} ${NCCL_LIB}) endif() diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 1781be133..f9fe9ff33 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:84a6439038eb0a7d2913c3fe051684ab7779e42635c074bc6df30bfc46807929 -size 4358834 +oid sha256:460b75a97c0de65941839ccd5e0458cf5929574b9345b3cb723a695ae5a056e0 +size 4404838 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index f90fc6d03..555117707 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1b321340ea622b28ed7d38fe18ff7707091d2efa414af40f6db516959a4fa2f4 -size 4466694 +oid sha256:645bbbad2c38b573df7c6e56588a6728d356a58444ac7c2f881d773faaca7593 +size 4516944 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index 2eff4c1d7..484d3274b 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -99062b35da1cb99df9c79368da1ff9de libtensorrt_llm_batch_manager_static.a -72e6a44f7636bb6d48b016db1c62cdc7 libtensorrt_llm_batch_manager_static.pre_cxx11.a -90dd0ad72954a5cc7cc2e298495e784906fe49b1 commit \ No newline at end of file +a348613d480961aa14d4e77939be8a34 libtensorrt_llm_batch_manager_static.a +317ec85caec48184c9c8b9cbd3eb44b1 libtensorrt_llm_batch_manager_static.pre_cxx11.a +49402939d007b39393cabaa8fe96c110d16f5b35 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index a2560a5a6..3e658d50e 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:1df6689f399313cac54ec1f4422975d6060957edbc698468a96f3a2c2a6542bc -size 4221016 +oid sha256:a785c4459bdb4a7dad9df0c832211f26f699a331ce0b2b9516e7a666f83b895a +size 4272894 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index f9251acc4..26532800d 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b118f1bbccd6fe8e5e001916f4de19ec34886e7dcc288b91dcdadf5500f0eb50 -size 4205756 +oid sha256:9a6b98589222f8bf8e82f122110cb1824b1728646bad45a41e8b9ada632539dc +size 4248190 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib index f69634e79..e3f219a93 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib +++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:47162c3eaab9b6f60bca8927eef0423c5521d7750f112b87a2f9175156ccb6cd -size 24807904 +oid sha256:7daa6c306a2fb738bbe8b3d30324691c83d59aa933c79b0e48342976edb4e356 +size 25540884 diff --git a/cpp/tensorrt_llm/common/cublasMMWrapper.cpp b/cpp/tensorrt_llm/common/cublasMMWrapper.cpp index 224ca77f6..27f179236 100644 --- a/cpp/tensorrt_llm/common/cublasMMWrapper.cpp +++ b/cpp/tensorrt_llm/common/cublasMMWrapper.cpp @@ -37,16 +37,12 @@ CublasMMWrapper::CublasMMWrapper(std::shared_ptr cublasHandle, { } -CublasMMWrapper::~CublasMMWrapper() -{ - mMutex = nullptr; -} +CublasMMWrapper::~CublasMMWrapper() {} CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper) : mCublasHandle(wrapper.mCublasHandle) , mCublasLtHandle(wrapper.mCublasLtHandle) , mStream(wrapper.mStream) - , mMutex(wrapper.mMutex) { } @@ -135,8 +131,6 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, i half h_alpha = (half) (f_alpha); half h_beta = (half) (f_beta); - std::lock_guard lock(*mMutex); - // TODO: default cublas libs usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3); bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F; @@ -179,8 +173,6 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati half h_alpha = (half) f_alpha; half h_beta = (half) f_beta; - std::lock_guard lock(*mMutex); - int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); @@ -198,7 +190,6 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati half h_alpha = (half) f_alpha; half h_beta = (half) f_beta; - std::lock_guard lock(*mMutex); bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); void const* beta = isFp16ComputeType ? 
reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); diff --git a/cpp/tensorrt_llm/common/cublasMMWrapper.h b/cpp/tensorrt_llm/common/cublasMMWrapper.h index 9418b0faf..21062f2f2 100644 --- a/cpp/tensorrt_llm/common/cublasMMWrapper.h +++ b/cpp/tensorrt_llm/common/cublasMMWrapper.h @@ -21,7 +21,6 @@ #include #include #include -#include #include #include @@ -48,8 +47,6 @@ class CublasMMWrapper cublasLtMatrixLayout_t mCDesc{NULL}; cudaStream_t mStream; - //@fixme: we may not need the mutex if we copy the wrapper instead of sharing in GemmPlugin::clone() - std::shared_ptr mMutex{std::make_shared()}; void* mCublasWorkspace = nullptr; diff --git a/cpp/tensorrt_llm/common/mpiUtils.cpp b/cpp/tensorrt_llm/common/mpiUtils.cpp index dce6bf855..720e73b81 100644 --- a/cpp/tensorrt_llm/common/mpiUtils.cpp +++ b/cpp/tensorrt_llm/common/mpiUtils.cpp @@ -234,12 +234,14 @@ void MpiComm::bcast(runtime::IBuffer& buf, int root) const std::shared_ptr MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const { + TLLM_LOG_DEBUG("start MPI_Isend with size %d", size); std::shared_ptr r = std::make_shared(); #if ENABLE_MULTI_DEVICE MPICHECK(MPI_Isend(buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest)); #else TLLM_THROW("Multi device support is disabled."); #endif + TLLM_LOG_DEBUG("end MPI_Isend with size %d", size); return r; } @@ -250,11 +252,13 @@ std::shared_ptr MpiComm::sendAsync(runtime::IBuffer const& buf, int void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const { + TLLM_LOG_DEBUG("start MPI_Send with size %d", size); #if ENABLE_MULTI_DEVICE MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm)); #else TLLM_THROW("Multi device support is disabled."); #endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Send with size %d", size); } void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const @@ -264,12 +268,14 @@ void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const { + TLLM_LOG_DEBUG("start MPI_Recv with size %d", size); MPI_Status status{}; #if ENABLE_MULTI_DEVICE MPICHECK(MPI_Recv(buffer, size, getMpiDtype(dtype), source, tag, mComm, &status)); #else TLLM_THROW("Multi device support is disabled."); #endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Recv with size %d", size); return status; } diff --git a/cpp/tensorrt_llm/common/safetensors.cpp b/cpp/tensorrt_llm/common/safetensors.cpp index b8e73f31e..8637f7f46 100644 --- a/cpp/tensorrt_llm/common/safetensors.cpp +++ b/cpp/tensorrt_llm/common/safetensors.cpp @@ -18,7 +18,6 @@ #include "nlohmann/json.hpp" #include "tensorrt_llm/common/assert.h" #include -#include #include #include #include @@ -153,9 +152,9 @@ class SafeTensor : public ISafeTensor { auto const& value = it->second; int64_t offset = mJsonSize + sizeof(mJsonSize); - return std::shared_ptr(new SafeTensorArray(mFs, value["dtype"], value["shape"], + return std::make_shared(mFs, value["dtype"], value["shape"], static_cast(value["data_offsets"][0]) + offset, - static_cast(value["data_offsets"][1]) + offset)); + static_cast(value["data_offsets"][1]) + offset); } TLLM_THROW("Tensor not found: " + std::string(name)); } diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a index 1d94a394e..f8d7e496e 100644 --- 
a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:15d9a383921d4e955112fc69db35a861c6d5b7c72afc3708dc2a32177c9e5dfe -size 1453326 +oid sha256:f75b47945f8bb945a7086a0bcde038490ebfd2fbb406dfa0f3391f262cfac365 +size 1529360 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index b7b6a8ece..51005d6b9 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:729953123d9ef18c78cc10bf72acb8132d70cd9892d18293350ba35b14405728 -size 1482178 +oid sha256:0790f83b79f8ff2a2313d238bdd409d8f082d92edf8e22e6dc75f6f5dfa8327d +size 1553716 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt index 51b6e7108..e9e3d1275 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -0b92c82ef47ae42243c50506a0b40583 libtensorrt_llm_executor_static.a -58aff9dae183ea725f3cf1407aa96594 libtensorrt_llm_executor_static.pre_cxx11.a -90dd0ad72954a5cc7cc2e298495e784906fe49b1 commit \ No newline at end of file +bf15d213c14dcbe75d2116945bd24c82 libtensorrt_llm_executor_static.a +492e0b37b7f004c5b7a7c46d079f354d libtensorrt_llm_executor_static.pre_cxx11.a +49402939d007b39393cabaa8fe96c110d16f5b35 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a index 9feed6caa..3366214ea 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8273d6b5add26400e3b4ff032c9cfd754009def88abf55e16efd535fd596753f -size 1501596 +oid sha256:7f5fed27f812506b319a1275a6f00b71e3b8e3c0a8a2f71370b7c4673820306f +size 1588916 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index 19ab55f6b..1a1400544 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:44950c4e09e881c6d43040fbdff75f3e9015d91ac32e4e8b09ab1ec16c0c366e -size 1434742 +oid sha256:c9e712e014960458ae1fbda43fcb882eb98f04f00c9e95afce2d881b29d2c5cf +size 1517700 diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib index bac4ed11f..517415958 100644 --- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib +++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:212111cf11a06d0ed6ed7f4ac085f6c8bb70a339a1e2771028ed37e0c416b43b -size 14582948 +oid sha256:e433610d288aa1533fd36c467fd67929fefec68043e486f45dd3a774a55667cd +size 16515186 diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt index 215e46c51..f02829789 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ 47b5d2e14616709b1dfb86b16213308e libtensorrt_llm_nvrtc_wrapper.so -90dd0ad72954a5cc7cc2e298495e784906fe49b1 commit \ No newline at end of file +49402939d007b39393cabaa8fe96c110d16f5b35 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll index b0fe7471d..98548f6cf 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:698a56ec294b5d82bd3e967b4595dce43baef96306001cd1f23dfe77a701e0d4 +oid sha256:6b16e47ce5d366249f54bf1c5edb46841efa84be58de97d78539ac0ba4fc710b size 1127936 diff --git a/cpp/tensorrt_llm/kernels/decodingCommon.h b/cpp/tensorrt_llm/kernels/decodingCommon.h index 494990994..bd695b35d 100644 --- a/cpp/tensorrt_llm/kernels/decodingCommon.h +++ b/cpp/tensorrt_llm/kernels/decodingCommon.h @@ -16,6 +16,7 @@ #pragma once +#include "tensorrt_llm/executor/types.h" #include #include @@ -62,7 +63,7 @@ class FinishedState mState |= kFinishedEos; } - __host__ __device__ bool constexpr isFinishedEOS() + __host__ __device__ bool constexpr isFinishedEOS() const { return anyBitSet(kFinishedEos); } @@ -72,7 +73,7 @@ class FinishedState mState |= kFinishedStopWords; } - __host__ __device__ bool constexpr isFinishedStopWords() + __host__ __device__ bool constexpr isFinishedStopWords() const { return anyBitSet(kFinishedStopWords); } @@ -82,7 +83,7 @@ class FinishedState mState |= kFinishedMaxLength; } - __host__ __device__ bool constexpr isFinishedMaxLength() + __host__ __device__ bool constexpr isFinishedMaxLength() const { return anyBitSet(kFinishedMaxLength); } @@ -107,6 +108,23 @@ class FinishedState return anyBitSet(kSkipDecoding); } + executor::FinishReason toFinishReason() const + { + if (isFinishedEOS()) + { + return executor::FinishReason::kEND_ID; + } + if (isFinishedStopWords()) + { + return executor::FinishReason::kSTOP_WORDS; + } + if (isFinishedMaxLength()) + { + return executor::FinishReason::kLENGTH; + } + return executor::FinishReason::kNOT_FINISHED; + } + using UnderlyingType = uint8_t; private: diff --git a/cpp/tensorrt_llm/kernels/gptKernels.cu b/cpp/tensorrt_llm/kernels/gptKernels.cu index d0d7b96c5..fa61d4cef 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.cu +++ b/cpp/tensorrt_llm/kernels/gptKernels.cu @@ -71,7 +71,8 @@ template __global__ __launch_bounds__(THREADS_PER_BLOCK) void 
computeSeqAndPaddingOffsets(BuildDecoderInfoParams params) { // Dynamic shared memory for storing seqOffsets. - extern __shared__ int smemSeqQOffsets[]; + extern __shared__ int smem[]; + int* smemSeqQOffsets = (int*) (smem); // Fixed Q sequence lengths. bool const fixed_q_seqlen = params.seqQLengths == nullptr; @@ -82,6 +83,10 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets // Whether to calculate cumulative packed mask rows. bool const calculate_packed_mask_row_offsets = params.packedMaskRowOffsets != nullptr; + // Compute the padding offsets for Encoder Inputs. + bool const need_encoder_padding_offsets = (params.encoderPaddingOffsets != nullptr) && calculate_kv_offsets; + [[maybe_unused]] int* smemEncoderSeqQOffsets; + // The implementation of the parallel scan in the thread block (see CUB for details). using BlockScan = cub::BlockScan; @@ -95,6 +100,11 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets BlockPrefixCallbackOp prefixMaskOp(0); BlockPrefixCallbackOp prefixKVOp(0); + if (need_encoder_padding_offsets) + { + smemEncoderSeqQOffsets = (int*) (&smemSeqQOffsets[params.batchSize + 1]); + } + // Iterate over the sequences in the batch. // // The loop index does not depend on the thread index to make sure all the threads enter the @@ -140,6 +150,10 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets if (batchIdx <= batchSizeBound) { smemSeqQOffsets[batchIdx] = seqQOffset; + if (need_encoder_padding_offsets) + { + smemEncoderSeqQOffsets[batchIdx] = seqKVOffset; + } } // Store the result. @@ -160,27 +174,35 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets __syncthreads(); } - // Compute the padding offsets. - // Block x dimension is the batch dimension, while threads iterate all tokens in the sequence. int batchIdx = blockIdx.x; - // The beginning of the sequence. - int seqBegin = smemSeqQOffsets[batchIdx]; - // The offset to the 1st element of the next sequence. - int seqEnd = smemSeqQOffsets[batchIdx + 1]; - // The length of the sequence. - int seqLength = seqEnd - seqBegin; - // The number of padded tokens in the previous sequences. - int paddingOffset = batchIdx * params.maxQSeqLength - seqBegin; - bool const need_padding_offsets = params.paddingOffsets != nullptr; - - if (need_padding_offsets) + // Compute the padding offsets. + auto compute_padding_offset = [&](int* smem_offset, int maxSeqLength, int* paddingOffsets) { + // Block x dimension is the batch dimension, while threads iterate all tokens in the sequence. + int seqBegin = smem_offset[batchIdx]; + // The offset to the 1st element of the next sequence. + int seqEnd = smem_offset[batchIdx + 1]; + // The length of the sequence. + int seqLength = seqEnd - seqBegin; + // The number of padded tokens in the previous sequences. + int paddingOffset = batchIdx * maxSeqLength - seqBegin; + // Iterate over the tokens to update the number of padded elements. 
for (int tokenIdx = threadIdx.x; tokenIdx < seqLength; tokenIdx += blockDim.x) { - params.paddingOffsets[seqBegin + tokenIdx] = paddingOffset; + paddingOffsets[seqBegin + tokenIdx] = paddingOffset; } + }; + + if (params.paddingOffsets != nullptr) + { + compute_padding_offset(smemSeqQOffsets, params.maxQSeqLength, params.paddingOffsets); + } + + if (need_encoder_padding_offsets) + { + compute_padding_offset(smemEncoderSeqQOffsets, params.maxEncoderQSeqLength, params.encoderPaddingOffsets); } // Each block generates the rotary embedding inv_freq tensor for the corresponding sequence. @@ -311,7 +333,10 @@ void invokeBuildDecoderInfo(BuildDecoderInfoParams const& params, cudaStream_ "Rotary embedding dim is assumed to be smaller than 512 and multiple of 2."); TLLM_CHECK_WITH_INFO( !(params.seqKVLengths == nullptr && params.rotaryEmbeddingDim > 0), "KV sequence lengths buffer is invalid."); - const size_t smem_size = (params.batchSize + 1) * sizeof(int); + bool const need_encoder_padding_offsets + = (params.encoderPaddingOffsets != nullptr) && (params.seqKVOffsets != nullptr); + const size_t smem_size + = (need_encoder_padding_offsets ? (params.batchSize + 1) * 2 : (params.batchSize + 1)) * sizeof(int); computeSeqAndPaddingOffsets <<>>(params); diff --git a/cpp/tensorrt_llm/kernels/gptKernels.h b/cpp/tensorrt_llm/kernels/gptKernels.h index 2eea65469..53441abf1 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.h +++ b/cpp/tensorrt_llm/kernels/gptKernels.h @@ -100,8 +100,12 @@ struct BuildDecoderInfoParams int* seqQOffsets; // The offsets to the 1st token in each sequence of KV buffer. Shape: [batchSize+1]. int* seqKVOffsets; - // The number of padded tokens in the corresponding padded tensor before the current token. Shape: [numTokens]. + // The number of padded tokens in the corresponding padded tensor before the current token, for Decoder. Shape: + // [numTokens]. int* paddingOffsets; + // The number of padded tokens in the corresponding padded tensor before the current token, for Encoder. Shape: + // [numTokens]. + int* encoderPaddingOffsets; // The offsets to the 1st row in each sequence of packed mask buffer. Shape: [batchSize+1]. int* packedMaskRowOffsets; @@ -120,8 +124,10 @@ struct BuildDecoderInfoParams // The number of sequences in the batch. int batchSize; - // The maximum query length of a sequence; it includes input and output. + // The maximum query length of a sequence for Decoder (max_input_length), N for ctx phase, 1 for gen phase. int maxQSeqLength; + // The maximum query length of a sequence for Encoder, for cross attention (cross_qkv_length). + int maxEncoderQSeqLength; // Whether remove the input padding or not. bool removePadding; // The kv cache capacity. 
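> Editorial note (not part of the diff): the `gptKernels` hunks above add encoder-side padding offsets to `computeSeqAndPaddingOffsets`. When both `encoderPaddingOffsets` and `seqKVOffsets` are set, the kernel keeps a second cumulative-offset array of `batchSize + 1` ints in dynamic shared memory directly behind the decoder array, and the launch doubles the shared-memory reservation accordingly. A small host-side restatement of that sizing rule is sketched below; the helper name is hypothetical.

```cpp
// Sketch only: shared-memory sizing mirroring invokeBuildDecoderInfo above.
// One (batchSize + 1) int array for decoder Q offsets, plus a second one for
// encoder (KV) offsets when encoder padding offsets are requested.
#include <cstddef>

size_t buildDecoderInfoSmemSize(int batchSize, bool needEncoderPaddingOffsets)
{
    size_t const offsetsPerArray = static_cast<size_t>(batchSize) + 1;
    size_t const numArrays = needEncoderPaddingOffsets ? 2 : 1;
    return numArrays * offsetsPerArray * sizeof(int);
}
```

The result is what the diff passes as the dynamic shared-memory argument of the `computeSeqAndPaddingOffsets` launch.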
@@ -164,12 +170,20 @@ struct BuildDecoderInfoParams << *(runtime::ITensor::wrap( (void*) paddingOffsets, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batchSize}))) << std::endl; + if (encoderPaddingOffsets != nullptr) + { + ss << "encoderPaddingOffsets: " + << *(runtime::ITensor::wrap((void*) encoderPaddingOffsets, nvinfer1::DataType::kINT32, + runtime::ITensor::makeShape({batchSize}))) + << std::endl; + } ss << "attentionMask: " << static_cast(attentionMask) << std::endl; ss << "seqQLengths: " << seqQLengths << std::endl; ss << "seqKVLengths: " << seqKVLengths << std::endl; ss << "fmhaTileCounter: " << fmhaTileCounter << std::endl; ss << "batchSize: " << batchSize << std::endl; ss << "maxQSeqLength: " << maxQSeqLength << std::endl; + ss << "maxEncoderQSeqLength: " << maxEncoderQSeqLength << std::endl; ss << "removePadding: " << std::boolalpha << removePadding << std::endl; ss << "attentionWindowSize: " << attentionWindowSize << std::endl; ss << "sinkTokenLength: " << sinkTokenLength << std::endl; diff --git a/cpp/tensorrt_llm/kernels/groupGemm.cu b/cpp/tensorrt_llm/kernels/groupGemm.cu index 86656f779..b133f8e84 100644 --- a/cpp/tensorrt_llm/kernels/groupGemm.cu +++ b/cpp/tensorrt_llm/kernels/groupGemm.cu @@ -72,13 +72,12 @@ void groupedGemm_(std::vector problem_sizes, std::vect using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; - int const kAlignmentA = 8; - int const kAlignmentB = 8; + int constexpr kAlignment = 8; int problem_count = problem_sizes.size(); using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped, cutlass::gemm::GemmShape, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination::value, @@ -121,9 +120,13 @@ void groupedGemm_(std::vector problem_sizes, std::vect auto problem = problem_sizes.at(i); lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0); + TLLM_CHECK(lda_host[i] % kAlignment == 0); ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0); + TLLM_CHECK(ldb_host[i] % kAlignment == 0); ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); + TLLM_CHECK(ldc_host[i] % kAlignment == 0); ldd_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); + TLLM_CHECK(ldd_host[i] % kAlignment == 0); } cutlass::gemm::GemmCoord* problem_sizes_device = reinterpret_cast(gemmParamsWorkSpace); diff --git a/cpp/tensorrt_llm/kernels/lora/lora.cpp b/cpp/tensorrt_llm/kernels/lora/lora.cpp index f275cb559..25a46fec9 100644 --- a/cpp/tensorrt_llm/kernels/lora/lora.cpp +++ b/cpp/tensorrt_llm/kernels/lora/lora.cpp @@ -26,12 +26,14 @@ #include -namespace tk = tensorrt_llm::kernels; using namespace nvinfer1; using namespace tensorrt_llm::common; using tensorrt_llm::kernels::LoraImpl; using tensorrt_llm::kernels::CublasGemmWrapperPtr; +namespace tensorrt_llm::kernels +{ + // TODO should reuse the function in gemmPlugin void _getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int& m, int& n, int& k, int& lda, int& ldb, int& ldc, bool transA, bool transB, int M, int N, int K) @@ -103,9 +105,9 @@ int64_t getLowRankWorkSpaceSize(int64_t numTokens, int64_t maxLoraModuleNum, int return divUp(numTokens * maxLoraModuleNum * maxLowRank * typeSize, 16) * 16; } -int64_t getGroupedGemmParamsWorkSpaceSize(int64_t nbReq) +int64_t getGemmParamsWorkSpaceSize(int64_t nbReq) { - return std::max(tk::getSplitkGroupedGemmParamsWorkSpaceSize(nbReq), tk::getGroupedGemmParamsWorkSpaceSize(nbReq)); + return 
std::max(getSplitkGroupedGemmParamsWorkSpaceSize(nbReq), getGroupedGemmParamsWorkSpaceSize(nbReq)); } int64_t getSplitkGroupedGemmWorkSpaceSize( @@ -129,7 +131,7 @@ size_t LoraImpl::getWorkspaceSize( return (size_t) getGemmWorkSpaceSize(numTokens, mNumLoraModules, mMaxLowRank, mSplitKSlices) + getLowRankWorkSpaceSize(numTokens, mNumLoraModules, mMaxLowRank, typeSize) - + getGroupedGemmParamsWorkSpaceSize(numReqs * mNumLoraModules); + + getGemmParamsWorkSpaceSize(numReqs * mNumLoraModules); } void LoraImpl::setBestTactic(std::optional config) @@ -160,7 +162,7 @@ int LoraImpl::run(int64_t numTokens, int64_t numReqs, void const* input, int32_t setGemmConfig(); int64_t GemmWorkSpaceSize = getGemmWorkSpaceSize(numTokens, mNumLoraModules, mMaxLowRank, mSplitKSlices); - int64_t groupGemmParamsWorkSpaceSize = getGroupedGemmParamsWorkSpaceSize(numReqs * mNumLoraModules); + int64_t groupGemmParamsWorkSpaceSize = getGemmParamsWorkSpaceSize(numReqs * mNumLoraModules); void* gemmWorkSpace = workspace; // [gemmWorkSpace, lowrankWorkSpace, groupGemmParamsWorkSpace] void* lowRankWorkSpace = static_cast(gemmWorkSpace) + GemmWorkSpaceSize; void* groupGemmParamsWorkSpace = static_cast(lowRankWorkSpace) @@ -321,11 +323,11 @@ int LoraImpl::run(int64_t numTokens, int64_t numReqs, void const* input, int32_t TLLM_CHECK_WITH_INFO(mTransA == false && mTransB == true, fmtstr("Invalid transA (%d) transB (%d). transA must be false, transB must be true", int(mTransA), int(mTransB))); - tk::splitkGroupedGemm(problem_sizes, ptrA, ptrB, ptrC, ptrD, groupGemmParamsWorkSpace, + splitkGroupedGemm(problem_sizes, ptrA, ptrB, ptrC, ptrD, groupGemmParamsWorkSpace, groupGemmParamsWorkSpaceSize, gemmWorkSpace, GemmWorkSpaceSize, splitkBufferOffsets, true, mType, mSplitKSlices, stream); sync_check_cuda_error(); - tk::groupedGemm(problem_sizes_2, ptrA_2, ptrB_2, ptrC_2, ptrD_2, groupGemmParamsWorkSpace, + groupedGemm(problem_sizes_2, ptrA_2, ptrB_2, ptrC_2, ptrD_2, groupGemmParamsWorkSpace, groupGemmParamsWorkSpaceSize, gemmWorkSpace, GemmWorkSpaceSize, false, mType, stream); sync_check_cuda_error(); } @@ -333,3 +335,5 @@ int LoraImpl::run(int64_t numTokens, int64_t numReqs, void const* input, int32_t return 0; } + +} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu index a7298820e..c3fe5e475 100644 --- a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu @@ -1746,16 +1746,27 @@ T const* CutlassMoeFCRunner::l auto fc1_lora_impl = lora_params.fc1_lora_impl; int num_reqs = lora_params.num_reqs; - int64_t num_tokens_handled = 0; - T* lora_gated_out = lora_fc1_result_ + expanded_num_rows * inter_size; + T *lora_gated_out = nullptr, *lora_fc1_result = nullptr; + + if (is_gated_activation) + { + lora_gated_out = lora_fc1_result_; + lora_fc1_result = lora_fc1_result_ + expanded_num_rows * inter_size; + } + else + { + lora_fc1_result = lora_fc1_result_; + } + void* lora_workspace = lora_params.workspace; + int64_t num_tokens_handled = 0; // TODO: Remove the weightIndex parameter from the 'loraImpl->run' function and consolidate it into a single // 'groupGEMM' operation. 
for (int expert_id = 0; expert_id < num_experts_per_node; expert_id += 1) { int64_t expert_num_rows = host_expert_first_token_offset[expert_id + 1] - num_tokens_handled; - void* tmp_lora_fc_result = static_cast(lora_fc1_result_ + num_tokens_handled * inter_size); + void* tmp_lora_fc_result = static_cast(lora_fc1_result + num_tokens_handled * inter_size); fc1_lora_impl->run(expert_num_rows, num_reqs, permuted_data_ + num_tokens_handled * hidden_size, &host_permuted_fc1_lora_ranks[num_tokens_handled], &host_permuted_fc1_weight_ptrs[num_tokens_handled * 2], expert_id + start_expert, &tmp_lora_fc_result, lora_workspace, stream); diff --git a/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu b/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu index b4c49453f..d4bf74a01 100644 --- a/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu +++ b/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu @@ -81,13 +81,12 @@ void splitkGroupedGemm_(std::vector problem_sizes, std using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; - int const kAlignmentA = 8; - int const kAlignmentB = 8; + int constexpr kAlignment = 8; int problem_count = problem_sizes.size(); using GemmKernel = typename cutlass::gemm::kernel::DefaultSplitkGemmGrouped, cutlass::gemm::GemmShape, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination::value, @@ -141,9 +140,13 @@ void splitkGroupedGemm_(std::vector problem_sizes, std auto problem = problem_sizes.at(i); lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0); + TLLM_CHECK(lda_host[i] % kAlignment == 0); ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0); + TLLM_CHECK(ldb_host[i] % kAlignment == 0); ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); + TLLM_CHECK(ldc_host[i] % kAlignment == 0); ldd_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); + TLLM_CHECK(ldd_host[i] % kAlignment == 0); offset_host[i] = cumulative_offsets; cumulative_offsets += problem.m() * problem.n(); diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu index 63394a78c..7eafc17a5 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu @@ -1065,8 +1065,10 @@ __global__ void transpose_remove_padding(T const* src, T* dst, int const batch_s // do remove_sequence_length_padding int const bid = blockIdx.x; // batch * seq_len or valid_word_num - int const src_batch_id = (bid + mask_offset[bid]) / seq_len; - int const src_seq_id = (bid + mask_offset[bid]) % seq_len; + int const mask_offset_value = (mask_offset == nullptr) ? 
0 : mask_offset[bid]; + + int const src_batch_id = (bid + mask_offset_value) / seq_len; + int const src_seq_id = (bid + mask_offset_value) % seq_len; int const dst_seq_id = bid; @@ -1166,7 +1168,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, // k_buf, v_buf: [batch, kv_head_num, seq_len, size_per_head] // For cross attention where q/k/v buffer could be nullptr, writing to split buffer is suppressed when null T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; - bool const has_padding = padding_offset == nullptr; + bool const has_padding = padding_offset != nullptr; int const hidden = head_num * size_per_head; // hidden dim Q int const n = hidden + 2 * kv_head_num * size_per_head; @@ -1175,11 +1177,11 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, int const bias_id = index % n; int const token_idx = index / n; - int const token_padded_idx = token_idx + (has_padding ? 0 : padding_offset[token_idx]); + int const token_padded_idx = token_idx + (has_padding ? padding_offset[token_idx] : 0); int const target_batch_id = token_padded_idx / seq_len; int const actual_seq_len = seq_lens[target_batch_id]; int const seq_id = token_padded_idx % seq_len; - bool const valid_seq = seq_id < actual_seq_len || !has_padding; + bool const valid_seq = seq_id < actual_seq_len || has_padding; int qkv_id; int head_id; @@ -1319,12 +1321,12 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, int const token_idx = blockIdx.x; int const token_padding_offset = (padding_offset == nullptr || token_idx < 0) ? 0 : padding_offset[token_idx]; int const tgt_token_idx = token_idx + token_padding_offset; - bool const has_padding = padding_offset == nullptr; + bool const has_padding = padding_offset != nullptr; int const batch_idx = tgt_token_idx / seq_len; int const seq_idx = tgt_token_idx % seq_len; int const actual_seq_len = seq_lens[batch_idx]; - bool const valid_seq = seq_idx < actual_seq_len || !has_padding; + bool const valid_seq = seq_idx < actual_seq_len || has_padding; int const head_idx = blockIdx.y; int const tidx = threadIdx.x; diff --git a/cpp/tensorrt_llm/layers/decodingLayer.cpp b/cpp/tensorrt_llm/layers/decodingLayer.cpp index a33cdd627..505150db8 100644 --- a/cpp/tensorrt_llm/layers/decodingLayer.cpp +++ b/cpp/tensorrt_llm/layers/decodingLayer.cpp @@ -20,6 +20,7 @@ #include "tensorrt_llm/layers/decodingParams.h" #include "tensorrt_llm/layers/explicitDraftTokensLayer.h" #include "tensorrt_llm/layers/layerUtils.h" +#include "tensorrt_llm/layers/lookaheadDecodingLayer.h" #include "tensorrt_llm/layers/medusaDecodingLayer.h" #include "tensorrt_llm/layers/samplingLayer.h" @@ -66,6 +67,7 @@ bool hasDiffRuntimeArgs(std::shared_ptr DecodingLayer::DecodingLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain, std::shared_ptr bufferManager) @@ -88,8 +90,7 @@ DecodingLayer::DecodingLayer(executor::DecodingMode const& mode, DecoderDomai } else if (mDecodingMode.isLookahead()) { - // TODO(nkorobov) add lookahead layer - TLLM_LOG_WARNING("Lookahead decoding is not supported yet."); + mDecodingLayer = std::make_unique>(mDecoderDomain, mBufferManager); } else if (mDecodingMode.isExplicitDraftTokens()) { @@ -134,7 +135,7 @@ void DecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, BufferC else if (mDecodingMode.isLookahead()) { TLLM_CHECK_WITH_INFO(beamWidth == 1, "Decoding mode is Lookahead, but beamWidth != 1 (%d != 1)", beamWidth); - // TODO(nkorobov) add lookahead layer + 
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams); } else if (mDecodingMode.isExplicitDraftTokens()) { @@ -235,7 +236,8 @@ std::tuple, std::shared_ptr decodingParams; }; -class LookaheadSetupParams : public DecodingSetupParams +struct LookaheadSetupParams : public DecodingSetupParams { -public: + using TensorPtr = runtime::ITensor::SharedPtr; + std::vector prompt; // [batchSize][maxSeqLen] on cpu - std::optional> randomSeed; // [1] or [batchSize] on cpu std::vector algoConfigs; // [1 or batchSize] on cpu + + //! see LookaheadDecodingOutputs::generationLengths + TensorPtr generationLengths; + //! see LookaheadDecodingOutputs::positionOffsets + TensorPtr positionOffsets; + //! see LookaheadDecodingOutputs::attentionPackedMasks + TensorPtr attentionPackedMasks; + //! see LookaheadDecodingOutputs::actualGenerationLengths + TensorPtr actualGenerationLengths; }; class BaseDecodingInputs @@ -396,17 +405,11 @@ class ExplicitDraftTokensInputs : public DecodingInputs class LookaheadDecodingInputs : public DecodingInputs { - using TensorConstPtr = runtime::ITensor::SharedConstPtr; - public: - explicit LookaheadDecodingInputs(TensorPtr endIds, TensorConstPtr batchSlots) + explicit LookaheadDecodingInputs(TensorConstPtr endIds, TensorConstPtr batchSlots) : DecodingInputs{std::move(endIds), std::move(batchSlots)} - //, logits{logits} { } - // TODO(liweim) reuse base logits and curTokensPerStep. - // TensorConstPtr logits; // [batchSize, maxTokensPerStep, vocabSizePadded] on gpu - // TensorConstPtr tokensPerStep; // [maxBatchSize] on gpu }; class BaseDecodingOutputs @@ -527,6 +530,33 @@ class SpeculativeDecodingOutputs : public BaseDecodingOutputs TensorPtr packedMasks; }; +class LookaheadDecodingOutputs : public SpeculativeDecodingOutputs +{ + using TensorPtr = runtime::ITensor::SharedPtr; + +public: + explicit LookaheadDecodingOutputs(TensorPtr outputIds) + : SpeculativeDecodingOutputs{std::move(outputIds)} + { + } + + //! for TLLM engine input "spec_decoding_generation_lengths", indicating how many tokens to be generated. + //! currently, the 1st step of generation is 1, set at `setup`, others are maxDecodingTokens, set at `forward`. + //! [maxBatchSize] + TensorPtr generationLengths; + //! for TLLM engine input "spec_decoding_position_offsets", + //! indicating each token position offset base on the last golden token = 0. + //! ABCefgxyz--- // sequence tokens, ABCD: golden; efg, xyz: draft; ---: padding. + //! ***<0>123123--- // positionOffsets. + //! 012<3>456456--- // positionIds. + //! [maxBatchSize, maxDecodingTokens] + TensorPtr positionOffsets; + //! [maxBatchSize, maxDecodingTokens] + TensorPtr positionIds; + //! The actual decoding tokens length, for debug and for future. + TensorPtr actualGenerationLengths; +}; + class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs { public: diff --git a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp index 7392c4e8d..5b3062be0 100644 --- a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp +++ b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp @@ -64,9 +64,6 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT std::copy(std::prev(promptRange.end(), mN - 1), promptRange.end(), goldRange.begin()); mGuessTokens = ITensor::slice(mGuessTokensMax, 0, 0); mFilling = (mN - 1) > 0 ? 
1 : 0; - PRINT_TOKENS(prompt); - PRINT_TOKENS(mPrefills); - PRINT_TOKENS(mPastTokens); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -271,10 +268,10 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce BufferRange acceptedOffsetsRange(*acceptedOffsets); auto lookSize = 1 + mN - 2 - mFilling + mFilling * mW; - acceptedOffsetsRange[0] = 0; + // acceptedOffsetsRange[0] = 0; for (SizeType32 i = 0; i < maxHit; i++) { - acceptedOffsetsRange[1 + i] = lookSize + hitIdx * (mN - 1) + i; + acceptedOffsetsRange[i] = lookSize + hitIdx * (mN - 1) + i - 1; } *BufferRange(*acceptedLength).begin() = maxHit + 1; diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp index 87098f88b..18697e6c8 100644 --- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp +++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp @@ -16,6 +16,7 @@ #include "lookaheadDecodingLayer.h" #include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/executor/executor.h" @@ -32,6 +33,7 @@ #include #include #include +#include #include namespace tensorrt_llm::layers @@ -44,19 +46,28 @@ using namespace tensorrt_llm::runtime; template LookaheadDecodingLayer::CpuAlgorithmResources::CpuAlgorithmResources(DecoderDomain const& decoderDomain) { - auto maxBatchSize = decoderDomain.getBatchSize(); + auto const maxBatchSize = decoderDomain.getBatchSize(); + auto const beamWidth = decoderDomain.getBeamWidth(); + auto const decodingTokens = decoderDomain.getMaxDecodingTokens(); auto lookaheadModule = std::dynamic_pointer_cast(decoderDomain.getSpeculativeDecodingModule()); auto const [maxW, maxN, maxG] = lookaheadModule->getExecutionConfig().get(); + SizeType32 maxTokensPerStep, maxNumNewTokens, maxDraftLen, maxAcceptedDraftLen; + std::tie(maxTokensPerStep, maxNumNewTokens, maxDraftLen, maxAcceptedDraftLen) + = executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource(); + TLLM_CHECK_WITH_INFO(beamWidth == 1, "Lookahead requires beam width = 1"); + TLLM_CHECK_WITH_INFO(maxTokensPerStep == decodingTokens, "%d != %d", maxTokensPerStep, decodingTokens); for (SizeType32 id = 0; id < maxBatchSize; id++) { mAlgos.emplace_back(maxW, maxN, maxG, id); } - SizeType32 maxTokensPerStep, maxNumNewTokens, maxDraftLen; - std::tie(maxTokensPerStep, maxNumNewTokens, maxDraftLen, std::ignore) - = executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource(); + mPrompts.reserve(maxBatchSize); + for (auto bi = 0; bi < maxBatchSize; bi++) + { + mPrompts.emplace_back(BufferManager::cpu(ITensor::makeShape({0}), nvinfer1::DataType::kINT32)); + } auto const maxBatchShape1D = ITensor::makeShape({maxBatchSize}); mBatchSlots = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); @@ -66,14 +77,22 @@ LookaheadDecodingLayer::CpuAlgorithmResources::CpuAlgorithmResources(DecoderD mEndIds = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); mOutputIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxNumNewTokens}), nvinfer1::DataType::kINT32); - mPathsOffsets = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxNumNewTokens}), nvinfer1::DataType::kINT32); + mNewTokens = BufferManager::cpu( + ITensor::makeShape({maxTokensPerStep, maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32); + mPathsOffsets + = BufferManager::cpu(ITensor::makeShape({maxBatchSize, 
maxAcceptedDraftLen}), nvinfer1::DataType::kINT32); mNumNewTokens = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); mNumNewTokensCumSum = BufferManager::cpu(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32); mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32); mNextDraftPosIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32); - auto divUp32 = [](SizeType32 x) { return x / 32 + ((x % 32) ? 1 : 0); }; - mPackedMasks = BufferManager::cpu( - ITensor::makeShape({maxBatchSize, maxTokensPerStep, divUp32(maxTokensPerStep)}), nvinfer1::DataType::kINT32); + mGenerationLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); + mGenerationLengthsMax = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); + mPositionOffsets + = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32); + mPositionIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32); + mPackedMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep, + static_cast(divUp(maxTokensPerStep, 32))}), + nvinfer1::DataType::kINT32); mSamplingMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kBOOL); mNextDraftLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); mSequenceLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); @@ -87,6 +106,10 @@ LookaheadDecodingLayer::LookaheadDecodingLayer( { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + auto lookaheadModule + = std::dynamic_pointer_cast(decoderDomain.getSpeculativeDecodingModule()); + auto const [maxW, maxN, maxG] = lookaheadModule->getExecutionConfig().get(); + auto const maxBatchSize = mDecoderDomain.getBatchSize(); auto const maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens(); auto const vocabSizePadded = mDecoderDomain.getVocabSizePadded(); @@ -97,7 +120,6 @@ LookaheadDecodingLayer::LookaheadDecodingLayer( auto workspaceSize = getTopKWorkspaceSize(maxBatchSize, maxTokensPerStep, maxTopK, vocabSizePadded); mSamplingWorkspaceDevice = mBufferManager->gpu(ITensor::makeShape({static_cast(workspaceSize)}), nvinfer1::DataType::kINT8); - TLLM_LOG_DEBUG("workspaceSize=%d", getWorkspaceSize()); mTargetTokensDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kINT32); mRandomSeedsDevice = mBufferManager->gpu(maxBatchShape1D, nvinfer1::DataType::kINT64); mSamplingMaskDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kBOOL); @@ -119,13 +141,23 @@ void LookaheadDecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth auto& algoConfigs = setupParams->algoConfigs; TLLM_CHECK_WITH_INFO(algoConfigs.size() == 1 || algoConfigs.size() == batchSize, "Lookahead runtime configuration size should be either 1 or batchSize"); + + for (auto bi = 0; bi < batchSize; bi++) + { + PRINT_SHAPE(setupParams->prompt[bi]); + PRINT_TOKENS(setupParams->prompt[bi]); + mCpuAlgo->mPrompts[bi]->reshape(setupParams->prompt[bi]->getShape()); + mBufferManager->copy(*setupParams->prompt[bi], *mCpuAlgo->mPrompts[bi]); + } + + mBufferManager->getStream().synchronize(); // sync prompt gpu to cpu + auto const batchSlotsRange = BufferRange(*batchSlots); for (SizeType32 bi = 0; bi < batchSize; bi++) { auto const gbi = batchSlotsRange[bi]; SizeType32 bi1orN = (algoConfigs.size() == 1) ? 
0 : bi; - TLLM_LOG_DEBUG("CPU ALGO [ %d ] setup", gbi); - PRINT_TOKENS(setupParams->prompt[bi]); + TLLM_LOG_DEBUG("CPU ALGO [ %d ] setup prompt %s", gbi, D(mCpuAlgo->mPrompts[bi]).values().c_str()); auto [w, n, g] = algoConfigs[bi1orN].get(); SizeType32 runtimeTokensPerStep; std::tie(runtimeTokensPerStep, std::ignore, std::ignore, std::ignore) @@ -133,8 +165,42 @@ void LookaheadDecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth TLLM_CHECK_WITH_INFO(runtimeTokensPerStep <= mDecoderDomain.getMaxDecodingTokens(), "runtime w(%d) n(%d) g(%d) exceeds maxTokensPerStep(%d)", w, n, g, mDecoderDomain.getMaxDecodingTokens()); - mCpuAlgo->mAlgos[gbi].setup(setupParams->prompt[bi], w, n, g); + PRINT_VALUES(mCpuAlgo->mPrompts[bi]); + mCpuAlgo->mAlgos[gbi].setup(mCpuAlgo->mPrompts[bi], w, n, g); } + + for (runtime::SizeType32 bi = 0; bi < batchSize; bi++) + { + SizeType32 gbi = batchSlotsRange[bi]; + (BufferRange(*mCpuAlgo->mGenerationLengths))[gbi] = 1; + BufferLocation(*mCpuAlgo->mPositionOffsets).at(gbi, 0) = 0; + BufferRange packedMaskRange(*ITensor::at(mCpuAlgo->mPackedMask, {gbi})); + for (auto& mask : packedMaskRange) + { + mask = 0; + } + packedMaskRange[0] = 1; + + PRINT_SHAPE(mCpuAlgo->mGenerationLengths); + PRINT_SHAPE(setupParams->generationLengths); + PRINT_SHAPE(mCpuAlgo->mPositionOffsets); + PRINT_SHAPE(setupParams->positionOffsets); + PRINT_SHAPE(mCpuAlgo->mPackedMask); + PRINT_SHAPE(setupParams->attentionPackedMasks); + mBufferManager->copy( + *ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}), *ITensor::at(setupParams->generationLengths, {gbi})); + if (setupParams->actualGenerationLengths) + { + mBufferManager->copy(*ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}), + *ITensor::at(setupParams->actualGenerationLengths, {gbi})); + } + mBufferManager->copy( + *ITensor::at(mCpuAlgo->mPositionOffsets, {gbi}), *ITensor::at(setupParams->positionOffsets, {gbi})); + mBufferManager->copy( + *ITensor::at(mCpuAlgo->mPackedMask, {gbi}), *ITensor::at(setupParams->attentionPackedMasks, {gbi})); + } + + mBufferManager->getStream().synchronize(); // sync outputs cpu to gpu } auto curandStatesDevicePtr = reinterpret_cast(bufferCast(*mCurandStatesDevice)); @@ -171,22 +237,14 @@ void LookaheadDecodingLayer::forwardAsync( { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto inputs = std::dynamic_pointer_cast(inputParams); - auto outputs = std::dynamic_pointer_cast(outputParams); + auto outputs = std::dynamic_pointer_cast(outputParams); auto batchSize = inputs->localBatchSize; TLLM_CHECK_WITH_INFO(inputs->batchSlots, "Batch slots must be provided for LookaheadDecoding"); TLLM_CHECK_WITH_INFO(inputs->curTokensPerStep, "curTokensPerStep must be provided for LookaheadDecoding"); TLLM_CHECK_WITH_INFO(outputs->sequenceLength, "sequenceLength must be provided for LookaheadDecoding"); - // TODO(liweim) to be confirmed. 
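A note on the per-slot initialization just above: `mPackedMask` holds one attention-mask row per decoding token, packed 32 positions per `int32_t` (hence the `divUp(maxTokensPerStep, 32)` inner dimension), and `setup()` zeroes the slot and writes a single `1` into word 0 because the first generation step emits one token that attends only to itself; later steps rebuild the mask from the draft position IDs via `posIdsToMask`. The following self-contained sketch only illustrates that bit-packed layout — `divUp`, `setPackedMaskBit`, and `firstStepMask` are hypothetical helpers, not the layer's actual code:

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: a mask row covers maxTokensPerStep columns, stored as
// divUp(maxTokensPerStep, 32) packed 32-bit words, matching the shape of mPackedMask.
inline int divUp(int x, int y)
{
    return (x + y - 1) / y;
}

inline void setPackedMaskBit(std::vector<int32_t>& row, int col)
{
    row[col / 32] |= static_cast<int32_t>(1u << (col % 32));
}

// First step of generation: a single token that attends only to itself, which is
// why setup() above writes one '1' into the first word of the slot's mask.
std::vector<std::vector<int32_t>> firstStepMask(int maxTokensPerStep)
{
    std::vector<std::vector<int32_t>> mask(
        maxTokensPerStep, std::vector<int32_t>(divUp(maxTokensPerStep, 32), 0));
    setPackedMaskBit(mask[0], 0);
    return mask;
}
```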
- TLLM_CHECK(inputs->logits); - - mBufferManager->copy( - bufferCast(*inputs->batchSlots), *mCpuAlgo->mBatchSlots, runtime::MemoryType::kGPU); - mBufferManager->copy(bufferCast(*inputs->curTokensPerStep.value()), *mCpuAlgo->mTokensPerStep, - runtime::MemoryType::kGPU); - mBufferManager->copy(bufferCast(*inputs->endIds), *mCpuAlgo->mEndIds, runtime::MemoryType::kGPU); - mBufferManager->copy(bufferCast(*outputs->sequenceLength.value()), *mCpuAlgo->mSequenceLengths, - runtime::MemoryType::kGPU); + TLLM_CHECK_WITH_INFO(inputs->logits, "logits must be provided for lookaheadDecoding"); + TLLM_CHECK_WITH_INFO(inputs->localBatchSize > 0, "batchSize must be"); TopKSamplingKernelParams params; params.maxBatchSize = mDecoderDomain.getBatchSize(); @@ -197,7 +255,6 @@ void LookaheadDecodingLayer::forwardAsync( params.maxSeqLen = mDecoderDomain.getMaxDecodingTokens(); params.vocabSizePadded = mDecoderDomain.getVocabSizePadded(); params.batchSlots = bufferCast(*inputs->batchSlots); - TLLM_LOG_DEBUG("batchSize = %d", batchSize); params.logProbs = bufferCastOrNull(inputs->logits); params.outputIds = bufferCast(*mTargetTokensDevice); params.workspace = bufferCast(*mSamplingWorkspaceDevice); @@ -215,19 +272,13 @@ void LookaheadDecodingLayer::forwardAsync( // Finished state is not set. invokeBatchTopKSampling(params, getStream()); - mBufferManager->copy(*mTargetTokensDevice, *mCpuAlgo->mTargetTokens); - - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); -} - -template -void LookaheadDecodingLayer::forwardSync( - std::shared_ptr const& outputParams, std::shared_ptr const& inputParams) -{ if (mCpuAlgo) { - forwardSyncCPU(outputParams, inputParams); + forwardSyncCPU(outputs, inputs); + mGlobalSteps += 1; } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template @@ -275,33 +326,55 @@ void LookaheadDecodingLayer::posIdsToMask(TensorPtr mask, TensorConstPtr posI template void LookaheadDecodingLayer::forwardSyncCPU( - std::shared_ptr const& outputParams, std::shared_ptr const& inputParams) + std::shared_ptr const& outputs, std::shared_ptr const& inputs) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto inputs = std::dynamic_pointer_cast(inputParams); - auto outputs = std::dynamic_pointer_cast(outputParams); + + mCpuAlgo->mBatchSlots->reshape(inputs->batchSlots->getShape()); + mBufferManager->copy(*inputs->batchSlots, *mCpuAlgo->mBatchSlots); + mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep); + mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep); + mBufferManager->copy(*inputs->endIds, *mCpuAlgo->mEndIds); + mBufferManager->copy(*outputs->sequenceLength.value(), *mCpuAlgo->mSequenceLengths); + + mBufferManager->copy(*mTargetTokensDevice, *mCpuAlgo->mTargetTokens); + + mBufferManager->getStream().synchronize(); + auto const batchSize = inputs->localBatchSize; + auto const beamIndex = 0; BufferRange tokensPerStepRange(*mCpuAlgo->mTokensPerStep); + BufferRange endIdsRange(*mCpuAlgo->mEndIds); + BufferLocation newTokensLocation(*mCpuAlgo->mNewTokens); BufferRange numNewTokensRange(*mCpuAlgo->mNumNewTokens); BufferRange numNewTokensCumSumRange(*mCpuAlgo->mNumNewTokensCumSum); BufferRange batchSlotsRange(*mCpuAlgo->mBatchSlots); + BufferRange generationLengthsRange(*mCpuAlgo->mGenerationLengths); + BufferRange generationLengthsMaxRange(*mCpuAlgo->mGenerationLengthsMax); BufferRange nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths); BufferRange sequenceLengthsRange(*mCpuAlgo->mSequenceLengths); + BufferLocation 
pathsOffsetLocation(*mCpuAlgo->mPathsOffsets); + BufferLocation outputIdsLocation(*mCpuAlgo->mOutputIds); + + mBufferManager->setZero(*mCpuAlgo->mPathsOffsets); + mBufferManager->setZero(*mCpuAlgo->mNumNewTokens); + mBufferManager->setZero(*mCpuAlgo->mNumNewTokensCumSum); for (SizeType32 bi = 0; bi < batchSize; bi++) { SizeType32 gbi = batchSlotsRange[bi]; LookaheadAlgorithm& theAlgo(mCpuAlgo->mAlgos[gbi]); - SizeType32 const tokensPerStep = tokensPerStepRange[gbi]; + SizeType32 const tokensPerStep = generationLengthsRange[gbi]; TensorPtr sampledTokens = ITensor::slice(mCpuAlgo->mTargetTokens, {gbi, 0}, tokensPerStep); + PRINT_VALUES(sampledTokens); if (tokensPerStep == 1) - { // The first step in generation phase has no draft tokens. + { + // The first step in generation phase has no draft tokens. theAlgo.accept(sampledTokens); mBufferManager->copy(*sampledTokens, *ITensor::slice(mCpuAlgo->mOutputIds, {gbi, 0}, tokensPerStep)); - BufferLocation(*mCpuAlgo->mPathsOffsets).at(gbi, 0) = 0; numNewTokensRange[gbi] = tokensPerStep; BufferLocation(*mCpuAlgo->mNextDraftLengths).at(gbi) = 0; } @@ -318,7 +391,7 @@ void LookaheadDecodingLayer::forwardSyncCPU( auto maxNumNewTokens = mCpuAlgo->mOutputIds->getShape().d[1]; mBufferManager->copy(*ITensor::at(mCpuAlgo->mOutputIds, {gbi}), - *ITensor::slice(outputs->outputIds, {gbi, sequenceLengthsRange[gbi]}, maxNumNewTokens)); + *ITensor::slice(outputs->outputIds, {gbi, 0, sequenceLengthsRange[gbi]}, maxNumNewTokens)); sequenceLengthsRange[gbi] += numNewTokensRange[gbi]; @@ -330,38 +403,83 @@ void LookaheadDecodingLayer::forwardSyncCPU( ITensor::at(mCpuAlgo->mSequenceLengths, {gbi}), // ITensor::at(mCpuAlgo->mOutputIds, {gbi, numNewTokensRange[gbi] - 1})); - posIdsToMask( // - ITensor::at(mCpuAlgo->mPackedMasks, {gbi}), // + BufferLocation posIdsLocation(*ITensor::at(mCpuAlgo->mPositionIds, {gbi})); + for (auto& posid : posIdsLocation) + { + posid = sequenceLengthsRange[gbi] - 1; + } + mBufferManager->copy(*ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]), + *ITensor::slice(mCpuAlgo->mPositionIds, {gbi, 1}, nextDraftLengthsRange[gbi])); + + posIdsToMask( // + ITensor::at(mCpuAlgo->mPackedMask, {gbi}), // ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi])); + + BufferRange offsetRange(*ITensor::at(mCpuAlgo->mPositionOffsets, {gbi})); + TLLM_CHECK_WITH_INFO( + posIdsLocation.size() == offsetRange.size(), "%ld, %ld", posIdsLocation.size(), offsetRange.size()); + for (auto i = 0; i < posIdsLocation.size(); i++) + { + offsetRange[i] = posIdsLocation[i] - posIdsLocation[0]; + } + TensorPtr accepted = ITensor::slice(mCpuAlgo->mOutputIds, {gbi, 0}, numNewTokensRange[gbi]); + TensorPtr draft = ITensor::slice(mCpuAlgo->mNextDraftTokens, {gbi, 0}, nextDraftLengthsRange[gbi]); + + TLLM_LOG_DEBUG("CPU ALGO [ %d ] forward, %s", gbi, D(sampledTokens).values().c_str()); + TLLM_LOG_DEBUG("[%d][%d] CPU ALGO [ %d ] forward, %s, %s", mGlobalSteps, batchSize, gbi, + D(accepted).values().c_str(), D(draft).values().c_str()); } numNewTokensCumSumRange[0] = 0; - for (SizeType32 i = 0; i < numNewTokensRange.size(); i++) + SizeType32 pi = 0; + for (SizeType32 bi = 0; bi < numNewTokensRange.size(); bi++) + { + SizeType32 acceptedDraftLen = numNewTokensRange[bi] <= 1 ? 
0 : (numNewTokensRange[bi] - 1); + numNewTokensCumSumRange[bi + 1] = numNewTokensCumSumRange[bi] + acceptedDraftLen; + for (SizeType32 tj = 0; tj < acceptedDraftLen; tj++) + { + pathsOffsetLocation[pi++] = pathsOffsetLocation.at(bi, tj); + } + } + for (; pi < pathsOffsetLocation.size(); pi++) { - numNewTokensCumSumRange[i + 1] = numNewTokensCumSumRange[i] + numNewTokensRange[i]; + pathsOffsetLocation[pi++] = 0; } TLLM_CHECK(outputs->numNewTokens); - mBufferManager->copy(*mCpuAlgo->mSequenceLengths, // - const_cast(outputs->sequenceLength.value()->data()), runtime::MemoryType::kGPU); - mBufferManager->copy(*mCpuAlgo->mPathsOffsets, // - const_cast(outputs->pathsOffsets->data()), runtime::MemoryType::kGPU); - mBufferManager->copy(*mCpuAlgo->mNumNewTokens, // - const_cast(outputs->numNewTokens.value()->data()), runtime::MemoryType::kGPU); - mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, // - const_cast(outputs->numNewTokensCumSum->data()), runtime::MemoryType::kGPU); - mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, // - const_cast(outputs->nextDraftTokens->data()), runtime::MemoryType::kGPU); - mBufferManager->copy(*mCpuAlgo->mNextDraftPosIds, // - const_cast(outputs->nextDraftPosIds->data()), runtime::MemoryType::kGPU); - mBufferManager->copy(*mCpuAlgo->mPackedMasks, // - const_cast(outputs->packedMasks->data()), runtime::MemoryType::kGPU); - mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, // - const_cast(outputs->nextDraftLengths->data()), runtime::MemoryType::kGPU); - - // TODO(liweim) do we need this? - // mBufferManager->getStream().synchronize(); + mBufferManager->copy(*mCpuAlgo->mSequenceLengths, *outputs->sequenceLength.value()); + mBufferManager->copy(*mCpuAlgo->mNewTokens, *outputs->newTokens); + + mBufferManager->copy(*mCpuAlgo->mPathsOffsets, *outputs->pathsOffsets); + mBufferManager->copy(*mCpuAlgo->mNumNewTokens, *outputs->numNewTokens.value()); + mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, *outputs->numNewTokensCumSum); // + mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, *outputs->nextDraftTokens); + + mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks); + + if (outputs->nextDraftLengths) + { + mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, *outputs->nextDraftLengths); + } + + for (SizeType32 bi = 0; bi < batchSize; bi++) + { + SizeType32 gbi = batchSlotsRange[bi]; + generationLengthsRange[gbi] = nextDraftLengthsRange[gbi] + 1; + generationLengthsMaxRange[gbi] = mDecoderDomain.getMaxDecodingTokens(); + } + mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks); + mBufferManager->copy(*mCpuAlgo->mGenerationLengthsMax, *outputs->generationLengths); + mBufferManager->copy(*mCpuAlgo->mPositionOffsets, *outputs->positionOffsets); + mBufferManager->copy(*mCpuAlgo->mPositionIds, *outputs->positionIds); + + if (outputs->actualGenerationLengths) + { + mBufferManager->copy(*mCpuAlgo->mGenerationLengths, *outputs->actualGenerationLengths); + } + + mBufferManager->getStream().synchronize(); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h index 0b6c44761..d68254074 100644 --- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h +++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h @@ -42,15 +42,12 @@ class LookaheadDecodingLayer : public BaseLayer void forwardAsync(std::shared_ptr const& outputParams, std::shared_ptr const& inputParams) override; - void forwardSync(std::shared_ptr const& outputParams, - std::shared_ptr 
const& inputParams) override; - //! @returns workspace needed for this layer in bytes [[nodiscard]] size_t getWorkspaceSize() const noexcept; private: - void forwardSyncCPU(std::shared_ptr const& outputParams, - std::shared_ptr const& inputParams); + void forwardSyncCPU(std::shared_ptr const& outputs, + std::shared_ptr const& inputs); void posIdsToMask(TensorPtr mask, TensorConstPtr posIds); private: @@ -67,6 +64,7 @@ class LookaheadDecodingLayer : public BaseLayer explicit CpuAlgorithmResources(DecoderDomain const& decoderDomain); std::vector mAlgos; + std::vector mPrompts; TensorPtr mBatchSlots; TensorPtr mTargetTokens; TensorPtr mTokensPerStep; @@ -76,16 +74,23 @@ class LookaheadDecodingLayer : public BaseLayer TensorPtr mPathsOffsets; TensorPtr mNumNewTokens; TensorPtr mNumNewTokensCumSum; + TensorPtr mNewTokens; TensorPtr mNextDraftTokens; TensorPtr mNextDraftPosIds; - TensorPtr mPackedMasks; TensorPtr mSamplingMask; TensorPtr mNextDraftLengths; TensorPtr mSequenceLengths; + TensorPtr mGenerationLengths; + TensorPtr mGenerationLengthsMax; + TensorPtr mPackedMask; + TensorPtr mPositionOffsets; + TensorPtr mPositionIds; }; std::optional mCpuAlgo; + + runtime::SizeType32 mGlobalSteps{0}; }; } // namespace tensorrt_llm::layers diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingUtils.h b/cpp/tensorrt_llm/layers/lookaheadDecodingUtils.h index f6cabbc34..ce85491c1 100644 --- a/cpp/tensorrt_llm/layers/lookaheadDecodingUtils.h +++ b/cpp/tensorrt_llm/layers/lookaheadDecodingUtils.h @@ -16,6 +16,7 @@ #pragma once +#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/tensorView.h" @@ -182,9 +183,22 @@ class DebugTensor return (BufferLocation(mTensor))[idx]; } + runtime::BufferManager::ITensorPtr copyToHostOptional() + { + runtime::BufferManager::ITensorPtr hostPtr{nullptr}; + if (mTensor.getMemoryType() == runtime::MemoryType::kGPU) + { + runtime::BufferManager manager{std::make_shared()}; + hostPtr = manager.copyFrom(mTensor, runtime::MemoryType::kCPU); + manager.getStream().synchronize(); + } + return hostPtr; + } + std::string string(void) { - runtime::BufferRange range(mTensor); + runtime::BufferManager::ITensorPtr hostPtr = copyToHostOptional(); + runtime::BufferRange range(hostPtr ? (*hostPtr) : mTensor); std::string result(range.size(), '\0'); std::copy(range.begin(), range.end(), result.begin()); return result; @@ -195,8 +209,10 @@ class DebugTensor using namespace tensorrt_llm::runtime; std::ostringstream buf; auto shape = mTensor.getShape(); - runtime::BufferRange tensorRange(mTensor); - buf << mName << ": " << shape; + runtime::BufferManager::ITensorPtr hostPtr = copyToHostOptional(); + runtime::BufferRange tensorRange(hostPtr ? (*hostPtr) : mTensor); + + buf << mName << ": " << mTensor.getMemoryTypeName() << ',' << mTensor.getDataTypeName() << ',' << shape; auto line = [&buf](TokenIdType const* array, SizeType32 size) { buf << '['; @@ -249,14 +265,16 @@ class DebugTensor using namespace tensorrt_llm::runtime; std::ostringstream buf; auto shape = mTensor.getShape(); - runtime::BufferRange tensorRange(mTensor); - buf << mName << ": " << shape; + runtime::BufferManager::ITensorPtr hostPtr = copyToHostOptional(); + runtime::BufferRange tensorRange(hostPtr ? 
(*hostPtr) : mTensor); + + buf << mName << ": " << mTensor.getMemoryTypeName() << ',' << mTensor.getDataTypeName() << ',' << shape; auto line = [&buf](T const* array, SizeType32 size) { buf << '['; for (SizeType32 i = 0; i < size; i++) { - buf << array[i]; + buf << static_cast(array[i]); if (i != size - 1) { buf << ','; diff --git a/cpp/tensorrt_llm/plugins/CMakeLists.txt b/cpp/tensorrt_llm/plugins/CMakeLists.txt index 5af9b78ee..045a34d1c 100755 --- a/cpp/tensorrt_llm/plugins/CMakeLists.txt +++ b/cpp/tensorrt_llm/plugins/CMakeLists.txt @@ -60,7 +60,7 @@ foreach(PLUGIN_ITER ${PLUGIN_LISTS}) add_subdirectory(${PLUGIN_ITER}) endforeach(PLUGIN_ITER) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) include_directories(ncclPlugin) add_subdirectory(ncclPlugin) endif() @@ -86,7 +86,7 @@ target_include_directories( PUBLIC ${CUDA_INSTALL_DIR}/include PRIVATE ${TARGET_DIR}) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) target_include_directories(${PLUGIN_SHARED_TARGET} PUBLIC ${MPI_C_INCLUDE_DIRS}) endif() @@ -134,6 +134,6 @@ target_link_libraries( ${CMAKE_DL_LIBS} ${SHARED_TARGET}) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) target_link_libraries(${PLUGIN_SHARED_TARGET} ${MPI_C_LIBRARIES} ${NCCL_LIB}) endif() diff --git a/cpp/tensorrt_llm/plugins/common/plugin.cpp b/cpp/tensorrt_llm/plugins/common/plugin.cpp index efb03b3bc..95401ade4 100644 --- a/cpp/tensorrt_llm/plugins/common/plugin.cpp +++ b/cpp/tensorrt_llm/plugins/common/plugin.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #ifdef _MSC_VER #define FN_NAME __FUNCTION__ @@ -212,11 +213,75 @@ class PerCudaCtxSingletonCreator // CUDA resources are per-context. std::unordered_map> mObservers; }; + +template +class PerThreadSingletonCreator +{ +public: + using CreatorFunc = std::function()>; + using DeleterFunc = std::function; + + // creator returning std::unique_ptr is by design. + // It forces separation of memory for T and memory for control blocks. + // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. + // creator itself must not own CUDA resources. Only the object it creates can. + PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter) + : mCreator{std::move(creator)} + , mDeleter{std::move(deleter)} + { + } + + std::shared_ptr operator()() + { + std::lock_guard lk{mMutex}; + + std::thread::id thread = std::this_thread::get_id(); + std::shared_ptr result = mObservers[thread].lock(); + + if (result == nullptr) + { + // Create the resource and register with an observer. + result = std::shared_ptr{mCreator().release(), + [this, thread](T* obj) + { + if (obj == nullptr) + { + return; + } + mDeleter(obj); + + // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts + // frequently. + std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. + std::lock_guard lk{mMutex}; + // Must check observer again because another thread may created new instance for this ctx just + // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is + // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic + // operation, and the observer may be changed to observe another instance. 
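To make the threading behavior of `PerThreadSingletonCreator` concrete: repeated lookups from the same thread share one cached instance for as long as some `shared_ptr` to it is alive, while every thread receives its own instance. Below is a self-contained, simplified stand-in — `PerThreadCache` is hypothetical and omits the custom deleter and stale-observer cleanup handled by the class above:

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>

// Simplified sketch of per-thread caching: one instance per thread id,
// handed out as a shared_ptr and observed via a weak_ptr.
template <typename T>
class PerThreadCache
{
public:
    explicit PerThreadCache(std::function<std::unique_ptr<T>()> creator)
        : mCreator(std::move(creator))
    {
    }

    std::shared_ptr<T> get()
    {
        std::lock_guard<std::mutex> lk(mMutex);
        auto& observer = mObservers[std::this_thread::get_id()];
        auto instance = observer.lock();
        if (!instance)
        {
            instance = std::shared_ptr<T>(mCreator().release());
            observer = instance;
        }
        return instance;
    }

private:
    std::function<std::unique_ptr<T>()> mCreator;
    std::mutex mMutex;
    std::unordered_map<std::thread::id, std::weak_ptr<T>> mObservers;
};

int main()
{
    PerThreadCache<int> cache([] { return std::make_unique<int>(42); });
    auto a = cache.get();
    auto b = cache.get();
    assert(a == b); // same thread: the instance is shared
    std::thread([&] { assert(cache.get() != a); }).join(); // another thread: its own instance
    return 0;
}
```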
+ observedObjHolder = mObservers.at(thread).lock(); + if (observedObjHolder == nullptr) + { + mObservers.erase(thread); + } + }}; + mObservers.at(thread) = result; + } + return result; + } + +private: + CreatorFunc mCreator; + DeleterFunc mDeleter; + mutable std::mutex mMutex; + // CUDA resources are per-thread. + std::unordered_map> mObservers; +}; + } // namespace std::shared_ptr getCublasHandle() { - static PerCudaCtxSingletonCreator creator( + static PerThreadSingletonCreator creator( []() -> auto { auto handle = std::unique_ptr(new cublasHandle_t); @@ -233,7 +298,7 @@ std::shared_ptr getCublasHandle() std::shared_ptr getCublasLtHandle() { - static PerCudaCtxSingletonCreator creator( + static PerThreadSingletonCreator creator( []() -> auto { auto handle = std::unique_ptr(new cublasLtHandle_t); @@ -248,6 +313,20 @@ std::shared_ptr getCublasLtHandle() return creator(); } +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) +{ + static PerThreadSingletonCreator creator( + [cublasHandle, cublasltHandle, stream, workspace]() -> auto + { + auto wrapper = std::unique_ptr( + new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace)); + return wrapper; + }, + [](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; }); + return creator(); +} + PluginFieldParser::PluginFieldParser(int32_t nbFields, nvinfer1::PluginField const* fields) : mFields{fields} { diff --git a/cpp/tensorrt_llm/plugins/common/plugin.h b/cpp/tensorrt_llm/plugins/common/plugin.h index 39053f1a4..96bd1ef47 100644 --- a/cpp/tensorrt_llm/plugins/common/plugin.h +++ b/cpp/tensorrt_llm/plugins/common/plugin.h @@ -17,6 +17,7 @@ #pragma once +#include "tensorrt_llm/common/cublasMMWrapper.h" #include "tensorrt_llm/common/workspace.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/plugins/common/checkMacrosPlugin.h" @@ -179,6 +180,8 @@ std::shared_ptr getComm(std::set const& group); //! 
Get cublas and cublasLt handle for current cuda context std::shared_ptr getCublasHandle(); std::shared_ptr getCublasLtHandle(); +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace); #ifndef DEBUG diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp index 5d97a8412..3a1dacdcc 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp @@ -184,9 +184,9 @@ GemmPlugin::GemmPlugin(void const* data, size_t length, GemmPlugin::PluginProfil void GemmPlugin::init() { - auto cublasHandle = getCublasHandle(); - auto cublasLtHandle = getCublasLtHandle(); - mCublasWrapper = std::make_shared(cublasHandle, cublasLtHandle, nullptr, nullptr); + mcublasHandle = getCublasHandle(); + mcublasLtHandle = getCublasLtHandle(); + mCublasWrapper = getCublasMMWrapper(mcublasHandle, mcublasLtHandle, nullptr, nullptr); mPluginProfiler->setTranspose(mTransA, mTransB); mPluginProfiler->setOutputType(mOutputType); @@ -347,7 +347,9 @@ int GemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P // mat2 [K, N] (mTransB = False) // outputs // mat [M, N] - + mcublasHandle = getCublasHandle(); + mcublasLtHandle = getCublasLtHandle(); + mCublasWrapper = getCublasMMWrapper(mcublasHandle, mcublasLtHandle, nullptr, nullptr); setGemmConfig(); int const nbDimsA = inputDesc[0].dims.nbDims; diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h index bf2b5540f..9ba090882 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h @@ -131,6 +131,8 @@ class GemmPlugin : public BasePlugin // @fixme: seems this is shared across multiple clones. // If we deep copy the wrapper inside clone(), then we may avoid the mutex inside the wrapper? CublasGemmWrapperPtr mCublasWrapper; + std::shared_ptr mcublasHandle; + std::shared_ptr mcublasLtHandle; GemmDims mDims{}; GemmIdCublas mGemmId{}; diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp index b63869341..0a2c5e771 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp @@ -618,9 +618,10 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; + size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; size_t const fmha_scheduler_counter = mEnableContextFMHA ? 
sizeof(uint32_t) : 0; - int const NUM_BUFFERS = 15; + int const NUM_BUFFERS = 16; size_t workspaces[NUM_BUFFERS]; workspaces[0] = CUBLAS_WORKSPACE_SIZE; workspaces[1] = attention_mask_size; @@ -636,7 +637,8 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t workspaces[11] = qk_buf_float_size; workspaces[12] = fp8_qkv_buffer_size; workspaces[13] = padding_offset_size; - workspaces[14] = fmha_scheduler_counter; + workspaces[14] = encoder_padding_offset_size; + workspaces[15] = fmha_scheduler_counter; context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS); return context_workspace_size; @@ -795,9 +797,10 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams(nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size)); + int* encoder_padding_offset = (mEnableContextFMHA && !isCrossAttention()) + ? nullptr + : reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, encoder_padding_offset_size)); uint32_t* fmha_tile_counter_ptr = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_scheduler_counter)); @@ -836,12 +842,16 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams(params.cross_qkv), const_cast(params.qkv_bias), params.encoder_input_lengths, - mRemovePadding ? padding_offset : nullptr, params.batch_size, params.cross_qkv_length, + mRemovePadding ? encoder_padding_offset : nullptr, params.batch_size, params.cross_qkv_length, params.num_encoder_tokens, mNumHeads, mNumKVHeads, getHeadSize(), mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale, mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, stream); diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index f921b66af..4411b1473 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -659,7 +659,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 if (past_key_value_cache != outputs[1]) { auto shape = outputDesc[1].dims; - auto const size = std::accumulate(shape.d, shape.d + shape.nbDims, 1, std::multiplies{}); + auto const size + = cacheElemSize * std::accumulate(shape.d, shape.d + shape.nbDims, 1, std::multiplies{}); cudaMemcpyAsync(outputs[1], past_key_value_cache, size, cudaMemcpyDeviceToDevice, stream); } } diff --git a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp index 60dfe0648..ffc4b7a8d 100644 --- a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp @@ -39,7 +39,7 @@ std::vector LoraPluginCreator::mPluginAttributes; LoraPlugin::LoraPlugin(int in_hidden_size, std::vector out_hidden_sizes, int transA, int transB, int num_lora_modules, nvinfer1::DataType type, LoraPlugin::PluginProfilerPtr const& pluginProfiler, - bool remove_input_padding, int max_num_tokens, int max_low_rank, int weight_index) + bool remove_input_padding, int max_low_rank, int weight_index) : mInHiddenSize(in_hidden_size) , mTransA(transA) , mTransB(transB) @@ -47,7 +47,6 @@ LoraPlugin::LoraPlugin(int in_hidden_size, std::vector out_hidden_sizes, in , mType(type) , mPluginProfiler(pluginProfiler) , mRemoveInputPadding(remove_input_padding) - , mMaxNumTokens(max_num_tokens) , mMaxLowRank(max_low_rank) , mWeightIndex(weight_index) { @@ -69,7 +68,6 
@@ LoraPlugin::LoraPlugin(void const* data, size_t length, LoraPlugin::PluginProfil read(d, mNumLoraModules); read(d, mType); read(d, mRemoveInputPadding); - read(d, mMaxNumTokens); read(d, mMaxLowRank); read(d, mWeightIndex); mOutHiddenSizes.resize(mNumLoraModules); @@ -212,7 +210,7 @@ size_t LoraPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, in int const nbReq = inputs[getLoraRanksIdx()].dims.d[0]; auto const type = inputs[getInputTensorIdx()].type; auto const numTokens = getNumTokens(inputs); - return mLoraImpl->getWorkspaceSize(mMaxNumTokens, nbReq, type); + return mLoraImpl->getWorkspaceSize(numTokens, nbReq, type); } int64_t LoraPlugin::getNumTokens(nvinfer1::PluginTensorDesc const* input_tensors) const @@ -233,6 +231,11 @@ int LoraPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + if (isBuilding()) + { + return 0; + } + auto const numReqs = inputDesc[getLoraRanksIdx()].dims.d[0]; void const* input = inputs[getInputTensorIdx()]; int const seqLen = mRemoveInputPadding ? 0 : inputDesc[getInputTensorIdx()].dims.d[1]; @@ -344,8 +347,8 @@ size_t LoraPlugin::getSerializationSize() const noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); return sizeof(mInHiddenSize) + sizeof(mTransA) + sizeof(mTransB) + sizeof(mNumLoraModules) + sizeof(mType) - + mPluginProfiler->getSerializationSize(mGemmId) + sizeof(mRemoveInputPadding) + sizeof(mMaxNumTokens) - + sizeof(mMaxLowRank) + sizeof(mWeightIndex) + sizeof(int) * mNumLoraModules; // selected tactics container size + + mPluginProfiler->getSerializationSize(mGemmId) + sizeof(mRemoveInputPadding) + sizeof(mMaxLowRank) + + sizeof(mWeightIndex) + sizeof(int) * mNumLoraModules; // selected tactics container size } void LoraPlugin::serialize(void* buffer) const noexcept @@ -358,7 +361,6 @@ void LoraPlugin::serialize(void* buffer) const noexcept write(d, mNumLoraModules); write(d, mType); write(d, mRemoveInputPadding); - write(d, mMaxNumTokens); write(d, mMaxLowRank); write(d, mWeightIndex); for (int i = 0; i < mNumLoraModules; i++) @@ -414,7 +416,6 @@ IPluginV2* LoraPluginCreator::createPlugin(char const* name, PluginFieldCollecti int num_lora_modules; int in_hidden_size, transA, transB; bool remove_input_padding; - int max_num_tokens; int max_low_rank; int weight_index; // Read configurations from each fields @@ -446,11 +447,6 @@ IPluginV2* LoraPluginCreator::createPlugin(char const* name, PluginFieldCollecti TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); remove_input_padding = static_cast(*(static_cast(fields[i].data))); } - else if (!strcmp(attrName, "max_num_tokens")) - { - TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - max_num_tokens = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "max_low_rank")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); @@ -488,7 +484,7 @@ IPluginV2* LoraPluginCreator::createPlugin(char const* name, PluginFieldCollecti // FIXME enable tactic profiler auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true); auto* obj = new LoraPlugin(in_hidden_size, out_hidden_sizes, transA, transB, num_lora_modules, type, - pluginProfiler, remove_input_padding, max_num_tokens, max_low_rank, weight_index); + pluginProfiler, remove_input_padding, max_low_rank, weight_index); obj->setPluginNamespace(mNamespace.c_str()); return obj; } diff --git a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h index 
6104abd72..7795f7b7c 100644 --- a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h +++ b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h @@ -37,8 +37,8 @@ class LoraPlugin : public BasePlugin LoraPlugin() = delete; LoraPlugin(int in_hidden_size, std::vector out_hidden_sizes, int transA, int transB, int num_lora_modules, - nvinfer1::DataType type, PluginProfilerPtr const& profiler, bool remove_input_padding, int max_num_tokens, - int max_low_rank, int weight_index); + nvinfer1::DataType type, PluginProfilerPtr const& profiler, bool remove_input_padding, int max_low_rank, + int weight_index); LoraPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler); @@ -117,7 +117,6 @@ class LoraPlugin : public BasePlugin int mTransB; nvinfer1::DataType mType; bool mRemoveInputPadding; - int mMaxNumTokens; int mNumLoraModules; int mInHiddenSize; int mMaxLowRank; diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index 279eb825d..7586aa4ba 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -647,6 +647,11 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace_ptr, cudaStream_t stream) noexcept { + if (isBuilding()) + { + return 0; + } + int64_t const num_tokens = getNumTokens(inputDesc); int64_t const num_reqs = getNumLoraRequests(inputDesc); int64_t const num_not_finished = num_tokens; // TODO Take this as an input diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp index 110a07000..784055cc5 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp @@ -46,10 +46,6 @@ AllreducePlugin::AllreducePlugin(std::set group, nvinfer1::DataType type, A , mAffine(affine) , mBias(bias) { - if (std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY") != nullptr) - { - mStrategy = AllReduceStrategyType::NCCL; - } } // Parameterized constructor @@ -58,10 +54,6 @@ AllreducePlugin::AllreducePlugin(void const* data, size_t length) char const *d = reinterpret_cast(data), *a = d; read(d, mType); read(d, mStrategy); - if (std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY") != nullptr) - { - mStrategy = AllReduceStrategyType::NCCL; - } read(d, mConfig); read(d, mOp); read(d, mEps); @@ -239,7 +231,9 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe kernels::AllReduceStrategyType runtimeStrategy; - if (mStrategy == AllReduceStrategyType::NCCL) + static char* forceNcclAllReduceStrategyChar = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY"); + bool forceNcclAllReduceStrategy = (forceNcclAllReduceStrategyChar != nullptr); + if (forceNcclAllReduceStrategy || mStrategy == AllReduceStrategyType::NCCL) { runtimeStrategy = AllReduceStrategyType::NCCL; } diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp index 96f30363f..62e3e4a6b 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp @@ -16,6 +16,7 @@ */ #include "recvPlugin.h" +#include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/mpiUtils.h" #include @@ -91,7 +92,9 @@ int RecvPlugin::enqueue(nvinfer1::PluginTensorDesc 
const* inputDesc, nvinfer1::P { size *= inputDesc[0].dims.d[i]; } + TLLM_LOG_DEBUG("start ncclRecv with size %d", size); NCCLCHECK(ncclRecv(outputs[0], size, (*getDtypeMap())[inputDesc[0].type], 0, mComm, stream)); + TLLM_LOG_DEBUG("end ncclRecv with size %d", size); return 0; } diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp index 3d20b9911..de0bce1c6 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp @@ -16,6 +16,7 @@ */ #include "sendPlugin.h" +#include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/mpiUtils.h" #include @@ -93,7 +94,9 @@ int SendPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P size *= inputDesc[0].dims.d[i]; } + TLLM_LOG_DEBUG("start ncclSend with size %d", size); NCCLCHECK(ncclSend(inputs[0], size, (*getDtypeMap())[inputDesc[0].type], 1, mComm, stream)); + TLLM_LOG_DEBUG("end ncclSend with size %d", size); return 0; } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp index 0f51ef0d7..05f2b93e1 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp @@ -77,8 +77,8 @@ std::shared_ptr LlmRequest::toTrtLlm() const return std::make_shared(mRequestId, mMaxNewTokens, std::make_shared>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, mLoraTaskId, loraWeights, - loraConfig, returnLogProbs(), mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits, - mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor), mApplyLogitsPostProcessorBatched, + loraConfig, mLookaheadConfig, returnLogProbs(), mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, + draftLogits, mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor), mApplyLogitsPostProcessorBatched, mEncoderTokens, mReturnEncoderOutput, mClientId, mPriority); } @@ -90,21 +90,23 @@ void LlmRequest::initBindings(py::module_& m) std::optional, std::optional, std::optional, std::optional, std::optional, std::optional, std::optional, - std::optional, bool, bool, bool, std::optional, - std::optional, bool, std::optional, bool, - std::optional, bool, std::optional, executor::PriorityType>(), + std::optional, std::optional, bool, bool, + bool, std::optional, std::optional, bool, + std::optional, bool, std::optional, bool, + std::optional, executor::PriorityType>(), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt, py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt, py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_task_id") = std::nullopt, py::arg("lora_weights") = std::nullopt, py::arg("lora_config") = std::nullopt, - py::arg("return_log_probs") = false, py::arg("return_context_logits") = false, - py::arg("return_generation_logits") = false, py::arg("draft_tokens") = std::nullopt, - py::arg("draft_logits") = std::nullopt, py::arg("exclude_input_from_output") = false, - py::arg("logits_post_processor") = std::nullopt, py::arg("apply_logits_post_processor_batched") = false, - py::arg("encoder_input_tokens") = std::nullopt, 
py::arg("return_encoder_output") = false, - py::arg("client_id") = std::nullopt, py::arg("priority") = executor::Request::kDefaultPriority) + py::arg("lookahead_config") = std::nullopt, py::arg("return_log_probs") = false, + py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false, + py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt, + py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt, + py::arg("apply_logits_post_processor_batched") = false, py::arg("encoder_input_tokens") = std::nullopt, + py::arg("return_encoder_output") = false, py::arg("client_id") = std::nullopt, + py::arg("priority") = executor::Request::kDefaultPriority) .def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam")) .def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens) .def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos")) @@ -122,6 +124,7 @@ void LlmRequest::initBindings(py::module_& m) .def_property_readonly("lora_task_id", &LlmRequest::getLoraTaskId) .def_property_readonly("lora_weights", &LlmRequest::getLoraWeights) .def_property_readonly("lora_config", &LlmRequest::getLoraConfig) + .def_property_readonly("lookahead_config", &LlmRequest::getLookaheadConfig) .def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias) .def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList) .def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h index 84c593ff0..a4415c581 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h @@ -55,7 +55,8 @@ class LlmRequest : public tb::GenericLlmRequest std::optional promptEmbeddingTable = std::nullopt, std::optional promptVocabSize = std::nullopt, std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt, - std::optional loraConfig = std::nullopt, bool returnLogProbs = false, + std::optional loraConfig = std::nullopt, + std::optional lookaheadConfig = std::nullopt, bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, std::optional draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt, @@ -64,7 +65,7 @@ class LlmRequest : public tb::GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, - promptVocabSize, loraTaskId, loraWeights, loraConfig, returnLogProbs, returnContextLogits, + promptVocabSize, loraTaskId, loraWeights, loraConfig, lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, draftTokens.has_value() ? 
std::make_shared(std::move(draftTokens.value())) : std::make_shared(), diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp index 83aac4e98..4a91d0b8b 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -20,17 +20,20 @@ #include #include #include -#include #include "bindings.h" #include "executor.h" #include "streamCaster.h" #include "tensorCaster.h" +#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/tensor.h" #include "tensorrt_llm/executor/types.h" +#include +#include + namespace py = pybind11; namespace tle = tensorrt_llm::executor; using Tensor = tle::Tensor; @@ -54,6 +57,15 @@ void InitBindings(pybind11::module_& m) .value("STATIC", tle::BatchingType::kSTATIC) .value("INFLIGHT", tle::BatchingType::kINFLIGHT); + auto decodingModeGetstate = [](tle::DecodingMode const& self) { return py::make_tuple(self.getState()); }; + auto decodingModeSetstate = [](py::tuple state) + { + if (state.size() != 1) + { + throw std::runtime_error("Invalid state!"); + } + return tle::DecodingMode(state[0].cast()); + }; py::class_(m, "DecodingMode") .def("Auto", &tle::DecodingMode::Auto) .def("TopK", &tle::DecodingMode::TopK) @@ -69,7 +81,8 @@ void InitBindings(pybind11::module_& m) .def("isTopKandTopP", &tle::DecodingMode::isTopKandTopP) .def("isBeamSearch", &tle::DecodingMode::isBeamSearch) .def("isMedusa", &tle::DecodingMode::isMedusa) - .def("isLookahead", &tle::DecodingMode::isLookahead); + .def("isLookahead", &tle::DecodingMode::isLookahead) + .def(py::pickle(decodingModeGetstate, decodingModeSetstate)); py::enum_(m, "CapacitySchedulerPolicy") .value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION) @@ -223,22 +236,30 @@ void InitBindings(pybind11::module_& m) .def_property_readonly("weights", &tle::LoraConfig::getWeights) .def_property_readonly("config", &tle::LoraConfig::getConfig); + py::class_(m, "LookaheadDecodingConfig") + .def(py::init(), py::arg("max_window_size"), py::arg("max_ngram_size"), + py::arg("max_verification_set_size")) + .def_property_readonly("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize) + .def_property_readonly("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize) + .def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize); + py::class_ request(m, "Request"); request .def(py::init const&, std::optional const&, std::optional>, std::optional>, std::optional, std::optional, std::optional, - std::optional, std::optional, std::optional, - std::optional, bool>(), + std::optional, std::optional, + std::optional, std::optional, std::optional, bool>(), py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false, py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"), py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(), py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(), py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(), py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(), - py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(), - py::arg("client_id") = py::none(), py::arg("return_all_generated_tokens") = false) + py::arg("lookahead_config") = py::none(), 
py::arg("logits_post_processor_name") = py::none(), + py::arg("encoder_input_token_ids") = py::none(), py::arg("client_id") = py::none(), + py::arg("return_all_generated_tokens") = false) .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds) .def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens) .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -254,6 +275,7 @@ void InitBindings(pybind11::module_& m) .def_property( "prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig) .def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig) + .def_property("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig) .def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName, &tle::Request::setLogitsPostProcessorName) .def_property( @@ -263,6 +285,12 @@ void InitBindings(pybind11::module_& m) &tle::Request::setReturnAllGeneratedTokens); request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName; + py::enum_(m, "FinishReason") + .value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED) + .value("END_ID", tle::FinishReason::kEND_ID) + .value("STOP_WORDS", tle::FinishReason::kSTOP_WORDS) + .value("LENGTH", tle::FinishReason::kLENGTH); + py::class_(m, "Result") .def(py::init<>()) .def_readwrite("is_final", &tle::Result::isFinal) @@ -271,7 +299,8 @@ void InitBindings(pybind11::module_& m) .def_readwrite("log_probs", &tle::Result::logProbs) .def_readwrite("context_logits", &tle::Result::contextLogits) .def_readwrite("generation_logits", &tle::Result::generationLogits) - .def_readwrite("encoder_output", &tle::Result::encoderOutput); + .def_readwrite("encoder_output", &tle::Result::encoderOutput) + .def_readwrite("finish_reasons", &tle::Result::finishReasons); py::class_(m, "Response") .def(py::init(), py::arg("request_id"), py::arg("error_msg")) @@ -421,13 +450,18 @@ void InitBindings(pybind11::module_& m) .def_property_readonly("host_cache_size", &tle::PeftCacheConfig::getHostCacheSize) .def(py::pickle(peftCacheConfigGetstate, peftCacheConfigSetstate)); - py::class_(m, "LookaheadDecodingConfig") - .def(py::init(), py::arg("max_window_size"), py::arg("max_ngram_size"), - py::arg("max_verification_set_size")) - .def_property_readonly("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize) - .def_property_readonly("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize) - .def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize); - + auto decodingConfigGetstate = [](tle::DecodingConfig const& self) + { return py::make_tuple(self.getDecodingMode(), self.getLookaheadDecodingConfig(), self.getMedusaChoices()); }; + auto decodingConfigSetstate = [](py::tuple state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid state!"); + } + return tle::DecodingConfig(state[0].cast>(), + state[1].cast>(), + state[2].cast>()); + }; py::class_(m, "DecodingConfig") .def(py::init, std::optional, std::optional>(), @@ -436,7 +470,55 @@ void InitBindings(pybind11::module_& m) .def_property("decoding_mode", &tle::DecodingConfig::getDecodingMode, &tle::DecodingConfig::setDecodingMode) .def_property("lookahead_decoding_config", &tle::DecodingConfig::getLookaheadDecodingConfig, &tle::DecodingConfig::setLookaheadDecoding) - .def_property("medusa_choices", &tle::DecodingConfig::getMedusaChoices, 
&tle::DecodingConfig::setMedusaChoices); + .def_property("medusa_choices", &tle::DecodingConfig::getMedusaChoices, &tle::DecodingConfig::setMedusaChoices) + .def(py::pickle(decodingConfigGetstate, decodingConfigSetstate)); + + auto debugConfigGetstate = [](tle::DebugConfig const& self) + { return py::make_tuple(self.getDumpInputTensors(), self.getDumpOutputTensors(), self.getDebugTensorNames()); }; + auto debugConfigSetstate = [](py::tuple state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid state!"); + } + return tle::DebugConfig( + state[0].cast(), state[1].cast(), state[2].cast>()); + }; + py::class_(m, "DebugConfig") + .def(py::init>(), py::arg("dump_input_tensors") = false, + py::arg("dump_output_tensors") = false, py::arg("debug_tensor_names") = py::none()) + .def_property( + "dump_input_tensors", &tle::DebugConfig::getDumpInputTensors, &tle::DebugConfig::setDumpInputTensors) + .def_property( + "dump_output_tensors", &tle::DebugConfig::getDumpOutputTensors, &tle::DebugConfig::setDumpOuputTensors) + .def_property( + "debug_tensor_names", &tle::DebugConfig::getDebugTensorNames, &tle::DebugConfig::setDebugTensorNames) + .def(py::pickle(debugConfigGetstate, debugConfigSetstate)); + + auto logitsPostProcessorConfigGetstate = [](tle::LogitsPostProcessorConfig const& self) + { return py::make_tuple(self.getProcessorMap(), self.getProcessorBatched(), self.getReplicate()); }; + auto logitsPostProcessorConfigSetstate = [](py::tuple state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid LogitsPostProcessorConfig state!"); + } + return tle::LogitsPostProcessorConfig(state[0].cast>(), + state[1].cast>(), state[2].cast()); + }; + + py::class_(m, "LogitsPostProcessorConfig") + .def(py::init, std::optional, + bool>(), + py::arg("processor_map") = py::none(), py::arg("processor_batched") = py::none(), + py::arg("replicate") = true) + .def_property("processor_map", &tle::LogitsPostProcessorConfig::getProcessorMap, + &tle::LogitsPostProcessorConfig::setProcessorMap) + .def_property("processor_batched", &tle::LogitsPostProcessorConfig::getProcessorBatched, + &tle::LogitsPostProcessorConfig::setProcessorBatched) + .def_property( + "replicate", &tle::LogitsPostProcessorConfig::getReplicate, &tle::LogitsPostProcessorConfig::setReplicate) + .def(py::pickle(logitsPostProcessorConfigGetstate, logitsPostProcessorConfigSetstate)); auto extendedRuntimePerfKnobConfigSetstate = [](py::tuple state) { @@ -457,71 +539,38 @@ void InitBindings(pybind11::module_& m) &tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc) .def(py::pickle(extendedRuntimePerfKnobConfigGetstate, extendedRuntimePerfKnobConfigSetstate)); - auto executorConfigGetState - = [&peftCacheConfigGetstate, &kvCacheConfigGetstate, &schedulerConfigGetstate, ¶llelConfigGetstate, - &extendedRuntimePerfKnobConfigGetstate](tle::ExecutorConfig const& self) + auto executorConfigGetState = [](tle::ExecutorConfig const& self) { - py::object peftCacheConfigState = py::none(); - - if (self.getPeftCacheConfig().has_value()) - { - peftCacheConfigState = peftCacheConfigGetstate(self.getPeftCacheConfig().value()); - } - auto kvCacheConfigState = kvCacheConfigGetstate(self.getKvCacheConfig()); - auto schedulerConfigState = schedulerConfigGetstate(self.getSchedulerConfig()); - auto extendedRuntimePerfKnobConfigState - = extendedRuntimePerfKnobConfigGetstate(self.getExtendedRuntimePerfKnobConfig()); - py::object parallelConfigState = py::none(); - if (self.getParallelConfig().has_value()) - { - parallelConfigState = 
parallelConfigGetstate(self.getParallelConfig().value()); - } - - return py::make_tuple(self.getMaxBeamWidth(), schedulerConfigState, kvCacheConfigState, + return py::make_tuple(self.getMaxBeamWidth(), self.getSchedulerConfig(), self.getKvCacheConfig(), self.getEnableChunkedContext(), self.getNormalizeLogProbs(), self.getIterStatsMaxIterations(), self.getRequestStatsMaxIterations(), self.getBatchingType(), self.getMaxBatchSize(), self.getMaxNumTokens(), - parallelConfigState, peftCacheConfigState, self.getLogitsPostProcessorMap(), - self.getLogitsPostProcessorBatched(), self.getReplicateLogitsPostProcessor(), self.getDecodingConfig(), - self.getGpuWeightsPercent(), self.getMaxQueueSize(), extendedRuntimePerfKnobConfigState); + self.getParallelConfig(), self.getPeftCacheConfig(), self.getLogitsPostProcessorConfig(), + self.getDecodingConfig(), self.getGpuWeightsPercent(), self.getMaxQueueSize(), + self.getExtendedRuntimePerfKnobConfig(), self.getDebugConfig()); }; - auto executorConfigSetState = [&kvCacheConfigSetstate, &peftCacheConfigSetstate, &schedulerConfigSetstate, - ¶llelConfigSetstate, &extendedRuntimePerfKnobConfigSetstate](py::tuple state) + auto executorConfigSetState = [](py::tuple state) { - if (state.size() != 19) + if (state.size() != 18) { throw std::runtime_error("Invalid state!"); } - auto kvCacheConfig = kvCacheConfigSetstate(state[2].cast()); - auto schedulerConfig = schedulerConfigSetstate(state[1].cast()); - auto extendedRuntimePerfKnobConfig = extendedRuntimePerfKnobConfigSetstate(state[18].cast()); - - std::optional peftCacheConfig; - if (state[11].cast() != py::none()) - { - peftCacheConfig = peftCacheConfigSetstate(state[11].cast()); - } - std::optional parallelConfig; - if (state[10].cast() != py::none()) - { - parallelConfig = parallelConfigSetstate(state[10].cast()); - } - - return tle::ExecutorConfig(state[0].cast(), schedulerConfig, kvCacheConfig, state[3].cast(), - state[4].cast(), state[5].cast(), state[6].cast(), - state[7].cast(), state[8].cast>(), - state[9].cast>(), parallelConfig, peftCacheConfig, - state[12].cast>(), - state[13].cast>(), state[14].cast(), - state[15].cast>(), state[16].cast(), - state[17].cast>(), extendedRuntimePerfKnobConfig); + return tle::ExecutorConfig(state[0].cast(), state[1].cast(), + state[2].cast(), state[3].cast(), state[4].cast(), + state[5].cast(), state[6].cast(), state[7].cast(), + state[8].cast>(), state[9].cast>(), + state[10].cast>(), state[11].cast>(), + state[12].cast>(), + state[13].cast>(), state[14].cast(), + state[15].cast>(), state[16].cast(), + state[17].cast>()); }; py::class_(m, "ExecutorConfig") .def(py::init, std::optional, std::optional, tle::PeftCacheConfig const&, - std::optional, std::optional, bool, - std::optional, float, std::optional, - tle::ExtendedRuntimePerfKnobConfig const&>(), + std::optional, std::optional, float, + std::optional, tle::ExtendedRuntimePerfKnobConfig const&, + std::optional>(), py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"), py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"), py::arg("enable_chunked_context") = false, py::arg("normalize_log_probs") = true, @@ -531,11 +580,11 @@ void InitBindings(pybind11::module_& m) py::arg("max_batch_size") = py::none(), py::arg("max_num_tokens") = py::none(), py::arg("parallel_config") = py::none(), py::arg_v("peft_cache_config", tle::PeftCacheConfig(), "PeftCacheConfig()"), - py::arg("logits_post_processor_map") = py::none(), 
py::arg("logits_post_processor_batched") = py::none(), - py::arg("replicate_logits_post_processor") = true, py::arg("decoding_config") = py::none(), + py::arg("logits_post_processor_config") = py::none(), py::arg("decoding_config") = py::none(), py::arg("gpu_weights_percent") = 1.0, py::arg("max_queue_size") = py::none(), py::arg_v("extended_runtime_perf_knob_config", tle::ExtendedRuntimePerfKnobConfig(), - "ExtendedRuntimePerfKnobConfig()")) + "ExtendedRuntimePerfKnobConfig()"), + py::arg("debug_config") = py::none()) .def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth) .def_property("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize) .def_property("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens) @@ -555,12 +604,8 @@ void InitBindings(pybind11::module_& m) "parallel_config", &tle::ExecutorConfig::getParallelConfig, &tle::ExecutorConfig::setParallelConfig) .def_property( "peft_cache_config", &tle::ExecutorConfig::getPeftCacheConfig, &tle::ExecutorConfig::setPeftCacheConfig) - .def_property("logits_post_processor_map", &tle::ExecutorConfig::getLogitsPostProcessorMap, - &tle::ExecutorConfig::setLogitsPostProcessorMap) - .def_property("logits_post_processor_batched", &tle::ExecutorConfig::getLogitsPostProcessorBatched, - &tle::ExecutorConfig::setLogitsPostProcessorBatched) - .def_property("replicate_logits_post_processor", &tle::ExecutorConfig::getReplicateLogitsPostProcessor, - &tle::ExecutorConfig::setReplicateLogitsPostProcessor) + .def_property("logits_post_processor_config", &tle::ExecutorConfig::getLogitsPostProcessorConfig, + &tle::ExecutorConfig::setLogitsPostProcessorConfig) .def_property( "decoding_config", &tle::ExecutorConfig::getDecodingConfig, &tle::ExecutorConfig::setDecodingConfig) .def_property("gpu_weights_percent", &tle::ExecutorConfig::getGpuWeightsPercent, @@ -568,6 +613,7 @@ void InitBindings(pybind11::module_& m) .def_property("max_queue_size", &tle::ExecutorConfig::getMaxQueueSize, &tle::ExecutorConfig::setMaxQueueSize) .def_property("extended_runtime_perf_knob_config", &tle::ExecutorConfig::getExtendedRuntimePerfKnobConfig, &tle::ExecutorConfig::setExtendedRuntimePerfKnobConfig) + .def_property("debug_config", &tle::ExecutorConfig::getDebugConfig, &tle::ExecutorConfig::setDebugConfig) .def(py::pickle(executorConfigGetState, executorConfigSetState)); tensorrt_llm::pybind::executor::Executor::initBindings(m); diff --git a/cpp/tensorrt_llm/pybind/executor/executor.cpp b/cpp/tensorrt_llm/pybind/executor/executor.cpp index 87af73ab6..c9e10673e 100644 --- a/cpp/tensorrt_llm/pybind/executor/executor.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executor.cpp @@ -42,11 +42,15 @@ Executor::Executor(std::filesystem::path const& encoderModelPath, std::filesyste mExecutor = std::make_unique(encoderModelPath, decoderModelPath, modelType, executorConfig); } -Executor::Executor(std::string const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, +Executor::Executor(pybind11::buffer engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig) { - mExecutor = std::make_unique( - std::vector(engineBuffer.begin(), engineBuffer.end()), jsonConfigStr, modelType, executorConfig); + py::buffer_info info = engineBuffer.request(); + auto begin = reinterpret_cast(info.ptr); + // the buffer is just 1-D array of uint8_t, so .shape[0] == number of bytes + auto end = 
reinterpret_cast(begin) + info.shape[0]; + mExecutor + = std::make_unique(std::vector(begin, end), jsonConfigStr, modelType, executorConfig); } Executor::Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr, @@ -92,7 +96,7 @@ void Executor::initBindings(py::module_& m) tle::ExecutorConfig const&>(), py::arg("encoder_model_path"), py::arg("decoder_model_path"), py::arg("model_type"), py::arg("executor_config")) - .def(py::init(), + .def(py::init(), py::arg("engine_buffer"), py::arg("json_config_str"), py::arg("model_type"), py::arg("executor_config")) .def(py::init(), diff --git a/cpp/tensorrt_llm/pybind/executor/executor.h b/cpp/tensorrt_llm/pybind/executor/executor.h index 5c950a0ff..e19cfcc77 100644 --- a/cpp/tensorrt_llm/pybind/executor/executor.h +++ b/cpp/tensorrt_llm/pybind/executor/executor.h @@ -34,7 +34,7 @@ class Executor Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); - Executor(std::string const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, + Executor(pybind11::buffer engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr, diff --git a/cpp/tensorrt_llm/runtime/CMakeLists.txt b/cpp/tensorrt_llm/runtime/CMakeLists.txt index 10ddbc3a2..07c1bd6fb 100644 --- a/cpp/tensorrt_llm/runtime/CMakeLists.txt +++ b/cpp/tensorrt_llm/runtime/CMakeLists.txt @@ -20,6 +20,7 @@ set(SRCS utils/debugUtils.cu bufferManager.cpp explicitDraftTokensBuffers.cpp + lookaheadBuffers.cpp layerProfiler.cpp loraManager.cpp loraUtils.cpp @@ -68,6 +69,6 @@ set_property(TARGET runtime_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_include_directories(runtime_src PRIVATE ${MPI_C_INCLUDE_DIRS}) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) target_link_libraries(runtime_src PUBLIC ${NCCL_LIB}) endif() diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index 484736afb..39c09ce59 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -59,7 +59,8 @@ GptDecoder::GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSiz template void GptDecoder::setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots, - std::optional const& output) + std::optional const& output, + std::optional const> const& requestsOpt) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -135,7 +136,29 @@ void GptDecoder::setup(SamplingConfig const& samplingConfig, size_t batchSize setupParams->decodingParams = explicitDraftTokensParams; } + else if (mDecodingMode.isLookahead()) + { + TLLM_CHECK_WITH_INFO(output.has_value(), "Output tensors must be provided for Lookahead decoding"); + TLLM_LOG_DEBUG("gptDecoder setup lookahead, batchSize=%d", batchSize); + auto lookaheadParams = std::make_shared(); + TLLM_CHECK(requestsOpt); + auto& requests = requestsOpt.value(); + lookaheadParams->prompt.resize(0); + lookaheadParams->prompt.reserve(batchSize); + lookaheadParams->algoConfigs.resize(0); + lookaheadParams->algoConfigs.reserve(batchSize); + for (size_t bi = 0; bi < batchSize; bi++) + { + lookaheadParams->prompt.emplace_back(ITensor::slice(requests[bi].ids, 0, requests[bi].inputLen)); + TLLM_CHECK(requests[bi].lookaheadRuntimeConfig); + 
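For context on the Python-facing changes in this patch: `ExecutorConfig` now takes a single `logits_post_processor_config` object in place of the old map/batched/replicate keyword trio, gains a `debug_config` keyword, and the `Executor` constructor accepts any buffer-protocol object (e.g. `bytes`, `bytearray`, or a NumPy `uint8` array) for the serialized engine. A rough usage sketch follows; the module path, the `ModelType.DECODER_ONLY` value, and the engine file names are assumptions, while the keyword names come from the `py::arg` definitions above.

```python
# Sketch only: module path, ModelType value, and file names are assumptions,
# not taken from this change; keyword names match the py::arg definitions above.
from pathlib import Path

import tensorrt_llm.bindings.executor as trtllm_exec

engine_dir = Path("/path/to/engine")                       # placeholder path
engine_bytes = (engine_dir / "rank0.engine").read_bytes()  # bytes satisfies pybind11::buffer
json_config_str = (engine_dir / "config.json").read_text()

config = trtllm_exec.ExecutorConfig(
    max_beam_width=1,
    # One config object replaces the old map/batched/replicate keyword trio.
    logits_post_processor_config=None,
    # Newly exposed: dump input/output tensors for debugging.
    debug_config=trtllm_exec.DebugConfig(dump_input_tensors=True,
                                         dump_output_tensors=True),
)

executor = trtllm_exec.Executor(
    engine_buffer=engine_bytes,        # treated as a 1-D uint8 buffer, no string copy
    json_config_str=json_config_str,
    model_type=trtllm_exec.ModelType.DECODER_ONLY,  # enum value assumed
    executor_config=config,
)
```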
lookaheadParams->algoConfigs.emplace_back(requests[bi].lookaheadRuntimeConfig.value()); + } + lookaheadParams->generationLengths = output->lookaheadOutputs->generationLengths; + lookaheadParams->positionOffsets = output->lookaheadOutputs->positionOffsets; + lookaheadParams->attentionPackedMasks = output->lookaheadOutputs->packedMasks; + setupParams->decodingParams = std::move(lookaheadParams); + } setupParams->decodingParams->randomSeed = mSamplingConfig.randomSeed; mDynamicDecodeLayer->setup(batchSize, mSamplingConfig.beamWidth, batchSlots, setupParams); @@ -248,6 +271,18 @@ void prepareExplicitDraftTokensInput( TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +void prepareLookaheadInputs( + DecodingInput const& inputs, size_t maxBatchSize, std::shared_ptr& baseInputs) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto inputParams = std::dynamic_pointer_cast(baseInputs); + auto const& lookaheadInputs = inputs.lookaheadInputs.value(); + inputParams->curTokensPerStep = lookaheadInputs.tokensPerStep; + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + template std::shared_ptr prepareInputs(DecodingInput const& input, size_t maxBatchSize, tle::DecodingMode const& decodingMode, std::shared_ptr bufferManager) @@ -272,7 +307,7 @@ std::shared_ptr prepareInputs(DecodingInput const& input } else if (decodingMode.isLookahead()) { - // TODO add lookahead inputs + forwardParams = std::make_shared(input.endIds, input.batchSlots); } else if (decodingMode.isExplicitDraftTokens()) { @@ -319,9 +354,9 @@ std::shared_ptr prepareInputs(DecodingInput const& input forwardParams->stopCriteriaInputs = prepareStopCriteriaInputs(input); - if (input.finished) + if (input.finishReasons) { - forwardParams->finished = input.finished; + forwardParams->finished = input.finishReasons; } // Medusa @@ -336,6 +371,12 @@ std::shared_ptr prepareInputs(DecodingInput const& input prepareExplicitDraftTokensInput(input, maxBatchSize, forwardParams); } + if (input.lookaheadInputs) + { + prepareLookaheadInputs(input, maxBatchSize, forwardParams); + forwardParams->localBatchSize = input.batchSize; + } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); return forwardParams; @@ -430,6 +471,15 @@ void prepareSpeculativeDecodingOutputs(DecodingOutput& output, std::shared_ptrgenerationLengthsHost = explicitDraftTokensBuffers->generationLengthsHost; outputParams->maxGenLengthHost = explicitDraftTokensBuffers->maxGenLengthHost; } + if (decodingMode.isLookahead()) + { + TLLM_CHECK(output.lookaheadOutputs); + auto outputParams = std::dynamic_pointer_cast(baseOutputs); + outputParams->packedMasks = output.lookaheadOutputs->packedMasks; + outputParams->positionIds = output.lookaheadOutputs->positionIds; + outputParams->positionOffsets = output.lookaheadOutputs->positionOffsets; + outputParams->generationLengths = output.lookaheadOutputs->generationLengths; + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -444,10 +494,14 @@ std::shared_ptr prepareOutputs( { outputParams = std::make_shared(output.ids); } - else if (decodingMode.isMedusa() || decodingMode.isLookahead()) + else if (decodingMode.isMedusa()) { outputParams = std::make_shared(output.ids); } + else if (decodingMode.isLookahead()) + { + outputParams = std::make_shared(output.ids); + } else if (decodingMode.isExplicitDraftTokens()) { outputParams = std::make_shared(output.ids); @@ -470,9 +524,9 @@ std::shared_ptr prepareOutputs( outputParams->parentIds = output.parentIds; } - if (output.finished) + if (output.finishReasons) { - outputParams->finished = 
output.finished; + outputParams->finished = output.finishReasons; } if (output.finishedSum) @@ -609,8 +663,7 @@ void GptDecoder::gatherTree(DecodingOutput const& decodingOutput, DecodingInp bh.numBeamsCBA = bufferCast(*decodingOutput.beamHypotheses.numBeamsCBA); bh.minNormedScoresCBA = bufferCast(*decodingOutput.beamHypotheses.minNormedScoresCBA); bh.batchDones = bufferCast(*decodingOutput.beamHypotheses.batchDones); - bh.finished = reinterpret_cast( - bufferCast(*decodingOutput.finished)); + bh.finished = bufferCast(*decodingOutput.finishReasons); bh.outputIdsUnfinish = bufferCast(*decodingOutput.ids); bh.parentIdsUnfinish = bufferCast(*decodingOutput.parentIds); diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index e7151cca2..b343796db 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -20,6 +20,8 @@ #include "tensorrt_llm/kernels/decodingKernels.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/cudaEvent.h" +#include "tensorrt_llm/runtime/memoryCounters.h" +#include "tensorrt_llm/runtime/runtimeBuffers.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include @@ -122,6 +124,8 @@ GptDecoderBatched::GptDecoderBatched(std::size_t vocabSize, std::size_t vocabSiz dOutput->cumLogProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); dOutput->logProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); dOutput->beamHypotheses.empty(mBufferManager); + dOutput->finishReasons + = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType::value); mNumDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); mCurandStates = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT8); @@ -180,6 +184,10 @@ void GptDecoderBatched::allocateSpeculativeDecodingBuffers() = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); } } + if (mSpeculativeDecodingMode.isLookaheadDecoding()) + { + dInput->lookaheadInputs = DecodingInput::LookaheadInputs(); + } if (mSpeculativeDecodingMode.needsKVCacheRewind()) { speculativeDecodingOutputs.acceptedTokensLen @@ -204,6 +212,17 @@ void GptDecoderBatched::setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inp TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +void GptDecoderBatched::setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + TLLM_CHECK(mSpeculativeDecodingMode.isLookaheadDecoding()); + mJointDecodingOutput->lookaheadOutputs = std::move(lookaheadDecodingBuffers); + mJointDecodingInput->lookaheadInputs->tokensPerStep = mJointDecodingOutput->lookaheadOutputs->generationLengths; + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, SizeType32 maxTokensPerEngineStep, nvinfer1::DataType dtype, ModelConfig const& modelConfig) @@ -270,6 +289,9 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max mFinishedSteps->reshape(maxTokensPerStepXmaxBatchSizeXmaxBeamWidth); mBufferManager.setZero(*mFinishedSteps); + dOutput.finishReasons->reshape(maxBatchSizeXmaxBeamWidth); + mBufferManager.setZero(*dOutput.finishReasons); + mBatchSlotsSetup->reshape(ITensor::makeShape({maxBatchSize})); 
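The setup above aliases the per-slot `tokensPerStep` input to the decoder's lookahead `generationLengths` output; elsewhere in this patch the invariant `generation_length == draft_len + 1` is asserted, and a newly added request starts with no draft, so its first step generates a single token. A pure-Python sketch of that bookkeeping (the names below are illustrative, not the library API):

```python
# Illustrative bookkeeping only; names are not the library API.
class SlotState:
    """Per-sequence-slot lookahead bookkeeping."""

    def __init__(self) -> None:
        self.draft_len = 0          # no speculative draft before the first step

    @property
    def generation_length(self) -> int:
        # One freshly sampled token plus the current speculative draft.
        return self.draft_len + 1

    def tokens_per_step(self) -> int:
        # Value fed to the engine for this slot on the next step.
        return self.generation_length


slot = SlotState()
assert slot.tokens_per_step() == 1   # first generation step: a single token
slot.draft_len = 4                   # decoder produced a 4-token draft
assert slot.tokens_per_step() == 5   # later steps: 1 + draft_len
```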
mBatchSlotsDecoder->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize})); mBatchSlotsAcceptTokens->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize})); @@ -684,8 +706,15 @@ void GptDecoderBatched::newRequestLookahead(SizeType32 batchIdx, decoder_batch:: { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - // TODO(nkorobov) add lookahead layer - TLLM_LOG_WARNING("Lookahead decoding is not supported yet."); + TLLM_CHECK(mJointDecodingOutput->lookaheadOutputs); + + auto& stream = mRuntimeStream; + + // The first generation step only generate 1 token. + TensorPtr curTokensPerStepSlice + = ITensor::slice(constPointerCast(mJointDecodingInput->lookaheadInputs->tokensPerStep), batchIdx, 1); + kernels::invokeFill(*curTokensPerStepSlice, 1, *stream); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -747,7 +776,7 @@ void GptDecoderBatched::newRequests(std::vector const& seqSlots, TensorPtr batchSlotsView = ITensor::slice(mBatchSlotsSetup, 0, localBatchSize); auto samplingConfig = SamplingConfig(samplingConfigs); - mDecoder->setup(samplingConfig, localBatchSize, batchSlotsView, {*mJointDecodingOutput}); + mDecoder->setup(samplingConfig, localBatchSize, batchSlotsView, {*mJointDecodingOutput}, {requests}); auto const& stream = mDecoderStream; CudaEvent event{}; @@ -905,7 +934,7 @@ void GptDecoderBatched::forwardDecoder( TensorPtr newTokensStepView = ITensor::slice(dOutput.newTokensSteps, step, mMaxDecodingDecoderTokens); dInput.logitsVec = logitsVec; - dInput.finished = finishedStepsInput; + dInput.finishReasons = finishedStepsInput; if (maxBeamWidth > 1 && input.seqSlots) { @@ -925,7 +954,7 @@ void GptDecoderBatched::forwardDecoder( } dOutput.newTokens = newTokensStepView; - dOutput.finished = finishedStepsOutput; + dOutput.finishReasons = finishedStepsOutput; dOutput.lengths = sequenceLengths; if (localBatchDecoderIdx > 0) @@ -1057,7 +1086,7 @@ CudaEvent GptDecoderBatched::postProcessRequest( slice(dOutput.cumLogProbs, dJointOutput.cumLogProbs); slice(dOutput.cacheIndirection, dJointOutput.cacheIndirection); slice(dOutput.lengths, dJointOutput.lengths); - slice(dOutput.finished, dJointOutput.finished); + slice(dOutput.finishReasons, dJointOutput.finishReasons); slice(dOutput.logProbs, dJointOutput.logProbs); dOutput.newTokens = ITensor::view(dJointOutput.newTokens); diff --git a/cpp/tensorrt_llm/runtime/gptSession.cpp b/cpp/tensorrt_llm/runtime/gptSession.cpp index 63687f042..a871113b5 100644 --- a/cpp/tensorrt_llm/runtime/gptSession.cpp +++ b/cpp/tensorrt_llm/runtime/gptSession.cpp @@ -20,7 +20,6 @@ #include "iBuffer.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/safetensors.h" #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/runtime/gptDecoderBatched.h" #include "tensorrt_llm/runtime/ipcUtils.h" @@ -386,7 +385,7 @@ void GptSession::setup(Config const& sessionConfig) if (mModelConfig.getManageWeightsType() != ModelConfig::ManageWeightsType::kDisabled) { TLLM_CHECK_WITH_INFO(sessionConfig.enginePath.has_value(), "Engine path is not set."); - auto weightPath = sessionConfig.enginePath.value().parent_path() + auto weightPath = sessionConfig.enginePath->parent_path() / ("rank" + std::to_string(mWorldConfig.getLocalRank()) + "_managed_weights.safetensors"); mRuntime->loadManagedWeights(weightPath.string()); } diff --git a/cpp/tensorrt_llm/runtime/ipcUtils.cpp b/cpp/tensorrt_llm/runtime/ipcUtils.cpp index 49c4bdca1..727e34a21 100644 --- 
a/cpp/tensorrt_llm/runtime/ipcUtils.cpp +++ b/cpp/tensorrt_llm/runtime/ipcUtils.cpp @@ -29,7 +29,7 @@ namespace tensorrt_llm::runtime namespace { -bool setPeerAccess(WorldConfig const& worldConfig, bool enable) +bool canAccessPeer(WorldConfig const& worldConfig) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const srcDevice = worldConfig.getDevice(); @@ -50,20 +50,6 @@ bool setPeerAccess(WorldConfig const& worldConfig, bool enable) TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); return false; } - - if (enable) - { - cudaDeviceEnablePeerAccess(destDevice, 0); - } - else - { - cudaDeviceDisablePeerAccess(destDevice); - } - auto const error = cudaGetLastError(); - if (error != cudaErrorPeerAccessAlreadyEnabled && error != cudaErrorPeerAccessNotEnabled) - { - TLLM_CUDA_CHECK(error); - } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); return true; @@ -147,7 +133,7 @@ AllReduceBuffers::AllReduceBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWi SizeType32 hiddenSize, BufferManager const& manager, WorldConfig const& worldConfig) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto const isP2pSupported = setPeerAccess(worldConfig, true); + auto const isP2pSupported = canAccessPeer(worldConfig); auto const tpSize = worldConfig.getTensorParallelism(); auto const bufferSize = tpSize diff --git a/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp b/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp new file mode 100644 index 000000000..465641bf0 --- /dev/null +++ b/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp @@ -0,0 +1,153 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#include "tensorrt_llm/runtime/lookaheadBuffers.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/common.h" + +namespace tensorrt_llm::runtime +{ + +LookaheadDecodingBuffers::LookaheadDecodingBuffers( + SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& bufferManager) + : generationLengths(bufferManager.gpu(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32)) + , positionOffsets( + bufferManager.gpu(ITensor::makeShape({maxNumSequences, maxTokensPerStep}), nvinfer1::DataType::kINT32)) + , packedMasks(bufferManager.gpu(ITensor::makeShape({maxNumSequences, maxTokensPerStep, + static_cast(common::divUp(maxTokensPerStep, 32))}), + nvinfer1::DataType::kINT32)) + , positionIds( + bufferManager.gpu(ITensor::makeShape({maxNumSequences, maxTokensPerStep}), nvinfer1::DataType::kINT32)) +{ + TLLM_LOG_DEBUG( + "LookaheadDecodingBuffers, maxNumSequences = %d, maxTokensPerStep = %d", maxNumSequences, maxTokensPerStep); +} + +LookaheadRuntimeBuffers::LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, + runtime::BufferManager const& manager, runtime::ModelConfig const& modelConfig, + runtime::WorldConfig const& worldConfig, executor::DecodingConfig const& /* decodingConfig */, + runtime::TllmRuntime const& runtime) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + TLLM_CHECK_WITH_INFO(maxBeamWidth == 1, "Lookahead decoding does not support beam search"); + + // auto const tokensPerStep = modelConfig.getMaxTokensPerStep(); + auto const tokensPerStep = modelConfig.getMaxDecodingTokens(); + auto const numPackedMasks = static_cast(tensorrt_llm::common::divUp(tokensPerStep, 32)); + + // Copy buffers to device + packedMasksDevice + = manager.gpu(ITensor::makeShape({maxBatchSize * tokensPerStep, numPackedMasks}), nvinfer1::DataType::kINT32); + positionOffsetsDevice = manager.gpu(ITensor::makeShape({maxBatchSize, tokensPerStep}), nvinfer1::DataType::kINT32); + generationLengthsDevice = manager.gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + positionIdsDevice = manager.gpu(ITensor::makeShape({maxBatchSize, tokensPerStep}), nvinfer1::DataType::kINT32); + + packedMaskHost = manager.cpu(packedMasksDevice->getShape(), nvinfer1::DataType::kINT32); + positionOffsetsHost = manager.cpu(positionOffsetsDevice->getShape(), nvinfer1::DataType::kINT32); + generationLengthsHost = manager.cpu(generationLengthsDevice->getShape(), nvinfer1::DataType::kINT32); + positionIdsHost = manager.gpu(positionOffsetsDevice->getShape(), nvinfer1::DataType::kINT32); + + packedMaskHostCopy = manager.cpu(packedMasksDevice->getShape(), nvinfer1::DataType::kINT32); + positionOffsetsHostCopy = manager.cpu(positionOffsetsDevice->getShape(), nvinfer1::DataType::kINT32); + generationLengthsHostCopy = manager.cpu(generationLengthsDevice->getShape(), nvinfer1::DataType::kINT32); + positionIdsHostCopy = manager.cpu(positionIdsDevice->getShape(), nvinfer1::DataType::kINT32); + + batchSlotsHostCopy = manager.cpu(generationLengthsDevice->getShape(), nvinfer1::DataType::kINT32); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +void LookaheadRuntimeBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSequences, + ITensor const& requestTypes, ITensor const& seqSlots, LookaheadDecodingBuffers const& decoderLookaheadBuffers, + TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig) const +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const& manager = 
runtime.getBufferManager(); + + auto const tokensPerStep = modelConfig.getMaxDecodingTokens(); + + manager.copy(*decoderLookaheadBuffers.positionOffsets, *positionOffsetsHostCopy); + manager.copy(*decoderLookaheadBuffers.packedMasks, *packedMaskHostCopy); + manager.copy(*decoderLookaheadBuffers.positionIds, *positionIdsHostCopy); + manager.copy(seqSlots, *batchSlotsHostCopy); + manager.copy(*decoderLookaheadBuffers.generationLengths, *generationLengthsHostCopy); + + manager.getStream().synchronize(); + + BufferRange batchSlotsRange(*batchSlotsHostCopy); + for (SizeType32 bi = 0; bi < numGenSequences; bi++) + { + SizeType32 gbi = batchSlotsRange[bi + numCtxSequences]; + manager.copy(*ITensor::at(generationLengthsHostCopy, {gbi}), *ITensor::at(generationLengthsHost, {bi})); + manager.copy(*ITensor::at(positionOffsetsHostCopy, {gbi}), *ITensor::at(positionOffsetsHost, {bi})); + manager.copy(*ITensor::slice(packedMaskHostCopy, gbi * tokensPerStep, tokensPerStep), + *ITensor::slice(packedMaskHost, bi * tokensPerStep, tokensPerStep)); + manager.copy(*ITensor::at(positionIdsHostCopy, {gbi}), *ITensor::at(positionIdsHost, {bi})); + } + manager.copy(*ITensor::slice(generationLengthsHost, 0, numGenSequences), + *ITensor::slice(generationLengthsDevice, 0, numGenSequences)); + manager.copy(*ITensor::slice(positionOffsetsHost, 0, numGenSequences), + *ITensor::slice(positionOffsetsDevice, 0, numGenSequences)); + manager.copy(*ITensor::slice(packedMaskHost, 0, numGenSequences * tokensPerStep), + *ITensor::slice(packedMasksDevice, 0, numGenSequences * tokensPerStep)); + manager.copy( + *ITensor::slice(positionIdsHost, 0, numGenSequences), *ITensor::slice(positionIdsDevice, 0, numGenSequences)); + + manager.getStream().synchronize(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +void LookaheadRuntimeBuffers::reshape(SizeType32 numCtxSequences, SizeType32 numGenSequences, SizeType32 tokensPerStep) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const numSequences = numGenSequences; + + auto packedMaskShape = packedMasksDevice->getShape(); + packedMaskShape.d[0] = numSequences * tokensPerStep; + packedMasksDevice->reshape(packedMaskShape); + packedMaskHost->reshape(packedMaskShape); + + auto generationLengthsShape = generationLengthsDevice->getShape(); + generationLengthsShape.d[0] = numSequences; + generationLengthsDevice->reshape(generationLengthsShape); + generationLengthsHost->reshape(generationLengthsShape); + + auto positionOffsetsShape = positionOffsetsDevice->getShape(); + positionOffsetsShape.d[0] = numSequences; + positionOffsetsDevice->reshape(positionOffsetsShape); + positionOffsetsHost->reshape(positionOffsetsShape); + + auto positionIdsShape = positionIdsDevice->getShape(); + positionIdsShape.d[0] = numSequences; + positionIdsDevice->reshape(positionIdsShape); + positionIdsHost->reshape(positionIdsShape); + + auto batchSlotsShape = batchSlotsHostCopy->getShape(); + batchSlotsShape.d[0] = numCtxSequences + numGenSequences; + batchSlotsHostCopy->reshape(batchSlotsShape); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +void LookaheadRuntimeBuffers::insertInputTensors( + TensorMap& inputBuffers, TensorMap& /* outputBuffers */, runtime::WorldConfig const& /* worldConfig */) const +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + inputBuffers.insert_or_assign("spec_decoding_packed_mask", packedMasksDevice); + inputBuffers.insert_or_assign("spec_decoding_generation_lengths", generationLengthsDevice); + 
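The `spec_decoding_packed_mask` tensor bound here stores each row of the boolean speculative-decoding attention mask as `ceil(tokens_per_step / 32)` 32-bit words, which is where the `divUp(maxTokensPerStep, 32)` dimension above comes from. A small NumPy sketch of that packing (the LSB-first bit order is an assumption for illustration, not taken from this change):

```python
# Illustrative packing sketch (not the runtime kernel): each mask row of
# tokens_per_step booleans is stored as ceil(tokens_per_step / 32) words,
# matching the packedMasks shape built with divUp(maxTokensPerStep, 32) above.
# LSB-first bit order is an assumption for illustration only.
import numpy as np


def div_up(x: int, y: int) -> int:
    return (x + y - 1) // y


def pack_mask(mask_bool: np.ndarray) -> np.ndarray:
    tokens = mask_bool.shape[0]
    packed = np.zeros((tokens, div_up(tokens, 32)), dtype=np.uint32)
    for row in range(tokens):
        for col in range(tokens):
            if mask_bool[row, col]:
                packed[row, col // 32] |= 1 << (col % 32)
    # Same bit pattern, int32 storage as in the spec_decoding_packed_mask buffer.
    return packed.view(np.int32)


# Causal (lower-triangular) mask for 5 speculative tokens: one word per row.
causal = np.tril(np.ones((5, 5), dtype=bool))
print(pack_mask(causal).ravel().tolist())  # [1, 3, 7, 15, 31]
```

The int32 words are the representation that the lookahead tests later expand back into a boolean mask via `convertInt32ToBool` before feeding the simulated LLM.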
inputBuffers.insert_or_assign("spec_decoding_position_offsets", positionOffsetsDevice); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +} // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp b/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp index 48a399e12..8601ad0cc 100644 --- a/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/statefulGptDecoder.cpp @@ -55,7 +55,7 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS dOutput->newTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType); dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType); - dOutput->finished + dOutput->finishReasons = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType::value); dOutput->finishedSum = mBufferManager.pinnedPool(ITensor::makeShape({1}), nvSizeType); dOutput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); @@ -129,9 +129,9 @@ void StatefulGptDecoder::reshapeBuffers(SizeType32 batchSize, SizeType32 beamWid dOutput.newTokens->reshape(batchSizeXbeamWidth); mBufferManager.setZero(*dOutput.newTokens); dOutput.parentIds->reshape(outputIdsShape); - dOutput.finished->reshape(batchSizeXbeamWidth); - dInput.finished = ITensor::view(dOutput.finished); - mBufferManager.setZero(*dOutput.finished); + dOutput.finishReasons->reshape(batchSizeXbeamWidth); + dInput.finishReasons = ITensor::view(dOutput.finishReasons); + mBufferManager.setZero(*dOutput.finishReasons); dOutput.finishedSum->reshape(batchSizeShape); mBufferManager.setZero(*dOutput.finishedSum); @@ -266,7 +266,7 @@ void StatefulGptDecoder::newBatch( // output auto& dOutput = *mDecodingOutput; manager.setZero(*dOutput.newTokens); - manager.setZero(*dOutput.finished); + manager.setZero(*dOutput.finishReasons); manager.setZero(*dOutput.finishedSum); // If outputs contains cumLogProbs, use that diff --git a/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp b/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp index ceb9b0704..63516f916 100644 --- a/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp +++ b/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp @@ -17,6 +17,7 @@ #include "tensorrt_llm/runtime/utils/numpyUtils.h" #include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/runtime/bufferManager.h" @@ -50,8 +51,10 @@ std::string getNumpyTypeDesc(nvinfer1::DataType type) return type_map.count(type) > 0 ? 
type_map.at(type) : "x"; } -nvinfer1::DataType typeFromNumpyDesc(std::string type) +nvinfer1::DataType typeFromNumpyDesc(std::string const& type) { + TLLM_LOG_DEBUG("numpy type: %s", type.c_str()); + using dt = nvinfer1::DataType; static const std::unordered_map type_map{{"?", dt::kBOOL}, {"u1", dt::kUINT8}, {"i1", dt::kINT8}, {"i4", dt::kINT32}, {"i8", dt::kINT64}, {"f2", dt::kHALF}, {"f4", dt::kFLOAT}}; @@ -77,6 +80,8 @@ void parseNpyIntro(FILE*& f_ptr, uint32_t& header_len, uint32_t& start_data) n_elems = fread((void*) &npy_major, sizeof(uint8_t), 1, f_ptr); n_elems += fread((void*) &npy_minor, sizeof(uint8_t), 1, f_ptr); + TLLM_LOG_DEBUG("npy format version: %d.%d", npy_major, npy_minor); + if (npy_major == 1) { uint16_t header_len_u16 = 0; @@ -109,11 +114,18 @@ int parseNpyHeader(FILE*& f_ptr, uint32_t header_len, nvinfer1::DataType& type, std::string header(header_c, header_len); free(header_c); + TLLM_LOG_DEBUG("npy header: %s", header.c_str()); + size_t start, end; start = header.find("'descr'") + 7; start = header.find("'", start); + // ignore byte order specifier + if (header[start + 1] == '<' || header[start + 1] == '>' || header[start + 1] == '=') + { + ++start; + } end = header.find("'", start + 1); - type = typeFromNumpyDesc(header.substr(start + 2, end - start - 2)); + type = typeFromNumpyDesc(header.substr(start + 1, end - start - 1)); start = header.find("'fortran_order'") + 15; start = header.find(":", start); diff --git a/cpp/tensorrt_llm/runtime/utils/numpyUtils.h b/cpp/tensorrt_llm/runtime/utils/numpyUtils.h index 6cb93ddbf..b5b253068 100644 --- a/cpp/tensorrt_llm/runtime/utils/numpyUtils.h +++ b/cpp/tensorrt_llm/runtime/utils/numpyUtils.h @@ -25,8 +25,7 @@ namespace tensorrt_llm::runtime::utils { //! \brief Create new tensor from numpy file. -[[nodiscard]] ITensor::UniquePtr loadNpy( - BufferManager const& manager, std::string const& npyFile, const MemoryType where); +[[nodiscard]] ITensor::UniquePtr loadNpy(BufferManager const& manager, std::string const& npyFile, MemoryType where); //! \brief Save tensor to numpy file. 
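The header-parsing change above makes `parseNpyHeader` tolerate an explicit byte-order character in the `descr` field, so `'<f4'` and `'f4'` resolve to the same dtype. A Python sketch of the same idea, with a type table mirroring `typeFromNumpyDesc` (note that NumPy may also emit `'|'` for single-byte types, which this change does not special-case):

```python
# Conceptual sketch, not the C++ implementation: strip an optional byte-order
# prefix before looking up the .npy 'descr' code. The table mirrors
# typeFromNumpyDesc in numpyUtils.cpp.
NPY_DESCR_TO_DTYPE = {
    "?": "bool", "u1": "uint8", "i1": "int8",
    "i4": "int32", "i8": "int64",
    "f2": "float16", "f4": "float32",
}


def dtype_from_descr(descr: str) -> str:
    # Same three byte-order markers handled by the patch above.
    if descr and descr[0] in "<>=":
        descr = descr[1:]
    return NPY_DESCR_TO_DTYPE[descr]


assert dtype_from_descr("<f4") == dtype_from_descr("f4") == "float32"
assert dtype_from_descr("=i4") == "int32"
```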
void saveNpy(BufferManager const& manager, ITensor const& tensor, std::string const& filename); diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index bb99db098..035ad18b5 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -82,7 +82,7 @@ add_gtest(gptDecoderBatchedTest runtime/gptDecoderBatchedTest.cpp) add_gtest(gptSessionTest runtime/gptSessionTest.cpp) target_link_libraries(gptSessionTest PRIVATE modelSpecStatic) add_gtest(memoryUtilsTest common/memoryUtilsTest.cu) -if(ENABLE_MULTI_DEVICE EQUAL 1) +if(ENABLE_MULTI_DEVICE) add_gtest(mpiUtilsTest common/mpiUtilsTest.cpp) endif() add_gtest(quantizationTest common/quantizationTest.cpp) @@ -99,6 +99,7 @@ add_gtest(samplingTest runtime/samplingTest.cpp) add_gtest(samplingConfigTest runtime/samplingConfigTest.cpp) add_gtest(iTensorTest runtime/iTensorTest.cpp) add_gtest(iBufferTest runtime/iBufferTest.cpp) +add_gtest(utilsTest runtime/utilsTest.cpp) add_gtest(worldConfigTest runtime/worldConfigTest.cpp) add_gtest(medusaModuleTest runtime/medusaModuleTest.cpp) add_gtest(mixtureOfExpertsTest kernels/mixtureOfExpertsTest.cu) diff --git a/cpp/tests/kernels/ropeTest.cu b/cpp/tests/kernels/ropeTest.cu index f2ca57eb4..36bc3478c 100644 --- a/cpp/tests/kernels/ropeTest.cu +++ b/cpp/tests/kernels/ropeTest.cu @@ -275,7 +275,8 @@ protected: std::shared_ptr mBufferManager; std::shared_ptr mStream; BufferManager::ITensorPtr cu_q_seqlens_tensor{nullptr}, cu_kv_seqlens_tensor{nullptr}, - padding_offset_tensor{nullptr}, fmha_tile_counter_ptr_tensor{nullptr}, rotary_inv_freq_buf_tensor{nullptr}; + padding_offset_tensor{nullptr}, encoder_padding_offset_tensor{nullptr}, fmha_tile_counter_ptr_tensor{nullptr}, + rotary_inv_freq_buf_tensor{nullptr}; std::mt19937 gen; ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // initialize params coming from GPTAttentionPluginCommon @@ -364,9 +365,10 @@ protected: cu_q_seqlens_tensor = mBufferManager->pinned(ITensor::makeShape({cu_seqlens_size}), nvinfer1::DataType::kINT32); cu_kv_seqlens_tensor = mBufferManager->pinned(ITensor::makeShape({cu_seqlens_size}), nvinfer1::DataType::kINT32); - padding_offset_tensor = mBufferManager->pinned( - ITensor::makeShape({batch_size, (mCrossAttention ? cross_qkv_length : input_seq_length)}), - nvinfer1::DataType::kINT32); + padding_offset_tensor + = mBufferManager->pinned(ITensor::makeShape({batch_size, input_seq_length}), nvinfer1::DataType::kINT32); + encoder_padding_offset_tensor + = mBufferManager->pinned(ITensor::makeShape({batch_size, cross_qkv_length}), nvinfer1::DataType::kINT32); fmha_tile_counter_ptr_tensor = mBufferManager->pinned(ITensor::makeShape({mEnableContextFMHA ? 1 : 0}), nvinfer1::DataType::kINT32); rotary_inv_freq_buf_tensor = mBufferManager->pinned( @@ -470,6 +472,7 @@ protected: cu_q_seqlens = bufferCast(*(this->cu_q_seqlens_tensor)); int* cu_kv_seqlens = bufferCast(*(this->cu_kv_seqlens_tensor)); int* padding_offset = bufferCast(*(this->padding_offset_tensor)); + int* encoder_padding_offset = bufferCast(*(this->encoder_padding_offset_tensor)); uint32_t* fmha_tile_counter_ptr = bufferCast(*(this->fmha_tile_counter_ptr_tensor)); rotary_inv_freq_buf = bufferCast(*(this->rotary_inv_freq_buf_tensor)); @@ -478,12 +481,14 @@ protected: decoderParams.seqQOffsets = cu_q_seqlens; decoderParams.seqKVOffsets = cu_kv_seqlens; decoderParams.paddingOffsets = padding_offset; + decoderParams.encoderPaddingOffsets = mCrossAttention ? 
encoder_padding_offset : nullptr; decoderParams.attentionMask = mCrossAttention ? nullptr : attention_mask; // manually set for cross attn // Fixed sequence length offset if not removing the padding (cu_q_seqlens[ii] = ii * seq_length). - decoderParams.seqQLengths = mCrossAttention ? encoder_input_lengths : q_seq_lengths; + decoderParams.seqQLengths = q_seq_lengths; decoderParams.seqKVLengths = mCrossAttention ? encoder_input_lengths : kv_seq_lengths; decoderParams.batchSize = batch_size; - decoderParams.maxQSeqLength = mCrossAttention ? cross_qkv_length : input_seq_length; + decoderParams.maxQSeqLength = input_seq_length; + decoderParams.maxEncoderQSeqLength = mCrossAttention ? cross_qkv_length : 0; decoderParams.removePadding = mRemovePadding; decoderParams.attentionWindowSize = cyclic_attention_window_size; decoderParams.sinkTokenLength = sink_token_length; diff --git a/cpp/tests/layers/lookaheadAlgorithmTest.cpp b/cpp/tests/layers/lookaheadAlgorithmTest.cpp index d3075c8f5..fc70b2bff 100644 --- a/cpp/tests/layers/lookaheadAlgorithmTest.cpp +++ b/cpp/tests/layers/lookaheadAlgorithmTest.cpp @@ -41,9 +41,9 @@ bool verifyAcceptOffsets(TensorPtr output, TensorPtr accepted, TensorPtr accepte BufferRange acceptedRange(*accepted); BufferRange offsetsRange(*acceptedOffsets); bool result = true; - for (SizeType32 i = 0; i < acceptedRange.size(); i++) + for (SizeType32 i = 1; i < acceptedRange.size(); i++) { - result &= outputRange[offsetsRange[i]] == acceptedRange[i]; + result &= outputRange[offsetsRange[i - 1] + 1] == acceptedRange[i]; } return result; } diff --git a/cpp/tests/layers/lookaheadDecodingLayerTest.cpp b/cpp/tests/layers/lookaheadDecodingLayerTest.cpp index 745d36aae..c66735a1f 100644 --- a/cpp/tests/layers/lookaheadDecodingLayerTest.cpp +++ b/cpp/tests/layers/lookaheadDecodingLayerTest.cpp @@ -26,6 +26,7 @@ #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/kernels/samplingTopKKernels.h" +#include "tensorrt_llm/layers/decodingParams.h" #include "tensorrt_llm/layers/lookaheadDecodingLayer.h" #include "tensorrt_llm/layers/lookaheadDecodingUtils.h" #include "tensorrt_llm/runtime/common.h" @@ -220,7 +221,6 @@ class LookaheadDecodingLayerTest : public testing::Test TensorPtr mAlgoConfigBatch; - TensorPtr mFinished; TensorPtr mOutputIds; TensorPtr mSequenceLengths; TensorPtr mProbs; @@ -230,14 +230,19 @@ class LookaheadDecodingLayerTest : public testing::Test TensorPtr mBatchSlots; TensorPtr mBatchSlotsMax; + TensorPtr mNewTokens; TensorPtr mNumNewTokens; - TensorPtr mKNumNewTokensCumSum; + TensorPtr mNumNewTokensCumSum; TensorPtr mPathsOffsets; TensorPtr mDraftLengths; TensorPtr mDraftTokens; - TensorPtr mDraftPosIds; TensorPtr mPackedMasks; TensorPtr mPackedMasksBool; + TensorPtr mGenerationLengths; + TensorPtr mGenerationLengthsMax; + TensorPtr mPositionOffsets; + TensorPtr mPositionIds; + TensorPtr mAttentionPackedMask; TensorPtr mInputTokensBatch; TensorPtr mPositionIdsBatch; @@ -279,9 +284,10 @@ void LookaheadDecodingLayerTest::allocateBuffers() TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); auto const maxBatchSize = mTestParam.maxBatchSize; auto const vocabSize = mAscii->getVocabSize(); + auto const maxBeamSize = 1; - SizeType32 maxNumNewTokens, maxDraftLen; - std::tie(mMaxTokensPerStep, maxNumNewTokens, maxDraftLen, std::ignore) + SizeType32 maxNumNewTokens, maxDraftLen, maxAcceptedDraftLen; + std::tie(mMaxTokensPerStep, maxNumNewTokens, maxDraftLen, maxAcceptedDraftLen) = executor::LookaheadDecodingConfig(mTestParam.maxW, 
mTestParam.maxN, mTestParam.maxG) .calculateSpeculativeResource(); // mMaxTokensPerStep = maxTokensPerStep; @@ -348,12 +354,11 @@ void LookaheadDecodingLayerTest::allocateBuffers() mAlgoConfigBatch = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, 3}), nvinfer1::DataType::kINT32); - mFinished = BufferManager::pinnedPool(maxBatchShape1D, TRTDataType::value); mEndIds = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mTokensPerStep = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mOutputIds = BufferManager::pinnedPool( - ITensor::makeShape({maxBatchSize, mMaxSeqLen + mMaxTokensPerStep}), nvinfer1::DataType::kINT32); + ITensor::makeShape({maxBatchSize, maxBeamSize, mMaxSeqLen + mMaxTokensPerStep}), nvinfer1::DataType::kINT32); mSequenceLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mProbs = BufferManager::pinnedPool( @@ -366,21 +371,27 @@ void LookaheadDecodingLayerTest::allocateBuffers() mPositionIdsBatch = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, mMaxTokensPerStep}), nvinfer1::DataType::kINT32); + mNewTokens = BufferManager::pinnedPool( + ITensor::makeShape({mMaxTokensPerStep, maxBatchSize, 1}), nvinfer1::DataType::kINT32); mNumNewTokens = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mDraftLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mDraftTokens = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32); - mDraftPosIds - = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32); - auto divUp32 = [](SizeType32 x) { return x / 32 + ((x % 32) ? 1 : 0); }; - mPackedMasks = BufferManager::pinnedPool( - ITensor::makeShape({maxBatchSize, mMaxTokensPerStep, divUp32(mMaxTokensPerStep)}), nvinfer1::DataType::kINT32); + auto packedMaskShape = ITensor::makeShape( + {maxBatchSize, mMaxTokensPerStep, static_cast(common::divUp(mMaxTokensPerStep, 32))}); + mPackedMasks = BufferManager::pinnedPool(packedMaskShape, nvinfer1::DataType::kINT32); mPackedMasksBool = BufferManager::pinnedPool( ITensor::makeShape({maxBatchSize, mMaxTokensPerStep, mMaxTokensPerStep}), nvinfer1::DataType::kBOOL); - mKNumNewTokensCumSum - = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32); - mPathsOffsets - = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, maxNumNewTokens}), nvinfer1::DataType::kINT32); + mNumNewTokensCumSum = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32); + mPathsOffsets = BufferManager::pinnedPool( + ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32); + mGenerationLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); + mGenerationLengthsMax = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); + mPositionOffsets + = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, mMaxTokensPerStep}), nvinfer1::DataType::kINT32); + mPositionIds + = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, mMaxTokensPerStep}), nvinfer1::DataType::kINT32); + mAttentionPackedMask = BufferManager::pinnedPool(packedMaskShape, nvinfer1::DataType::kINT32); mBatchSlotsMax = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); @@ -390,11 +401,9 @@ void LookaheadDecodingLayerTest::allocateBuffers() mBatchSlots = ITensor::slice(mBatchSlotsMax, 0, batchSize); - 
trk::invokeFill(*mFinished, uint8_t{0}, *mStream); trk::invokeFill(*mEndIds, mAscii->getEndToken(), *mStream); trk::invokeFill(*mOutputIds, int32_t{0}, *mStream); trk::invokeFill(*mSequenceLengths, int32_t{0}, *mStream); - // trk::invokeFill(*mGeneratedLengths, int32_t{0}, *mStream); trk::invokeFill(*mTokensPerStep, mMaxTokensPerStep, *mStream); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -419,7 +428,7 @@ void LookaheadDecodingLayerTest::newRequests(std::vector requestIds) TokenIdType contextToken = mOracle[gbi][len]; SizeType32 contextLen = len + 1; - BufferRange outputRange(*ITensor::at(mOutputIds, {gbi})); + BufferRange outputRange(*ITensor::at(mOutputIds, {gbi, 0})); for (auto& v : outputRange) { v = 0; @@ -430,7 +439,7 @@ void LookaheadDecodingLayerTest::newRequests(std::vector requestIds) BufferLocation(*mDraftLengths).at(gbi) = 0; BufferLocation(*mNumNewTokens).at(gbi) = 0; - mPrompt[gbi] = ITensor::slice(mOutputIds, {gbi, 0}, len + 1); + mPrompt[gbi] = ITensor::slice(mOutputIds, {gbi, 0, 0}, len + 1); for (auto& v : BufferRange(*mHistogram[gbi])) { @@ -455,6 +464,11 @@ void LookaheadDecodingLayerTest::newRequests(std::vector requestIds) setupParams->prompt.emplace_back(mPrompt[gbi]); setupParams->algoConfigs.emplace_back(mTestParam.w, mTestParam.n, mTestParam.g); PRINT_TOKENS(setupParams->prompt[bi]); + setupParams->generationLengths = mGenerationLengthsMax; + setupParams->actualGenerationLengths = mGenerationLengths; + setupParams->positionOffsets = mPositionOffsets; + // setupParams->outputs.positionIds = mPositionIds; + setupParams->attentionPackedMasks = mPackedMasks; } std::vector seed(requestIds.begin(), requestIds.end()); setupParams->randomSeed = std::make_optional(seed); @@ -463,6 +477,8 @@ void LookaheadDecodingLayerTest::newRequests(std::vector requestIds) PRINT_VALUES(mBatchSlotsMax); mDecoder->setup(requestSize, beamSize, newRequestSlots, setupParams); + PRINT_VALUES(mPositionOffsets); + batchSize += requestIds.size(); mBatchSlots = ITensor::slice(mBatchSlotsMax, 0, batchSize); TLLM_LOG_DEBUG("newwRequests mBatchSlots %s", D(mBatchSlots).values().c_str()); @@ -493,7 +509,7 @@ void LookaheadDecodingLayerTest::manageBatch(void) SizeType32 gbi = batchSlotsRange[bi]; SizeType32 nbi = newBatchSize; - TensorPtr theSequence = ITensor::at(mOutputIds, {gbi}); + TensorPtr theSequence = ITensor::at(mOutputIds, {gbi, 0}); BufferRange theSequenceRange(*theSequence); auto theSequenceLength = BufferRange(*mSequenceLengths)[gbi]; auto theNumNewTokens = BufferRange(*mNumNewTokens)[gbi]; @@ -520,19 +536,16 @@ void LookaheadDecodingLayerTest::manageBatch(void) } auto theDraftLen = BufferRange(*mDraftLengths)[gbi]; - BufferLocation(*mTokensPerStep).at(gbi) = 1 + theDraftLen; + auto theGenerationLength = BufferRange(*mGenerationLengths)[gbi]; + TLLM_CHECK_DEBUG_WITH_INFO( + theDraftLen + 1 == theGenerationLength, "%d + 1 == %d", theDraftLen, theGenerationLength); + BufferLocation(*mTokensPerStep).at(gbi) = theGenerationLength; - BufferLocation(*mPositionIdsBatch).at(nbi, 0) = theSequenceLength - 1; BufferLocation(*mInputTokensBatch).at(nbi, 0) = theSequenceRange[theSequenceLength - 1]; - - TLLM_LOG_DEBUG("W=%d, N=%d, G=%d, w=%d, n=%d, g=%d, draftLen = %d", mTestParam.maxW, mTestParam.maxN, - mTestParam.maxG, mTestParam.w, mTestParam.n, mTestParam.g, theDraftLen); - PRINT_VALUES(mInputTokensBatch); - mBufferManager->copy(*ITensor::slice(mDraftTokens, {gbi, 0}, theDraftLen), *ITensor::slice(mInputTokensBatch, {nbi, 1}, theDraftLen)); - 
mBufferManager->copy(*ITensor::slice(mDraftPosIds, {gbi, 0}, theDraftLen), - *ITensor::slice(mPositionIdsBatch, {nbi, 1}, theDraftLen)); + mBufferManager->copy(*ITensor::slice(mPositionIds, {gbi, 0}), *ITensor::slice(mPositionIdsBatch, {nbi, 0})); + BufferLocation(*mPositionIdsBatch).at(nbi, 0) = theSequenceLength - 1; TLLM_LOG_DEBUG("W=%d, N=%d, G=%d, w=%d, n=%d, g=%d, draftLen = %d", mTestParam.maxW, mTestParam.maxN, mTestParam.maxG, mTestParam.w, mTestParam.n, mTestParam.g, theDraftLen); @@ -599,16 +612,38 @@ void LookaheadDecodingLayerTest::llmForward(void) for (SizeType32 bi = 0; bi < batchSize; bi++) { auto gbi = BufferRange(*mBatchSlots)[bi]; + auto start = BufferRange(*mSequenceLengths)[gbi] - 1; auto len = BufferRange(*mTokensPerStep)[gbi]; + TLLM_LOG_DEBUG("LookaheadDecodingLayerTest::llmForward input len=%d", len); TensorPtr output = ITensor::slice(mProbs, {bi, 0}, len); TensorPtr golden = ITensor::slice(mGoldenSampledTokens, {gbi, 0}, len); - convertInt32ToBool(ITensor::at(mPackedMasksBool, {gbi}), ITensor::at(mPackedMasks, {gbi})); + BufferRange idRange(*ITensor::slice(mPositionIdsBatch, {bi, 0}, len)); + BufferRange offsetRange(*ITensor::slice(mPositionOffsets, {gbi, 0}, len)); + PRINT_VALUES(ITensor::slice(mPositionIdsBatch, {bi, 0})); + PRINT_VALUES(ITensor::slice(mPositionOffsets, {bi, 0})); + for (auto i = 0; i < idRange.size(); i++) + { + TLLM_CHECK(idRange[i] == start + offsetRange[i]); + } - mLlm[gbi]->forward(output, // - ITensor::slice(mInputTokensBatch, {bi, 0}, len), // - ITensor::slice(mPositionIdsBatch, {bi, 0}, len), // - ITensor::at(mPackedMasksBool, {gbi})); + if (false) + { + convertInt32ToBool(ITensor::at(mPackedMasksBool, {gbi}), ITensor::at(mPackedMasks, {gbi})); + mLlm[gbi]->forward(output, // + ITensor::slice(mInputTokensBatch, {bi, 0}, len), // + ITensor::slice(mPositionIdsBatch, {bi, 0}, len), // + ITensor::at(mPackedMasksBool, {gbi})); + } + else + { + convertInt32ToBool(ITensor::at(mPackedMasksBool, {gbi}), ITensor::at(mPackedMasks, {gbi})); + mLlm[gbi]->forward(output, // + start, // + ITensor::slice(mInputTokensBatch, {bi, 0}, len), // + ITensor::slice(mPositionOffsets, {gbi, 0}, len), // + ITensor::at(mPackedMasksBool, {gbi})); + } mAscii->logitsToTensor(golden, output); TLLM_LOG_DEBUG("batch[%d] LLM golden: '%s'", gbi, D(golden).tokens().c_str()); @@ -627,21 +662,25 @@ void LookaheadDecodingLayerTest::decodeForward(void) auto inputParams = std::make_shared(mEndIds, mBatchSlots); inputParams->localBatchSize = batchSize; inputParams->logits = ITensor::slice(mProbs, 0, batchSize); - inputParams->finished = mFinished; // TODO(liweim) ask finished protocol + inputParams->batchSlots = mBatchSlots; inputParams->curTokensPerStep = mTokensPerStep; - auto outputParams = std::make_shared(mOutputIds); + auto outputParams = std::make_shared(mOutputIds); PRINT_VALUES(mSequenceLengths); outputParams->sequenceLength = mSequenceLengths; - outputParams->finished = mFinished; outputParams->nextDraftLengths = mDraftLengths; outputParams->nextDraftTokens = mDraftTokens; - outputParams->nextDraftPosIds = mDraftPosIds; outputParams->packedMasks = mPackedMasks; outputParams->numNewTokens = mNumNewTokens; - outputParams->numNewTokensCumSum = mKNumNewTokensCumSum; + outputParams->newTokens = mNewTokens; + outputParams->numNewTokensCumSum = mNumNewTokensCumSum; outputParams->pathsOffsets = mPathsOffsets; + outputParams->generationLengths = mGenerationLengthsMax; + outputParams->actualGenerationLengths = mGenerationLengths; + outputParams->positionOffsets = 
mPositionOffsets; + outputParams->positionIds = mPositionIds; + outputParams->packedMasks = mPackedMasks; PRINT_VALUES(mTokensPerStep); @@ -663,28 +702,42 @@ void LookaheadDecodingLayerTest::verifyDecode(void) { auto gbi = BufferRange(*mBatchSlots)[bi]; auto len = BufferRange(*mTokensPerStep)[gbi]; - TensorPtr golden = ITensor::slice(mGoldenSampledTokens, {gbi, 0}, len); auto sequenceLength = BufferLocation(*mSequenceLengths).at(gbi); - auto numNewTokens = BufferLocation(*mNumNewTokens).at(gbi); - TensorPtr newTokens = ITensor::slice(mOutputIds, {gbi, sequenceLength - numNewTokens}, numNewTokens); - TensorPtr pathOffsets = ITensor::slice(mPathsOffsets, {gbi, 0}, numNewTokens); - BufferRange goldenRange(*golden); - BufferRange newTokensRange(*newTokens); - BufferRange offsetsRange(*pathOffsets); - for (SizeType32 i = 0; i < newTokensRange.size(); i++) + auto draftLength = BufferLocation(*mDraftLengths).at(gbi); + auto generationLength = BufferLocation(*mGenerationLengths).at(gbi); + BufferRange posOffsetRange(*ITensor::slice(mPositionOffsets, {gbi, 0}, generationLength)); + BufferRange posIdRange(*ITensor::slice(mPositionIds, {gbi, 0}, generationLength)); + TLLM_LOG_DEBUG("generationLength = %d, draftLength = %d", generationLength, draftLength); + TLLM_CHECK(draftLength + 1 == generationLength); + TLLM_CHECK(posOffsetRange[0] == 0); + TLLM_CHECK(posIdRange[0] == sequenceLength - 1); + for (SizeType32 i = 0; i < posIdRange.size(); i++) { - TLLM_CHECK(goldenRange[offsetsRange[i]] == newTokensRange[i]); + TLLM_CHECK(posIdRange[i] == posOffsetRange[i] + sequenceLength - 1); } } - BufferRange cumSumRange(*mKNumNewTokensCumSum); - SizeType32 sum = 0; - TLLM_CHECK(cumSumRange[0] == sum); + + BufferRange cumSumRange(*mNumNewTokensCumSum); + BufferRange pathOffsetsRange(*mPathsOffsets); + PRINT_VALUES(mNumNewTokensCumSum); for (SizeType32 gbi = 0; gbi < mTestParam.maxBatchSize; gbi++) { + SizeType32 pathOffsetBegin = cumSumRange[gbi]; + SizeType32 pathOffsetEnd = cumSumRange[gbi + 1]; + TensorPtr golden = ITensor::at(mGoldenSampledTokens, {gbi}); + auto sequenceLength = BufferLocation(*mSequenceLengths).at(gbi); auto numNewTokens = BufferLocation(*mNumNewTokens).at(gbi); - sum += numNewTokens; - TLLM_CHECK(cumSumRange[gbi + 1] == sum); + TensorPtr newTokens = ITensor::slice(mOutputIds, {gbi, 0, sequenceLength - numNewTokens}, numNewTokens); + BufferRange goldenRange(*ITensor::at(mGoldenSampledTokens, {gbi})); + BufferRange newTokensRange( + *ITensor::slice(mOutputIds, {gbi, 0, sequenceLength - numNewTokens}, numNewTokens)); + + SizeType32 ni = 1; + for (SizeType32 poi = pathOffsetBegin; poi < pathOffsetEnd; poi++) + { + TLLM_CHECK(goldenRange[pathOffsetsRange[poi] + 1] == newTokensRange[ni++]); + } } } diff --git a/cpp/tests/layers/randomLlm.cpp b/cpp/tests/layers/randomLlm.cpp index f644ae797..2116186a6 100644 --- a/cpp/tests/layers/randomLlm.cpp +++ b/cpp/tests/layers/randomLlm.cpp @@ -206,6 +206,19 @@ bool RandomLlm::verify(SizeType32 const offset, TensorConstPtr const& script) co return result; } +void RandomLlm::forward(TensorPtr const& output, runtime::SizeType32 startId, TensorConstPtr const& input, + TensorConstPtr const& offsets, TensorConstPtr const mask) const +{ + TensorPtr posIds = BufferManager::cpu(input->getShape(), nvinfer1::DataType::kINT32); + BufferRange idRange(*posIds); + BufferRange offsetRange(*offsets); + for (auto i = 0; i < idRange.size(); i++) + { + idRange[i] = startId + offsetRange[i]; + } + forward(output, input, posIds, mask); +} + void RandomLlm::forward(TensorPtr 
const& output, TensorConstPtr const& input, TensorConstPtr const& position, TensorConstPtr const mask) const { diff --git a/cpp/tests/layers/randomLlm.h b/cpp/tests/layers/randomLlm.h index 191aa7d1e..a6e898bae 100644 --- a/cpp/tests/layers/randomLlm.h +++ b/cpp/tests/layers/randomLlm.h @@ -109,6 +109,8 @@ class RandomLlm } // simulate forward in a LLM. + void forward(TensorPtr const& output, runtime::SizeType32 startId, TensorConstPtr const& input, + TensorConstPtr const& offsets, TensorConstPtr const mask = nullptr) const; void forward(TensorPtr const& output, TensorConstPtr const& input, TensorConstPtr const& position, TensorConstPtr const mask = nullptr) const; //! set inout[i] invalid if mask[i]==false; diff --git a/cpp/tests/resources/data/test_model_lora_config.json b/cpp/tests/resources/data/test_model_lora_config.json index 18affde58..73a598d01 100644 --- a/cpp/tests/resources/data/test_model_lora_config.json +++ b/cpp/tests/resources/data/test_model_lora_config.json @@ -160,7 +160,6 @@ "streamingllm": false }, "use_strip_plan": false, - "max_encoder_input_len": 1024, - "use_fused_mlp": false + "max_encoder_input_len": 1024 } } diff --git a/cpp/tests/resources/scripts/build_chatglm_engines.py b/cpp/tests/resources/scripts/build_chatglm_engines.py index cabe6107c..39829f25c 100644 --- a/cpp/tests/resources/scripts/build_chatglm_engines.py +++ b/cpp/tests/resources/scripts/build_chatglm_engines.py @@ -132,7 +132,7 @@ def build_engines(model_cache: typing.Optional[str] = None, model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) model_spec_obj.use_gpt_plugin() engine_dir = Path( model_dir @@ -142,7 +142,7 @@ def build_engines(model_cache: typing.Optional[str] = None, build_engine(ckpt_dir, engine_dir, False, is_chatglm_6b_or_glm_10b) model_spec_obj.use_packed_input() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) engine_dir = Path( model_dir ) / "rt_engine" / model_name / model_spec_obj.get_model_path( diff --git a/cpp/tests/resources/scripts/build_engines_utils.py b/cpp/tests/resources/scripts/build_engines_utils.py index 46fcce51b..2c57e5a3d 100644 --- a/cpp/tests/resources/scripts/build_engines_utils.py +++ b/cpp/tests/resources/scripts/build_engines_utils.py @@ -65,10 +65,17 @@ def wincopy(source: str, dest: str, isdir: bool, cwd=None) -> None: # Helper function to locate model_spec module. -def init_model_spec_module(): +def init_model_spec_module(force_init_trtllm_bindings=True): import os + # model spec depends on tensorrt_llm bindings. This will trigger initialization of bindings. # Rely on unique built model_spec to locate the module. + if force_init_trtllm_bindings: + import tensorrt_llm.bindings as _tb + + # Ensure the KVCacheType enum is available. 
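The script changes in this area switch the test engine builders from `model_spec.KVCacheType` to the `KVCacheType` enum exported by the C++ bindings, and `init_model_spec_module` can now import `tensorrt_llm.bindings` up front so the enum exists before any `ModelSpec` is configured. A condensed sketch of the resulting pattern, assuming the caller sits next to `build_engines_utils.py` as the build scripts do:

```python
# Condensed sketch of the pattern used by the engine-build scripts; it assumes
# the caller lives alongside build_engines_utils.py so the helper can locate
# the test-only model_spec module.
from build_engines_utils import init_model_spec_module

init_model_spec_module()   # locates model_spec and (now) initializes the bindings
import model_spec

import tensorrt_llm.bindings as _tb

assert _tb.KVCacheType("PAGED") is not None   # same sanity check as above

spec = model_spec.ModelSpec("input_tokens.npy", _tb.DataType.HALF)
spec.use_gpt_plugin()
spec.use_packed_input()
# KVCacheType now comes from the bindings, not from model_spec itself.
spec.set_kv_cache_type(_tb.KVCacheType.PAGED)
print(spec.get_model_path())
```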
+ assert _tb.KVCacheType('PAGED') is not None + cpp_root_dir = _pl.Path(__file__).parent.resolve().parent.parent.parent found_locations = [] diff --git a/cpp/tests/resources/scripts/build_gpt_engines.py b/cpp/tests/resources/scripts/build_gpt_engines.py index 02bfa0266..6cdf1c7bf 100755 --- a/cpp/tests/resources/scripts/build_gpt_engines.py +++ b/cpp/tests/resources/scripts/build_gpt_engines.py @@ -202,23 +202,21 @@ def build_engines(model_cache: Optional[str] = None, no_kv_cache_args = ['--kv_cache_type=disabled'] def get_ifb_args(kv_cache_type): - if kv_cache_type == model_spec.KVCacheType.DISABLED: + if kv_cache_type == _tb.KVCacheType.DISABLED: return ifb_base_args + no_kv_cache_args - elif kv_cache_type == model_spec.KVCacheType.PAGED: + elif kv_cache_type == _tb.KVCacheType.PAGED: return ifb_base_args + paged_kv_cache_args else: assert False, f"Unsupported kv_cache_type: {kv_cache_type}" model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() model_spec_current = model_spec_obj.__copy__() - for kv_cache_type in [ - model_spec.KVCacheType.DISABLED, model_spec.KVCacheType.PAGED - ]: + for kv_cache_type in [_tb.KVCacheType.DISABLED, _tb.KVCacheType.PAGED]: model_spec_current.set_kv_cache_type(kv_cache_type) build_engine( str(fp16_ckpt_dir), @@ -235,7 +233,7 @@ def get_ifb_args(kv_cache_type): str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir), f'--max_draft_len={max_draft_tokens}', '--speculative_decoding_mode=draft_tokens_external', - *get_ifb_args(model_spec.KVCacheType.PAGED)) + *get_ifb_args(_tb.KVCacheType.PAGED)) model_spec_current = model_spec_obj.__copy__() model_spec_current.use_multiple_profiles() @@ -243,8 +241,7 @@ def get_ifb_args(kv_cache_type): build_engine( str(fp16_ckpt_dir), str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir), - '--multiple_profiles=enable', - *get_ifb_args(model_spec.KVCacheType.PAGED)) + '--multiple_profiles=enable', *get_ifb_args(_tb.KVCacheType.PAGED)) model_spec_current = model_spec_obj.__copy__() max_input_len = 128 @@ -253,7 +250,7 @@ def get_ifb_args(kv_cache_type): build_engine(str(fp16_ckpt_dir), str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir), - *get_ifb_args(model_spec.KVCacheType.PAGED), + *get_ifb_args(_tb.KVCacheType.PAGED), max_input_len=max_input_len) # Build the target model with return accepted token logits @@ -270,8 +267,7 @@ def get_ifb_args(kv_cache_type): str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir), f'--max_draft_len={max_draft_len}', '--speculative_decoding_mode=draft_tokens_external', - '--gather_generation_logits', - *get_ifb_args(model_spec.KVCacheType.PAGED)) + '--gather_generation_logits', *get_ifb_args(_tb.KVCacheType.PAGED)) # We build almost the same engine twice. 
But this engine has gather_all_token_logits # to extract logits from python runtime and uses context FMHA for generation to match draft model executions, @@ -283,19 +279,7 @@ def get_ifb_args(kv_cache_type): build_engine( str(fp16_ckpt_dir), str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir), - '--gather_all_token_logits', - *get_ifb_args(model_spec.KVCacheType.PAGED)) - - model_spec_current = model_spec_obj.__copy__() - model_spec_current.use_look_ahead_decoding() - max_draft_len = 64 - model_spec_current.set_draft_tokens(max_draft_len) - build_engine( - str(fp16_ckpt_dir), - str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir), - f'--max_draft_len={max_draft_len}', - '--speculative_decoding_mode=lookahead_decoding', - *get_ifb_args(model_spec.KVCacheType.PAGED)) + '--gather_all_token_logits', *get_ifb_args(_tb.KVCacheType.PAGED)) # build engine with lora enabled model_spec_current = model_spec_obj.__copy__() @@ -304,7 +288,7 @@ def get_ifb_args(kv_cache_type): str(fp16_ckpt_dir), str(engine_dir / model_spec_current.get_model_path() / tp_pp_dir), "--lora_target_modules=attn_qkv", '--lora_plugin=float16', - *get_ifb_args(model_spec.KVCacheType.PAGED)) + *get_ifb_args(_tb.KVCacheType.PAGED)) if model_cache: llm_datasets_root = Path(model_cache) / "datasets" @@ -326,9 +310,7 @@ def get_ifb_args(kv_cache_type): model_spec_current.use_packed_input() model_spec_current.set_quant_method(model_spec.QuantMethod.SMOOTH_QUANT) - for kv_cache_type in [ - model_spec.KVCacheType.DISABLED, model_spec.KVCacheType.PAGED - ]: + for kv_cache_type in [_tb.KVCacheType.DISABLED, _tb.KVCacheType.PAGED]: model_spec_current.set_kv_cache_type(kv_cache_type) build_engine( str(fp16_sq_ckpt_dir), diff --git a/cpp/tests/resources/scripts/build_gptj_engines.py b/cpp/tests/resources/scripts/build_gptj_engines.py index 96410da88..65d1a15dd 100755 --- a/cpp/tests/resources/scripts/build_gptj_engines.py +++ b/cpp/tests/resources/scripts/build_gptj_engines.py @@ -133,7 +133,7 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False): get_ckpt_with_modelopt_quant(hf_dir, fp8_ckpt_path, model_cache) model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FP8) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() build_engine(fp8_ckpt_path, engine_dir / model_spec_obj.get_model_path() / tp_pp_dir, @@ -146,7 +146,7 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False): print("\nBuilding fp16-plugin engine") model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) build_engine(fp16_ckpt_path, engine_dir / model_spec_obj.get_model_path() / tp_pp_dir, @@ -163,7 +163,7 @@ def build_engines(model_cache: _tp.Optional[str] = None, only_fp8=False): '--remove_input_padding=enable', "--context_fmha=disable") print("\nBuilding fp16-plugin-packed-paged engine") - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) build_engine(fp16_ckpt_path, engine_dir / model_spec_obj.get_model_path() / tp_pp_dir, '--gpt_attention_plugin=float16', diff --git a/cpp/tests/resources/scripts/build_llama_engines.py b/cpp/tests/resources/scripts/build_llama_engines.py index 48f4b2c21..c9a5c8494 100644 --- 
a/cpp/tests/resources/scripts/build_llama_engines.py +++ b/cpp/tests/resources/scripts/build_llama_engines.py @@ -27,17 +27,18 @@ import tensorrt_llm.bindings as _tb -def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, *args): +def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, convert_extra_args, + build_extra_args): ckpt_dir = engine_dir / 'ckpt' - covert_cmd = [_sys.executable, "examples/llama/convert_checkpoint.py" - ] + ([f'--model_dir={weight_dir}'] if weight_dir else []) + [ - f'--output_dir={ckpt_dir}', - '--dtype=float16', - ] + list(args) + convert_cmd = [_sys.executable, "examples/llama/convert_checkpoint.py" + ] + ([f'--model_dir={weight_dir}'] if weight_dir else []) + [ + f'--output_dir={ckpt_dir}', + '--dtype=float16', + ] + convert_extra_args - run_command(covert_cmd) + run_command(convert_cmd) build_args = [ 'trtllm-build', @@ -52,7 +53,7 @@ def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, *args): '--log_level=error', '--paged_kv_cache=enable', '--remove_input_padding=enable', - ] + ] + build_extra_args run_command(build_args) @@ -83,7 +84,7 @@ def build_engines(model_cache: str, only_multi_gpu: bool): model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() tp_pp_sizes = [(1, 1)] @@ -97,7 +98,16 @@ def build_engines(model_cache: str, only_multi_gpu: bool): build_engine(hf_dir, engine_dir / model_spec_obj.get_model_path() / tp_pp_dir, - f'--tp_size={tp_size}', f'--pp_size={pp_size}') + [f'--tp_size={tp_size}', f'--pp_size={pp_size}'], []) + + ## build lookahead engine + model_spec_obj.use_lookahead_decoding() + build_engine(hf_dir, + engine_dir / model_spec_obj.get_model_path() / 'tp1-pp1-gpu', + [], [ + '--max_draft_len=39', + '--speculative_decoding_mode=lookahead_decoding' + ]) print("Done.") diff --git a/cpp/tests/resources/scripts/build_mamba_engines.py b/cpp/tests/resources/scripts/build_mamba_engines.py index b920ab96f..a3a7fed18 100644 --- a/cpp/tests/resources/scripts/build_mamba_engines.py +++ b/cpp/tests/resources/scripts/build_mamba_engines.py @@ -112,7 +112,7 @@ def build_engines(model_cache: _tp.Optional[str] = None): ckpt_dir = models_dir / 'rt_ckpt' / model_name engine_dir = models_dir / 'rt_engine' / model_name model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) model_spec_obj.use_tensor_parallelism(tp_size) model_spec_obj.use_pipeline_parallelism(pp_size) @@ -132,7 +132,7 @@ def build_engines(model_cache: _tp.Optional[str] = None): engine_dir / model_spec_obj.get_model_path() / tp_pp_dir, '--remove_input_padding=enable', '--paged_state=disable') print("\nBuilding fp16-plugin-packed-paged engine") - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir, engine_dir / model_spec_obj.get_model_path() / tp_pp_dir, '--remove_input_padding=enable', '--paged_state=enable') diff --git a/cpp/tests/resources/scripts/build_medusa_engines.py b/cpp/tests/resources/scripts/build_medusa_engines.py index b6f52c30b..401f7ac0d 100755 --- a/cpp/tests/resources/scripts/build_medusa_engines.py +++ 
b/cpp/tests/resources/scripts/build_medusa_engines.py @@ -92,7 +92,7 @@ def build_engines(model_cache: str): model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() model_spec_obj.use_medusa() diff --git a/cpp/tests/resources/scripts/build_recurrentgemma_engines.py b/cpp/tests/resources/scripts/build_recurrentgemma_engines.py index d912455dc..495c849d2 100644 --- a/cpp/tests/resources/scripts/build_recurrentgemma_engines.py +++ b/cpp/tests/resources/scripts/build_recurrentgemma_engines.py @@ -114,7 +114,7 @@ def build_engines(model_cache: _tp.Optional[str] = None): model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() model_spec_obj.use_packed_input() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) print("\nBuilding fp16-plugin-packed-paged engine") build_engine(hf_dir, ckpt_dir / model_spec_obj.get_model_path() / tp_pp_dir, diff --git a/cpp/tests/resources/scripts/generate_expected_chatglm_output.py b/cpp/tests/resources/scripts/generate_expected_chatglm_output.py index 000e1e7b1..d90fa1877 100755 --- a/cpp/tests/resources/scripts/generate_expected_chatglm_output.py +++ b/cpp/tests/resources/scripts/generate_expected_chatglm_output.py @@ -55,9 +55,9 @@ def generate_output( model_spec_obj_list = [ model_spec.ModelSpec( input_file, _tb.DataType.HALF).use_gpt_plugin().set_kv_cache_type( - model_spec.KVCacheType.CONTINUOUS), + _tb.KVCacheType.CONTINUOUS), model_spec.ModelSpec(input_file, _tb.DataType.HALF).use_gpt_plugin(). - use_packed_input().set_kv_cache_type(model_spec.KVCacheType.PAGED), + use_packed_input().set_kv_cache_type(_tb.KVCacheType.PAGED), ] for model_spec_obj in model_spec_obj_list: diff --git a/cpp/tests/resources/scripts/generate_expected_gpt_output.py b/cpp/tests/resources/scripts/generate_expected_gpt_output.py index 020d70a9f..4037a236f 100755 --- a/cpp/tests/resources/scripts/generate_expected_gpt_output.py +++ b/cpp/tests/resources/scripts/generate_expected_gpt_output.py @@ -112,7 +112,7 @@ def generate_outputs(num_beams): print('Generating GPT2 FP32 outputs') model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.FLOAT) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) if num_beams == 1: generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, @@ -126,7 +126,7 @@ def generate_outputs(num_beams): print('Generating GPT2 FP16 outputs') model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) if num_beams == 1: generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, @@ -142,7 +142,7 @@ def generate_outputs(num_beams): num_beams=num_beams, input_name=input_name, model_spec_obj=model_spec_obj) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.gather_logits() generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, @@ -163,7 +163,7 @@ def generate_outputs(num_beams): model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - 
model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, @@ -183,7 +183,7 @@ def generate_outputs(num_beams): model_spec_obj = model_spec.ModelSpec(input_name_long, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() model_spec_obj.use_packed_input() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, input_name=input_name_long, @@ -193,7 +193,7 @@ def generate_outputs(num_beams): model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() model_spec_obj.use_packed_input() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.set_quant_method(model_spec.QuantMethod.SMOOTH_QUANT) generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, diff --git a/cpp/tests/resources/scripts/generate_expected_gptj_output.py b/cpp/tests/resources/scripts/generate_expected_gptj_output.py index b18291535..e10c8e42f 100755 --- a/cpp/tests/resources/scripts/generate_expected_gptj_output.py +++ b/cpp/tests/resources/scripts/generate_expected_gptj_output.py @@ -70,7 +70,7 @@ def generate_outputs(only_fp8, num_beams): if only_fp8 and num_beams == 1: model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.FP8) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() print('Generating GPT-J FP8-kv-cache outputs') @@ -81,7 +81,7 @@ def generate_outputs(only_fp8, num_beams): print('Generating GPT-J FP16 outputs') model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, model_spec_obj=model_spec_obj) @@ -91,7 +91,7 @@ def generate_outputs(only_fp8, num_beams): num_beams=num_beams, model_spec_obj=model_spec_obj) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, model_spec_obj=model_spec_obj) diff --git a/cpp/tests/resources/scripts/generate_expected_llama_output.py b/cpp/tests/resources/scripts/generate_expected_llama_output.py index d0e51faba..08d904201 100644 --- a/cpp/tests/resources/scripts/generate_expected_llama_output.py +++ b/cpp/tests/resources/scripts/generate_expected_llama_output.py @@ -79,7 +79,7 @@ def generate_outputs(num_beams, only_multi_gpu=False): ) model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() for tp_size, pp_size in tp_pp_sizes: diff --git a/cpp/tests/resources/scripts/generate_expected_mamba_output.py b/cpp/tests/resources/scripts/generate_expected_mamba_output.py index cdfeb2151..080da0ade 100644 --- a/cpp/tests/resources/scripts/generate_expected_mamba_output.py +++ 
b/cpp/tests/resources/scripts/generate_expected_mamba_output.py @@ -74,7 +74,7 @@ def generate_outputs(num_beams): print('Generating Mamba FP16 outputs') input_name = 'input_tokens.npy' model_spec_obj = model_spec.ModelSpec(input_name, _tb.DataType.HALF) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, @@ -96,7 +96,7 @@ def generate_outputs(num_beams): model_spec_obj=model_spec_obj) print('Generating Mamba FP16-plugin-packed-paged outputs') - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) generate_output(engine=model_spec_obj.get_model_path(), num_beams=num_beams, input_name=input_name, diff --git a/cpp/tests/resources/scripts/generate_expected_medusa_output.py b/cpp/tests/resources/scripts/generate_expected_medusa_output.py index 4703edd9d..e87a6ab57 100755 --- a/cpp/tests/resources/scripts/generate_expected_medusa_output.py +++ b/cpp/tests/resources/scripts/generate_expected_medusa_output.py @@ -68,7 +68,7 @@ def generate_outputs(): model_spec_obj.use_gpt_plugin() model_spec_obj.set_max_output_length(max_output_len) model_spec_obj.use_packed_input() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_medusa() generate_output(engine=model_spec_obj.get_model_path(), diff --git a/cpp/tests/resources/scripts/generate_expected_recurrentgemma_output.py b/cpp/tests/resources/scripts/generate_expected_recurrentgemma_output.py index 66fdc6b74..9fa89fc32 100644 --- a/cpp/tests/resources/scripts/generate_expected_recurrentgemma_output.py +++ b/cpp/tests/resources/scripts/generate_expected_recurrentgemma_output.py @@ -74,7 +74,7 @@ def generate_outputs(num_beams): input_file = 'input_tokens.npy' model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) model_spec_obj.use_gpt_plugin() - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_packed_input() print('Generating RecurrentGemma FP16-plugin-packed-paged outputs') diff --git a/cpp/tests/resources/scripts/test_cpp.py b/cpp/tests/resources/scripts/test_cpp.py index bb02e5468..6b3a50d61 100755 --- a/cpp/tests/resources/scripts/test_cpp.py +++ b/cpp/tests/resources/scripts/test_cpp.py @@ -406,6 +406,9 @@ def prepare_model_tests(model_name: str, python_exe, str(scripts_dir / f"build_{model_name}_engines.py") ] + model_cache_arg + only_fp8_arg + only_multi_gpu_arg + enc_dec_model_name_arg + + if model_name in ['gpt']: + build_engines += ['--clean'] run_command(build_engines, cwd=root_dir, env=model_env, timeout=1800) model_env["PYTHONPATH"] = "examples" @@ -415,6 +418,10 @@ def prepare_model_tests(model_name: str, ] + only_fp8_arg + only_multi_gpu_arg + enc_dec_model_name_arg if "enc_dec" in model_name: generate_expected_output += model_cache_arg + + if model_name in ['gpt']: + generate_expected_output += ['--clean'] + if only_multi_gpu_arg and model_name != 'enc_dec': for world_size in (2, 4): generate_command = [ @@ -543,6 +550,16 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500): ] run_command(mpi_utils_test, cwd=tests_dir, env=cpp_env, timeout=300) + # Cache transceiver tests + cache_trans_test = [ + "mpirun", + "-n", + "2", + "--allow-run-as-root", + "batch_manager/cacheTransceiverTest", + ] + 
run_command(cache_trans_test, cwd=tests_dir, env=cpp_env, timeout=300) + xml_output_file = build_dir / "results-multi-gpu-real-decoder.xml" trt_model_test = produce_mpirun_command( global_commands=["mpirun", "--allow-run-as-root"], @@ -653,7 +670,7 @@ def run_benchmarks(model_name: str, python_exe: str, root_dir: _pl.Path, if model_name == "gpt": input_file = 'input_tokens.npy' model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.CONTINUOUS) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.CONTINUOUS) model_spec_obj.use_gpt_plugin() model_engine_path = model_engine_dir / model_spec_obj.get_model_path( ) / "tp1-pp1-gpu" @@ -694,7 +711,7 @@ def run_benchmarks(model_name: str, python_exe: str, root_dir: _pl.Path, if model_name == "gpt": input_file = 'input_tokens.npy' model_spec_obj = model_spec.ModelSpec(input_file, _tb.DataType.HALF) - model_spec_obj.set_kv_cache_type(model_spec.KVCacheType.PAGED) + model_spec_obj.set_kv_cache_type(_tb.KVCacheType.PAGED) model_spec_obj.use_gpt_plugin() model_spec_obj.use_packed_input() model_engine_path = model_engine_dir / model_spec_obj.get_model_path( @@ -887,7 +904,7 @@ def run_benchmarks(model_name: str, python_exe: str, root_dir: _pl.Path, from build_engines_utils import init_model_spec_module - init_model_spec_module() + init_model_spec_module(force_init_trtllm_bindings=False) if test_args.run_all_models: test_args.run_gpt = True diff --git a/cpp/tests/runtime/gptDecoderTest.cpp b/cpp/tests/runtime/gptDecoderTest.cpp index e76eebb47..36ed4a4d2 100644 --- a/cpp/tests/runtime/gptDecoderTest.cpp +++ b/cpp/tests/runtime/gptDecoderTest.cpp @@ -39,7 +39,7 @@ bool forwardAndSync(std::unique_ptr const& decoder, DecodingOutput& BufferManager::ITensorPtr finishedSum; std::int32_t* finishedSumHost = nullptr; - if (input.sequenceLimitLength && output.finished) + if (input.sequenceLimitLength && output.finishReasons) { finishedSumHost = bufferCast(*output.finishedSum); for (SizeType32 bi = 0; bi < maxBatchSize; ++bi) @@ -52,7 +52,7 @@ bool forwardAndSync(std::unique_ptr const& decoder, DecodingOutput& if (finishedSumHost) { - auto const numToFinish = output.finished->getSize(); + auto const numToFinish = output.finishReasons->getSize(); TLLM_CUDA_CHECK(::cudaStreamSynchronize(stream->get())); SizeType32 finishedSum = 0; @@ -149,10 +149,10 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC std::vector sequenceLengthsVec(batchSize * beamWidth, maxInputLength); outputs.lengths = manager.copyFrom(sequenceLengthsVec, ITensor::makeShape({batchSize, beamWidth}), MemoryType::kGPU); - outputs.finished = manager.gpu(ITensor::makeShape({batchSize, beamWidth}), + outputs.finishReasons = manager.gpu(ITensor::makeShape({batchSize, beamWidth}), TRTDataType::value); - inputs.finished = ITensor::view(outputs.finished); - manager.setZero(*outputs.finished); + inputs.finishReasons = ITensor::view(outputs.finishReasons); + manager.setZero(*outputs.finishReasons); outputs.finishedSum = BufferManager::pinnedPool(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); auto finishedSumHost = bufferCast(*outputs.finishedSum); for (SizeType32 bi = 0; bi < batchSize; ++bi) @@ -227,7 +227,7 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC { finishedSum += finishedSumHost[bi]; } - EXPECT_EQ(finishedSum, outputs.finished->getSize()); + EXPECT_EQ(finishedSum, outputs.finishReasons->getSize()); } } diff --git a/cpp/tests/runtime/utilsTest.cpp 
b/cpp/tests/runtime/utilsTest.cpp new file mode 100644 index 000000000..a7a6aa1e6 --- /dev/null +++ b/cpp/tests/runtime/utilsTest.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/utils/numpyUtils.h" + +#include +#include + +#include +#include +#include +#include + +using namespace tensorrt_llm::runtime; +namespace tc = tensorrt_llm::common; +namespace fs = std::filesystem; + +class UtilsTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init) +{ +protected: + void SetUp() override + { + mDeviceCount = tc::getDeviceCount(); + if (mDeviceCount == 0) + GTEST_SKIP(); + + mStream = std::make_unique(); + mManager = std::make_unique(mStream); + } + + void TearDown() override {} + + int mDeviceCount; + std::unique_ptr mManager; + BufferManager::CudaStreamPtr mStream; +}; + +TEST_F(UtilsTest, LoadNpy) +{ + auto const testResourcePath = fs::path{TOP_LEVEL_DIR} / "cpp/tests/resources"; + auto const inputFile = testResourcePath / "data/input_tokens.npy"; + + auto loadedTensor = utils::loadNpy(*mManager, inputFile.string(), MemoryType::kCPU); + + ASSERT_EQ(loadedTensor->getSize(), 96); + EXPECT_EQ(loadedTensor->getShape().nbDims, 2); + EXPECT_EQ(loadedTensor->getShape().d[0], 8); + EXPECT_EQ(loadedTensor->getShape().d[1], 12); +} + +TEST_F(UtilsTest, LoadStoreNpy) +{ + auto dims = ITensor::makeShape({2, 3, 4}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor{BufferManager::cpu(dims, dataType)}; + auto tensorRange = BufferRange(*tensor); + std::iota(tensorRange.begin(), tensorRange.end(), 0); + + std::string filename{"tensor.npy"}; + utils::saveNpy(*mManager, *tensor, filename); + auto loadedTensor = utils::loadNpy(*mManager, filename, MemoryType::kCPU); + + ASSERT_EQ(loadedTensor->getSize(), tensor->getSize()); + EXPECT_EQ(loadedTensor->getShape().nbDims, tensor->getShape().nbDims); + EXPECT_EQ(loadedTensor->getShape().d[0], tensor->getShape().d[0]); + EXPECT_EQ(loadedTensor->getShape().d[1], tensor->getShape().d[1]); + EXPECT_EQ(loadedTensor->getShape().d[2], tensor->getShape().d[2]); + + auto loadedTensorRange = BufferRange(*loadedTensor); + for (size_t i = 0; i < tensor->getSize(); ++i) + { + EXPECT_EQ(loadedTensorRange[i], tensorRange[i]); + } +} + +TEST_F(UtilsTest, LoadStoreNpyGPU) +{ + auto dims = ITensor::makeShape({2, 3, 4}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor{BufferManager::cpu(dims, dataType)}; + auto tensorRange = BufferRange(*tensor); + std::iota(tensorRange.begin(), tensorRange.end(), 0); + + auto deviceTensor = mManager->copyFrom(*tensor, MemoryType::kGPU); + + std::string filename{"tensor.npy"}; + utils::saveNpy(*mManager, *deviceTensor, filename); + auto loadedTensor = utils::loadNpy(*mManager, filename, 
MemoryType::kGPU); + + ASSERT_EQ(loadedTensor->getSize(), tensor->getSize()); + EXPECT_EQ(loadedTensor->getShape().nbDims, tensor->getShape().nbDims); + EXPECT_EQ(loadedTensor->getShape().d[0], tensor->getShape().d[0]); + EXPECT_EQ(loadedTensor->getShape().d[1], tensor->getShape().d[1]); + EXPECT_EQ(loadedTensor->getShape().d[2], tensor->getShape().d[2]); + + auto hostTensor = mManager->copyFrom(*loadedTensor, MemoryType::kCPU); + + auto loadedTensorRange = BufferRange(*hostTensor); + for (size_t i = 0; i < tensor->getSize(); ++i) + { + EXPECT_EQ(loadedTensorRange[i], tensorRange[i]); + } +} diff --git a/docs/source/architecture/model-weights-loader.md b/docs/source/architecture/model-weights-loader.md new file mode 100644 index 000000000..8b5c02369 --- /dev/null +++ b/docs/source/architecture/model-weights-loader.md @@ -0,0 +1,254 @@ +# TensorRT-LLM Model Weights Loader + +## Overview + +The weights loader is designed for easily converting and loading external weight checkpoints into TensorRT-LLM models. + +## Workflow + +Weight checkpoints can be generated from all sources, and may have different naming and data layouts compared to TRT-LLM's requirements. E.g.: + +```bash +# HuggingFace LLaMA checkpoints +{ + "model.embed_tokens.weight": torch.Tensor([vocab_size, hidden_size]) + "model.layers.0.input_layernorm.weight": torch.Tensor([hidden_size]), + "model.layers.0.mlp.down_proj.weight": torch.Tensor([hidden_size, inter_size]), + "model.layers.0.mlp.gate_proj.weight": torch.Tensor([inter_size, hidden_size]), + "model.layers.0.mlp.up_proj.weight": torch.Tensor([inter_size, hidden_size]), + "model.layers.0.post_attention_layernorm.weight": torch.Tensor([hidden_size]), + "model.layers.0.self_attn.q_proj.weight": torch.Tensor([hidden_size, hidden_size]), + "model.layers.0.self_attn.k_proj.weight": torch.Tensor([hidden_size, hidden_size]), + "model.layers.0.self_attn.v_proj.weight": torch.Tensor([hidden_size, hidden_size]), + "model.layers.0.self_attn.o_proj.weight": torch.Tensor([hidden_size, hidden_size]), + ..., +} +# TensorRT-LLM expected weights +{ + "transformer.vocab_embedding.weight": torch.Tensor([vocab_size, hidden_size]) + "transformer.layers.0.input_layernorm.weight": torch.Tensor([hidden_size]), + "transformer.layers.0.mlp.down_proj.weight": torch.Tensor([hidden_size, inter_size]), + "transformer.layers.0.mlp.gate_proj.weight": torch.Tensor([inter_size, hidden_size]), + "transformer.layers.0.mlp.up_proj.weight": torch.Tensor([inter_size, hidden_size]), + "transformer.layers.0.post_layernorm.weight": torch.Tensor([hidden_size]), + "transformer.layers.0.attention.qkv.weight": torch.Tensor([hidden_size * 3, hidden_size]), # Different layout + "transformer.layers.0.attention.dense.weight": torch.Tensor([hidden_size, hidden_size]), + ..., +} +``` + +Conversion means converting the dictionary of `{external_keys:external_weights}` into `{tllm_keys:tllm_weights}`, it includes changing the naming logic and data layouts, and is contains of the following parts: + +1. Translate a TRT-LLM parameter name into external-format name(s). +2. Loading tensor slice(s) according to the translated names. +3. Postprocess the tensor(s) into target layout. + +### Translator + +TRT-LLM parameter names are translated in units of sections divided by dots. 
E.g.: + +| TensorRT-LLM key | `transformer` |.| `layers` |.| `0` |.| `attention` |.| `dense` |.| `weight` | +| :---------------------: | :-----------: |-| :------: |-|:---:|-| :---------: |-| :------: |-| :------: | +| Translated external key | `model` |.| `layers` |.| `0` |.| `self_attn` |.| `o_proj` |.| `weight` | + +The mapping between TRT-LLM keywords and HF keywords are described in `tllm_to_externel_key_dict` of `ModelWeightsLoader` class object. \ +If any of the mappings has one-to-multiple corresponding, the translated key will get multiplied accordingly. E.g.: + +| TensorRT-LLM key and related keyword mapping | Translated external keys | +| :----------------------------------------------------------: | :----------------------: | +| `transformer.layers.0.attention.qkv.weight`
<br> `{"qkv":[q_proj, k_proj, v_proj]}` | `model.layers.0.self_attn.q_proj.weights` <br> `model.layers.0.self_attn.k_proj.weights` <br> `model.layers.0.self_attn.v_proj.weights`| +| `transformer.layers.0.mlp.fc.weight` <br> `{"weight":[qweight, qzeros, scales]}` | `model.layers.0.mlp.gate_proj.qweight` <br> `model.layers.0.mlp.gate_proj.qzeros` <br>
`model.layers.0.mlp.gate_proj.scales`| + +The default `tllm_to_externel_key_dict` is based on HF LLaMA as: + +```python +class ModelWeightsLoader: + def __init__(self, model_dir, customized_key_dict: dict = {}) -> None: + ... + self.tllm_to_externel_key_dict = { + "transformer": "model", + "vocab_embedding": "embed_tokens", + "lm_head": "lm_head", + "ln_f": "norm", + "attention": "self_attn", + "qkv": ["q_proj", "k_proj", "v_proj"], + "dense": "o_proj", + "gate": "up_proj", + "proj": "down_proj", + "fc": "gate_proj", + "input_layernorm": "input_layernorm", + "post_layernorm": "post_attention_layernorm", + } + self.tllm_to_externel_key_dict.update(customized_key_dict) + ... +``` + +It can be updated through passing `customized_key_dict` when initializing `ModelWeightsLoader`. + +The dictionary will also get updated according to the layer classes. When iterating over parameters, +if the layer class has attribute `tllm_to_externel_key_dict`, for keywords exist both in the default one and the layer-specified one, +the weight loader will translate according to the layer attribute with higher priority. +This can enable the support for different quantization precisions automatically. + + +### Loading function + +The loading function can load an arbitrary tensor slice according to its `key`, `tp_size`, `tp_dim` and `tp_rank`. + +The template for loading function is as following. + +```python +def load_tensor(self, key, tp_size, tp_dim, tp_rank): + # Retrieve file pointer index + if key in self.shard_map: + ptr_idx = self.shard_map[key] + else: + return None + + # Load tensor from the corresponding shard + if self.format == ModelWeightsFormat.SAFETENSORS: + tensor = self.shards[ptr_idx].get_slice(key) + tensor_shape = tensor.get_shape() + else: + ... + + # Shard and return a tensor slice + slice_shape = ... + return tensor[slice_shape] +``` + +When initializing the `ModelWeightsLoader` object, the file format will be derived from `model_dir` through `detect_format`. The following formats are supported for now: + + * Directory contains or file named `*.safetensors` (Recommended, has better performance) + * Directory contains or file named `*.bin` + * Directory contains or file named `*.pth` + +To support other formats or in-memory loaded models, the format need to be claimed in `ModelWeightsFormat`, `detect_format()`, `preload()` and `load_tensor()`. + +### Postprocessing functions + +After translation and loading, a TRT-LLM key will become a tensor or a list of tensors, which is the input of postprocessing functions. \ +Operations including QKV concatenating, MoE weight stacking and weight-only quantization can be handled here. +The template of postprocessing function is: + +```python +# Example for 1-1 weights mapping +class CustomizedModuleA(Module): + def __init__(...): + super().__init__(...) + ... + self.tp_dim = 0 # Need to set or inherit from parent class + + def postprocess(self, tllm_key, weights, **kwargs): + weights = proc(weights) + return {tllm_key: weights} + +# Example for multiple-multiple weights mapping +class CustomizedModuleB(Module): + def __init__(...): + super().__init__(...) + ... 
+ self.tp_dim = 0 # Need to set or inherit from parent class + # The default value of "weights" in tllm_to_externel_key_dict will be override + self.tllm_to_externel_key_dict = {"weight": ["qweight", "qzeros", "scales"]} + + def postprocess(self, tllm_key, weights, **kwargs): + # Skipped the postprocess of zeros and weights_scaling_factor + # They are loaded in the postprocess of weight + config = kwargs.get("config", None) # Passed through kwargs by default + if not tllm_key.endswith("weight"): + return {} + # The order in weights is defined in tllm_to_externel_key_dict + qweight, qzeros, scales = weights + proccessed_weight, proccessed_zeros = proc(qweight, qzeros, config.num_heads) + return { + tllm_key: proccessed_weight, + tllm_key.replace("weight", "zeros"): proccessed_zeros, + tllm_key.replace("weight", "weights_scaling_factor"): scales, + } +``` + +## Examples + +The `ModelWeightsLoader` class can support different models with the following levels: + +### Natively supported models +For models with native support, users can call the default weight loader without any other operations. +```python +# Using the model weights loader for LLaMA +from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader +loader = ModelWeightsLoader(external_checkpoint_dir) +loader.generate_tllm_weights(trtllm_model) +``` +For calibration-free quantization precisions, passing a properly quantized `trtllm_model` will let the weight loader load at the given precision accordingly. The configurations will be read from `trtllm_model.config` automatically. For now, LLaMA family models using the default `tllm_to_externel_key_dict` is supported natively. + +### Models with customized key names +For models with different naming logic, users can still call the default weight loader with `customized_key_dict` specified. +```python +# Using the model weights loader for the LLM part of LLaVA +from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader +llava_dict = { + "transformer": "language_model.model", + "lm_head": "language_model.lm_head" +} +loader = ModelWeightsLoader(external_checkpoint_dir, llava_dict) +loader.generate_tllm_weights(trtllm_model) +``` +Users need to specify the different part from the default `tllm_to_externel_key_dict`. The loader still have support across different precisions. + +### Models with customized weight layout +For models with different weight layout, users can write the conversion loop explicitly and do customized operations. 
+```python +# Using the model weights loader for BLOOM +from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader +bloom_dict = { + "transformer": "", + "layers": "h", + "ln_f": "ln_f", + "lm_head": "word_embeddings", + "ln_embed": "word_embeddings_layernorm", + "vocab_embedding": "word_embeddings", + "attention": "self_attention", + "qkv": "query_key_value", + "dense": "dense", + "fc": "dense_h_to_4h", + "proj": "dense_4h_to_h", + "post_layernorm": "post_attention_layernorm", +} +loader = ModelWeightsLoader(external_checkpoint_dir, bloom_dict) +# See ModelWeightsLoader.generate_tllm_weights() +loader.update_key_mapping(trtllm_model) +tllm_weights = {} +for tllm_key, _ in tqdm(trtllm_model.named_parameters()): + if tllm_key.endswith("qkv"): + # Passing the callable handle + tllm_weights.update(loader.load(tllm_key, preprocess=customized_preprocess)) + else: + tllm_weights.update(loader.load(tllm_key)) +loader.check(tllm_weights) +``` +This will apply `preprocess` after `load_tensor()` and before `postprocess`, and demonstrates how to convert the loaded shard into default HF layout. The loader still have support for precisions quantized from FP16/BF16 (e.g. INT8-wo/INT4-wo), the other precisions may require special operations, and can be addressed inside the `preprocess` function. + +### Fully customized +If the model weights loader cannot satisfy the requirements, users can write the conversion loop totally on their own. +```python +tllm_weights = {} +for tllm_key, param in tqdm(trtllm_model.named_parameters()): + # Load from external checkpoints + # The load_tensor() function can also be called here + tensor = ... + # Convert tensor and set the values according to the config + if trtllm_model.config.quantization.quant_algo == xxx: + ... + else: + ... + param.value = tensor +``` +In this mode, every precision require user's own support. + +## Trouble shooting +The weights loader is an experimental feature fow now, and is enabled for LLaMA family models by default. + +If users are encountered with failure caused by `ModelWeightsLoader`, a workaround is passing environmental variable `TRTLLM_DISABLE_UNIFIED_CONVERTER=1` to disable the model weights loader and fallback to the legacy path. + +This workaround will be removed in future version after the LLaMA weights conversion is stable. diff --git a/docs/source/conf.py b/docs/source/conf.py index 2455ed8ee..59d60e9c0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -71,6 +71,12 @@ print('CPP_GEN_DIR', CPP_GEN_DIR) +def setup(app): + from docs.source.generate_examples import generate_examples + + generate_examples() + + def gen_cpp_doc(ofile_name: str, header_dir: str, summary: str): cpp_header_files = [ file for file in os.listdir(header_dir) if file.endswith('.h') diff --git a/docs/source/executor.md b/docs/source/executor.md index 076f718f3..3c7964614 100644 --- a/docs/source/executor.md +++ b/docs/source/executor.md @@ -22,7 +22,7 @@ Users can alter the logits produced by the network, by providing a map of named ``` std::unordered_map)>> ``` -to the `ExecutorConfig`. The map key is the name associated with that logits post-processing callback. Each request can then specify the name of the logits post-processor to use for that particular request, if any. +to an instance of `LogitsPostProcessorConfig`. The map key is the name associated with that logits post-processing callback. Each request can then specify the name of the logits post-processor to use for that particular request, if any. 
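For illustration, a minimal sketch with the Python bindings might look like the following, mirroring the registration pattern used by `examples/bindings/executor/example_logits_processor.py` in this change. The engine path and the callback body are placeholders, and the meaning of each callback argument is described in the next paragraph.

```python
import tensorrt_llm.bindings.executor as trtllm

# Placeholder callback: arguments are (request id, logits tensor,
# tokens generated so far, CUDA stream, optional client id).
def my_logits_pp(req_id, logits, token_ids, stream_ptr, client_id):
    # Adjust `logits` here, e.g. bias or mask selected vocabulary entries.
    pass

logits_proc_config = trtllm.LogitsPostProcessorConfig()
logits_proc_config.processor_map = {"my_logits_pp": my_logits_pp}

executor_config = trtllm.ExecutorConfig(1)
executor_config.logits_post_processor_config = logits_proc_config

# "engine_dir" is a placeholder path to a built TensorRT-LLM engine.
executor = trtllm.Executor("engine_dir", trtllm.ModelType.DECODER_ONLY,
                           executor_config)
# Individual requests then opt in by specifying the name "my_logits_pp".
```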
The first argument to the callback is the request id, second is the logits tensor, third are the tokens produced by the request so far, fourth is the operation stream used by the logits tensor, and last one is an optional client id. The callback returns a modified tensor of logits. @@ -37,14 +37,14 @@ We also provide a batched version that allows altering logits of multiple reques std::function const&, std::vector&, std::vector> const&, StreamPtr const&, std::vector> const&)> ``` -A single batched callback can be specified in `ExecutorConfig`. Each request can opt to apply this callback by specifying the name of the logits +A single batched callback can be specified in `LogitsPostProcessorConfig`. Each request can opt to apply this callback by specifying the name of the logits post-processor as `Request::kBatchedPostProcessorName`. Note: Neither callback variant is supported with the `STATIC` batching type for the moment. In a multi-GPU run, callback is invoked on all tensor parallel ranks (in last pipeline rank) by default. For correct execution, user should replicate client-side state accessed by callback on all tensor parallel ranks. -If replication is expensive or infeasible, use `ExecutorConfig::setReplicateLogitsPostProcessor(false)` to invoke callback only on first tensor parallel rank. +If replication is expensive or infeasible, use `LogitsPostProcessorConfig::setReplicate(false)` to invoke callback only on first tensor parallel rank. ### The Request Class diff --git a/docs/source/generate_examples.py b/docs/source/generate_examples.py new file mode 100644 index 000000000..5dc0c6d90 --- /dev/null +++ b/docs/source/generate_examples.py @@ -0,0 +1,50 @@ +from pathlib import Path + + +def underline(title: str, character: str = "=") -> str: + return f"{title}\n{character * len(title)}" + + +def generate_title(filename: str) -> str: + # Turn filename into a title + title = filename.replace("_", " ").title() + # Underline title + title = underline(title) + return title + + +def generate_examples(): + root_dir = Path(__file__).parent.parent.parent.resolve() + + # Source paths + script_dir = root_dir / "examples/high-level-api" + script_paths = sorted(script_dir.glob("*.py")) + + # Destination paths + doc_dir = root_dir / "docs/source/high-level-api-examples" + doc_paths = [doc_dir / f"{path.stem}.rst" for path in script_paths] + + # Generate the example docs for each example script + for script_path, doc_path in zip(script_paths, doc_paths): + if script_path.name == '__init__.py': + continue + script_url = f"https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/high-level-api/{script_path.name}" + + # Make script_path relative to doc_path and call it include_path + include_path = '../../..' / script_path.relative_to(root_dir) + content = (f"{generate_title(doc_path.stem)}\n\n" + f"Source {script_url}.\n\n" + f".. 
literalinclude:: {include_path}\n" + " :language: python\n" + " :linenos:\n") + with open(doc_path, "w+") as f: + f.write(content) + + # Generate the toctree for the example scripts + with open(doc_dir / "examples_index.template.rst") as f: + examples_index = f.read() + with open(doc_dir / "high_level_api_examples.rst", "w+") as f: + example_docs = "\n ".join(path.stem for path in script_paths + if path.stem.find("__init__") == -1) + + f.write(examples_index.replace(r"%EXAMPLE_DOCS%", example_docs)) diff --git a/docs/source/high-level-api-examples/advanced.md b/docs/source/high-level-api-examples/advanced.md new file mode 100644 index 000000000..a1addcc5f --- /dev/null +++ b/docs/source/high-level-api-examples/advanced.md @@ -0,0 +1,140 @@ +# Advanced Usage + +## Quantization +By simply setting several flags in the `LLM`, TensorRT-LLM can quantize the HuggingFace model automatically. For example, to perform an Int4 AWQ quantization, the following code triggers the model quantization. + +``` python +from tensorrt_llm.hlapi import QuantConfig, QuantAlgo + +quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ) + +llm = LLM(, quant_config=quant_config) +``` + + +## Customization + +### Customizing sampling with `SamplingParams` +With SamplingParams, you can customize the sampling strategy, such as beam search, temperature, and so on. + +To enable beam search with a beam size of 4, set the `sampling_params` as follows: + +```python +from tensorrt_llm.hlapi import LLM, SamplingParams, BuildConfig + +build_config = BuildConfig() +build_config.max_beam_width = 4 + +llm = LLM(, build_config=build_config) +# Let the LLM object generate text with the default sampling strategy, or +# you can create a SamplingParams object as well with several fields set manually +sampling_params = SamplingParams(beam_width=4) # current limitation: beam_width should be equal to max_beam_width + +for output in llm.generate(, sampling_params=sampling_params): + print(output) +``` + +`SamplingParams` manages and dispatches fields to C++ classes including: +* [SamplingConfig](https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#_CPPv4N12tensorrt_llm7runtime14SamplingConfigE) +* [OutputConfig](https://nvidia.github.io/TensorRT-LLM/_cpp_gen/executor.html#_CPPv4N12tensorrt_llm8executor12OutputConfigE) + +Please refer to these classes for more details. + + +### Build configuration +Apart from the arguments mentioned above, you can also customize the build configuration with the `build_config` class and other arguments borrowed from the lower-level APIs. For example: + +```python +llm = LLM(, + build_config=BuildConfig( + max_new_tokens=4096, + max_batch_size=128, + max_beam_width=4)) +``` + +### Runtime customization +Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other arguments borrowed from the lower-level APIs. For example: + + +```python +from tensorrt_llm.hlapi import LLM, KvCacheConfig + +llm = LLM(, + kv_cache_config=KvCacheConfig( + max_new_tokens=128, + free_gpu_memory_fraction=0.8)) +``` + +### Tokenizer Customization + +By default, the high-level API uses transformers’ `AutoTokenizer`. You can override it with your own tokenizer by passing it when creating the LLM object. For example: + +```python +llm = LLM(, tokenizer=) +``` + +The LLM() workflow should use your tokenizer instead. 
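For example, a minimal sketch that reuses the stock `transformers` tokenizer and passes it in explicitly is shown below; the model name is only an example, and any supported checkpoint works the same way.

```python
from transformers import AutoTokenizer

from tensorrt_llm.hlapi import LLM

model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # example model; replace with your own
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Pass the tokenizer explicitly; it overrides the default AutoTokenizer lookup.
llm = LLM(model=model_dir, tokenizer=tokenizer)
for output in llm.generate(["Hello, my name is"]):
    print(output)
```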
+ +It is also possible to input token IDs directly without Tokenizers with the following code, note that the result will be also IDs without text since the tokenizer is not used. + +``` python +llm = LLM() + +for output in llm.generate([32, 12]): + ... +``` + +### Disable Tokenizer +For performance considerations, you can disable the tokenizer by passing `skip_tokenizer_init=True` when creating `LLM`. In this case, `LLM.generate` and `LLM.generate_async` will expect prompt token ids as input. For example: + +```python +llm = LLM() +for output in llm.generate([[32, 12]], skip_tokenizer_init=True): + print(output) +``` + +You will get something like: +```python +RequestOutput(request_id=1, prompt=None, prompt_token_ids=[1, 15043, 29892, 590, 1024, 338], outputs=[CompletionOutput(index=0, text='', token_ids=[518, 10858, 4408, 29962, 322, 306, 626, 263, 518, 10858, 20627, 29962, 472, 518, 10858, 6938, 1822, 306, 626, 5007, 304, 4653, 590, 4066, 297, 278, 518, 11947, 18527, 29962, 2602, 472], cumulative_logprob=None, logprobs=[])], finished=True) +``` + +Note that the `text` field in `CompletionOutput` is empty since the tokenizer is deactivated. + + +## Generation + +### `asyncio`-based generation +With the high-level API, you can also perform asynchronous generation with the `generate_async` method. For example: + +```python +llm = LLM(model=) + +async for output in llm.generate_async(, streaming=True): + print(output) +``` + +When the `streaming` flag is set to `True`, the `generate_async` method will return a generator that yields the token results as soon as they are available. Otherwise, it will return a generator that yields the final results only. + +### Future-style generation +The result of the `generate_async` method is a Future-like object, it doesn't block the thread unless the `.result()` is called. + +```python +# This will not block the main thread +generation = llm.generate_async() +# Do something else here +# call .result() to explicitly block the main thread and wait for the result when needed +output = generation.result() +``` + +The `.result()` method works like the [result](https://docs.python.org/zh-cn/3/library/asyncio-future.html#asyncio.Future.result) method in the Python Future, you can specify a timeout to wait for the result. + +```python +output = generation.result(timeout=10) +``` + +There is an async version, where the `.aresult()` is used. + +```python +generation = llm.generate_async() +output = await generation.aresult() +``` diff --git a/docs/source/high-level-api-examples/examples_index.template.rst b/docs/source/high-level-api-examples/examples_index.template.rst new file mode 100644 index 000000000..a9ebc60dd --- /dev/null +++ b/docs/source/high-level-api-examples/examples_index.template.rst @@ -0,0 +1,8 @@ +Examples +================================= + +.. 
toctree:: + :maxdepth: 2 + :caption: Scripts + + %EXAMPLE_DOCS% diff --git a/docs/source/high-level-api-examples/introduction.md b/docs/source/high-level-api-examples/introduction.md new file mode 100644 index 000000000..f25c1f9f0 --- /dev/null +++ b/docs/source/high-level-api-examples/introduction.md @@ -0,0 +1,46 @@ +# High Level API(HLAPI) Introduction + +## Concept + + +## HLAPI Supported Model +* LLaMA (including variants Mistral, Mixtral, InternLM) +* GPT (including variants Starcoder-1/2, Santacoder) +* Gemma-1/2 +* Phi-1/2/3 +* ChatGLM (including variants glm-10b, chatglm, chatglm2, chatglm3, glm4) +* QWen-1/1.5/2 +* Falcon +* Baichuan-1/2 +* GPT-J + +## Model Preparation +The `LLM` class supports the following types of model inputs: + +1. **Hugging Face model name**: triggers a download from the Hugging Face model hub, e.g. `TinyLlama/TinyLlama-1.1B-Chat-v1.0` in the quickstart. +2. **Local Hugging Face models**: uses a locally stored Hugging Face model. +3. **Local TensorRT-LLM engine**: built by `trtllm-build` tool or saved by the HLAPI + + +All kinds of the model inputs can be seamlessly integrated with the HLAPI, and the `LLM(model=)` construcotr can accommodate models in any of the above formats. + +Let's delve into the preparation of the three kinds of local model formats. + +### Option 1: From Hugging Face models +Given its popularity, the TensorRT-LLM HLAPI chooses to support Hugging Face format as one of the start points, to use the HLAPI on LLaMA3.1 models, you need to download the model from [LLaMA3.1 8B model page](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) via below command +```bash +git clone https://huggingface.co/meta-llama/Meta-Llama-3.1-8B +``` + +### Option 2: From TensorRT-LLM engine +There are two ways to build the TensorRT-LLM engine: + +1. You can build the TensorRT-LLM engine from the Hugging Face model directly with the `trtllm-build` tool, and save the engine to disk for later use. Please consult the LLaMA's [README](../llama/README.md). +2. Use the HLAPI to save one: + +```python +llm = LLM() + +# Save engine to local disk +llm.save() +``` diff --git a/docs/source/index.rst b/docs/source/index.rst index bd9c31702..043e684e7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -84,6 +84,30 @@ Welcome to TensorRT-LLM's Documentation! _cpp_gen/executor.rst _cpp_gen/runtime.rst + +.. toctree:: + :maxdepth: 2 + :caption: High Level API Examples + :hidden: + + high-level-api-examples/high_level_api_examples + high-level-api-examples/introduction.md + high-level-api-examples/advanced.md + + +.. toctree:: + :maxdepth: 2 + :caption: Python API + :hidden: + + python-api/tensorrt_llm.layers.rst + python-api/tensorrt_llm.functional.rst + python-api/tensorrt_llm.models.rst + python-api/tensorrt_llm.plugin.rst + python-api/tensorrt_llm.quantization.rst + python-api/tensorrt_llm.runtime.rst + + .. 
toctree:: :maxdepth: 2 :caption: Blogs diff --git a/docs/source/media/picture-08-06-2024.png b/docs/source/media/picture-08-06-2024.png deleted file mode 100644 index 5eeacbdf7..000000000 Binary files a/docs/source/media/picture-08-06-2024.png and /dev/null differ diff --git a/docs/source/media/picture-08-13-2024.png b/docs/source/media/picture-08-13-2024.png new file mode 100644 index 000000000..02ec7b1a1 Binary files /dev/null and b/docs/source/media/picture-08-13-2024.png differ diff --git a/docs/source/performance/perf-best-practices.md b/docs/source/performance/perf-best-practices.md index a8b4c7622..845337ed3 100644 --- a/docs/source/performance/perf-best-practices.md +++ b/docs/source/performance/perf-best-practices.md @@ -205,7 +205,7 @@ downside is slight reduction of accuracy because one of the quantization scaling factors are discarded. If both model and batch sizes are large, it is recommended to enable the feature -by using the `--use_fused_mlp` argument with `trtllm-build`. When the workload +by using the `--use_fused_mlp=enable` argument with `trtllm-build`. When the workload is very small, or if you're using FP8 PTQ and the accuracy after enabling it does not satisfy your requirement, it is not recommended to enable that feature. @@ -217,7 +217,7 @@ the downside is slight reduction of accuracy because one of the quantization scaling factors are discarded. If model is large and you are running it on Hopper with FP8 precision, it is -recommended to enable the feature by using the `--use_fused_mlp --gemm_swiglu_plugin fp8` +recommended to enable the feature by using the `--use_fused_mlp=enable --gemm_swiglu_plugin fp8` argument with `trtllm-build`. When the workload is very small, or the accuracy after enabling it does not satisfy your requirement, it is not recommended to enable that feature. diff --git a/docs/source/performance/perf-overview.md b/docs/source/performance/perf-overview.md index f8718c04d..31525584e 100644 --- a/docs/source/performance/perf-overview.md +++ b/docs/source/performance/perf-overview.md @@ -27,9 +27,9 @@ The issue will be addressed in future releases. ### Fused Matmul + Gated-SiLU (LLaMA) The current implementation combines two Matmul operations into one Matmul followed by -a separate SwiGLU kernel (when `--use_fused_mlp` is enabled). There is also a more +a separate SwiGLU kernel (when `--use_fused_mlp=enable` is enabled). There is also a more efficient implementation that runs single Matmul + SwiGLU fused kernel for FP8 on Hopper -(when `--use_fused_mlp --gemm_swiglu_plugin fp8` is enabled). The gemm_swiglu_plugin +(when `--use_fused_mlp=enable --gemm_swiglu_plugin fp8` is enabled). The gemm_swiglu_plugin will support more data types and GPU architectures in the future release. 
## Throughput Measurements @@ -160,7 +160,7 @@ The following tables are references for commands that are used as part of the be | Stage | Description | Command | | :- | - | - | -| [Build](#engine-building) | Build a TensorRT-LLM engine | `trtllm-build --model_config $model_cfg --use_fused_mlp --gpt_attention_plugin float16 --output_dir $engine_dir --max_batch_size $max_batch_size --max_input_len 2048 --max_seq_len 2048 --reduce_fusion disable --workers $tp_size --max_num_tokens $max_num_tokens --use_paged_context_fmha enable --multiple_profiles enable` | +| [Build](#engine-building) | Build a TensorRT-LLM engine | `trtllm-build --model_config $model_cfg --use_fused_mlp=enable --gpt_attention_plugin float16 --output_dir $engine_dir --max_batch_size $max_batch_size --max_input_len 2048 --max_seq_len 2048 --reduce_fusion disable --workers $tp_size --max_num_tokens $max_num_tokens --use_paged_context_fmha enable --multiple_profiles enable` | | [Dataset](#preparing-a-dataset) | Create a synthetic dataset | `benchmarks/cpp/prepare_dataset.py --output=$dataset_file --tokenizer=$model_name token-norm-dist --num-requests=2000 --input-mean=$isl --output-mean=$osl --input-stdev=0 --output-stdev=0` | | [Run](#running-the-benchmark) | Run a benchmark with a dataset | `mpirun -n $tp_size --allow-run-as-root --oversubscribe cpp/build/benchmarks/gptManagerBenchmark --engine_dir $engine_dir --type IFB --dataset $dataset_file --eos_id -1 --scheduler_policy guaranteed_no_evict --kv_cache_free_gpu_mem_fraction 0.99 --output_csv result.csv --request_rate -1.0 --enable_chunked_context --warm_up 0` | @@ -193,7 +193,7 @@ for the model that you would like to build (see [below](#network-configuration-f command is as follows: ```shell -trtllm-build --model_config $model_cfg --use_fused_mlp --gpt_attention_plugin float16 --output_dir $engine_dir --max_batch_size $max_batch_size --max_input_len 2048 --max_seq_len 2048 --reduce_fusion disable --workers $tp_size --max_num_tokens $max_num_tokens --use_paged_context_fmha enable --multiple_profiles enable +trtllm-build --model_config $model_cfg --use_fused_mlp=enable --gpt_attention_plugin float16 --output_dir $engine_dir --max_batch_size $max_batch_size --max_input_len 2048 --max_seq_len 2048 --reduce_fusion disable --workers $tp_size --max_num_tokens $max_num_tokens --use_paged_context_fmha enable --multiple_profiles enable ``` Some notes about the command: @@ -253,7 +253,7 @@ input and output sequence legnths within the same model. "kv_cache_quant_algo": "FP8" }, "rotary_dim": 64, - "kv_dtype": "float16" + "quant_dtype": "float16" } ``` @@ -297,7 +297,7 @@ input and output sequence legnths within the same model. "bias": false, "parallel_attention": true, "new_decoder_architecture": true, - "kv_dtype": "float16" + "quant_dtype": "float16" } ``` @@ -332,7 +332,7 @@ input and output sequence legnths within the same model. "quant_algo": "FP8", "kv_cache_quant_algo": "FP8" }, - "kv_dtype": "float16" + "quant_dtype": "float16" } ``` @@ -372,7 +372,7 @@ input and output sequence legnths within the same model. "tp_size": 4, "pp_size": 1 }, - "kv_dtype": "float16" + "quant_dtype": "float16" } ``` @@ -416,7 +416,7 @@ input and output sequence legnths within the same model. "quant_algo": "FP8", "kv_cache_quant_algo": "FP8" }, - "kv_dtype": "float16" + "quant_dtype": "float16" } ``` @@ -456,7 +456,7 @@ input and output sequence legnths within the same model. 
"quant_algo": "FP8", "kv_cache_quant_algo": "FP8" }, - "kv_dtype": "float16" + "quant_dtype": "float16" } ``` @@ -499,7 +499,7 @@ input and output sequence legnths within the same model. "quant_algo": "FP8", "kv_cache_quant_algo": "FP8" }, - "kv_dtype": "float16" + "quant_dtype": "float16" } ``` diff --git a/docs/source/quick-start-guide.md b/docs/source/quick-start-guide.md index 37eaa643f..31af7d037 100644 --- a/docs/source/quick-start-guide.md +++ b/docs/source/quick-start-guide.md @@ -125,6 +125,36 @@ curl -X POST localhost:8000/v2/models/ensemble/generate -d \ }' ``` +## High Level API +We are working on a Python high-level API(HLAPI) for LLM workflow, which is still in incubation and may change later. +Here we show you a preview of how it works and how to use it. + +Note that the APIs are not stable and only support the few models. We appreciate your patience and understanding as we improve this API. + +Here is a simple example to show how to use the HLAPI with TinyLlama. +```python +from tensorrt_llm import LLM, SamplingParams + +llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +outputs = llm.generate(prompts, sampling_params) + +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +``` + ## Next Steps In this Quick Start Guide, you: diff --git a/examples/apps/requirements.txt b/examples/apps/requirements.txt index 59c3de5ff..03e4094c6 100644 --- a/examples/apps/requirements.txt +++ b/examples/apps/requirements.txt @@ -1,3 +1,4 @@ fastapi uvicorn colorama +httpx diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt index 87feb6958..686aa335a 100644 --- a/examples/baichuan/requirements.txt +++ b/examples/baichuan/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.15.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/bindings/executor/README.md b/examples/bindings/executor/README.md index b33e35a36..e3fdd5d3d 100644 --- a/examples/bindings/executor/README.md +++ b/examples/bindings/executor/README.md @@ -19,6 +19,16 @@ cd examples/bindings python3 example_basic.py --model_path=../llama/tmp/7B/trt_engines/fp16/1-gpu/ ``` +### Debug example + +This example shows how you can define which engine IO tensors should be dumped to numpy files. +Run `example_debug.py`, passing in the directory where the TensorRT engine was generated. For example: + +``` +cd examples/bindings +python3 example_debug.py --model_path=../llama/tmp/7B/trt_engines/fp16/1-gpu/ +``` + ### Advanced example This example shows how you can use the python bindings to generate tokens for a larger number of requests concurrently and demonstrate how tokens can be returned in a streaming fashion. 
diff --git a/examples/bindings/executor/example_debug.py b/examples/bindings/executor/example_debug.py new file mode 100644 index 000000000..f87b0f026 --- /dev/null +++ b/examples/bindings/executor/example_debug.py @@ -0,0 +1,52 @@ +import argparse +import pathlib as pl + +import numpy as np + +import tensorrt_llm.bindings.executor as trtllm + +# This example hows to use the python bindings to create an executor, enqueue a +# request, and get the generated tokens. + +# First, follow the steps in README.md to generate the engines. + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Executor Bindings Example") + parser.add_argument("--model_path", + type=str, + required=True, + help="Directory containing model engine") + args = parser.parse_args() + + # debug_config = trtllm.DebugConfig(dump_input_tensors=True, + # dump_output_tensors=True, + # debug_tensor_names=["test"]) + + # Select which tensors should be dumped + debug_config = trtllm.DebugConfig(debug_tensor_names=["host_request_types"]) + + # Create the executor. + executor = trtllm.Executor( + args.model_path, trtllm.ModelType.DECODER_ONLY, + trtllm.ExecutorConfig(1, debug_config=debug_config)) + + if executor.can_enqueue_requests(): + # Create the request. + request = trtllm.Request(input_token_ids=[1, 2, 3, 4], max_new_tokens=2) + + # Enqueue the request. + request_id = executor.enqueue_request(request) + + # Wait for the new tokens. + responses = executor.await_responses(request_id) + output_tokens = responses[0].result.output_token_ids + + # Print tokens. + print(output_tokens) + + print("debug tensors:") + debug_dir = pl.Path("/tmp/tllm_debug/PP_1/TP_1") + for iter_dir in [x for x in debug_dir.iterdir() if x.is_dir()]: + print(iter_dir.name) + for file in [x for x in iter_dir.iterdir() if x.is_file()]: + print(file.name, np.load(file)) diff --git a/examples/bindings/executor/example_logits_processor.py b/examples/bindings/executor/example_logits_processor.py index 75c3f3482..cec810455 100644 --- a/examples/bindings/executor/example_logits_processor.py +++ b/examples/bindings/executor/example_logits_processor.py @@ -179,12 +179,14 @@ def logits_post_processor_batched( # Create the executor. 
executor_config = trtllm.ExecutorConfig(args.beam_width) + logits_proc_config = trtllm.LogitsPostProcessorConfig() if not args.lpp_batched: - executor_config.logits_post_processor_map = { + logits_proc_config.processor_map = { "my_logits_pp": logits_post_processor } else: - executor_config.logits_post_processor_batched = logits_post_processor_batched + logits_proc_config.processor_batched = logits_post_processor_batched + executor_config.logits_post_processor_config = logits_proc_config executor = trtllm.Executor(args.engine_path, trtllm.ModelType.DECODER_ONLY, executor_config) diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt index bc7c3d95a..e30d6ebae 100644 --- a/examples/bloom/requirements.txt +++ b/examples/bloom/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt index dc884373a..a884a17ab 100644 --- a/examples/chatglm/requirements.txt +++ b/examples/chatglm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 protobuf diff --git a/examples/cpp/executor/CMakeLists.txt b/examples/cpp/executor/CMakeLists.txt index 8ee344bf7..9a8546ede 100644 --- a/examples/cpp/executor/CMakeLists.txt +++ b/examples/cpp/executor/CMakeLists.txt @@ -108,6 +108,9 @@ include_directories(${TRTLLM_INCLUDE_DIR} ${CUDAToolkit_INCLUDE_DIRS}) add_executable(executorExampleBasic executorExampleBasic.cpp) target_link_libraries(executorExampleBasic nvinfer_plugin_tensorrt_llm) +add_executable(executorExampleDebug executorExampleDebug.cpp) +target_link_libraries(executorExampleDebug nvinfer_plugin_tensorrt_llm) + add_executable(executorExampleLogitsProcessor executorExampleLogitsProcessor.cpp) target_link_libraries(executorExampleLogitsProcessor diff --git a/examples/cpp/executor/README.md b/examples/cpp/executor/README.md index 57a1337e3..96e883d6d 100644 --- a/examples/cpp/executor/README.md +++ b/examples/cpp/executor/README.md @@ -33,6 +33,16 @@ From the `examples/cpp/executor/build` folder, you can get run the `executorExam ``` where `` is the path to the directly containing the TensorRT engine files. +### executorExampleDebug + +This example shows how you can define which engine IO tensors should be dumped to numpy files. +From the `examples/cpp/executor/build` folder, you can get run the `executorExampleDebug` example with: + +``` +./executorExampleDebug +``` +where `` is the path to the directly containing the TensorRT engine files. + ### executorExampleAdvanced From the `examples/cpp/executor/build` folder, you can also run the `executorExampleAdvanced` example. 
To get the full list of supported input arguments, type diff --git a/examples/cpp/executor/executorExampleBasic.cpp b/examples/cpp/executor/executorExampleBasic.cpp index 993eb7383..b3ae33283 100644 --- a/examples/cpp/executor/executorExampleBasic.cpp +++ b/examples/cpp/executor/executorExampleBasic.cpp @@ -46,7 +46,7 @@ int main(int argc, char* argv[]) auto request = tle::Request(inputTokens, maxNewTokens); // Enqueue the request - auto requestId = executor.enqueueRequest(std::move(request)); + auto requestId = executor.enqueueRequest(request); // Wait for the response auto responses = executor.awaitResponses(requestId); diff --git a/examples/cpp/executor/executorExampleDebug.cpp b/examples/cpp/executor/executorExampleDebug.cpp new file mode 100644 index 000000000..4179f6bd9 --- /dev/null +++ b/examples/cpp/executor/executorExampleDebug.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/plugins/api/tllmPlugin.h" + +#include + +namespace tlc = tensorrt_llm::common; +namespace tle = tensorrt_llm::executor; + +int main(int argc, char* argv[]) +{ + // Register the TRT-LLM plugins + initTrtLlmPlugins(); + + if (argc != 2) + { + TLLM_LOG_ERROR("Usage: %s ", argv[0]); + return 1; + } + + // Create the executor for this engine + tle::SizeType32 beamWidth = 1; + auto executorConfig = tle::ExecutorConfig(beamWidth); + // Select which tensors should be dumped + auto debugConfig = tle::DebugConfig(); + debugConfig.setDebugTensorNames({"host_request_types"}); + executorConfig.setDebugConfig(debugConfig); + + auto trtEnginePath = argv[1]; + auto executor = tle::Executor(trtEnginePath, tle::ModelType::kDECODER_ONLY, executorConfig); + + // Create the request + tle::SizeType32 maxNewTokens = 2; + tle::VecTokens inputTokens{1, 2, 3, 4}; + auto request = tle::Request(inputTokens, maxNewTokens); + + // Enqueue the request + auto requestId = executor.enqueueRequest(request); + + // Wait for the response + auto responses = executor.awaitResponses(requestId); + + // Get outputTokens + auto outputTokens = responses.at(0).getResult().outputTokenIds.at(beamWidth - 1); + + TLLM_LOG_INFO("Output tokens: %s", tlc::vec2str(outputTokens).c_str()); + + return 0; +} diff --git a/examples/cpp/executor/executorExampleLogitsProcessor.cpp b/examples/cpp/executor/executorExampleLogitsProcessor.cpp index e3dff7f4c..0913b77b1 100644 --- a/examples/cpp/executor/executorExampleLogitsProcessor.cpp +++ b/examples/cpp/executor/executorExampleLogitsProcessor.cpp @@ -61,9 +61,12 @@ int main(int argc, char* argv[]) // Create the executor for this engine tle::SizeType32 beamWidth = 1; auto executorConfig = tle::ExecutorConfig(beamWidth); - executorConfig.setLogitsPostProcessorMap( - std::unordered_map{ - {logitsPostProcessorName, logitsPostProcessorFn}}); + + auto logitsProcConfig = tle::LogitsPostProcessorConfig(); + 
logitsProcConfig.setProcessorMap(std::unordered_map{ + {logitsPostProcessorName, logitsPostProcessorFn}}); + executorConfig.setLogitsPostProcessorConfig(logitsProcConfig); + auto trtEnginePath = argv[1]; auto executor = tle::Executor(trtEnginePath, tle::ModelType::kDECODER_ONLY, executorConfig); diff --git a/examples/dbrx/requirements.txt b/examples/dbrx/requirements.txt index 5d81c2f7c..ea7fd8f1e 100644 --- a/examples/dbrx/requirements.txt +++ b/examples/dbrx/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/dit/sample.py b/examples/dit/sample.py index 8c800ba49..8b8f41641 100644 --- a/examples/dit/sample.py +++ b/examples/dit/sample.py @@ -10,7 +10,6 @@ from torchvision.utils import save_image import tensorrt_llm -from tensorrt_llm._ipc_utils import set_peer_access from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch from tensorrt_llm.logger import logger from tensorrt_llm.plugin.plugin import CustomAllReduceHelper @@ -77,11 +76,10 @@ def __init__(self, expected_tensor_names = ['latent', 'timestep', 'label', 'output'] if self.mapping.tp_size > 1: - is_p2p_supported = set_peer_access(self.mapping) self.buffer, self.all_reduce_workspace = CustomAllReduceHelper.allocate_workspace( self.mapping, CustomAllReduceHelper.max_workspace_size_auto( - self.mapping.tp_size), is_p2p_supported) + self.mapping.tp_size)) self.inputs['all_reduce_workspace'] = self.all_reduce_workspace expected_tensor_names += ['all_reduce_workspace'] diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt index f0c1e9f12..decac46a6 100644 --- a/examples/falcon/requirements.txt +++ b/examples/falcon/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 transformers>=4.31.0 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt index dd5a143ca..c28634625 100644 --- a/examples/gemma/requirements.txt +++ b/examples/gemma/requirements.txt @@ -3,7 +3,7 @@ # WAR the new posting of "nvidia-cudnn-cu12~=9.0". # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9". 
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64" -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 flax~=0.8.0 # jax[cuda12_pip]~=0.4.19; platform_system != "Windows" jax~=0.4.19; platform_system == "Windows" diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index b71f09c9d..3ff478877 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptj/requirements.txt b/examples/gptj/requirements.txt index 8d7092105..a4d628574 100644 --- a/examples/gptj/requirements.txt +++ b/examples/gptj/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt index 76d7fa110..42ecdcecb 100644 --- a/examples/gptneox/requirements.txt +++ b/examples/gptneox/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 rouge_score~=0.1.2 evaluate~=0.4.1 diff --git a/examples/grok/requirements.txt b/examples/grok/requirements.txt index 12eb5ee0f..e2177ded1 100644 --- a/examples/grok/requirements.txt +++ b/examples/grok/requirements.txt @@ -1,6 +1,6 @@ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/high-level-api/README.md b/examples/high-level-api/README.md index 907b61a3b..15213b438 100644 --- a/examples/high-level-api/README.md +++ b/examples/high-level-api/README.md @@ -2,7 +2,18 @@ We are working on a Python high-level API(HLAPI) for LLM workflow, which is still in incubation and may change later. Here we show you a preview of how it works and how to use it. -Note that the APIs are not stable and only support the LLaMA model. We appreciate your patience and understanding as we improve this API. +Note that the APIs are not stable and we appreciate your patience and understanding as we improve this API. + +## HLAPI Supported Model +* LLaMA (including variants Mistral, Mixtral, InternLM) +* GPT (including variants Starcoder-1/2, Santacoder) +* Gemma-1/2 +* Phi-1/2/3 +* ChatGLM (including variants glm-10b, chatglm, chatglm2, chatglm3, glm4) +* QWen-1/1.5/2 +* Falcon +* Baichuan-1/2 +* GPT-J ## Quick start diff --git a/examples/high-level-api/requirements.txt b/examples/high-level-api/requirements.txt index 46456a250..f7e1fd97f 100644 --- a/examples/high-level-api/requirements.txt +++ b/examples/high-level-api/requirements.txt @@ -1,2 +1,2 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 diff --git a/examples/internlm/README.md b/examples/internlm/README.md index d9b9c0e93..aae120486 100644 --- a/examples/internlm/README.md +++ b/examples/internlm/README.md @@ -18,7 +18,7 @@ The TensorRT-LLM InternLM implementation is based on the LLaMA model. The implem be found in [tensorrt_llm/models/llama/model.py](../../tensorrt_llm/models/llama/model.py). 
The TensorRT-LLM InternLM example code lies in [`examples/llama`](./): -* [`convert_checkpoint.py`](../llama/convert_checkpoint.py) converts the Huggingface Model of Skywork into TensorRT-LLM checkpoint. +* [`convert_checkpoint.py`](../llama/convert_checkpoint.py) converts the Huggingface Model of InternLM into TensorRT-LLM checkpoint. * [`convert_checkpoint.py`] to to convert a checkpoint from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT-LLM format In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation: diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt index d3de57f84..211fa8ceb 100644 --- a/examples/internlm/requirements.txt +++ b/examples/internlm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets==2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/jais/requirements.txt b/examples/jais/requirements.txt index b71f09c9d..3ff478877 100644 --- a/examples/jais/requirements.txt +++ b/examples/jais/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/llama/README.md b/examples/llama/README.md index 5bc69c50d..584bc0b51 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -85,9 +85,9 @@ The defaults have been carefully tuned for better performance. For example, `gpt Normally `trtllm-build` only requires single GPU, but if you've already got all the GPUs needed for inference, you could enable parallel building to make the engine building process faster by adding `--workers` argument. Please note that currently `workers` feature only supports single node. -`--use_fused_mlp` enables GEMM horizontal fusion in gated MLP layer, which reduces input traffic and potentially improves performance. For FP8 PTQ, the downside is slight reduction of accuracy because one of the quantization scaling factors are discarded (accuracy 0.45734 vs 0.45755 for LLaMA-v2 7B using modelopt/examples/hf/instruct_eval/mmlu.py). +`--use_fused_mlp=enable` enables GEMM horizontal fusion in gated MLP layer, which reduces input traffic and potentially improves performance. For FP8 PTQ, the downside is slight reduction of accuracy because one of the quantization scaling factors are discarded (accuracy 0.45734 vs 0.45755 for LLaMA-v2 7B using modelopt/examples/hf/instruct_eval/mmlu.py). -`--use_fused_mlp --gemm_swiglu_plugin ` fuses 2 GEMMs without biases and SwiGLU into one kernel. This is a preview feature and is only supported for dtype `fp8`. The supported architecture is SM90. +`--use_fused_mlp=enable --gemm_swiglu_plugin ` fuses 2 GEMMs without biases and SwiGLU into one kernel. This is a preview feature and is only supported for dtype `fp8`. The supported architecture is SM90. 
Here're some examples: diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index 6ae2e6d9f..4656b4f4c 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/mamba/README.md b/examples/mamba/README.md index b218ae070..90b9f6573 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -25,10 +25,10 @@ In addition, there are two shared files in the parent folder [`examples`](../) f ## Support Matrix -| Model Name | FP16 | BF16 | -| :--------------: | :---: | :---: | -| Mamba1 | Y | Y | -| Mamba2 | Y | Y | +| Model Name | FP16 | BF16 | TP | +| :--------------: | :---: | :---: | :-: | +| Mamba1 | Y | Y | N | +| Mamba2 | Y | Y | Y | * Mamba2: TensorRT-LLM can only support the pure Mamba model for now, will support the hybrid models later. @@ -78,6 +78,9 @@ git clone https://huggingface.co/mistralai/mathstral-7B-v0.1 ./mamba_model/maths ### 2. Convert weights from HF Transformers to TensorRT-LLM format The [`convert_checkpoint.py`](./convert_checkpoint.py) script converts HF weights to TensorRT-LLM checkpoints. +For the Mamba2 models, if they can support tensor parallelism, you can run them with 1, 2, 4 or 8 GPUs. Here we use +mamba-codestral-7B-v0.1 as an example. + ```bash # mamba-2.8b python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \ @@ -103,6 +106,12 @@ python convert_checkpoint.py --model_dir ./mamba_model/mamba2-130m/ \ python convert_checkpoint.py --model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \ --dtype float16 \ --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/1-gpu/ + +# mamba-codestral-7B-v0.1 with 2-way tensor parallelism. +python convert_checkpoint.py --model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \ + --dtype float16 \ + --world_size 2 \ + --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/2-gpu/ ``` ### 3. Build TensorRT engine(s) @@ -153,6 +162,15 @@ trtllm-build --checkpoint_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp1 --max_input_len 924 \ --max_seq_len 1024 \ --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/1-gpu/ + +# mamba-codestral-7B-v0.1 with 2-way tensor parallelism. +trtllm-build --checkpoint_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/2-gpu/ \ + --paged_kv_cache disable \ + --gemm_plugin auto \ + --max_batch_size 8 \ + --max_input_len 924 \ + --max_seq_len 1024 \ + --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/2-gpu/ ``` Note that when building Mamba models, you need to disable the `paged_kv_cache` as it is used for @@ -200,4 +218,12 @@ python ../summarize.py --test_trt_llm \ --tokenizer_dir ./mamba_model/mathstral-7B-v0.1/ \ --data_type fp16 \ --engine_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/1-gpu/ + +# mamba-codestral-7B-v0.1 with 2-way tensor parallelism. 
+mpirun -n 2 --allow-run-as-root \ + python ../summarize.py --test_trt_llm \ + --hf_model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \ + --tokenizer_dir ./mamba_model/mathstral-7B-v0.1/ \ + --data_type fp16 \ + --engine_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/2-gpu/ ``` diff --git a/examples/mamba/convert_checkpoint.py b/examples/mamba/convert_checkpoint.py index 831ceab46..6c7d6ed4c 100644 --- a/examples/mamba/convert_checkpoint.py +++ b/examples/mamba/convert_checkpoint.py @@ -23,6 +23,10 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--model_dir', type=Path, default=None) + parser.add_argument("--world_size", + type=int, + default=1, + help="world size, only support tensor parallelism now") parser.add_argument('--dtype', type=str, default='float16', @@ -60,6 +64,14 @@ def get_tllm_linear_weight(weight, prefix, bias=None): return results +def split(v, tp_size, idx, dim=0): + assert v.shape[dim] % tp_size == 0 + split_size = v.shape[dim] // tp_size + if tp_size == 1: + return v + return torch.split(v, split_size, dim=dim)[idx] + + def convert_hf_mamba(hf_mamba, rank=0, dtype='float32', @@ -160,13 +172,18 @@ def rename_hf_to_tllm(name: str): return name -def convert_from_hf_checkpoint(model_dir: Union[str, Path], +def convert_from_hf_checkpoint(mamba_config: dict, + model_dir: Union[str, Path], rank=0, dtype: Union[str, torch.dtype] = torch.float32, mamba_version: str = 'Mamba1'): logger.info('Loading weights from HF Mamba...') tik = time.time() + tp_rank = rank + tp_size = mamba_config['mapping']['tp_size'] + d_inner = mamba_config['rnn_hidden_size'] + d_state = mamba_config['state_size'] weights = {} if isinstance(dtype, str): dtype = tensorrt_llm.str_dtype_to_torch(dtype) @@ -196,6 +213,44 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path], in_proj_params = torch.split(param, param.size(0) // 2, dim=0) weights[tllm_name.replace('proj', 'proj_x')] = in_proj_params[0] weights[tllm_name.replace('proj', 'proj_z')] = in_proj_params[1] + elif 'in_proj' in name and mamba_version == 'Mamba2': + nheads = d_inner // mamba_config['rnn_head_size'] + ngroups = mamba_config['ngroups'] + in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split( + param, [ + d_inner, d_inner, ngroups * d_state, ngroups * d_state, + nheads + ], + dim=0) + in_proj_z = split(in_proj_z, tp_size, tp_rank, dim=0) + in_proj_x = split(in_proj_x, tp_size, tp_rank, dim=0) + in_proj_b = split(in_proj_b, tp_size, tp_rank, dim=0) + in_proj_c = split(in_proj_c, tp_size, tp_rank, dim=0) + in_proj_dt = split(in_proj_dt, tp_size, tp_rank, dim=0) + in_proj = torch.concat( + [in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt]) + weights[tllm_name] = in_proj.contiguous() + elif 'conv1d' in name and mamba_version == 'Mamba2': + ngroups = mamba_config['ngroups'] + conv_x, conv_b, conv_c = torch.split( + param, [d_inner, ngroups * d_state, ngroups * d_state], + dim=0) + conv_x = split(conv_x, tp_size, tp_rank, dim=0) + conv_b = split(conv_b, tp_size, tp_rank, dim=0) + conv_c = split(conv_c, tp_size, tp_rank, dim=0) + conv = torch.concat([conv_x, conv_b, conv_c]) + weights[tllm_name] = conv.contiguous() + elif any(keyword in name for keyword in ( + 'mixer.norm.weight', + 'A_log', + 'D', + 'dt_proj.bias', + 'dt_bias', + )) and mamba_version == 'Mamba2': + weights[tllm_name] = split(param, tp_size, tp_rank, dim=0) + elif 'out_proj' in name and mamba_version == 'Mamba2': + weights[tllm_name] = split(param, tp_size, tp_rank, + dim=1).contiguous() else: 
weights[tllm_name] = param del model_params @@ -205,6 +260,11 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path], if 'lm_head.weight' not in weights or weights['lm_head.weight'].data_ptr( ) == emb.data_ptr(): weights['lm_head.weight'] = copy.deepcopy(emb) + if mamba_version == 'Mamba2': + weights['lm_head.weight'] = split(weights['lm_head.weight'], + tp_size, + tp_rank, + dim=0) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) @@ -218,9 +278,7 @@ def do_convert_from_ckpt(args): def convert(worker_rank, args, convert_args): convert_from_ckpt = do_convert_from_ckpt(args) - world_size = 1 - args.workers = 1 - for rank in range(worker_rank, world_size, args.workers): + for rank in range(worker_rank, args.world_size): if convert_from_ckpt: weights = convert_from_hf_checkpoint(rank=rank, **convert_args) else: @@ -352,13 +410,18 @@ def main(): 'residual_in_fp32': hf_config.residual_in_fp32, 'pad_vocab_size_multiple': hf_config.pad_vocab_size_multiple, 'hidden_act': 'silu', - 'num_attention_heads': 1, + 'num_attention_heads': args.world_size, 'rnn_hidden_size': hf_config.intermediate_size, 'rnn_conv_dim_size': hf_config.intermediate_size, 'state_size': hf_config.state_size, 'conv_kernel': hf_config.conv_kernel, 'use_bias': hf_config.use_bias, 'mamba_version': mamba_version, + 'mapping': { + 'world_size': args.world_size, + 'tp_size': args.world_size, + 'pp_size': 1 + }, } if mamba_version == 'Mamba2': conv_dim = hf_config.intermediate_size + 2 * hf_config.ngroups * hf_config.state_size @@ -377,6 +440,7 @@ def main(): convert_from_ckpt = do_convert_from_ckpt(args) # TODO: Add convert_hf_mamba support for Mamba2 when transformers can support Mamba2 models assert convert_from_ckpt or mamba_version == 'Mamba2', "Mamba2 can only support convert from checkpoints." + assert args.world_size == 1 or mamba_version == 'Mamba2', "Mamba1 can not support tensor parallelism." 
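To make the tensor-parallel conversion above easier to follow, here is a minimal, self-contained sketch of the `split` helper the patch adds, applied to a toy tensor (PyTorch only; the toy shape is illustrative, not taken from a real checkpoint):

```python
import torch


def split(v, tp_size, idx, dim=0):
    # Mirrors the helper added to convert_checkpoint.py: every rank keeps an
    # equal, contiguous slice of the tensor along `dim`.
    assert v.shape[dim] % tp_size == 0
    split_size = v.shape[dim] // tp_size
    if tp_size == 1:
        return v
    return torch.split(v, split_size, dim=dim)[idx]


weight = torch.arange(12).reshape(6, 2)  # toy "weight" with 6 rows
print(split(weight, 2, 0))               # rank 0 keeps rows 0-2
print(split(weight, 2, 1))               # rank 1 keeps rows 3-5
```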
if not convert_from_ckpt: logger.info(f'Convert by using model') hf_mamba = AutoModelForCausalLM.from_pretrained(args.model_dir, @@ -394,6 +458,7 @@ def main(): else: convert_args['hf_mamba'] = hf_mamba convert_args['mamba_version'] = mamba_version + convert_args['mamba_config'] = config convert(0, args, convert_args) diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt index 2fdedd010..465abbad0 100644 --- a/examples/mamba/requirements.txt +++ b/examples/mamba/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 transformers>=4.39.0 datasets~=2.14.5 evaluate diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt index e99999e73..774d5ee62 100644 --- a/examples/medusa/requirements.txt +++ b/examples/medusa/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt index e02c93f53..2727c662e 100644 --- a/examples/mixtral/requirements.txt +++ b/examples/mixtral/requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 transformers==4.38.2 accelerate==0.25.0 diff --git a/examples/model_api/README.md b/examples/model_api/README.md index b3fb8d736..065a81847 100644 --- a/examples/model_api/README.md +++ b/examples/model_api/README.md @@ -38,3 +38,30 @@ Using AWQ INT4 weight only algorithm to quantize the given hugging llama model f ```bash python ./llama_quantize.py --hf_model_dir --cache_dir ./llama.awq/ ``` + + +## AutoModelForCausalLM + +The API `tensorrt_llm.AutoModelForCausalLM` can read from a Hugging Face model directory, find the correct TRT-LLM model class and dispatch the `from_hugging_face` mothod to the correct TRT-LLM class. + +The following code snippets demonstrated the usage of the `AutoModelForCausalLM` class. + +```python + mapping = Mapping(world_size=world_size, rank=0, tp_size=tp, pp_size=pp) + trtllm_model = AutoModelForCausalLM.from_hugging_face(hf_model_dir, mapping=mapping) + engine = build(trtllm_model, build_config) + executor = GenerationExecutor.create(engine) +``` + +## AutoConfig + +The API `tensorrt_llm.AutoConfig` can read the configuration from a Hugging Face model directory, find and return the correct TRT-LLM configuration class if it's supported, and raise a `NotImplementedError` if not supported. This API is useful when one needs to create a TRT-LLM model object using dummy weights, for things like workflow testing, benchmarks, without reading the real weights from storage, since reading the weights for large models can take significant amount of time. 
The usage looks like below snippets: + +```python + mapping = Mapping(world_size=world_size, rank=0, tp_size=tp, pp_size=pp) + trtllm_config = AutoConfig.from_hugging_face(hf_model_dir, dtype='float16', mapping=mapping) + + # Use the __init__ constructor directly to create a TRT-LLM model object + # instead of using from_hugging_face class method, since from_hugging_face will read the weights + trtllm_model_fake_weights = AutoModelForCausalLM.get_trtllm_model_class(hf_model_dir)(trtllm_config) +``` diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt index 8d7092105..a4d628574 100644 --- a/examples/mpt/requirements.txt +++ b/examples/mpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md index a6aff6356..c9ff35237 100644 --- a/examples/multimodal/README.md +++ b/examples/multimodal/README.md @@ -300,7 +300,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gemm_plugin float16 \ - --use_fused_mlp \ + --use_fused_mlp=enable \ --max_batch_size 1 \ --max_input_len 2048 \ --max_seq_len 2560 \ @@ -405,7 +405,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gemm_plugin float16 \ - --use_fused_mlp \ + --use_fused_mlp=enable \ --max_batch_size 1 \ --max_input_len 2048 \ --max_seq_len 2560 \ @@ -417,7 +417,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gpt_attention_plugin float16 \ --gemm_plugin float16 \ - --use_fused_mlp \ + --use_fused_mlp=enable \ --max_batch_size 1 \ --max_input_len 4096 \ --max_seq_len 5120 \ @@ -427,9 +427,9 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in # for VILA trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ - --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gemm_plugin float16 \ - --use_fused_mlp \ + --use_fused_mlp=enable \ --max_batch_size 1 \ --max_input_len 2048 \ --max_seq_len 2560 \ @@ -458,7 +458,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in For VILA, you can use either local file or web url as input images. Suppose you have a local image `av.png` downloaded from `https://github.com/Efficient-Large-Model/VILA/blob/main/demo_trt_llm/av.png` and the url of `merlion.png` ```bash - wget -O av.png https://raw.githubusercontent.com/Efficient-Large-Model/VILA/main/demo_trt_llm/av.png + wget -O av.png https://raw.githubusercontent.com/Efficient-Large-Model/VILA/main/demo_images/av.png python run.py \ --max_new_tokens 100 \ @@ -507,7 +507,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --calib_size 32 ``` - Then follow the same `trtllm-build` and `run.py` steps as before. NOTE: for `trtllm-build` command, do not use `--use_fused_mlp` in these quantization modes. + Then follow the same `trtllm-build` and `run.py` steps as before. NOTE: for `trtllm-build` command, do not use `--use_fused_mlp=enable` in these quantization modes. 
## NeVA diff --git a/examples/nemotron/requirements.txt b/examples/nemotron/requirements.txt index 41442cada..286046ed2 100644 --- a/examples/nemotron/requirements.txt +++ b/examples/nemotron/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 transformers==4.40.2 # https://github.com/NVIDIA/NeMo/issues/9793 huggingface_hub==0.23.5 diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt index 8d7092105..a4d628574 100644 --- a/examples/opt/requirements.txt +++ b/examples/opt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt index cfe76b68c..6402bf4cb 100644 --- a/examples/phi/requirements.txt +++ b/examples/phi/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt index b5c226d7d..bb98915c8 100644 --- a/examples/quantization/requirements.txt +++ b/examples/quantization/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets>=2.14.4 nemo-toolkit[all]<=1.20.0,>=1.18.0 rouge_score~=0.1.2 diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt index 52f502ec1..24f07e2c1 100644 --- a/examples/qwen/requirements.txt +++ b/examples/qwen/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt index ef5a14bdd..a5c37c5fa 100644 --- a/examples/qwenvl/requirements.txt +++ b/examples/qwenvl/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/recurrentgemma/requirements.txt b/examples/recurrentgemma/requirements.txt index f05166baa..11768c51d 100644 --- a/examples/recurrentgemma/requirements.txt +++ b/examples/recurrentgemma/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 git+https://github.com/google-deepmind/recurrentgemma.git flax>=0.8.2 jax~=0.4.23 diff --git a/examples/run.py b/examples/run.py index 4313e4f59..1f65e14b4 100644 --- a/examples/run.py +++ b/examples/run.py @@ -358,6 +358,12 @@ def main(args): assert args.temperature == 1.0, "Medusa should use temperature == 1.0" assert args.num_beams == 1, "Medusa should use num_beams == 1" runner_kwargs.update(medusa_choices=args.medusa_choices) + if args.lookahead_config is not None: + args.lookahead_config = ast.literal_eval(args.lookahead_config) + assert len( + args.lookahead_config + ) == 3, "Lookahead needs [max_window_size, max_ngram_size, max_verification_set_size]" + runner_kwargs.update(lookahead_config=args.lookahead_config) if not args.use_py_session: runner_kwargs.update( max_batch_size=len(batch_input_ids), diff --git a/examples/skywork/requirements.txt 
b/examples/skywork/requirements.txt index ae491ace9..c2aaea208 100644 --- a/examples/skywork/requirements.txt +++ b/examples/skywork/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets~=2.16.1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/smaug/requirements.txt b/examples/smaug/requirements.txt index 6ae2e6d9f..4656b4f4c 100644 --- a/examples/smaug/requirements.txt +++ b/examples/smaug/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/summarize.py b/examples/summarize.py index 8ccdb5949..58cc19b9f 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -140,7 +140,7 @@ def main(args): bad_words_list = tensorrt_llm.runtime.decode_words_list( args.bad_words, tokenizer) - # random_seed = 5 + random_seed = args.random_seed temperature = args.temperature num_beams = args.num_beams length_penalty = args.length_penalty @@ -148,6 +148,7 @@ def main(args): repetition_penalty = args.repetition_penalty presence_penalty = args.presence_penalty frequency_penalty = args.frequency_penalty + torch.manual_seed(random_seed) output_dir = Path(args.output_dir) if args.output_dir else None if output_dir is not None: @@ -347,18 +348,28 @@ def eval_hf(datapoint, local_early_stopping = "never" with torch.no_grad(): + hf_config = {} + if num_beams == 1: + hf_config.update({ + "top_k": top_k, + "top_p": top_p, + "do_sample": True, + }) + else: + hf_config.update({ + "num_beams": num_beams, + "num_return_sequences": num_beams, + "early_stopping": local_early_stopping, + }) outputs = model.generate(batch_input_ids, max_new_tokens=output_len, - top_k=top_k, temperature=temperature, eos_token_id=end_id, pad_token_id=pad_id, - num_beams=num_beams, - num_return_sequences=num_beams, length_penalty=length_penalty, - early_stopping=local_early_stopping, output_scores=True, - return_dict_in_generate=True) + return_dict_in_generate=True, + **hf_config) if eval_ppl and batch_size == 1: # model.generate cannot return context logits? # Will cause additional latency diff --git a/examples/utils.py b/examples/utils.py index ab20f9ae0..5d975ad94 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -289,6 +289,13 @@ def add_common_args(parser): help="Medusa choice to use, if not none, will use Medusa decoding." " E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens." ) + parser.add_argument( + '--lookahead_config', + type=str, + default=None, + help="lookahead config to use, if not none, will use lookahead decoding." + " E.g.: [5, 6, 7] for [max_window_size, max_ngram_size, max_verification_set_size]." 
+ ) # model arguments parser.add_argument('--engine_dir', type=str, default='engine_outputs') diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index bfdd3b026..b426cf51e 100644 --- a/examples/whisper/requirements.txt +++ b/examples/whisper/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.13.0.dev2024081300 +tensorrt_llm==0.13.0.dev2024082000 tiktoken datasets kaldialign diff --git a/requirements.txt b/requirements.txt index ddfddd173..2efbdec26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,6 @@ optimum evaluate janus mpmath>=1.3.0 +click +click_option_group +aenum diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index a3d095a0b..0f7e4c750 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -76,6 +76,7 @@ def main(*, trt_root: str = None, nccl_root: str = None, clean: bool = False, + configure_cmake: bool = False, use_ccache: bool = False, fast_build: bool = False, cpp_only: bool = False, @@ -185,7 +186,7 @@ def main(*, source_dir = get_source_dir() with working_directory(build_dir): cmake_def_args = " ".join(cmake_def_args) - if clean or first_build: + if clean or first_build or configure_cmake: build_run( f'cmake -DCMAKE_BUILD_TYPE="{build_type}" -DBUILD_PYT="{build_pyt}" -DBUILD_PYBIND="{build_pybind}"' f' -DNVTX_DISABLE="{disable_nvtx}" -DBUILD_MICRO_BENCHMARKS={build_micro_benchmarks}' @@ -326,6 +327,9 @@ def add_arguments(parser: ArgumentParser): parser.add_argument("--cuda_architectures", "-a") parser.add_argument("--install", "-i", action="store_true") parser.add_argument("--clean", "-c", action="store_true") + parser.add_argument("--configure_cmake", + action="store_true", + help="Always configure cmake before building") parser.add_argument("--use_ccache", "-ccache", default=False, diff --git a/setup.py b/setup.py index 810fe7b71..a3db33e18 100644 --- a/setup.py +++ b/setup.py @@ -116,13 +116,17 @@ def has_ext_modules(self): 'libs/libtensorrt_llm_nvrtc_wrapper.so', 'libs/libdecoder_attention.so', 'bindings.*.so', - ]) + ['bindings/*.pyi', 'tools/plugin_gen/templates/*'], + ]) + [ + 'bindings/*.pyi', 'tools/plugin_gen/templates/*', + 'bench/build/benchmark_config.yml' + ], }, entry_points={ 'console_scripts': [ 'trtllm-build=tensorrt_llm.commands.build:main', 'trtllm-prune=tensorrt_llm.commands.prune:main', 'trtllm-refit=tensorrt_llm.commands.refit:main', + 'trtllm-bench=tensorrt_llm.commands.bench:main', ], }, scripts=['tensorrt_llm/hlapi/trtllm-hlapi-launch'], diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 4eef01cde..8b2bd324d 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -48,12 +48,15 @@ def _add_trt_llm_dll_directory(): from .hlapi.llm import LLM, LlmArgs, SamplingParams from .logger import logger from .mapping import Mapping +from .models.automodel import AutoConfig, AutoModelForCausalLM from .module import Module from .network import Network, net_guard from .parameter import Parameter from .version import __version__ __all__ = [ + 'AutoConfig', + 'AutoModelForCausalLM', 'logger', 'str_dtype_to_trt', 'torch_dtype_to_trt', diff --git a/tensorrt_llm/_ipc_utils.py b/tensorrt_llm/_ipc_utils.py index dc0884055..37fc5a7f8 100644 --- a/tensorrt_llm/_ipc_utils.py +++ b/tensorrt_llm/_ipc_utils.py @@ -15,7 +15,6 @@ import array import struct import sys -from contextlib import contextmanager from typing import List, Tuple from cuda import cudart @@ -31,17 +30,7 @@ def _raise_if_error(error: cudaError_t): raise 
RuntimeError(error) -@contextmanager -def peer_access(mapping: Mapping): - is_p2p_supported = set_peer_access(mapping, True) - assert is_p2p_supported, "P2P access not supported" - try: - yield - finally: - set_peer_access(mapping, False) - - -def set_peer_access(mapping: Mapping, enabled: bool = True) -> bool: +def can_access_peer(mapping: Mapping) -> bool: src_node = mapping.local_rank for rank in mapping.tp_group: dest_node = mapping.get_local_rank(rank) @@ -56,18 +45,6 @@ def set_peer_access(mapping: Mapping, enabled: bool = True) -> bool: logger.info( f"Cannot access peer device from {src_node} to {dest_node}") return False - - if enabled: - cudart.cudaDeviceEnablePeerAccess(dest_node, 0) - else: - cudart.cudaDeviceDisablePeerAccess(dest_node) - error = cudart.cudaGetLastError()[0] - if error not in [ - cudaError_t.cudaSuccess, - cudaError_t.cudaErrorPeerAccessAlreadyEnabled, - cudaError_t.cudaErrorPeerAccessNotEnabled - ]: - raise RuntimeError(error) return True diff --git a/tensorrt_llm/bench/__init__.py b/tensorrt_llm/bench/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorrt_llm/bench/build/__init__.py b/tensorrt_llm/bench/build/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorrt_llm/bench/build/benchmark_config.yml b/tensorrt_llm/bench/build/benchmark_config.yml new file mode 100644 index 000000000..ff2b32a17 --- /dev/null +++ b/tensorrt_llm/bench/build/benchmark_config.yml @@ -0,0 +1,69 @@ +meta-llama/Llama-2-7b-hf: + tp1_pp1: + general: + max_batch_size: 4096 + max_num_tokens: 8192 +meta-llama/Llama-2-70b-hf: + tp2_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 2048 + tp4_pp1: + general: + max_batch_size: 4096 + max_num_tokens: 8192 + 4224: + max_batch_size: 256 + max_num_tokens: 8192 + tp8_pp1: + general: + max_batch_size: 8192 + max_num_tokens: 16384 + 2176: + max_batch_size: 1024 + max_num_tokens: 16384 +tiiuae/falcon-180B: + tp4_pp1: + general: + max_batch_size: 4096 + max_num_tokens: 8192 + tp8_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 8192 +EleutherAI/gpt-j-6b: + tp1_pp1: + general: + max_batch_size: 128 + max_num_tokens: 2048 + 256: + max_batch_size: 2048 + max_num_tokens: 2048 +meta-llama/Meta-Llama-3-8B: + tp1_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 8192 +meta-llama/Meta-Llama-3-70B: + tp4_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 1024 + tp8_pp1: + general: + max_batch_size: 8192 + max_num_tokens: 16384 +mistralai/Mixtral-8x7B-v0.1: + tp2_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 3072 + tp4_pp1: + general: + max_batch_size: 8192 + max_num_tokens: 8192 +mistralai/Mistral-7B-v0.1: + tp1_pp1: + general: + max_batch_size: 4098 + max_num_tokens: 8192 diff --git a/tensorrt_llm/bench/build/build.py b/tensorrt_llm/bench/build/build.py new file mode 100644 index 000000000..30a5fd5f2 --- /dev/null +++ b/tensorrt_llm/bench/build/build.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +from pathlib import Path +from select import select +from sys import stdin +from typing import Dict, get_args +import click +from click_option_group import AllOptionGroup, optgroup, RequiredMutuallyExclusiveOptionGroup +from transformers import PretrainedConfig as HFPretrainedConfig +import yaml + +from tensorrt_llm.bench.dataclasses import BenchmarkEnvironment +from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer +from tensorrt_llm.bench.utils import (VALID_QUANT_ALGOS, VALID_COMPUTE_DTYPES) +from tensorrt_llm.builder 
import BuildConfig +from tensorrt_llm.hlapi import LLM +from tensorrt_llm.hlapi.llm_utils import QuantConfig +from tensorrt_llm.logger import logger +from tensorrt_llm.quantization.mode import QuantAlgo + +from .utils import DEFAULT_HF_MODEL_DIRS + + +def derive_model_name(model_name): + model_dir = Path(model_name) + if model_dir.exists() and model_dir.is_dir(): + hf_config = HFPretrainedConfig.from_pretrained(model_dir) + for arch in hf_config.architectures: + if arch in DEFAULT_HF_MODEL_DIRS.keys(): + model_name = DEFAULT_HF_MODEL_DIRS[arch] + return model_name + + +def get_benchmark_engine_settings( + model_name: str, + tp_size: int, + pp_size: int, + max_seq_len: int, +) -> Dict[str, int]: + """Retrieve benchmark settings for a specific model + configuration. + + Args: + model_name (str): Huggingface model name. + tp_size (int): Number of tensor parallel shards. + pp_size (int): Number of pipeline parallel stages. + max_seq_len (int): The maximum sequence length to compile the engine. + + Raises: + ValueError: When the model_name is not supported. + RuntimeError: When the tp_size/pp_size configuration is not found. + + Returns: + Dict[str, int]: Dictionary containing engine configuration information + for engine build (max_num_tokens, max_batch_size). + """ + # Load up reference configurations so that we can set the appropriate + # settings. + settings_yml = Path(__file__).parent / "benchmark_config.yml" + with open(settings_yml, "r") as config: + configs = yaml.safe_load(config) + + model_name = derive_model_name(model_name) + # Check that the model is a supported benchmark model. + if model_name not in configs: + raise ValueError( + f"'{model_name}' is not a model that is configured for benchmarking." + ) + # Try and load the configuration TP x PP. If not valid, inform the user. + try: + model_configs = configs[model_name][f"tp{tp_size}_pp{pp_size}"] + config = model_configs.get(max_seq_len, None) + config = config if config is not None else model_configs.get("general") + except KeyError: + raise RuntimeError( + f"TP-{tp_size} x PP-{pp_size} is not a supported configuration." + "Please specify a valid benchmark configuration.") + + return config + + +@click.command(name="build") +@optgroup.group("Engine Configuration", + help="Configuration of the TensorRT-LLM engine.") +@optgroup.option( + "--tp_size", + "-tp", + type=int, + default=1, + required=False, + help="Number of tensor parallel shards to run the benchmark with.", +) +@optgroup.option( + "--pp_size", + "-pp", + type=int, + default=1, + required=False, + help="Number of pipeline parallel shards to run the benchmark with.", +) +@optgroup.option( + "--dtype", + type=click.Choice(tuple(get_args(VALID_COMPUTE_DTYPES))), + default="auto", + required=False, + help="Activation and plugin data type.", +) +@optgroup.option( + "--quantization", + "-q", + type=click.Choice(tuple(get_args(VALID_QUANT_ALGOS))), + default=None, + help= + ("The quantization algorithm to be used when benchmarking. 
See the " + "documentations for more information.\n" + " - https://nvidia.github.io/TensorRT-LLM/precision.html" + " - https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/quantization-in-TRT-LLM.md" + ), +) +@optgroup.group( + "Engine IFB Engine Limits", + cls=AllOptionGroup, + help="Runtime inflight batching scheduler limits.", +) +@optgroup.option( + "--max_batch_size", + default=None, + type=int, + help="Maximum batch size to build the benchmark engine with.", +) +@optgroup.option( + "--max_num_tokens", + type=int, + default=None, + help="Maximumn number of tokens the engine can accept.", +) +@optgroup.group( + "Engine Input Configuration", + cls=RequiredMutuallyExclusiveOptionGroup, + help="Input settings for configuring engine limits.", +) +@optgroup.option( + "--dataset", + type=click.Path(exists=True, + readable=True, + path_type=Path, + resolve_path=True), + default=None, + help="Pass in a dataset file for parsing instead of stdin.", +) +@optgroup.option("--max_seq_length", + type=click.IntRange(min=1), + default=None, + help="Fixed maximum sequence length for engine build.") +@click.pass_obj +def build_command( + bench_env: BenchmarkEnvironment, + **params, +) -> None: + """Build engines for benchmarking.""" + logger.set_level("info") + + # Collect configuration parameters from CLI parameters. + tp_size = params.get("tp_size") + pp_size = params.get("pp_size") + dtype = params.get("dtype") + quantization = params.pop("quantization") + max_num_tokens = params.pop("max_num_tokens") + max_batch_size = params.pop("max_batch_size") + + # Dataset options + dataset_path: Path = params.pop("dataset") + max_seq_len: int = params.pop("max_seq_length") + data_on_stdin: bool = bool(len(select([ + stdin, + ], [], [], 0.0)[0])) + + # Initialize the HF tokenizer for the specified model. + tokenizer = initialize_tokenizer(bench_env.model) + + # If we are receiving data from a path or stdin, parse and gather metadata. + if dataset_path or data_on_stdin: + logger.info("Found dataset.") + # Cannot set the data file path and pipe in from stdin. Choose one. + if dataset_path is not None and data_on_stdin: + raise ValueError( + "Cannot provide a dataset on both stdin and by --dataset " + "option. Please pick one.") + stream = stdin if data_on_stdin else open(dataset_path, "r") + # Parse the dataset from stdin and return it plus its metadata. + metadata, _ = \ + create_dataset_from_stream(tokenizer, stream=stream) + # The max sequence length option for build is the sum of max osl + isl. + max_seq_len = metadata.max_sequence_length + logger.info(metadata.get_summary_for_logger.info()) + + # We have a specified ISL:OSL combination. + elif max_seq_len is None: + raise RuntimeError("Unknown input configuration. Exiting.") + + # Get the config for the engine + config = get_benchmark_engine_settings(bench_env.model, tp_size, pp_size, + max_seq_len) + + # If specified on the command line, override max batch size or max num + # tokens from baseline config. + max_batch_size = max_batch_size if max_batch_size is not None else config[ + "max_batch_size"] + max_num_tokens = max_num_tokens if max_num_tokens is not None else config[ + "max_num_tokens"] + + # Construct a TRT-LLM build config. + build_config = BuildConfig(max_batch_size=max_batch_size, + max_seq_len=max_seq_len, + max_num_tokens=max_num_tokens) + + # Set the compute quantization. 
+ quant_algo = QuantAlgo(quantization) if quantization is not None else None + quant_config = QuantConfig() + quant_config.quant_algo = quant_algo + # If the quantization is FP8, force the KV cache dtype to FP8. + quant_config.kv_cache_quant_algo = quant_algo.value \ + if quant_algo == QuantAlgo.FP8 else None + + # Enable multiple profiles and paged context FMHA. + build_config.plugin_config.multiple_profiles = True + # build_config.plugin_config._reduce_fusion = True + + # Enable FHMA, and FP8 FMHA if FP8 quantization is enabled. + # TODO: Revisit, there is an issue with enabling FHMA. If only + # paged FMHA is enabled with FP8 quantization, the Builder + # will not enable the FP8 FMHA. + build_config.plugin_config.use_paged_context_fmha = True + build_config.plugin_config.use_fp8_context_fmha = True \ + if quant_algo == QuantAlgo.FP8 else False + + # Construct the engine path and report the engine metadata. + model_name = derive_model_name(bench_env.model) + engine_dir = Path(bench_env.workspace, model_name, + f"tp_{tp_size}_pp_{pp_size}") + + logger.info( + "\n===========================================================\n" + "= ENGINE BUILD INFO\n" + "===========================================================\n" + f"Model Name:\t\t{bench_env.model}\n" + f"Workspace Directory:\t{bench_env.workspace}\n" + f"Engine Directory:\t{engine_dir}\n\n" + "===========================================================\n" + "= ENGINE CONFIGURATION DETAILS\n" + "===========================================================\n" + f"Max Sequence Length:\t\t{max_seq_len}\n" + f"Max Batch Size:\t\t\t{max_batch_size}\n" + f"Max Num Tokens:\t\t\t{max_num_tokens}\n" + f"Quantization:\t\t\t{quantization}\n" + "===========================================================\n") + + # Build the LLM engine with the HLAPI. + logger.set_level("error") + llm = LLM(bench_env.model, + tokenizer, + dtype=dtype, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + build_config=build_config, + quant_config=quant_config) + # Save the engine. 
+ llm.save(engine_dir) + llm.shutdown() + logger.set_level("info") + logger.info( + "\n\n===========================================================\n" + f"ENGINE SAVED: {engine_dir}\n" + "===========================================================\n") diff --git a/tensorrt_llm/bench/build/utils.py b/tensorrt_llm/bench/build/utils.py new file mode 100644 index 000000000..f5b2e1563 --- /dev/null +++ b/tensorrt_llm/bench/build/utils.py @@ -0,0 +1,22 @@ +DEFAULT_HF_MODEL_DIRS = { + 'BaichuanForCausalLM': 'baichuan-inc/Baichuan-13B-Chat', + 'BloomForCausalLM': 'bigscience/bloom-560m', + 'GLMModel': 'THUDM/glm-10b', + 'ChatGLMModel': 'THUDM/chatglm3-6b', + 'ChatGLMForCausalLM': 'THUDM/chatglm3-6b', + 'FalconForCausalLM': 'tiiuae/falcon-rw-1b', + 'GPTForCausalLM': 'gpt2-medium', + 'GPTJForCausalLM': 'EleutherAI/gpt-j-6b', + 'GPTNeoXForCausalLM': 'EleutherAI/gpt-neox-20b', + 'InternLMForCausalLM': 'internlm/internlm-chat-7b', + 'InternLM2ForCausalLM': 'internlm/internlm2-chat-7b', + 'LlamaForCausalLM': 'meta-llama/Llama-2-7b-hf', + 'MPTForCausalLM': 'mosaicml/mpt-7b', + 'PhiForCausalLM': 'microsoft/phi-2', + 'OPTForCausalLM': 'facebook/opt-350m', + 'QWenLMHeadModel': 'Qwen/Qwen-7B', + 'QWenForCausalLM': 'Qwen/Qwen-7B', + 'Qwen2ForCausalLM': 'Qwen/Qwen1.5-7B', + 'Qwen2MoeForCausalLM': 'Qwen/Qwen1.5-MoE-A2.7B', + 'RecurrentGemmaForCausalLM': 'google/recurrentgemma-2b', +} diff --git a/tensorrt_llm/bench/dataclasses.py b/tensorrt_llm/bench/dataclasses.py new file mode 100644 index 000000000..a4238689b --- /dev/null +++ b/tensorrt_llm/bench/dataclasses.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from pathlib import Path +from typing import List, Optional + +from pydantic import BaseModel, computed_field, model_validator + +from tensorrt_llm.bench.utils import (VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, + VALID_QUANT_ALGOS) + + +class EngineConstraints(BaseModel): + max_batch_size: int = 2048 + max_tokens: int = 2048 + max_sequence_length: int = 6144 + tp_size: int = 1 + pp_size: int = 1 + + @computed_field + def world_size(self) -> int: + return self.tp_size * self.pp_size + + +class EngineConfiguration(BaseModel): + quantization: Optional[VALID_QUANT_ALGOS] = None + kv_cache_dtype: Optional[VALID_CACHE_DTYPES] = "float16" + fused_mlp: Optional[bool] = False + dtype: Optional[VALID_COMPUTE_DTYPES] = "float16" + gemm_plugin: Optional[bool] = False + gpt_attn_plugin: Optional[bool] = True + paged_context_fmha: Optional[bool] = True + gemm_swiglu_plugin: Optional[bool] = False + multi_block_mode: Optional[bool] = False + multiple_profiles: Optional[bool] = True + build_options: List[str] = [] + + +class BuildConfiguration(BaseModel): + model: str + workspace: Path + engine_dir: Optional[Path] = None + engine_config: EngineConfiguration + engine_limits: EngineConstraints + + @computed_field + def get_build_feature_args(self) -> List[str]: + ... 
+ + @model_validator(mode="after") + def check_engine_dir(self) -> BuildConfiguration: + if self.engine_dir is None: + limits = self.engine_limits + engine_name: str = ( + f"BS_{limits.max_batch_size}_sl_{limits.max_sequence_length}_" + f"tp_{limits.tp_size}_pp_{limits.pp_size}") + self.engine_dir = Path( + self.workspace, + self.model, + engine_name, + ) + + return self + + +class BenchmarkEnvironment(BaseModel): + model: str + workspace: Path + + +class InferenceRequest(BaseModel): + task_id: int + prompt: Optional[str] = None + output_tokens: int + logits: Optional[List[int]] = None + + @model_validator(mode="after") + def verify_prompt_and_logits(self) -> InferenceRequest: + if self.prompt is None and self.logits is None: + raise ValueError( + f"Both prompt and logits for {self.task_id} are both None.") + return self + + +class DatasetMetadata(BaseModel): + max_isl: int + max_osl: int + max_sequence_length: int + num_requests: int + + def get_summary_for_print(self) -> str: + return ("===========================================================\n" + "= DATASET DETAILS\n" + "===========================================================\n" + f"Max Input Sequence Length:\t{self.max_isl}\n" + f"Max Output Sequence Length:\t{self.max_osl}\n" + f"Max Sequence Length:\t{self.max_sequence_length}\n" + f"Number of Sequences:\t{self.num_requests}\n" + "===========================================================\n" + f"\n") diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/enums.py b/tensorrt_llm/bench/enums.py similarity index 100% rename from benchmarks/suite/tensorrt_llm_bench/utils/enums.py rename to tensorrt_llm/bench/enums.py diff --git a/tensorrt_llm/bench/run/__init__.py b/tensorrt_llm/bench/run/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorrt_llm/bench/run/dataclasses.py b/tensorrt_llm/bench/run/dataclasses.py new file mode 100644 index 000000000..e234b4491 --- /dev/null +++ b/tensorrt_llm/bench/run/dataclasses.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from importlib.util import find_spec +from pathlib import Path +from typing import Any, List + +from pydantic import (BaseModel, Field, PositiveFloat, computed_field, + model_validator) + +import tensorrt_llm.bindings.executor as trtllm +from tensorrt_llm.bench.enums import IFBSchedulingPolicy + + +class RuntimeConfig(BaseModel): + model: str + engine_dir: Path + sw_version: str + settings_config: ExecutorSettingsConfig + world_config: ExecutorWorldConfig + + def get_config(self) -> trtllm.ExecutorConfig: + return trtllm.ExecutorConfig( + scheduler_config=self.settings_config.get_scheduler_config(), + kv_cache_config=self.settings_config.get_kvcache_config(), + parallel_config=self.world_config.get_parallel_config(), + batching_type=trtllm.BatchingType.INFLIGHT, + iter_stats_max_iterations=0, + request_stats_max_iterations=0, + max_batch_size=self.settings_config.max_batch_size, + max_num_tokens=self.settings_config.max_num_tokens, + enable_chunked_context=self.settings_config.chunking, + ) + + +class ExecutorWorldConfig(BaseModel): + pp_size: int = 1 + tp_size: int = 1 + world_size: int = 1 + gpus_per_node: int = 8 + leader_mode: bool = False + + @model_validator(mode="after") + def validate_world_size(self) -> ExecutorWorldConfig: + parallel_world = self.pp_size * self.tp_size + num_gpus = self.world_size * self.gpus_per_node + valid_world = bool(num_gpus >= parallel_world) + + if not valid_world: + raise ValueError( + f"World configuration is invalid, TP * PP ({parallel_world})" + 
"does not equal the total number of available GPUs" + f"({num_gpus}).") + + return self + + def _get_tensorrt_llm_executor_worker_path(self) -> Path: + module_path = find_spec("tensorrt_llm").loader.get_filename() + exec_path = Path(module_path).parent / 'bin' / 'executorWorker' + return exec_path.absolute() + + def get_parallel_config(self) -> trtllm.ParallelConfig: + if self.leader_mode: + comm_mode = trtllm.CommunicationMode.LEADER + orchestrator_config = None + else: + comm_mode = trtllm.CommunicationMode.ORCHESTRATOR + orchestrator_config = trtllm.OrchestratorConfig( + True, str(self._get_tensorrt_llm_executor_worker_path())) + + return trtllm.ParallelConfig( + trtllm.CommunicationType.MPI, + comm_mode, + orchestrator_config=orchestrator_config, + ) + + +class ExecutorSettingsConfig(BaseModel): + chunking: bool = True + scheduler_policy: IFBSchedulingPolicy = IFBSchedulingPolicy.MAX_UTILIZTION + max_batch_size: int + max_num_tokens: int + kv_cache_percent: PositiveFloat = Field(default=.90, le=1.0) + + def get_kvcache_config(self) -> trtllm.KvCacheConfig: + return trtllm.KvCacheConfig( + free_gpu_memory_fraction=self.kv_cache_percent, ) + + def get_scheduler_config(self) -> trtllm.SchedulerConfig: + return trtllm.SchedulerConfig( + capacity_scheduler_policy=self.scheduler_policy.value, + context_chunking_policy=trtllm.ContextChunkingPolicy. + FIRST_COME_FIRST_SERVED, + ) + + +class ResponseRecord(BaseModel): + request_id: int + timestamp: float + output_tokens: List[int] + is_final: bool + has_error: bool + + +class PercentileStats(BaseModel): + p50: float + p95: float + p99: float + minimum: float + maximum: float + average: float + + @classmethod + def from_iterable(cls, values: List[Any]) -> PercentileStats: + length = len(values) + return cls( + p50=values[int(length * 0.50)], + p95=values[int(length * 0.95)], + p99=values[int(length * 0.99)], + average=float(sum(values)) / length, + minimum=min(values), + maximum=max(values), + ) + + +class RequestStats(BaseModel): + request_id: int + input_tokens: int + time_log: List[float] = Field(default_factory=list, init=False) + error_responses: int = Field(default=0, init=False) + num_responses: int = Field(default=0, init=False) + num_tokens: int = Field(default=0, init=False) + + @computed_field + def first_token_latency(self) -> float: + try: + return self.time_log[1] - self.time_log[0] + except IndexError: + return 0 + + @computed_field + def request_latency(self) -> float: + return max(self.time_log) - min(self.time_log) + + def register_event(self, is_error: bool, is_response: bool, + timestamp: float, num_tokens: int) -> None: + self.time_log.append(timestamp) + self.error_responses += 1 if is_error else 0 + self.num_responses += 1 if is_response else 0 + self.num_tokens += num_tokens + + +class BenchmarkStatistics(BaseModel): + total_latency_ns: float + total_output_tokens: int + total_input_tokens: int + num_requests: int + issue_rate_ns: float + + request_percentiles: PercentileStats = None + token_percentiles: PercentileStats = None + + @computed_field + def token_throughput_ns(self) -> float: + return float(self.total_output_tokens) / self.total_latency_ns + + @computed_field + def request_throughput_ns(self) -> float: + return float(self.num_requests) / self.total_latency_ns + + @computed_field + def average_input_length(self) -> float: + return float(self.total_input_tokens) / self.num_requests + + @computed_field + def average_output_length(self) -> float: + return float(self.total_output_tokens) / self.num_requests diff 
--git a/tensorrt_llm/bench/run/run.py b/tensorrt_llm/bench/run/run.py new file mode 100644 index 000000000..52eed4f5a --- /dev/null +++ b/tensorrt_llm/bench/run/run.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +import json +import multiprocessing as mp +from copy import deepcopy +from datetime import timedelta +from pathlib import Path +from threading import Event, Thread +from time import monotonic_ns, sleep +from typing import Generator, List, Tuple + +import click +from click_option_group import optgroup + +import tensorrt_llm.bindings.executor as trtllm +from tensorrt_llm.bench.dataclasses import BenchmarkEnvironment +from tensorrt_llm.bench.enums import IFBSchedulingPolicy +from tensorrt_llm.bench.run.dataclasses import ResponseRecord, RuntimeConfig +from tensorrt_llm.bench.run.utils import (StatsKeeper, get_executor_request, + get_settings_from_engine) +from tensorrt_llm.bench.utils.data import generate_dataset_from_stream +from tensorrt_llm.logger import logger + + +@click.command(name="throughput") +@optgroup.group("Engine run configuration.", + help="Runtime settings for executing a TensorRT-LLM engine.") +@optgroup.option( + "--engine_dir", + type=click.Path(exists=True, + readable=True, + path_type=Path, + resolve_path=True), + required=True, + help="Path to a serialized TRT-LLM engine.", +) +@optgroup.option( + "--max_batch_size", + type=int, + help="Maximum runtime batch size to run the engine with.", +) +@optgroup.option( + "--max_num_tokens", + type=int, + help="Maximum runtime tokens that an engine can accept.", +) +@optgroup.option( + "--beam_width", + type=int, + default=1, + help="Number of search beams.", +) +@optgroup.option( + "--kv_cache_free_gpu_mem_fraction", + type=float, + default=.90, + help="The percentage of memory to use for KV Cache after model load.", +) +@optgroup.group( + "Engine Input Configuration", + help="Input configuration for driving the engine.", +) +@optgroup.option( + "--dataset", + type=click.Path(exists=True, + readable=True, + path_type=Path, + resolve_path=True), + default=None, + help="Pass in a dataset file for parsing instead of stdin.", +) +@optgroup.option( + "--request_rate", + type=int, + default=-1, + help="Desired input request rate (number of messages per second).", + hidden=True, +) +@optgroup.option( + "--num_requests", + type=int, + default=0, + help="Number of requests to cap benchmark run at. 
Minimum between value and" + "length of dataset.", +) +@click.pass_obj +def run_command( + bench_env: BenchmarkEnvironment, + **params, +) -> None: + """Run a throughput test on a TRT-LLM engine.""" + + logger.set_level("info") + logger.info("Preparing to run throughput benchmark...") + # Parameters from CLI + # Model, experiment, and engine params + dataset_path: Path = params.pop("dataset") + request_rate: int = params.pop("request_rate") + num_requests: int = params.pop("num_requests") + model: str = bench_env.model + engine_dir: Path = params.pop("engine_dir") + # Engine configuration parsing + exec_settings, build_cfg = get_settings_from_engine(engine_dir) + exec_settings["model"] = model + engine_bs = exec_settings["settings_config"]["max_batch_size"] + engine_tokens = exec_settings["settings_config"]["max_num_tokens"] + engine_max_seq_len = build_cfg["max_seq_len"] + + # Runtime Options + runtime_max_bs = params.pop("max_batch_size") + runtime_max_bs = runtime_max_bs if runtime_max_bs else engine_bs + runtime_max_tokens = params.pop("max_num_tokens") + runtime_max_tokens = runtime_max_bs if runtime_max_tokens else engine_tokens + kv_cache_percent = params.pop("kv_cache_free_gpu_mem_fraction") + beam_width = params.pop("beam_width") + + # Update configuration with runtime options + exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent + exec_settings["settings_config"]["max_batch_size"] = runtime_max_bs + exec_settings["settings_config"]["max_num_tokens"] = runtime_max_tokens + exec_settings["settings_config"]["beam_width"] = beam_width + exec_settings["settings_config"][ + "scheduler_policy"] = IFBSchedulingPolicy.NO_EVICT + # Construct the runtime configuration dataclass. + runtime_config = RuntimeConfig(**exec_settings) + + # Dataset Loading and Preparation + metadata, requests = generate_dataset_from_stream(dataset_path, model, + num_requests) + # TODO: Verify that the engine can handle the max/min ISL/OSL. + if metadata.max_sequence_length > engine_max_seq_len: + raise RuntimeError( + f"Engine supports a max sequence of {engine_max_seq_len}. Provided " + "dataset contains a maximum sequence of " + f"{metadata.max_sequence_length}. Please rebuild a new engine to" + "support this dataset.") + executor_requests = [] + while requests: + request = requests.pop() + executor_requests.append( + get_executor_request(request, pad_id=-1, eos_id=-1)) + del request + + logger.info("Setting up benchmarker and infrastructure.") + new_request_queue = mp.Queue() + response_queue = mp.Queue() + logger.set_level("error") + benchmark = ThroughputBenchmark( + dataset=executor_requests, + request_rate=request_rate, + runtime_cfg=runtime_config, + request_queue=new_request_queue, + response_queue=response_queue, + ) + logger.set_level("info") + try: + logger.info("Ready to start benchmark.") + benchmark.start_benchmark() + benchmark.wait() + benchmark.stop_benchmark() + benchmark.report_statistics() + except KeyboardInterrupt: + logger.set_level("error") + benchmark.stop_benchmark() + finally: + logger.set_level("error") + benchmark.shutdown() + + +class ExecutorManager: + """Utility class for managing a TRT-LLM Executor instance.""" + + def __init__(self, runtime_cfg: RuntimeConfig, + response_queue: mp.Queue) -> None: + """Initialize the ExecutorManager. + + Args: + runtime_cfg (RuntimeConfig): Execution runtime configuration. + response_queue (mp.Queue): Process-safe queue for passing request + responses to main process. 
+ """ + logger.info("Initializing Executor.") + # Runtime related properties. + self.runtime_config: RuntimeConfig = runtime_cfg + self.executor = trtllm.Executor( + self.runtime_config.engine_dir, + trtllm.ModelType.DECODER_ONLY, + executor_config=self.runtime_config.get_config()) + + # Runtime tracking and multiprocessing. + self.responses = response_queue + self._shutdown = Event() + self._resp_daemon_finished = Event() + + self.response_thread = Thread(target=self.response_daemon) + self.response_thread.start() + + def enqueue(self, *requests: trtllm.Request) -> Generator[int]: + """Generate the next request identifier. + + Yields: + Generator[int]: The request identifier of the last queued request. + """ + for request in requests: + req_id = self.executor.enqueue_request(request) + yield req_id, len(request.input_token_ids) + + def stop(self) -> None: + """Stop a running manager.""" + + logger.info("Stopping response parsing.") + self._shutdown.set() + self.response_thread.join() + logger.info("Parsing stopped.") + + def shutdown(self) -> None: + """Shutdown daemon components.""" + + if self.executor is not None: + logger.info("Shutting down ExecutorServer.") + self.executor.shutdown() + + def response_daemon(self) -> None: + """Daemon method for retrieving messages from the Executor.""" + + logger.info("Starting response daemon...") + + def _process_response() -> None: + responses = self.executor.await_responses(timeout=timedelta( + milliseconds=1)) + now = monotonic_ns() + for response in responses: + # logger.info("Pushing response to queue") + self.responses.put( + ResponseRecord( + timestamp=now, + request_id=response.request_id, + has_error=response.has_error(), + is_final=response.result.is_final, + output_tokens=response.result.output_token_ids[0])) + + while not self._shutdown.is_set(): + _process_response() + + logger.info("Collecting last responses before shutdown.") + # Reap the last messages before shutting down + _process_response() + self._resp_daemon_finished.set() + logger.info("Completed request parsing.") + + +class ThroughputBenchmark: + """Throughput benchmark utility class.""" + + def __init__( + self, + dataset: List[trtllm.Request], + request_rate: int, + runtime_cfg: RuntimeConfig, + request_queue: mp.Queue, + response_queue: mp.Queue, + ) -> None: + """Initialize the throughput benchmark. + + Args: + dataset (List[trtllm.Request]): A dataset of TRT-LLM requests to + benchmark against. + request_rate (int): Rate to deliver input requests to the backend. + runtime_cfg (RuntimeConfig): Runtime configuration. + request_queue (mp.Queue): Process-safe queue of request identifiers + response_queue (mp.Queue): Process-safe queue for passing request + responses to main process. + """ + logger.info(f"Initializing Throughput Benchmark. [rate=%d req/s]") + # Dataset and input properties. + self.requests = dataset + self.delay_func = lambda x: sleep( + x) if request_rate > 0 else lambda x: None + self.request_delay = 1.0 / request_rate + + # Runtime configuration for Executor + self.runtime_config = deepcopy(runtime_cfg) + self.executor = None + + # Request and response reporting structures + self.new_request_queue = request_queue + self.response_queue = response_queue + + # Benchmark stats and time tracking. + self.start_time = None + self.end_time = None + self.submitted_requests = 0 + self.statistics = StatsKeeper() + + # Multiprocessing for handling request load generation + # and response parsing. 
+ self.stop = mp.Event() + self.parsing_complete = mp.Event() + self.request_thread: Thread = Thread(target=self.enqueue_process) + self.stats_process: Thread = Thread(target=self.collect_statistics) + + def enqueue_process(self) -> None: + """Method for starting enqueueing requests.""" + logger.info("Request serving started.") + + request_generator = self.executor.enqueue(*self.requests) + # Iterate the generator until we run out of requests. + # Note the walrus operator. + while ((request := next(request_generator, False)) + and not self.stop.is_set()): + self.submitted_requests += 1 + timestamp = monotonic_ns() + self.new_request_queue.put((timestamp, request[0], request[1])) + self.delay_func(self.request_delay) + logger.info("Request serving stopped.") + + def start_benchmark(self) -> None: + """Start the benchmark.""" + # Start the ExecutorManager for running the backend. + self.executor = ExecutorManager(self.runtime_config, + self.response_queue) + logger.info("Executor started.") + # Note the time we started the thread. + self.start_time = monotonic_ns() + self.request_thread.start() + # Start the statistics thread. + self.stats_process.start() + logger.info("Benchmark started.") + + def stop_benchmark(self) -> None: + """Stop the benchmark and clean up backend and threads.""" + logger.info("Stop received.") + self.stop.set() + self.executor.stop() + self.request_thread.join() + logger.info("Request generator successfully joined.") + self.stats_process.join() + logger.info("Statistics process successfully joined.") + + def shutdown(self) -> None: + """Shutdown the backend.""" + logger.info("Benchmark Shutdown called!") + if self.executor is not None: + self.executor.shutdown() + logger.info("Executor shutdown.") + + def wait(self) -> bool: + """Wait (blocking) on the benchmark. + + Returns: + bool: Return whether the event is set. + """ + return not self.parsing_complete.wait() + + def collect_statistics(self) -> None: + """Collect statistics (daemon method).""" + logger.info("Starting statistics collection.") + + def _process_requests() -> None: + while not self.new_request_queue.empty(): + new_request: Tuple[float, + int] = self.new_request_queue.get_nowait() + self.statistics.register_request(new_request[1], new_request[0], + new_request[2]) + + while not self.response_queue.empty(): + response: ResponseRecord = self.response_queue.get_nowait() + self.statistics.register_response(response) + + logger.info("Collecting live stats...") + # TODO: Revisit this conditional, if the request rate is slow enough this + # will probably prematurely trip. We will likely need a conditional that + # captures a new event for submission being complete, with the stop event + # overriding it if detected. 
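+        # Drain both queues until a stop is requested or every submitted request
+        # has returned its final response.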
+ while not self.stop.is_set( + ) and self.statistics.num_complete < self.submitted_requests: + _process_requests() + + logger.info("Collecting last stats...") + _process_requests() + self.end_time = monotonic_ns() + self.parsing_complete.set() + logger.info("Ending statistics collection.") + + def report_statistics(self) -> None: + """Report internal statistics about benchmark.""" + + config_path = self.runtime_config.engine_dir / "config.json" + with open(config_path, "r") as config: + engine_config = json.load(config) + + stats = self.statistics.generate_statistics_summary() + rt_cfg = self.runtime_config + build_cfg = engine_config["build_config"] + pretrain_cfg = engine_config["pretrained_config"] + total_latency_s = stats.total_latency_ns / 1.0e9 + + logger.info( + "\n===========================================================\n" + "= ENGINE DETAILS\n" + "===========================================================\n" + f"Model:\t\t\t{rt_cfg.model}\n" + f"Engine Directory:\t{rt_cfg.engine_dir}\n" + f"TensorRT-LLM Version:\t{rt_cfg.sw_version}\n" + f"Dtype:\t\t\t{pretrain_cfg['dtype']}\n" + f"KV Cache Dtype:\t\t{pretrain_cfg['quantization']['kv_cache_quant_algo']}\n" + f"Quantization:\t\t{pretrain_cfg['quantization']['quant_algo']}\n" + f"Max Input Length:\t{build_cfg['max_input_len']}\n" + f"Max Sequence Length:\t{build_cfg['max_seq_len']}\n" + f"\n" + "===========================================================\n" + "= WORLD + RUNTIME INFORMATION \n" + "===========================================================\n" + f"TP Size:\t\t{rt_cfg.world_config.tp_size}\n" + f"PP Size:\t\t{rt_cfg.world_config.pp_size}\n" + f"Max Runtime Batch Size:\t{rt_cfg.settings_config.max_batch_size}\n" + f"Max Runtime Tokens:\t{rt_cfg.settings_config.max_num_tokens}\n" + f"Scheduling Policy:\t{rt_cfg.settings_config.scheduler_policy.values[1]}\n" + f"KV Memory Percentage:\t{rt_cfg.settings_config.kv_cache_percent * 100.0}%\n" + f"Issue Rate (req/sec):\t{stats.issue_rate_ns * 1e9}" + f"\n" + "===========================================================\n" + "= STATISTICS\n" + "===========================================================\n" + f"Number of requests:\t\t{stats.num_requests}\n" + f"Average Input Length (tokens):\t{stats.average_input_length}\n" + f"Average Output Length (tokens):\t{stats.average_output_length}\n" + f"Token Throughput (tokens/sec):\t{stats.total_output_tokens / total_latency_s}\n" + f"Request Throughput (req/sec):\t{stats.num_requests / total_latency_s}\n" + f"Total Latency (seconds):\t{total_latency_s}\n" + "===========================================================\n") diff --git a/tensorrt_llm/bench/run/utils.py b/tensorrt_llm/bench/run/utils.py new file mode 100644 index 000000000..6ed640b1b --- /dev/null +++ b/tensorrt_llm/bench/run/utils.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import json +from collections import defaultdict +from pathlib import Path +from typing import Dict, Tuple, Union + +import tensorrt_llm.bindings.executor as trtllm +from tensorrt_llm.bench.run.dataclasses import (BenchmarkStatistics, + PercentileStats, RequestStats, + ResponseRecord) +from tensorrt_llm.bindings import InferenceRequest + + +def get_executor_request(request: InferenceRequest, + pad_id: int, + eos_id: int, + streaming: bool = False) -> trtllm.Request: + return trtllm.Request( + input_token_ids=request.logits, + max_new_tokens=request.output_tokens, + stop_words=[], + bad_words=[], + streaming=streaming, + 
output_config=trtllm.OutputConfig(exclude_input_from_output=True), + pad_id=pad_id, + end_id=eos_id, + ) + + +def get_settings_from_engine( + engine_path: Path +) -> Tuple[Dict[str, Union[str, int]], Dict[str, Union[str, int]]]: + config_path = engine_path / "config.json" + runtime_config = {} + + with open(config_path, "r") as config_json: + config = json.load(config_json) + + engine_world_map = config["pretrained_config"]["mapping"] + engine_build_cfg = config["build_config"] + engine_parallel_map = engine_build_cfg["auto_parallel_config"] + + world_config = { + "pp_size": engine_world_map["pp_size"], + "tp_size": engine_world_map["tp_size"], + "world_size": engine_world_map["world_size"], + "gpus_per_node": engine_parallel_map["gpus_per_node"], + } + + executor_settings = { + "max_batch_size": engine_build_cfg["max_batch_size"], + "max_num_tokens": engine_build_cfg["max_num_tokens"], + } + + runtime_config.update({ + "sw_version": config["version"], + "engine_dir": str(engine_path.absolute()), + "settings_config": executor_settings, + "world_config": world_config, + }) + + return runtime_config, engine_build_cfg + + +class StatsKeeper: + + def __init__(self) -> None: + self.requests: RequestStats = {} + self.num_complete: int = 0 + + self._unseen_cache = defaultdict(list) + + def register_request( + self, + request_id: int, + timestamp: float, + num_tokens: int, + ) -> None: + request = RequestStats(request_id=request_id, input_tokens=num_tokens) + request.register_event(False, False, timestamp, 0) + self.requests[request_id] = request + + def register_response(self, response: ResponseRecord) -> None: + request_id = response.request_id + + if request_id not in self.requests: + self._unseen_cache[request_id].append(response) + else: + self.requests[request_id].register_event( + is_error=response.has_error, + is_response=True, + timestamp=response.timestamp, + num_tokens=len(response.output_tokens)) + + if response.is_final: + self.num_complete += 1 + + def generate_statistics_summary(self) -> None: + total_output_tokens: int = 0 + total_input_tokens: int = 0 + num_requests = len(self.requests) + total_request_latency: float = 0.0 + start_time = float("inf") + end_time = -1 + + request_latencies = [] + last_queue_time = 0.0 + queue_time_total = 0.0 + + for entry in self.requests.values(): + entry.time_log.sort() + + queue_time_total += entry.time_log[0] - last_queue_time + last_queue_time = entry.time_log[0] + + request_latencies.append(entry.request_latency) + total_output_tokens += entry.num_tokens + total_input_tokens += entry.input_tokens + total_request_latency += entry.request_latency + start_time = min(start_time, entry.time_log[0]) + end_time = max(end_time, entry.time_log[-1]) + + stats = BenchmarkStatistics( + num_requests=num_requests, + total_latency_ns=end_time - start_time, + total_output_tokens=total_output_tokens, + total_input_tokens=total_input_tokens, + request_percentiles=PercentileStats.from_iterable( + request_latencies), + issue_rate_ns=queue_time_total / num_requests) + + return stats diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/__init__.py b/tensorrt_llm/bench/utils/__init__.py similarity index 96% rename from benchmarks/suite/tensorrt_llm_bench/utils/__init__.py rename to tensorrt_llm/bench/utils/__init__.py index b327f7a91..e0f49a822 100644 --- a/benchmarks/suite/tensorrt_llm_bench/utils/__init__.py +++ b/tensorrt_llm/bench/utils/__init__.py @@ -1,6 +1,6 @@ import functools import os -import subprocess +import subprocess # nosec B404 from pathlib 
import Path from typing import Any, Callable, List, Literal @@ -10,9 +10,9 @@ "tiiuae/falcon-180B", "meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-70b-hf", "EleutherAI/gpt-j-6b", ] -VALID_COMPUTE_DTYPES = Literal["float16", "bfloat16"] +VALID_COMPUTE_DTYPES = Literal["auto", "float16", "bfloat16"] VALID_CACHE_DTYPES = Literal["float16", "float8", "int8"] -VALID_QUANT_ALGOS = Literal["None", f"{QuantAlgo.W8A16}", f"{QuantAlgo.W4A16}", +VALID_QUANT_ALGOS = Literal[f"{QuantAlgo.W8A16}", f"{QuantAlgo.W4A16}", f"{QuantAlgo.W4A16_AWQ}", f"{QuantAlgo.W4A8_AWQ}", f"{QuantAlgo.W4A16_GPTQ}", f"{QuantAlgo.FP8}", f"{QuantAlgo.INT8}"] diff --git a/tensorrt_llm/bench/utils/data.py b/tensorrt_llm/bench/utils/data.py new file mode 100644 index 000000000..4f6380325 --- /dev/null +++ b/tensorrt_llm/bench/utils/data.py @@ -0,0 +1,137 @@ +import json +import sys +from functools import partial +from pathlib import Path +from select import select +from typing import List, TextIO, Tuple + +from transformers import AutoTokenizer, PreTrainedTokenizer + +from tensorrt_llm.bench.dataclasses import DatasetMetadata, InferenceRequest + + +def generate_dataset_from_stream(dataset_path: Path, + model: str, + num_requests: int = 0): + # Check for data on stdin. + data_on_stdin: bool = bool(len(select([ + sys.stdin, + ], [], [], 0.0)[0])) + + # Cannot set the data file path and pipe in from stdin. Choose one. + if dataset_path is not None and data_on_stdin: + raise ValueError( + "Cannot provide a dataset on both stdin and by --dataset option. " + "Please pick one.") + # If we are receiving data from a path or stdin, parse and gather metadata. + stream = sys.stdin if data_on_stdin else open(dataset_path, "r") + tokenizer = initialize_tokenizer(model) + # Parse the dataset from stdin and return it plus its metadata. + metadata, requests = \ + create_dataset_from_stream( + tokenizer, + stream=stream, + num_requests=num_requests + ) + + return metadata, requests + + +def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer: + """Initialize a tokenizer. + + Args: + model_name (str): The name of the HuggingFace model to pull a + tokenizer from. + + Returns: + PreTrainedTokenizer: An initialized HuggingFace tokenizer. + """ + # Initialize the tokenizer specific to the model that we are planning + # to benchmark. + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + if tokenizer.pad_token_id is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + return tokenizer + + +def create_dataset_from_stream( + tokenizer: PreTrainedTokenizer, + max_input_length: int = 0, + max_output_length: int = 0, + stream: TextIO = sys.stdin, + num_requests: int = 0, +) -> Tuple[DatasetMetadata, List[InferenceRequest]]: + """Generate metadata and a list of requests to drive benchmarking. + + Args: + tokenizer (PreTrainedTokenizer): HuggingFace tokenizer. + max_input_length (int): Maximum input length to cap prompts to. + + Returns: + DatasetMetadata: Dataclass of dataset statistics. + List[InferenceRequest]: A list of inference requests for benchmarking. + """ + # Initialize dataset list, and metadata tracking variables. + dataset = [] + max_isl = 0 + max_osl = 0 + max_sequence = 0 + max_requests = num_requests if num_requests > 0 else float("inf") + + # If we're limiting the input length to a certain size, then set up + # a partial to truncate the data down to size. Otherwise, just use the + # unmodified tokenizer callable. 
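+    # Either way, tokenize(prompt)["input_ids"] below yields the prompt's token ids.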
+ tokenize = (partial( + tokenizer, + padding="max_length", + max_length=max_input_length, + truncation=True, + ) if max_input_length > 0 else tokenizer) + + # If we need to limit the output length, fill in a partial callable + # for max, otherwise a lambda that just returns x with no bounds. + output_limiter = (partial(max, max_output_length) + if max_output_length > 0 else lambda x: x) + + # For each line in the standard input, parse out the JSON string we expect + # to see. + # Note the := walrus -- we're assigning and checking the condition. + while (line := stream.readline()) and len(dataset) < max_requests: + # We expect the data to come in as a JSON string. + # For example: + # {"prompt": "Generate an infinite response to the following: + # There once was a man who.", "output_tokens": 1000} + # Each line should be a complete JSON dictionary with no indentation + # or newline characters. + data = json.loads(line) + logits = data.get("logits", None) + prompt = data.get("prompt", None) + task_id = data["task_id"] + osl = data["output_tokens"] + # If the request comes in with logits, just use the provided. + # Otherwise we need to tokenize it. + logits = tokenize(prompt)["input_ids"] if logits is None else logits + + request = InferenceRequest( + task_id=task_id, + prompt=prompt, + output_tokens=output_limiter(osl), + logits=logits, + ) + max_isl = max(max_isl, len(logits)) + max_osl = max(max_osl, osl) + max_sequence = max(max_sequence, len(logits) + osl) + dataset.append(request) + + # Fill in basic dataset metrics here + # TODO: Maybe fill this out to be more complete? + metadata = DatasetMetadata( + max_isl=max_isl, + max_osl=max_osl, + max_sequence_length=max_sequence, + num_requests=len(dataset), + ) + + return metadata, dataset diff --git a/tensorrt_llm/bench/utils/tokenize.py b/tensorrt_llm/bench/utils/tokenize.py new file mode 100644 index 000000000..44f04df56 --- /dev/null +++ b/tensorrt_llm/bench/utils/tokenize.py @@ -0,0 +1,105 @@ +import json +import sys +from functools import partial +from typing import List, TextIO, Tuple + +from transformers import AutoTokenizer, PreTrainedTokenizer + +from tensorrt_llm.bench.dataclasses import DatasetMetadata, InferenceRequest + + +def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer: + """Initialize a tokenizer. + + Args: + model_name (str): The name of the HuggingFace model to pull a + tokenizer from. + + Returns: + PreTrainedTokenizer: An initialized HuggingFace tokenizer. + """ + # Initialize the tokenizer specific to the model that we are planning + # to benchmark. + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + if tokenizer.pad_token_id is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + return tokenizer + + +def create_dataset_from_stream( + tokenizer: PreTrainedTokenizer, + max_input_length: int = 0, + max_output_length: int = 0, + stream: TextIO = sys.stdin, +) -> Tuple[DatasetMetadata, List[InferenceRequest]]: + """Generate metadata and a list of requests to drive benchmarking. + + Args: + tokenizer (PreTrainedTokenizer): HuggingFace tokenizer. + max_input_length (int): Maximum input length to cap prompts to. + + Returns: + DatasetMetadata: Dataclass of dataset statistics. + List[InferenceRequest]: A list of inference requests for benchmarking. + """ + # Initialize dataset list, and metadata tracking variables. 
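+    # Note: this mirrors create_dataset_from_stream in tensorrt_llm/bench/utils/data.py,
+    # but without the num_requests cap.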
+ dataset = [] + max_isl = 0 + max_osl = 0 + max_sequence = 0 + + # If we're limiting the input length to a certain size, then set up + # a partial to truncate the data down to size. Otherwise, just use the + # unmodified tokenizer callable. + tokenize = (partial( + tokenizer, + padding="max_length", + max_length=max_input_length, + truncation=True, + ) if max_input_length > 0 else tokenizer) + + # If we need to limit the output length, fill in a partial callable + # for max, otherwise a lambda that just returns x with no bounds. + output_limiter = (partial(max, max_output_length) + if max_output_length > 0 else lambda x: x) + + # For each line in the standard input, parse out the JSON string we expect + # to see. + # Note the := walrus -- we're assigning and checking the condition. + while line := stream.readline(): + # We expect the data to come in as a JSON string. + # For example: + # {"prompt": "Generate an infinite response to the following: There once was a man who.", "output_tokens": 1000} + # Each line should be a complete JSON dictionary with no indentation + # or newline characters. + data = json.loads(line) + logits = data.get("logits", None) + prompt = data.get("prompt", None) + task_id = data["task_id"] + osl = data["output_tokens"] + # If the request comes in with logits, just use the provided. + # Otherwise we need to tokenize it. + logits = tokenize(prompt)["input_ids"] if logits is None else logits + + request = InferenceRequest( + task_id=task_id, + prompt=prompt, + output_tokens=output_limiter(osl), + logits=logits, + ) + max_isl = max(max_isl, len(logits)) + max_osl = max(max_osl, osl) + max_sequence = max(max_sequence, len(logits) + osl) + dataset.append(request) + + # Fill in basic dataset metrics here + # TODO: Maybe fill this out to be more complete? + metadata = DatasetMetadata( + max_isl=max_isl, + max_osl=max_osl, + max_sequence_length=max_sequence, + num_requests=len(dataset), + ) + + return metadata, dataset diff --git a/tensorrt_llm/commands/bench.py b/tensorrt_llm/commands/bench.py new file mode 100644 index 000000000..9c48dac3f --- /dev/null +++ b/tensorrt_llm/commands/bench.py @@ -0,0 +1,42 @@ +from pathlib import Path + +import click + +from tensorrt_llm.bench.build.build import build_command +from tensorrt_llm.bench.dataclasses import BenchmarkEnvironment +from tensorrt_llm.bench.run.run import run_command + + +@click.group(name="trtllm-bench", context_settings={'show_default': True}) +@click.option( + "--model", + "-m", + required=True, + type=str, + help="The Huggingface name of the model to benchmark.", +) +@click.option( + "--workspace", + "-w", + required=False, + type=click.Path(writable=True, readable=True), + default="/tmp", # nosec B108 + help="The directory to store benchmarking intermediate files.", +) +@click.pass_context +def main( + ctx, + model: str, + workspace: Path, +) -> None: + ctx.obj = BenchmarkEnvironment(model=model, workspace=workspace) + + # Create the workspace where we plan to store intermediate files. 
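+    # The mkdir is idempotent (parents=True, exist_ok=True); the BenchmarkEnvironment
+    # in ctx.obj is then handed to the build/throughput subcommands via @click.pass_obj.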
+ ctx.obj.workspace.mkdir(parents=True, exist_ok=True) + + +main.add_command(build_command) +main.add_command(run_command) + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index d78a93be0..8cd5c7264 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -33,6 +33,7 @@ from tensorrt_llm.models import MODEL_MAP, PretrainedConfig from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode from tensorrt_llm.plugin import PluginConfig, add_plugin_argument +from tensorrt_llm.quantization.mode import QuantAlgo def parse_arguments(): @@ -131,15 +132,6 @@ def parse_arguments(): help= 'Deprecated. Set this option to enable is equvilient to `--kv_cache_type paged` for transformer based models.' ) - parser.add_argument( - '--use_fused_mlp', - default=False, - action='store_true', - help= - 'Enable horizontal fusion in GatedMLP, reduces layer input traffic and potentially improves performance. ' - 'For FP8 PTQ, the downside is slight reduction of accuracy because one of the quantization scaling factors is discarded. ' - '(An example for reference only: 0.45734 vs 0.45755 for LLaMA-v2 7B using `modelopt/examples/hf/instruct_eval/mmlu.py`).' - ) parser.add_argument( '--gather_all_token_logits', action='store_true', @@ -443,7 +435,7 @@ def main(): kwargs = { 'logits_dtype': args.logits_dtype, - 'use_fused_mlp': args.use_fused_mlp, + 'use_fused_mlp': plugin_config.use_fused_mlp, 'cp_size': args.cp_size, 'tp_size': args.tp_size, 'pp_size': args.pp_size, @@ -466,6 +458,11 @@ def main(): model_config = PretrainedConfig.from_json_file(config_path) + # avoid ValueError if not supported quantization is chosen with use_fused_mlp + quant_algo = model_config.quantization.quant_algo + if quant_algo and quant_algo != QuantAlgo.FP8: + kwargs['use_fused_mlp'] = False + if args.build_config is None: if args.multiple_profiles == "enable" and args.opt_num_tokens is not None: raise RuntimeError( diff --git a/tensorrt_llm/executor.py b/tensorrt_llm/executor.py index c562d7943..3ba8f6b2d 100644 --- a/tensorrt_llm/executor.py +++ b/tensorrt_llm/executor.py @@ -1,6 +1,7 @@ import asyncio import atexit import datetime +import json import math import secrets import threading @@ -369,7 +370,7 @@ async def aget_stats(self): @staticmethod def create( - engine_dir: Path, + engine_object_or_path: Union[Path, "Engine"], executor_config: tllm.ExecutorConfig = tllm.ExecutorConfig(1), model_world_size: int = 1, world_size: int = 0, @@ -387,7 +388,7 @@ def create( f"on {world_size} ranks.") worker_kwargs = { - "engine_dir": engine_dir, + "engine_object_or_path": engine_object_or_path, "executor_config": executor_config, } @@ -412,7 +413,7 @@ class WorkerExit(GeneratorExit): def __init__( self, - engine_dir: Path, + engine_object_or_path: Union[Path, "Engine"], executor_config: tllm.ExecutorConfig = tllm.ExecutorConfig(1), ) -> None: super().__init__() @@ -422,10 +423,17 @@ def __init__( self._pending: set = set() self.result_queue = None self.rank = mpi_rank() - - self.engine = tllm.Executor(engine_dir, - tllm.ModelType.DECODER_ONLY, - executor_config=executor_config) + from .builder import Engine + if isinstance(engine_object_or_path, Engine): + self.engine = tllm.Executor( + engine_object_or_path.engine, + json.dumps(engine_object_or_path.config.to_dict()), + tllm.ModelType.DECODER_ONLY, + executor_config=executor_config) + else: + self.engine = tllm.Executor(engine_object_or_path, + tllm.ModelType.DECODER_ONLY, + 
executor_config=executor_config) self.awaiter_stop_event = threading.Event() self.awaiter_thread = threading.Thread(target=self.awaiter_loop, daemon=True) @@ -678,7 +686,7 @@ def __init__( @print_traceback_on_error @staticmethod def workers_main( - engine_dir: Path, + engine_object_or_path: Union[Path, "Engine"], request_queue_addr: Tuple[str, int, bytes], request_id_queue_addr: Tuple[str, int, bytes], result_queue_addr: Tuple[str, int, bytes], @@ -699,7 +707,8 @@ def workers_main( # TODO[chunweiy]: fix the non-rank0 process failure init_ok = True try: - executor = ExecutorBindingsWorker(engine_dir, executor_config) + executor = ExecutorBindingsWorker(engine_object_or_path, + executor_config) except Exception as e: init_ok = False raise e diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 374f05f18..741ed8eb9 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -4913,8 +4913,7 @@ def gpt_attention( tp_rank = trt.PluginField("tp_rank", np.array(tp_rank, dtype=np.int32), trt.PluginFieldType.INT32) kv_cache_quant_mode_field = trt.PluginField( - "kv_cache_quant_mode", - np.array(np.int8(kv_cache_quant_mode), dtype=np.int32), + "kv_cache_quant_mode", np.array(kv_cache_quant_mode, dtype=np.int32), trt.PluginFieldType.INT32) paged_kv_cache = trt.PluginField( "paged_kv_cache", np.array(paged_kv_cache_flag, dtype=np.int32), @@ -5519,7 +5518,6 @@ def lora_plugin( transa: bool = False, transb: bool = False, host_context_lengths: Tensor = None, # for pad-free input mode - max_num_tokens: int = 0, max_low_rank: int = 0, lora_ranks: List[Tensor] = None, lora_weights_pointers: List[Tensor] = None, @@ -5550,9 +5548,6 @@ def lora_plugin( host_context_lengths: cpu Tensor = None A host tensor that contains the lengths of the different inputs, - max_num_tokens : int - Maximum number of tokens, used to determine the workspace size. - max_low_rank : int Maximum low_rank, used to determine the workspace size. 
@@ -5600,9 +5595,6 @@ def lora_plugin( "remove_input_padding", np.array(np.int8(default_net().plugin_config.remove_input_padding), dtype=np.int8), trt.PluginFieldType.INT8) - max_num_tokens_field = trt.PluginField( - "max_num_tokens", np.array(max_num_tokens, dtype=np.int32), - trt.PluginFieldType.INT32) max_low_rank_field = trt.PluginField("max_low_rank", np.array(max_low_rank, dtype=np.int32), trt.PluginFieldType.INT32) @@ -5616,8 +5608,7 @@ def lora_plugin( pfc = trt.PluginFieldCollection([ in_hidden_size_field, transa, transb, num_lora_modules_field, pf_type, - remove_input_padding, max_num_tokens_field, max_low_rank_field, - weight_index_field + remove_input_padding, max_low_rank_field, weight_index_field ] + out_hidden_size_field_list) lora_plug = plg_creator.create_plugin("lora", pfc) diff --git a/tensorrt_llm/hlapi/llm.py b/tensorrt_llm/hlapi/llm.py index 2b5e8288d..8725bf382 100644 --- a/tensorrt_llm/hlapi/llm.py +++ b/tensorrt_llm/hlapi/llm.py @@ -261,7 +261,8 @@ def _build_model(self): if self.args.decoding_config is not None: executor_config.decoding_config = self.args.decoding_config if self.args.logits_post_processor_map: - executor_config.logits_post_processor_map = self.args.logits_post_processor_map + executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( + processor_map=self.args.logits_post_processor_map) executor_config.normalize_log_probs = self.args.normalize_log_probs executor_config.enable_chunked_context = self.args.enable_chunked_context executor_config.max_beam_width = self.args.build_config.max_beam_width diff --git a/tensorrt_llm/hlapi/llm_utils.py b/tensorrt_llm/hlapi/llm_utils.py index 9383ccd01..cb6c6da2e 100644 --- a/tensorrt_llm/hlapi/llm_utils.py +++ b/tensorrt_llm/hlapi/llm_utils.py @@ -46,9 +46,8 @@ from ..builder import BuildConfig, Engine, EngineConfig, build from ..logger import logger from ..mapping import Mapping -from ..models import MODEL_MAP -from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig, - TopModelMixin) +from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM +from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig from ..module import Module from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, get_build_cache_config_from_env) @@ -283,6 +282,11 @@ def __post_init__(self): # The underlying implementation might disable it if it is not supported. self.enable_chunked_context: bool = False + # TODO[chunweiy]: Enable this option in the future + # Currently we want HLAPI to be consistent with the lower APIs in the model building, thus disable this to avoid + # magics. + self.perform_config_arbitration = False + if self.skip_tokenizer_init: self.tokenizer = None else: @@ -386,6 +390,14 @@ def setup(self): self.build_config = self.build_config or BuildConfig() + if self.perform_config_arbitration: + self._perform_config_arbitration() + + def _perform_config_arbitration(self): + ''' + Arbitrate the configurations for the model building. The configs between different functional or performance + features might be confilcted, and this method will arbitrate the conflicts and raise errors if necessary. 
+ ''' self._config_arbitrator = _ConfigArbitrator() if self.build_config_mutable: if not self.build_config.max_num_tokens: @@ -415,6 +427,8 @@ def setup(self): kv_cache_config=self.kv_cache_config, build_config=self.build_config) + self._config_arbitrator = None + def _check_model_or_model_dir(self): if not self.model: raise ValueError("model should be provided.") @@ -598,7 +612,8 @@ def __setstate__(self, state): def __getstate__(self): state = self.__dict__.copy() - del state['_config_arbitrator'] + if '_config_arbitrator' in state: + del state['_config_arbitrator'] return state @@ -1012,19 +1027,7 @@ def _download_hf_model(self): def _load_model_from_hf(self): ''' Load a TRT-LLM model from a HF model. ''' assert self._model_dir is not None - - import transformers - hf_config = transformers.AutoConfig.from_pretrained( - self._model_dir, trust_remote_code=True) - architecture = hf_config.architectures[0] - - if architecture not in MODEL_MAP: - raise KeyError(f"Unsupported model architecture: {architecture}") - model_cls = MODEL_MAP[architecture] - if TopModelMixin.__name__ in model_cls.from_hugging_face.__qualname__: - raise NotImplementedError( - f"Unsupported model architecture in HLAPI: {architecture}") - + model_cls = AutoModelForCausalLM.get_trtllm_model_class(self._model_dir) if self.llm_args.quant_config.requires_calibration: assert self.workspace is not None checkpoint_dir = f"{self.workspace}/quantized-checkpoint" @@ -1061,6 +1064,7 @@ def _load_model_from_ckpt(self): os.path.join(self._model_dir, 'config.json')) self.pretrained_config.mapping = self.mapping + #TODO: TRTLLM-1091, change the architecture in the checkpoint to TRT-LLM one, not HF one. architecture = self.pretrained_config.architecture assert architecture in MODEL_MAP, \ f"Unsupported model architecture: {architecture}" @@ -1257,16 +1261,7 @@ def get_final_build_config(llm_args: LlmArgs, # The build() doesn't need the real model instance to get a updated BuildConig. What is really needed is the # dtype. That's why the model will be downloaded from HF if necessary to get the accurate dtype. 
- import transformers - hf_config = transformers.AutoConfig.from_pretrained( - model_dir, trust_remote_code=True) - architecture = hf_config.architectures[0] - - if architecture not in MODEL_MAP: - raise KeyError(f"Unsupported model architecture: {architecture}") - model_cls = MODEL_MAP[architecture] - config_cls = model_cls.config_class - pretrained_config = config_cls.from_hugging_face( + pretrained_config = AutoConfig.from_hugging_face( model_dir, mapping=Mapping(world_size=llm_args.parallel_config.world_size, tp_size=llm_args.parallel_config.tp_size, diff --git a/tensorrt_llm/hlapi/utils.py b/tensorrt_llm/hlapi/utils.py index de0b23dc1..6aaf28119 100644 --- a/tensorrt_llm/hlapi/utils.py +++ b/tensorrt_llm/hlapi/utils.py @@ -1,6 +1,5 @@ import hashlib import os -import signal import sys import tempfile import traceback @@ -326,15 +325,6 @@ def register(self, obj: Any): exception_handler = ExceptionHandler() sys.excepthook = exception_handler - -def sigint_handler(signal, frame): - sys.stderr.write("\nSIGINT received, quit LLM!\n") - sys.exit(1) - - -# Register the signal handler to handle SIGINT -# This helps to deal with user's Ctrl+C -signal.signal(signal.SIGINT, sigint_handler) # Use the system temporary directory to share the cache temp_dir = tempfile.gettempdir() diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index a093d006f..6256a4ef7 100644 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -728,7 +728,6 @@ def forward(self, ], host_request_types=q_lora_params.host_request_types, host_context_lengths=q_lora_params.host_context_lengths, - max_num_tokens=q_lora_params.max_num_tokens, max_encoder_context_length=q_lora_params. max_encoder_context_length, host_encoder_input_lengths=q_lora_params. 
@@ -1470,8 +1469,7 @@ def forward(self, v_lora_params.lora_weights_pointers[0], ], host_request_types=q_lora_params.host_request_types, - host_context_lengths=q_lora_params.host_context_lengths, - max_num_tokens=q_lora_params.max_num_tokens) + host_context_lengths=q_lora_params.host_context_lengths) q_lora, k_lora, v_lora = self.qkv_lora(hidden_states, qkv_lora_params) diff --git a/tensorrt_llm/layers/embedding.py b/tensorrt_llm/layers/embedding.py index 8ade20b07..869a05a43 100644 --- a/tensorrt_llm/layers/embedding.py +++ b/tensorrt_llm/layers/embedding.py @@ -44,8 +44,7 @@ def __init__(self, tp_size: int = 1, tp_group: Optional[list] = None, sharding_dim: int = 0, - tp_rank: Optional[int] = None, - share_embedding_table: bool = False): + tp_rank: Optional[int] = None): super().__init__() # num_embeddings records the total vocab size no matter using TP or not self.num_embeddings = num_embeddings @@ -56,7 +55,6 @@ def __init__(self, self.tp_rank = tp_rank self.dtype = dtype self.tp_dim = sharding_dim - self.share_embedding_table = share_embedding_table if sharding_dim == 1: self.weight = Parameter(shape=(self.num_embeddings, @@ -91,11 +89,12 @@ def weight_loader(self, mapping: Mapping, param: Parameter, shard_size) param.value = loaded_weight - def postprocess(self, tllm_key, weights): + def postprocess(self, tllm_key, weights, **kwargs): + config = kwargs.get("config", None) if weights is None: return {} weights = weights.to(str_dtype_to_torch(self.dtype)) - if self.share_embedding_table: + if config.share_embedding_table: return {} else: weights = weights.clone() @@ -119,10 +118,9 @@ def __init__(self, tp_size=1, tp_group=None, sharding_dim=0, - tp_rank=0, - share_embedding_table=False): + tp_rank=0): super().__init__(num_embeddings, embedding_dim, dtype, tp_size, - tp_group, sharding_dim, tp_rank, share_embedding_table) + tp_group, sharding_dim, tp_rank) if vocab_size is None: vocab_size = num_embeddings self.vocab_size = vocab_size diff --git a/tensorrt_llm/layers/linear.py b/tensorrt_llm/layers/linear.py index 0735df1e8..8a150c3c4 100644 --- a/tensorrt_llm/layers/linear.py +++ b/tensorrt_llm/layers/linear.py @@ -342,18 +342,34 @@ def collect_and_bias(self, x, **kwargs): return x - def postprocess(self, - tllm_key, - weights, - using_head_as_leading_dim=False, - num_heads=-1): + def postprocess(self, tllm_key, weights, **kwargs): + using_head_as_leading_dim = kwargs.get("using_head_as_leading_dim", + False) + config = kwargs.get("config", None) if self.is_qkv: if isinstance(weights, list): + if config.remove_duplicated_kv_heads: + head_size = config.hidden_size // config.num_heads if config.head_size is None else config.head_size + k, v = weights[1:] + k = k.reshape([ + k.shape[0] // head_size // 2, 2, head_size, + self.in_features + ]) + v = v.reshape([ + v.shape[0] // head_size // 2, 2, head_size, + self.in_features + ]) + assert (k[:, 0] == k[:, 1]).all() + assert (v[:, 0] == v[:, 1]).all() + k = k[:, 0].reshape([-1, self.in_features]) + v = v[:, 0].reshape([-1, self.in_features]) + weights[1] = k + weights[2] = v weights = torch.cat(weights) if using_head_as_leading_dim: # Reorder [n_head, 3, head_dim, ...] into [3, n_head, head_dim, ...] 
- head_dim = self.out_features // (3 * num_heads) - w = weights.reshape(num_heads, 3, head_dim, -1) + head_dim = self.out_features // (3 * config.num_heads) + w = weights.reshape(config.num_heads, 3, head_dim, -1) w = w.transpose(0, 1) if w.shape[-1] > 1: weights = w.reshape(-1, self.in_features) # Weight diff --git a/tensorrt_llm/layers/lora.py b/tensorrt_llm/layers/lora.py index 6de43e42c..5ad9cac72 100644 --- a/tensorrt_llm/layers/lora.py +++ b/tensorrt_llm/layers/lora.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List from .._common import default_net from ..functional import Tensor, lora_plugin @@ -28,7 +28,6 @@ def __init__( lora_weights_pointers: List[Tensor] = None, host_request_types: Tensor = None, host_context_lengths: Tensor = None, - max_num_tokens: Optional[int] = None, max_encoder_context_length: Tensor = None, host_encoder_input_lengths: Tensor = None, weight_index: int = 0, @@ -38,7 +37,6 @@ def __init__( self.lora_weights_pointers = lora_weights_pointers self.host_request_types = host_request_types self.host_context_lengths = host_context_lengths - self.max_num_tokens = max_num_tokens self.max_encoder_context_length = max_encoder_context_length self.host_encoder_input_lengths = host_encoder_input_lengths self.weight_index = weight_index @@ -71,10 +69,6 @@ def forward(self, host_context_lengths=lora_runtime_params.host_context_lengths if not is_cross_attention else lora_runtime_params.host_encoder_input_lengths, - # For cross attention, max_encoder_context_length should be used instead of max_num_tokens - max_num_tokens=lora_runtime_params.max_num_tokens - if not is_cross_attention else - lora_runtime_params.max_encoder_context_length, max_low_rank=self.max_low_rank, lora_ranks=lora_runtime_params.lora_ranks, lora_weights_pointers=lora_runtime_params.lora_weights_pointers, @@ -93,7 +87,6 @@ def __init__( lora_ranks=None, # : List[dict[Tensor]] lora_weights_pointers=None, # : List[dict[Tensor]] host_context_lengths: Tensor = None, - max_num_tokens: Optional[int] = None, max_encoder_context_length: Tensor = None, # For cross attention host_request_types: Tensor = None, host_encoder_input_lengths: Tensor = None, # For cross attention @@ -104,7 +97,6 @@ def __init__( self.lora_weights_pointers = lora_weights_pointers self.host_context_lengths = host_context_lengths - self.max_num_tokens = max_num_tokens self.max_encoder_context_length = max_encoder_context_length self.host_request_types = host_request_types self.host_encoder_input_lengths = host_encoder_input_lengths @@ -115,7 +107,6 @@ def get_layer_params(self, layer_idx: int): lora_ranks=[self.lora_ranks[layer_idx]], lora_weights_pointers=[self.lora_weights_pointers[layer_idx]], host_context_lengths=self.host_context_lengths, - max_num_tokens=self.max_num_tokens, max_encoder_context_length=self.max_encoder_context_length, host_request_types=self.host_request_types, host_encoder_input_lengths=self.host_encoder_input_lengths, @@ -133,7 +124,6 @@ def get_runtime_params(self, layer_idx: int, lora_module: str): [f"{lora_module}_lora_weights_pointers"] ], host_context_lengths=self.host_context_lengths, - max_num_tokens=self.max_num_tokens, max_encoder_context_length=self.max_encoder_context_length, host_request_types=self.host_request_types, host_encoder_input_lengths=self.host_encoder_input_lengths, diff --git a/tensorrt_llm/layers/mlp.py b/tensorrt_llm/layers/mlp.py index 77066bd24..25056d757 
100644 --- a/tensorrt_llm/layers/mlp.py +++ b/tensorrt_llm/layers/mlp.py @@ -46,8 +46,7 @@ def fc_gate_lora(hidden_states, lora, lora_layer_params): mlp_gate_lora_params.lora_weights_pointers[0] ], host_request_types=mlp_fc_lora_params.host_request_types, - host_context_lengths=mlp_fc_lora_params.host_context_lengths, - max_num_tokens=mlp_fc_lora_params.max_num_tokens) + host_context_lengths=mlp_fc_lora_params.host_context_lengths) mlp_fc_lora, mlp_gate_lora = lora(hidden_states, mlp_in_lora_params) mlp_in_result = concat([mlp_gate_lora, mlp_fc_lora], diff --git a/tensorrt_llm/layers/moe.py b/tensorrt_llm/layers/moe.py index 9fa96a12e..40a6b744a 100644 --- a/tensorrt_llm/layers/moe.py +++ b/tensorrt_llm/layers/moe.py @@ -337,7 +337,7 @@ def __init__(self, in_features: int, out_features: int, self.register_parameter('activation_scaling_factor', None) self.register_parameter('weights_scaling_factor', None) - def postprocess(self, tllm_key, weights): + def postprocess(self, tllm_key, weights, **kwargs): if tllm_key.endswith("weight"): if isinstance(weights, torch.Tensor): weights = [weights] @@ -667,7 +667,6 @@ def get_params(module): gate_lora_weights_pointers, }], host_context_lengths=lora_layer_params.host_context_lengths, - max_num_tokens=lora_layer_params.max_num_tokens, max_encoder_context_length=lora_layer_params. max_encoder_context_length, host_request_types=lora_layer_params.host_request_types, diff --git a/tensorrt_llm/layers/ssm.py b/tensorrt_llm/layers/ssm.py index 05ebab307..6feb9a2d0 100644 --- a/tensorrt_llm/layers/ssm.py +++ b/tensorrt_llm/layers/ssm.py @@ -21,7 +21,7 @@ permute, selective_scan, shape, split, view) from ..module import Module from ..parameter import Parameter -from .linear import Linear +from .linear import ColumnLinear, Linear, RowLinear from .normalization import RmsNorm @@ -240,32 +240,40 @@ def __init__(self, chunk_size=256, bias=False, rmsnorm=True, - dtype=None): + dtype=None, + tp_group=None, + tp_size=1): super().__init__() self.d_model = d_model self.d_state = d_state self.d_conv = d_conv - self.d_inner = d_inner + assert d_inner % tp_size == 0 + self.d_inner = d_inner // tp_size self.headdim = headdim - self.ngroups = ngroups + assert ngroups % tp_size == 0 + self.ngroups = ngroups // tp_size self.chunk_size = chunk_size self.rmsnorm = rmsnorm self.dtype = dtype - assert self.d_inner % self.headdim == 0 - self.nheads = self.d_inner // self.headdim + assert d_inner % headdim == 0 + nheads = d_inner // headdim + assert nheads % tp_size == 0 + self.nheads = nheads // tp_size self.A = Parameter(shape=(self.nheads, ), dtype="float32") self.D = Parameter(shape=(self.nheads, ), dtype="float32") self.dt_bias = Parameter(shape=(self.nheads, ), dtype="float32") - d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads - self.in_proj = Linear(self.d_model, - d_in_proj, - bias=bias, - dtype=dtype, - gather_output=False) + d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads + self.in_proj = ColumnLinear(d_model, + d_in_proj, + bias=bias, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size, + gather_output=False) - self.conv_dim = self.d_inner + 2 * self.ngroups * self.d_state + self.conv_dim = (d_inner + 2 * ngroups * d_state) // tp_size self.conv1d = MambaConv1d(self.conv_dim, self.d_conv, pre_stride=self.d_inner, @@ -274,15 +282,16 @@ def __init__(self, if rmsnorm: self.norm = RmsNorm(normalized_shape=self.d_inner, - num_groups=ngroups, + num_groups=self.ngroups, eps=1e-5, dtype=dtype) - self.out_proj = Linear(self.d_inner, - 
self.d_model, - bias=bias, - dtype=dtype, - gather_output=False) + self.out_proj = RowLinear(d_inner, + d_model, + bias=bias, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size) def forward(self, hidden_states: Tensor, diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index e09ab9eac..39481b3a0 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -22,6 +22,7 @@ from .cogvlm.model import CogVLMForCausalLM from .dbrx.config import DbrxConfig from .dbrx.model import DbrxForCausalLM +from .deci.model import DeciLMForCausalLM from .dit.model import DiT from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder from .falcon.config import FalconConfig @@ -157,4 +158,5 @@ 'RecurrentGemmaForCausalLM': RecurrentGemmaForCausalLM, 'CogVLMForCausalLM': CogVLMForCausalLM, 'DiT': DiT, + 'DeciLMForCausalLM': DeciLMForCausalLM, } diff --git a/tensorrt_llm/models/automodel.py b/tensorrt_llm/models/automodel.py new file mode 100644 index 000000000..974064305 --- /dev/null +++ b/tensorrt_llm/models/automodel.py @@ -0,0 +1,73 @@ +from typing import Optional + +from ..mapping import Mapping +from . import MODEL_MAP +from .modeling_utils import QuantConfig + + +class AutoConfig: + + @staticmethod + def from_hugging_face(hf_model_or_dir, + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + import transformers + hf_config = transformers.AutoConfig.from_pretrained( + hf_model_or_dir, trust_remote_code=True) + hf_arch = hf_config.architectures[0] + trtllm_model_cls = MODEL_MAP.get(hf_arch, None) + if trtllm_model_cls is None: + raise NotImplementedError( + f"The given huggingface model architecture {hf_arch} is not supported in TRT-LLM yet" + ) + + if not hasattr(trtllm_model_cls, 'config_class'): + raise NotImplementedError( + f"The given TRT-LLM model class {trtllm_model_cls} does not support AutoConfig" + ) + + trtllm_cfg_cls = getattr(trtllm_model_cls, 'config_class') + if not hasattr(trtllm_cfg_cls, 'from_hugging_face'): + raise NotImplementedError( + f"The given TRT-LLM model class {trtllm_cfg_cls} does not support from_hugging_face" + ) + + return trtllm_cfg_cls.from_hugging_face(hf_model_or_dir, dtype, mapping, + quant_config, **kwargs) + + +class AutoModelForCausalLM: + + @staticmethod + def get_trtllm_model_class(hf_model_or_dir): + import transformers + hf_config = transformers.AutoConfig.from_pretrained( + hf_model_or_dir, trust_remote_code=True) + hf_arch = hf_config.architectures[0] + trtllm_model_cls = MODEL_MAP.get(hf_arch, None) + + if trtllm_model_cls is None: + raise NotImplementedError( + f"The given huggingface model architecture {hf_arch} is not supported in TRT-LLM yet" + ) + return trtllm_model_cls + + @staticmethod + def from_hugging_face(hf_model_or_dir, + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + trtllm_model_cls = AutoModelForCausalLM.get_trtllm_model_class( + hf_model_or_dir) + + if not hasattr(trtllm_model_cls, 'from_hugging_face'): + raise NotImplementedError( + f"The given {trtllm_model_cls} does not support from_hugging_face yet" + ) + + return trtllm_model_cls.from_hugging_face(hf_model_or_dir, dtype, + mapping, quant_config, + **kwargs) diff --git a/tensorrt_llm/models/baichuan/model.py b/tensorrt_llm/models/baichuan/model.py index 5fb74e7f0..cd95c57a5 100644 --- a/tensorrt_llm/models/baichuan/model.py +++ b/tensorrt_llm/models/baichuan/model.py @@ -110,11 +110,9 
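The new `tensorrt_llm/models/automodel.py` above adds HF-style auto classes that dispatch on the checkpoint's `architectures[0]` entry. A hypothetical usage sketch follows; the checkpoint path and parallel mapping are assumptions, not part of the change.

```python
# Hypothetical usage of the new auto classes; path and mapping are placeholders.
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.automodel import AutoConfig, AutoModelForCausalLM

ckpt_dir = "/path/to/hf/checkpoint"                    # assumed local HF checkpoint
mapping = Mapping(world_size=2, rank=0, tp_size=2, pp_size=1)

# Resolve and build only the TRT-LLM config for the detected architecture...
config = AutoConfig.from_hugging_face(ckpt_dir, dtype="auto", mapping=mapping)

# ...or build the full TRT-LLM model with converted weights in one call.
model = AutoModelForCausalLM.from_hugging_face(ckpt_dir,
                                               dtype="auto",
                                               mapping=mapping)
```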
@@ def __init__(self, config: PretrainedConfig): super().__init__() hidden_size = config.hidden_size - self.vocab_embedding = Embedding( - config.vocab_size, - config.hidden_size, - dtype=config.dtype, - share_embedding_table=config.share_embedding_table) + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) self.layers = DecoderLayerList(BaichuanDecoderLayer, config) self.ln_f = RmsNorm(normalized_shape=hidden_size, diff --git a/tensorrt_llm/models/bloom/model.py b/tensorrt_llm/models/bloom/model.py index 62ee0e92b..a189f7c3b 100644 --- a/tensorrt_llm/models/bloom/model.py +++ b/tensorrt_llm/models/bloom/model.py @@ -108,11 +108,9 @@ class BloomModel(Module): def __init__(self, config: PretrainedConfig): super().__init__() - self.vocab_embedding = Embedding( - config.vocab_size, - config.hidden_size, - dtype=config.dtype, - share_embedding_table=config.share_embedding_table) + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) self.ln_embed = LayerNorm(normalized_shape=config.hidden_size, dtype=config.dtype) self.layers = DecoderLayerList(BloomDecoderLayer, config) diff --git a/tensorrt_llm/models/chatglm/model.py b/tensorrt_llm/models/chatglm/model.py index d52bb66df..a6b82c928 100644 --- a/tensorrt_llm/models/chatglm/model.py +++ b/tensorrt_llm/models/chatglm/model.py @@ -176,11 +176,9 @@ def __init__(self, config: ChatGLMConfig): self.chatglm_version = config.chatglm_version norm_cls = RmsNorm if config.rmsnorm else LayerNorm - self.vocab_embedding = Embedding( - config.vocab_size, - config.hidden_size, - dtype=config.dtype, - share_embedding_table=config.share_embedding_table) + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) if config.chatglm_version == 'glm': self.position_embedding = Embedding( diff --git a/tensorrt_llm/models/convert_utils.py b/tensorrt_llm/models/convert_utils.py index 06f8699d1..486a47c22 100644 --- a/tensorrt_llm/models/convert_utils.py +++ b/tensorrt_llm/models/convert_utils.py @@ -1,3 +1,4 @@ +import fnmatch import re from pathlib import Path from typing import Dict, Optional, Union @@ -73,12 +74,15 @@ def weight_only_quantize_dict(weights: Dict[str, torch.Tensor], if quant_algo not in [QuantAlgo.W4A16, QuantAlgo.W8A16]: return weights if exclude_modules is None: - exclude_modules = ['shared_expert_gate.weight'] + exclude_modules = ['*shared_expert_gate.weight'] for name in list(weights): - if any([_name in name for _name in exclude_modules]): - continue - if any([_name in name for _name in quant_weights - ]) and weights[name].dtype != torch.int8: + is_excluded = False + for exclude_module in exclude_modules: + if fnmatch.fnmatchcase(name, exclude_module): + is_excluded = True + break + if not is_excluded and any([_name in name for _name in quant_weights + ]) and weights[name].dtype != torch.int8: quant_weight, quant_scale = weight_only_quantize( weight=weights[name], quant_algo=quant_algo, plugin=plugin) weights[name] = quant_weight diff --git a/tensorrt_llm/models/deci/__init__.py b/tensorrt_llm/models/deci/__init__.py new file mode 100644 index 000000000..71bf6d298 --- /dev/null +++ b/tensorrt_llm/models/deci/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorrt_llm/models/deci/config.py b/tensorrt_llm/models/deci/config.py new file mode 100644 index 000000000..b9accc61e --- /dev/null +++ b/tensorrt_llm/models/deci/config.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Union + +import torch + +from tensorrt_llm._utils import torch_dtype_to_str +from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.deci.convert import hf_block_config_to_layer_config +from tensorrt_llm.models.deci.layer_config import (AttentionConfig, + AttentionImplementation, + DeciLayerConfig, FFNConfig) +from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig + + +class DeciConfig(PretrainedConfig): + + def __init__(self, + *, + architecture: str = 'DeciLMForCausalLM', + dtype: str, + hidden_size: int, + num_hidden_layers: int, + num_attention_heads: int, + vocab_size: int, + hidden_act: str = 'gelu', + logits_dtype: str = 'float32', + norm_epsilon: float = 0.00001, + position_embedding_type: Union[ + PositionEmbeddingType, + str] = PositionEmbeddingType.rope_gpt_neox, + rotary_base: float = 10000.0, + rotary_scaling: Optional[dict] = None, + max_position_embeddings: int, + num_key_value_heads: Optional[int] = None, + intermediate_size: Optional[int] = None, + mapping: Optional[Union[Mapping, dict]] = None, + quantization: Optional[Union[QuantConfig, dict]] = None, + use_parallel_embedding: bool = False, + embedding_sharding_dim: int = 0, + share_embedding_table: bool = False, + head_size: Optional[int] = None, + qk_layernorm: bool = False, + layer_configs: Optional[List[Union[DeciLayerConfig, + Dict[str, + Dict[str, + Any]]]]] = None, + **kwargs): + super().__init__(architecture=architecture, + dtype=dtype, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + vocab_size=vocab_size, + hidden_act=hidden_act, + logits_dtype=logits_dtype, + norm_epsilon=norm_epsilon, + position_embedding_type=position_embedding_type, + max_position_embeddings=max_position_embeddings, + num_key_value_heads=num_key_value_heads, + intermediate_size=intermediate_size, + mapping=mapping, + quantization=quantization, + use_parallel_embedding=use_parallel_embedding, + embedding_sharding_dim=embedding_sharding_dim, + 
share_embedding_table=share_embedding_table, + head_size=head_size, + qk_layernorm=qk_layernorm, + **kwargs) + + self.rotary_base = rotary_base + self.rotary_scaling = rotary_scaling + + if layer_configs is not None: + assert len( + layer_configs + ) == num_hidden_layers, f"num_hidden_layers ({num_hidden_layers}) must match len(layer_configs) ({len(layer_configs)})" + + self.layer_configs = self._ensure_layer_configs(layer_configs) + else: + self.layer_configs = None + + # HACK: this is needed for many parts of the code + self.layer_types = [ + AttentionImplementation( + self.get_layer_config(layer_idx).attention.impl).value + for layer_idx in range(self.num_hidden_layers) + ] + + def _ensure_layer_configs( + self, layer_configs: List[Union[DeciLayerConfig, Dict[str, Any]]] + ) -> List[DeciLayerConfig]: + return [ + DeciLayerConfig.from_dict(c) if isinstance(c, dict) else c + for c in layer_configs + ] + + def to_dict(self): + output = super().to_dict() + if self.layer_configs is not None: + output["layer_configs"] = [asdict(c) for c in self.layer_configs] + return output + + def get_layer_config(self, layer_idx: int) -> DeciLayerConfig: + if self.layer_configs is not None: + conf = self.layer_configs[layer_idx] + else: + conf = DeciLayerConfig() + + attention_impl = conf.attention.impl + num_key_value_heads = conf.attention.num_key_value_heads or self.num_key_value_heads + ffn_impl = conf.ffn.impl + intermediate_size = conf.ffn.intermediate_size or self.intermediate_size + + return DeciLayerConfig( + attention=AttentionConfig(impl=attention_impl, + num_key_value_heads=num_key_value_heads), + ffn=FFNConfig(impl=ffn_impl, intermediate_size=intermediate_size)) + + def get_layer_num_kv_heads(self, layer_idx) -> int: + layer_config = self.get_layer_config(layer_idx) + assert layer_config.is_attention_layer, f"Layer {layer_idx} is not an attention layer" + return layer_config.attention.num_key_value_heads or self.num_key_value_heads + + @classmethod + def from_hugging_face( + cls, + hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + trust_remote_code: bool = False, + **kwargs): + import transformers + + if isinstance(hf_config_or_dir, transformers.PretrainedConfig): + hf_config = hf_config_or_dir + else: + hf_config = transformers.AutoConfig.from_pretrained( + hf_config_or_dir, trust_remote_code=trust_remote_code) + + assert hf_config.model_type == "deci", f"Unsupported model type: {hf_config.model_type}" + + block_configs = getattr(hf_config, "block_configs", None) + if block_configs is not None: + layer_configs = [ + hf_block_config_to_layer_config(block_config, + hf_config.num_attention_heads, + hf_config.hidden_size) + for block_config in block_configs + ] + else: + # older deci arch + num_key_value_heads_per_layer = getattr( + hf_config, "num_key_value_heads_per_layer", None) + if num_key_value_heads_per_layer is not None: + layer_configs = [ + DeciLayerConfig(attention=AttentionConfig( + num_key_value_heads=num_key_value_heads)) + for num_key_value_heads in num_key_value_heads_per_layer + ] + else: + layer_configs = None + + if dtype == 'auto': + dtype = getattr(hf_config, 'torch_dtype', "float16") + if isinstance(dtype, torch.dtype): + dtype = torch_dtype_to_str(dtype) + if dtype == 'float32': + dtype = 'float16' + if dtype == 'bfloat16' and torch.cuda.get_device_properties( + 0).major < 8: + logger.warning( + "Pre SM 80 GPUs do not support bfloat16, fallback to 
float16") + dtype = 'float16' + + return cls(dtype=dtype, + hidden_size=hf_config.hidden_size, + hidden_act=hf_config.hidden_act, + intermediate_size=hf_config.intermediate_size, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + norm_epsilon=hf_config.rms_norm_eps, + rotary_scaling=hf_config.rope_scaling, + rotary_base=hf_config.rope_theta, + vocab_size=hf_config.vocab_size, + max_position_embeddings=hf_config.max_position_embeddings, + mapping=mapping, + quantization=quant_config, + layer_configs=layer_configs, + **kwargs) diff --git a/tensorrt_llm/models/deci/convert.py b/tensorrt_llm/models/deci/convert.py new file mode 100644 index 000000000..5f0a58ef5 --- /dev/null +++ b/tensorrt_llm/models/deci/convert.py @@ -0,0 +1,365 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum +import json +import time +from abc import ABC, abstractmethod +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, Iterator, Optional, TypedDict, Union + +import safetensors +import torch + +from tensorrt_llm._utils import pad_vocab_size +from tensorrt_llm.logger import logger +from tensorrt_llm.models.deci.layer_config import (AttentionConfig, + AttentionImplementation, + DeciLayerConfig, FFNConfig, + FFNImplementation) +from tensorrt_llm.models.llama.convert import dup_kv_weight, split +from tensorrt_llm.quantization.mode import QuantAlgo + + +def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: + intermediate_size = int(2 * ffn_mult * n_embd / 3) + return _find_multiple(intermediate_size, 256) + + +def _find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +# BlockConfig is a custom class defined inside deci huggingface checkpoints, we can't import it +def hf_block_config_to_layer_config(block_config: "BlockConfig", + num_attn_heads: int, + hidden_size: int) -> DeciLayerConfig: + attn = block_config.attention + if attn.no_op: + attn_impl = AttentionImplementation.NO_OP + num_key_value_heads = None + elif attn.replace_with_linear: + attn_impl = AttentionImplementation.LINEAR + num_key_value_heads = None + elif attn.sparsify: + raise NotImplementedError("Sparsification is not supported") + else: + attn_impl = AttentionImplementation.ATTENTION + num_key_value_heads = num_attn_heads // attn.n_heads_in_group + + ffn = block_config.ffn + if ffn.no_op: + ffn_impl = FFNImplementation.NO_OP + intermediate_size = None + elif ffn.replace_with_linear: + ffn_impl = FFNImplementation.LINEAR + intermediate_size = None + elif ffn.sparsify: + raise NotImplementedError("Sparsification is not supported") + else: + ffn_impl = FFNImplementation.MLP + intermediate_size = _ffn_mult_to_intermediate_size( + ffn.ffn_mult, hidden_size) + + return 
DeciLayerConfig(attention=AttentionConfig( + impl=attn_impl, num_key_value_heads=num_key_value_heads), + ffn=FFNConfig(impl=ffn_impl, + intermediate_size=intermediate_size)) + + +@contextmanager +def timed_loading() -> Iterator[None]: + tik = time.time() + yield + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logger.info(f'Weights loaded. Total time: {t}') + + +class TpDim(enum.IntEnum): + NO_TP = -1 + COLWISE = 0 + ROWWISE = 1 + + +class SafetensorsIndex(TypedDict): + metadata: Dict[str, Any] + weight_map: Dict[str, str] + + +class WeightsLoader(ABC): + + @abstractmethod + def get_weight(self, + name: str, + tp_dim: TpDim = TpDim.NO_TP, + tp_size: int = 1, + tp_rank: int = 0) -> torch.Tensor: + ... + + +class HFModelWeightsLoader(WeightsLoader): + + def __init__(self, *, hf_model: "transformers.PreTrainedModel", + dtype: str) -> None: + self.model_params = dict(hf_model.named_parameters()) + self.dtype = getattr(torch, dtype) + + def get_weight(self, + name: str, + tp_dim: TpDim = TpDim.NO_TP, + tp_size: int = 1, + tp_rank: int = 0) -> torch.Tensor: + weight = self.model_params[name] + if weight.dtype != self.dtype: + weight = weight.to(self.dtype) + weight = weight.detach() + + if tp_dim != TpDim.NO_TP: + weight = split(weight, tp_size, tp_rank, dim=tp_dim) + return weight + + +class SafetensorsWeightsLoader(WeightsLoader): + + def __init__(self, *, model_dir: Path, dtype: str) -> None: + self.model_dir = model_dir + self.dtype = getattr(torch, dtype) + + # the index has a weight map that maps weight names to the files they are found in + safetensor_index_json = self.model_dir / "model.safetensors.index.json" + has_safetensor_index_json = safetensor_index_json.is_file() + if has_safetensor_index_json: + with safetensor_index_json.open("r") as fr: + self.sharding_map: SafetensorsIndex = json.load(fr) + else: + self.sharding_map = SafetensorsIndex(metadata={}, weight_map={}) + + shard_files = {f.name for f in self.model_dir.glob("*.safetensors")} + if has_safetensor_index_json: + # only read the files that have weights according to the index + shard_files &= set(self.sharding_map["weight_map"].values()) + self.shard_files = sorted(list(shard_files)) + + self.safetensors_files = { + shard_file: safetensors.safe_open(model_dir / shard_file, + framework="pt", + device="cpu") + for shard_file in shard_files + } + + def get_weight(self, + name: str, + tp_dim: TpDim = TpDim.NO_TP, + tp_size: int = 1, + tp_rank: int = 0) -> torch.Tensor: + shard_filename = self.sharding_map['weight_map'].get( + name, self.shard_files[0]) + if tp_dim == TpDim.NO_TP: + res = self.safetensors_files[shard_filename].get_tensor(name) + else: + tensor_slice = self.safetensors_files[shard_filename].get_slice( + name) + tensor_shape = tensor_slice.get_shape() + if len(tensor_shape) == 1: + if tp_dim == TpDim.COLWISE: + slice_width = tensor_shape[0] // tp_size + res = tensor_slice[slice_width * tp_rank:slice_width * + (tp_rank + 1)] + else: # row-wise, but 1-dimensional ==> no tp + res = tensor_slice[:] + else: + assert tensor_shape[ + tp_dim] % tp_size == 0, f"Current weight shape is invalid for tp_size={tp_size}" + slice_width = tensor_shape[tp_dim] // tp_size + if tp_dim == TpDim.COLWISE: + res = tensor_slice[slice_width * tp_rank:slice_width * + (tp_rank + 1), :] + else: + res = tensor_slice[:, slice_width * tp_rank:slice_width * + (tp_rank + 1)] + + return res.to(self.dtype).contiguous() + + +def load_model_weights(loader: WeightsLoader, + config: "DeciConfig") -> Dict[str, 
torch.Tensor]: + mapping = config.mapping + num_hidden_layers = config.num_hidden_layers + vocab_size = config.vocab_size + pad_vocab = vocab_size % mapping.tp_size != 0 + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + + weights = {} + + def load_weight(name: str, tp_dim: TpDim = TpDim.NO_TP) -> torch.Tensor: + return loader.get_weight(name=name, + tp_dim=tp_dim, + tp_rank=mapping.tp_rank, + tp_size=mapping.tp_size) + + with timed_loading(): + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = load_weight( + "model.embed_tokens.weight", + TpDim(config.embedding_sharding_dim) + if config.use_parallel_embedding else + TpDim.NO_TP) # vocab_embedding + + if mapping.is_last_pp_rank(): + v = load_weight("lm_head.weight", + TpDim.NO_TP) if pad_vocab else load_weight( + "lm_head.weight", TpDim.COLWISE) # lm_head + if pad_vocab: + v = torch.nn.functional.pad( + v, (0, 0, 0, vocab_size_padded - vocab_size), 'constant', 0) + v = split(v, mapping.tp_size, mapping.tp_rank) + weights['lm_head.weight'] = v + weights['transformer.ln_f.weight'] = load_weight( + "model.norm.weight") # ln_f + + layers_range = mapping.pp_layers(num_hidden_layers) + for l in layers_range: + layer_config = config.get_layer_config(l) + layer_idx = l - layers_range[0] + tllm_prex = f'transformer.layers.{layer_idx}' + + # Attention + if layer_config.is_attention_layer: + weights[f'{tllm_prex}.input_layernorm.weight'] = load_weight( + f"model.layers.{l}.input_layernorm.weight" + ) # input_layernorm + + qkv = {} + for comp in ["q", "k", "v"]: + weight_part = load_weight( + f"model.layers.{l}.self_attn.{comp}_proj.weight", + TpDim.COLWISE) + qkv[comp] = weight_part + + if layer_config.attention.num_key_value_heads < mapping.tp_size: + # duplicate the KV heads up to tensor_parallel + qkv["k"] = dup_kv_weight( + qkv["k"], layer_config.attention.num_key_value_heads, + mapping.tp_size) + qkv["v"] = dup_kv_weight( + qkv["v"], layer_config.attention.num_key_value_heads, + mapping.tp_size) + + weights[f'{tllm_prex}.attention.qkv.weight'] = torch.cat( + [qkv["q"], qkv["k"], qkv["v"]], 0) + weights[f'{tllm_prex}.attention.dense.weight'] = load_weight( + f"model.layers.{l}.self_attn.o_proj.weight", + TpDim.ROWWISE) # attention.dense + + elif layer_config.is_linear_attention_layer: + weights[f'{tllm_prex}.input_layernorm.weight'] = load_weight( + f"model.layers.{l}.input_layernorm.weight" + ) # input_layernorm + + weights[f'{tllm_prex}.attention.weight'] = load_weight( + f"model.layers.{l}.self_attn.linear_attn.weight", + TpDim.COLWISE) + + elif not layer_config.is_noop_attention_layer: + raise NotImplementedError( + f"Loading weights for layer with attention of type {layer_config.attention.impl} is not supported" + ) + + # MLP + if layer_config.is_mlp_layer: + weights[f'{tllm_prex}.post_layernorm.weight'] = load_weight( + f"model.layers.{l}.post_attention_layernorm.weight" + ) # post_layernorm + + weights[f'{tllm_prex}.ffn.gate.weight'] = load_weight( + f"model.layers.{l}.mlp.up_proj.weight", + TpDim.COLWISE) # mlp.gate + weights[f'{tllm_prex}.ffn.proj.weight'] = load_weight( + f"model.layers.{l}.mlp.down_proj.weight", + TpDim.ROWWISE) # mlp.proj + weights[f'{tllm_prex}.ffn.fc.weight'] = load_weight( + f"model.layers.{l}.mlp.gate_proj.weight", + TpDim.COLWISE) # mlp.fc + + elif layer_config.is_linear_ffn_layer: + weights[f'{tllm_prex}.post_layernorm.weight'] = load_weight( + f"model.layers.{l}.post_attention_layernorm.weight" + ) # post_layernorm + + weights[f'{tllm_prex}.ffn.weight'] = 
load_weight( + f"model.layers.{l}.mlp.linear_mlp.weight", TpDim.COLWISE) + + elif not layer_config.is_noop_ffn_layer: + raise NotImplementedError( + f"Loading weights for a layer with FFN of type {layer_config.ffn.impl} is not implemented yet" + ) + + return weights + + +def load_weights_from_hf_model( + hf_model: "transformers.PreTrainedModel", + config: "DeciConfig", + act_range: Optional[dict] = None, + qkv_para: Optional[dict] = None, + smoother: Optional[dict] = None) -> Dict[str, torch.Tensor]: + quant_algo = config.quantization.quant_algo + use_weight_only = quant_algo in [QuantAlgo.W8A16, QuantAlgo.W4A16] + if quant_algo == QuantAlgo.W8A16: + torch.int8 + elif quant_algo == QuantAlgo.W4A16: + torch.quint4x2 + else: + pass + + use_smooth_quant = config.quantization.use_plugin_sq + int8_kv_cache = config.quantization.kv_cache_quant_algo == QuantAlgo.INT8 + if use_smooth_quant or int8_kv_cache: + assert act_range is not None + assert qkv_para is not None + assert smoother is not None + + # TODO(oargov): add support for these quants + assert not use_weight_only, "WOQ is not supported yet" + assert not use_smooth_quant, "SmoothQuant is not supported yet" + assert not int8_kv_cache, "INT8 KV cache is not supported yet" + + # TODO(oargov): support moe + moe_config = getattr(config, "moe", None) + assert moe_config is None, "MoE is not supported yet" + + # TODO(oargov): implement resisdual mlp + residual_mlp = getattr(config, "residual_mlp", None) + assert not residual_mlp, "Residual MLP is not supported yet" + + loader = HFModelWeightsLoader(hf_model=hf_model, dtype=config.dtype) + logger.info('Converting weights from Huggingface model...') + return load_model_weights(loader=loader, config=config) + + +def load_weights_from_hf_safetensors( + model_dir: Union[str, Path], + config: "DeciConfig") -> Dict[str, torch.Tensor]: + + if isinstance(model_dir, str): + model_dir = Path(model_dir) + + loader = SafetensorsWeightsLoader(model_dir=model_dir, dtype=config.dtype) + logger.info('Loading weights from Huggingface safetensors...') + return load_model_weights(loader=loader, config=config) diff --git a/tensorrt_llm/models/deci/layer_config.py b/tensorrt_llm/models/deci/layer_config.py new file mode 100644 index 000000000..84ed3487d --- /dev/null +++ b/tensorrt_llm/models/deci/layer_config.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
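The `_ffn_mult_to_intermediate_size` helper in `deci/convert.py` above sizes each variable-width FFN as 2/3 · ffn_mult · hidden_size, rounded up to a multiple of 256. A quick standalone check: the helpers are copied from the hunk above, while the example numbers are purely illustrative.

```python
# Helpers copied from the convert.py hunk above, plus a worked example.
def _find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    intermediate_size = int(2 * ffn_mult * n_embd / 3)
    return _find_multiple(intermediate_size, 256)


# e.g. a hypothetical block with ffn_mult=4.0 on a 4096-wide model:
# 2 * 4.0 * 4096 / 3 = 10922.67 -> 10922 -> rounded up to 11008.
assert _ffn_mult_to_intermediate_size(4.0, 4096) == 11008
```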
+ +import enum +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +class AttentionImplementation(str, enum.Enum): + ATTENTION = "attention" + LINEAR = "linear" + NO_OP = "no_op" + + +class FFNImplementation(str, enum.Enum): + MLP = "mlp" + LINEAR = "linear" + NO_OP = "no_op" + + +@dataclass(frozen=True, kw_only=True) +class AttentionConfig: + impl: AttentionImplementation = AttentionImplementation.ATTENTION + num_key_value_heads: Optional[int] = None + + @property + def needs_kv_cache(self) -> bool: + return self.impl == AttentionImplementation.ATTENTION + + +@dataclass(frozen=True, kw_only=True) +class FFNConfig: + impl: FFNImplementation = FFNImplementation.MLP + intermediate_size: Optional[int] = None + + +@dataclass(frozen=True, kw_only=True) +class DeciLayerConfig: + attention: AttentionConfig = field(default_factory=AttentionConfig) + ffn: FFNConfig = field(default_factory=FFNConfig) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "DeciLayerConfig": + assert "attention" in d, "Missing attention configuration" + assert "ffn" in d, "Missing mlp configuration" + + return cls( + attention=AttentionConfig(**d["attention"]), + ffn=FFNConfig(**d["ffn"]), + ) + + @property + def is_attention_layer(self) -> bool: + return self.attention.impl == AttentionImplementation.ATTENTION + + @property + def is_mlp_layer(self) -> bool: + return self.ffn.impl == FFNImplementation.MLP + + @property + def is_noop_attention_layer(self) -> bool: + return self.attention.impl == AttentionImplementation.NO_OP + + @property + def is_linear_attention_layer(self) -> bool: + return self.attention.impl == AttentionImplementation.LINEAR + + @property + def is_noop_ffn_layer(self) -> bool: + return self.ffn.impl == FFNImplementation.NO_OP + + @property + def is_linear_ffn_layer(self) -> bool: + return self.ffn.impl == FFNImplementation.LINEAR diff --git a/tensorrt_llm/models/deci/model.py b/tensorrt_llm/models/deci/model.py new file mode 100644 index 000000000..b0d0ded0e --- /dev/null +++ b/tensorrt_llm/models/deci/model.py @@ -0,0 +1,643 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
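The `layer_config.py` module above describes each transformer block with small frozen dataclasses, so attention and FFN can independently be full, linear, or no-op. A minimal construction sketch, with made-up head counts and sizes:

```python
# Minimal sketch of building and querying per-layer configs; values are made up.
from tensorrt_llm.models.deci.layer_config import (AttentionConfig,
                                                   AttentionImplementation,
                                                   DeciLayerConfig, FFNConfig,
                                                   FFNImplementation)

# A standard block: grouped-query attention with 8 KV heads plus a gated MLP.
layer = DeciLayerConfig(
    attention=AttentionConfig(impl=AttentionImplementation.ATTENTION,
                              num_key_value_heads=8),
    ffn=FFNConfig(impl=FFNImplementation.MLP, intermediate_size=11008),
)
assert layer.is_attention_layer and layer.is_mlp_layer
assert layer.attention.needs_kv_cache

# The same block in the plain-dict form accepted by from_dict, e.g. when
# layer_configs are read back from a serialized config.
same = DeciLayerConfig.from_dict({
    "attention": {"impl": "attention", "num_key_value_heads": 8},
    "ffn": {"impl": "mlp", "intermediate_size": 11008},
})
assert same.is_attention_layer and not same.is_noop_ffn_layer
```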
+from dataclasses import dataclass +from typing import List, Optional, Tuple, Type, Union + +from tensorrt_llm.bindings import KVCacheType +from tensorrt_llm.functional import (AllReduceFusionParams, AttentionMaskType, + PositionEmbeddingType, Tensor, + gather_last_token_logits, recv, send) +from tensorrt_llm.layers.attention import (Attention, AttentionParams, + KeyValueCacheParams, + SpecDecodingParams) +from tensorrt_llm.layers.embedding import Embedding +from tensorrt_llm.layers.linear import ColumnLinear +from tensorrt_llm.layers.lora import LoraParams +from tensorrt_llm.layers.mlp import GatedMLP +from tensorrt_llm.layers.normalization import RmsNorm +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.convert_utils import has_safetensors +from tensorrt_llm.models.deci.config import DeciConfig +from tensorrt_llm.models.deci.convert import (load_weights_from_hf_model, + load_weights_from_hf_safetensors) +from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM +from tensorrt_llm.module import Module, ModuleList +from tensorrt_llm.plugin.plugin import init_all_reduce_helper + +from ..._common import default_net +from ..._utils import pad_vocab_size +from ..modeling_utils import QuantConfig, preprocess_weights + + +@dataclass +class DeciLMLayerOutput: + hidden_states: Tensor + present_kv: Optional[Tensor] = None + + +@dataclass +class DeciLMLayerListOutput: + hidden_states: Tensor + present_kvs: List[Tensor] + + +class NoOp(Module): + + def forward(self, hidden_states: Tensor, *args, **kwargs) -> int: + return 0 + + +class NoOpAttention(NoOp): + + def forward(self, + hidden_states: Tensor, + attention_mask=None, + use_cache: bool = False, + *args, + **kwargs) -> Union[int, Tuple[int, None]]: + out = super().forward(hidden_states=hidden_states, + attention_mask=attention_mask, + use_cache=use_cache, + *args, + **kwargs) + if use_cache: + return out, None + return out + + +class LinearAttention(ColumnLinear): + + def forward(self, + hidden_states: Tensor, + attention_mask=None, + use_cache: bool = False, + *args, + **kwargs) -> Union[Tensor, Tuple[Tensor, None]]: + out = super().forward(x=hidden_states, + lora_runtime_params=None, + lora_hidden_state=None) + + if use_cache: + return out, None + return out + + +class LinearFFN(ColumnLinear): + + def forward( + self, + hidden_states, + lora_layer_params=None, + reduce_fusion_params: Optional[AllReduceFusionParams] = None + ) -> Tensor: + return super().forward(x=hidden_states, + lora_runtime_params=None, + lora_hidden_state=None) + + +NoOpFFN = NoOp +NoOpLayerNorm = NoOp + + +class DeciLMDecoderLayer(Module): + + def __init__(self, config: DeciConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.config = config + + layers_range = config.mapping.pp_layers(config.num_hidden_layers) + self.local_layer_idx = layer_idx - layers_range[0] + + self.layer_config = self.config.get_layer_config(self.layer_idx) + + layer_type_len = len(config.layer_types) + layer_types = config.layer_types * ((layer_idx + 1) // layer_type_len) + layer_types = layer_types + config.layer_types[0:( + (layer_idx + 1) % layer_type_len)] + + attention_layer_idx = layer_types.count('attention') - 1 + self._init_attention(attention_layer_idx) + self._init_ffn() + + def _init_attention(self, attention_layer_idx) -> None: + """ + Initialize some attention alternative + """ + # normal attention + if self.layer_config.is_attention_layer: + self.input_layernorm = RmsNorm( + normalized_shape=self.config.hidden_size, + 
eps=self.config.norm_epsilon, + dtype=self.config.dtype, + ) + + self.attention = Attention( + local_layer_idx=attention_layer_idx, + hidden_size=self.config.hidden_size, + attention_head_size=self.config.head_size, + num_attention_heads=self.config.num_attention_heads, + num_kv_heads=self.layer_config.attention.num_key_value_heads, + max_position_embeddings=self.config.max_position_embeddings, + dtype=self.config.dtype, + attention_mask_type=AttentionMaskType.causal, + bias=False, + position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + rotary_embedding_base=self.config.rotary_base, + rotary_embedding_scaling=self.config.rotary_scaling, + tp_group=self.config.mapping.tp_group, + tp_size=self.config.mapping.tp_size, + tp_rank=self.config.mapping.tp_rank, + quant_mode=self.config.quant_mode, + ) + + elif self.layer_config.is_noop_attention_layer: + self.input_layernorm = NoOpLayerNorm() + self.attention = NoOpAttention() + + elif self.layer_config.is_linear_attention_layer: + self.input_layernorm = RmsNorm( + normalized_shape=self.config.hidden_size, + eps=self.config.norm_epsilon, + dtype=self.config.dtype, + ) + + self.attention = LinearAttention( + in_features=self.config.hidden_size, + out_features=self.config.hidden_size, + bias=False, + dtype=self.config.dtype, + tp_group=self.config.mapping.tp_group, + tp_size=self.config.mapping.tp_size, + gather_output=True) + + else: + raise NotImplementedError( + f"Attention of type {str(self.layer_config.attention.impl)} is not implemented" + ) + + def _init_ffn(self) -> None: + """ + Initialize some ffn alternative + """ + + if self.layer_config.is_mlp_layer: + intermediate_size = self.layer_config.ffn.intermediate_size or self.config.intermediate_size + mlp_hidden_size = intermediate_size or self.config.hidden_size * 4 + + self.post_layernorm = RmsNorm( + normalized_shape=self.config.hidden_size, + eps=self.config.norm_epsilon, + dtype=self.config.dtype, + ) + + self.ffn = GatedMLP( + hidden_size=self.config.hidden_size, + ffn_hidden_size=mlp_hidden_size, + hidden_act=self.config.hidden_act, + bias=False, + dtype=self.config.dtype, + tp_group=self.config.mapping.tp_group, + tp_size=self.config.mapping.tp_size, + quant_mode=self.config.quant_mode, + ) + + elif self.layer_config.is_noop_ffn_layer: + self.post_layernorm = NoOpLayerNorm() + self.ffn = NoOpFFN() + + elif self.layer_config.is_linear_ffn_layer: + self.post_layernorm = RmsNorm( + normalized_shape=self.config.hidden_size, + eps=self.config.norm_epsilon, + dtype=self.config.dtype, + ) + + self.ffn = LinearFFN(in_features=self.config.hidden_size, + out_features=self.config.hidden_size, + bias=False, + dtype=self.config.dtype, + tp_group=self.config.mapping.tp_group, + tp_size=self.config.mapping.tp_size, + gather_output=True) + + else: + raise NotImplementedError( + f"FFN of type {str(self.layer_config.ffn.impl)} is not implemented" + ) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + use_cache: bool = False, + spec_decoding_params=None, + kv_cache_params: Optional[KeyValueCacheParams] = None, + attention_params: Optional[AttentionParams] = None, + lora_layer_params: Optional[LoraParams] = None, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + attention_output = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + use_cache=use_cache, + spec_decoding_params=spec_decoding_params, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + 
lora_layer_params=lora_layer_params, + ) + + if use_cache: + attention_output, present_kv = attention_output + else: + present_kv = None + + hidden_states = residual + attention_output + residual = hidden_states + hidden_states = self.post_layernorm(hidden_states) + hidden_states = self.ffn(hidden_states, + lora_layer_params=lora_layer_params) + hidden_states = residual + hidden_states + + return DeciLMLayerOutput(hidden_states=hidden_states, + present_kv=present_kv) + + +class DeciLMDecoderLayerList(ModuleList): + + def __init__(self, cls: Type[DeciLMDecoderLayer], config: DeciConfig): + self.num_hidden_layers = config.num_hidden_layers + # global indices of local layers + self.layer_list = config.mapping.pp_layers(config.num_hidden_layers) + super().__init__([cls(config, idx) for idx in self.layer_list]) + # global indices of local attention layers + self.attention_layer_list = [ + self.layer_list[i] for i, layer in enumerate(self) + if layer.layer_config.is_attention_layer + ] + + def forward( + self, + hidden_states: Tensor, + use_cache: bool, + attention_mask: Optional[Tensor], + kv_cache_params: KeyValueCacheParams, + attention_params: Optional[AttentionParams] = None, + position_ids: Optional[Tensor] = None, + lora_params: Optional[LoraParams] = None, + spec_decoding_params: Optional[SpecDecodingParams] = None, + ) -> DeciLMLayerListOutput: + kv_cache_params.fill_none_tensor_list(len(self.layer_list)) + + presents = [] + + # put None where we don't have attention layers + pkv_iter = iter(kv_cache_params.past_key_value) + + past_key_values = [x for x in pkv_iter] + + for layer_idx, (layer, past) in enumerate(zip(self, past_key_values)): + layer_out = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + attention_params=attention_params, + kv_cache_params=KeyValueCacheParams( + past_key_value=[past], + host_past_key_value_lengths=kv_cache_params. + host_past_key_value_lengths, + host_max_attention_window_sizes=kv_cache_params. + host_max_attention_window_sizes, + host_sink_token_length=kv_cache_params. + host_sink_token_length, + kv_cache_block_offsets=kv_cache_params. + kv_cache_block_offsets, + host_kv_cache_block_offsets=kv_cache_params. + host_kv_cache_block_offsets, + host_kv_cache_pool_pointers=kv_cache_params. 
+ host_kv_cache_pool_pointers, + cache_indirection=kv_cache_params.cache_indirection, + ), + spec_decoding_params=spec_decoding_params, + use_cache=use_cache, + lora_layer_params=lora_params.get_layer_config(layer_idx) + if lora_params is not None + and lora_params.lora_ranks is not None else None) + + hidden_states = layer_out.hidden_states + if use_cache and layer_out.present_kv is not None: + presents.append(layer_out.present_kv) + + return DeciLMLayerListOutput(hidden_states=hidden_states, + present_kvs=presents) + + +class DeciLMModel(Module): + + def __init__(self, config: DeciConfig) -> None: + super().__init__() + init_all_reduce_helper() + + self.mapping = config.mapping + if self.mapping.is_first_pp_rank(): + # first rank in pipeline-parallel handles token embedding + assert config.vocab_size is not None + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) + + self.position_embedding_type = config.position_embedding_type + self.layers = DeciLMDecoderLayerList(DeciLMDecoderLayer, config) + + if self.mapping.is_last_pp_rank(): + # last rank in pipeline-parallel handles final norm + self.ln_f = RmsNorm( + normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype, + ) + + def _vocab_embedding(self, + input_ids: Tensor, + prompt_embedding_table: Optional[Tensor] = None, + prompt_tasks: Optional[Tensor] = None, + prompt_vocab_size: Optional[Tensor] = None) -> Tensor: + # prompt tuning + ptuning_args = ([ + prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else []) + + hidden_states = self.vocab_embedding(input_ids, *ptuning_args) + return hidden_states + + def forward( + self, + input_ids, + position_ids=None, + use_cache: bool = False, + attention_mask: Optional[Tensor] = None, + spec_decoding_params=None, + kv_cache_params: Optional[KeyValueCacheParams] = None, + attention_params: Optional[AttentionParams] = None, + hidden_states: Optional[Tensor] = None, + prompt_embedding_table: Optional[Tensor] = None, + prompt_tasks: Optional[Tensor] = None, + prompt_vocab_size: Optional[Tensor] = None, + lora_params: Optional[LoraParams] = None, + ) -> DeciLMLayerListOutput: + + if self.mapping.is_first_pp_rank(): + # first pipeline rank ==> do prompt embedding + hidden_states = self._vocab_embedding( + input_ids=input_ids, + prompt_embedding_table=prompt_embedding_table, + prompt_tasks=prompt_tasks, + prompt_vocab_size=prompt_vocab_size) + else: + # receive hidden states from prior rank in the pipeline + hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) + + layers_out = self.layers.forward( + hidden_states, + use_cache=use_cache, + attention_mask=attention_mask, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + lora_params=lora_params, + spec_decoding_params=spec_decoding_params, + ) + + if self.mapping.is_last_pp_rank(): + # last pipeline rank ==> do final norm + hidden_states = self.ln_f(layers_out.hidden_states) + else: + # send hidden states to next rank in the pipeline + hidden_states = send(layers_out.hidden_states, + self.mapping.next_pp_rank()) + + return DeciLMLayerListOutput(hidden_states=hidden_states, + present_kvs=layers_out.present_kvs) + + +class DeciLMForCausalLM(DecoderModelForCausalLM): + config_class = DeciConfig + + def __init__(self, config: DeciConfig): + + transformer = DeciLMModel(config) + vocab_size_padded = pad_vocab_size(config.vocab_size, + config.mapping.tp_size) + + if config.mapping.is_last_pp_rank(): + # 
last pipeline rank needs to do calculate logits + lm_head = ColumnLinear( + config.hidden_size, + vocab_size_padded, + bias=False, + dtype=config.dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + gather_output=True, + ) + else: + lm_head = None + super().__init__(config, transformer, lm_head) + + # Create constant attention parameters to be reused by all layers. + Attention.create_attention_const_params(self, config) + self.position_embedding_type = config.position_embedding_type + + @classmethod + def from_hugging_face(cls, + hf_model_or_dir: Union[ + str, 'transformers.PreTrainedModel'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + load_by_shard: bool = False, + load_model_on_cpu: bool = False, + trust_remote_code: bool = False, + **kwargs) -> "DeciLMForCausalLM": + import transformers + + # TODO(oargov): add support for these + assert not load_by_shard, "load_by_shard is not implemented yet" + + use_preloading = isinstance(hf_model_or_dir, + transformers.PreTrainedModel) + if use_preloading: + hf_config_or_dir = hf_model_or_dir.config + else: + hf_config_or_dir = hf_model_or_dir + + config = DeciConfig.from_hugging_face( + hf_config_or_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + trust_remote_code=trust_remote_code, + **kwargs) + + if use_preloading: + assert not load_by_shard + weights = load_weights_from_hf_model(hf_model_or_dir, config) + elif has_safetensors( + hf_model_or_dir) and not config.quant_mode.has_any_quant(): + weights = load_weights_from_hf_safetensors(hf_model_or_dir, config) + else: + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + hf_model_or_dir, + device_map='auto' if not load_model_on_cpu else 'cpu', + torch_dtype=dtype, + trust_remote_code=trust_remote_code, + ) + weights = load_weights_from_hf_model(hf_model, config) + preprocess_weights(weights, config) + + model = DeciLMForCausalLM(config) + model.load(weights) + return model + + def forward( + self, + input_ids: Tensor, + position_ids: Optional[Tensor] = None, + use_cache: bool = False, + last_token_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + kv_cache_params: Optional[KeyValueCacheParams] = None, + attention_params: Optional[AttentionParams] = None, + hidden_states: Optional[Tensor] = None, + prompt_embedding_table: Optional[Tensor] = None, + prompt_tasks: Optional[Tensor] = None, + prompt_vocab_size: Optional[Tensor] = None, + lora_params: Optional[LoraParams] = None, + spec_decoding_params=None, + ): + # fill attention params. 
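The `DeciLMForCausalLM.from_hugging_face` classmethod above mirrors the other model classes: it builds a `DeciConfig`, loads weights either from a preloaded HF model or from safetensors, and returns a model ready for engine build. A hypothetical invocation; the checkpoint path and mapping below are assumptions, not part of the change.

```python
# Hypothetical call into the new from_hugging_face entry point.
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import DeciLMForCausalLM

model = DeciLMForCausalLM.from_hugging_face(
    "/path/to/deci/checkpoint",                       # assumed HF checkpoint dir
    dtype="auto",
    mapping=Mapping(world_size=1, rank=0, tp_size=1, pp_size=1),
    trust_remote_code=True,  # Deci checkpoints define custom HF classes
)
```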
+ attention_params = Attention.fill_attention_params( + self, attention_params) + + model_out = self.transformer.forward( + input_ids=input_ids, + position_ids=position_ids, + use_cache=use_cache, + attention_mask=attention_mask, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + lora_params=lora_params, + hidden_states=hidden_states, + prompt_embedding_table=prompt_embedding_table, + prompt_tasks=prompt_tasks, + prompt_vocab_size=prompt_vocab_size, + spec_decoding_params=spec_decoding_params) + hidden_states = model_out.hidden_states + + if self.config.mapping.is_last_pp_rank(): + hidden_states = gather_last_token_logits( + hidden_states, + last_token_ids, + default_net().plugin_config.remove_input_padding, + ) + + lm_logits = self.lm_head(hidden_states) + lm_logits.mark_output("logits", self.config.logits_dtype) + else: + hidden_states.mark_output("hidden_states_output", self.config.dtype) + + if use_cache and not default_net().plugin_config.paged_kv_cache: + presents = model_out.present_kvs + for i, present in zip(self.transformer.layers.attention_layer_list, + presents): + present.mark_output(f"present_key_value_{i}", + self.config.kv_dtype) + if self.config.mapping.is_last_pp_rank(): + return (lm_logits, presents, hidden_states) + return (hidden_states, presents) + else: + if self.config.mapping.is_last_pp_rank(): + return lm_logits, hidden_states + return hidden_states + + def prepare_attention_inputs( + self, + *, + max_batch_size: int, + max_beam_width: int, + max_input_len: int, + max_seq_len: int, + num_kv_heads: int, + head_size: int, + num_layers: int, + kv_dtype: str, + kv_cache_type: KVCacheType, + num_profiles: int = 1, + enable_ctx_gen_opt_profiles: bool = False, + remove_input_padding: bool = False, + use_gpt_attention_plugin: bool = False, + paged_kv_cache: bool = False, + tokens_per_block: int = 64, + mapping: Mapping = Mapping(), + use_cache: bool = True, + streamingllm: bool = False, + attn_layer_idx: Optional[List[int]] = None, + opt_batch_size: Optional[int] = None, + num_kv_heads_per_layer: Optional[List[int]] = None): + + if attn_layer_idx is None: + attn_layer_idx, num_kv_heads_per_layer = [], [] + for layer_idx in range(self.config.num_hidden_layers): + layer_config = self.config.get_layer_config(layer_idx) + if layer_config.is_attention_layer: + attn_layer_idx.append(layer_idx) + num_kv_heads_per_layer.append( + layer_config.attention.num_key_value_heads) + num_layers = len(attn_layer_idx) + + attention_inputs = super().prepare_attention_inputs( + max_batch_size=max_batch_size, + max_beam_width=max_beam_width, + max_input_len=max_input_len, + max_seq_len=max_seq_len, + num_kv_heads=num_kv_heads, + head_size=head_size, + num_layers=num_layers, + kv_dtype=kv_dtype, + num_profiles=num_profiles, + kv_cache_type=kv_cache_type, + enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles, + remove_input_padding=remove_input_padding, + use_gpt_attention_plugin=use_gpt_attention_plugin, + tokens_per_block=tokens_per_block, + mapping=mapping, + streamingllm=streamingllm, + attn_layer_idx=attn_layer_idx, + opt_batch_size=opt_batch_size, + num_kv_heads_per_layer=num_kv_heads_per_layer) + + kv_idx = 0 + past_key_value = [] + for i in range(self.config.num_hidden_layers): + layer_config = self.config.get_layer_config(i) + if layer_config.is_attention_layer: + past_key_value.append( + attention_inputs['past_key_value'][kv_idx]) + kv_idx += 1 + else: + past_key_value.append(None) + attention_inputs['past_key_value'] = past_key_value + + return 
attention_inputs diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py index 072900d05..78acb32ff 100644 --- a/tensorrt_llm/models/enc_dec/model.py +++ b/tensorrt_llm/models/enc_dec/model.py @@ -667,7 +667,6 @@ def forward(self, def prepare_inputs(self, max_batch_size, max_input_len, - max_num_tokens, prompt_embedding_table_size: int = 0, lora_target_modules: List[str] = None, *args, @@ -890,7 +889,6 @@ def prepare_inputs(self, lora_params = LoraParams( lora_ranks=lora_ranks, lora_weights_pointers=lora_weights_pointers, - max_num_tokens=max_num_tokens, host_request_types=host_request_types, host_context_lengths=host_context_lengths, ) @@ -1226,7 +1224,6 @@ def prepare_inputs(self, max_beam_width, max_decoder_input_len, max_seq_len, - max_num_tokens, max_encoder_input_len, gather_context_logits: bool = False, gather_generation_logits: bool = False, @@ -1596,7 +1593,6 @@ def prepare_inputs(self, lora_ranks=lora_ranks, lora_weights_pointers=lora_weights_pointers, host_context_lengths=host_context_lengths, - max_num_tokens=max_num_tokens, max_encoder_context_length=max_encoder_input_len, host_request_types=host_request_types, host_encoder_input_lengths=host_encoder_input_lengths, diff --git a/tensorrt_llm/models/generation_mixin.py b/tensorrt_llm/models/generation_mixin.py index 17300b3b0..cb12289f8 100644 --- a/tensorrt_llm/models/generation_mixin.py +++ b/tensorrt_llm/models/generation_mixin.py @@ -185,7 +185,8 @@ def prepare_attention_inputs(self, mapping=Mapping(), streamingllm=False, attn_layer_idx=None, - opt_batch_size=None): + opt_batch_size=None, + num_kv_heads_per_layer=None): default_range = GenerationMixin.default_range @@ -258,16 +259,24 @@ def prepare_attention_inputs(self, else: if kv_cache_type != KVCacheType.PAGED: for i in layers_range: + if num_kv_heads_per_layer is not None: + heads_dim_name = f"num_heads_{attn_layer_idx[i]}" + kv_heads = num_kv_heads_per_layer[i] + else: + heads_dim_name = "num_heads" + kv_heads = num_kv_heads + kv_dim_range = OrderedDict([ ('batch_size_beam_width', bb_range), ('kv', [2] * num_profiles), - ('num_heads', [num_kv_heads] * num_profiles), + (heads_dim_name, [kv_heads] * num_profiles), ('past_key_len', kv_cache_range), ('head_size', [head_size] * num_profiles), ]) + kv = Tensor(name=f'past_key_value_{attn_layer_idx[i]}', dtype=kv_dtype, - shape=[-1, 2, num_kv_heads, -1, head_size], + shape=[-1, 2, kv_heads, -1, head_size], dim_range=kv_dim_range) past_key_value.append(kv) else: @@ -774,7 +783,6 @@ def prepare_basic_inputs( mapping=mapping, streamingllm=streamingllm, opt_batch_size=opt_batch_size) - for key, value in attention_inputs.items(): basic_inputs[key] = value diff --git a/tensorrt_llm/models/mamba/model.py b/tensorrt_llm/models/mamba/model.py index 977154a7e..7d2aac4d6 100644 --- a/tensorrt_llm/models/mamba/model.py +++ b/tensorrt_llm/models/mamba/model.py @@ -21,7 +21,7 @@ from ..._utils import str_dtype_to_trt from ...functional import (Tensor, arange, cast, concat, expand, gather_last_token_logits, shape, unsqueeze) -from ...layers import Embedding, LayerNorm, Linear, Mamba, Mamba2, RmsNorm +from ...layers import ColumnLinear, Embedding, LayerNorm, Mamba, Mamba2, RmsNorm from ...module import Module, ModuleList from ...plugin import current_all_reduce_helper from ..generation_mixin import GenerationMixin @@ -38,6 +38,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): self.last_layer = layer_idx == n_layer - 1 if config.mamba_version == 'Mamba1': + assert config.mapping.tp_size == 1, 
"Mamba1 can not support tensor parallelism." self.ssm = Mamba(config.hidden_size, config.rnn_hidden_size, d_state=config.state_size, @@ -54,7 +55,9 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): chunk_size=config.chunk_size, bias=config.use_bias, rmsnorm=config.ssm_rmsnorm, - dtype=config.dtype) + dtype=config.dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size) if config.rms_norm: self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, @@ -105,17 +108,15 @@ class MambaModel(Module): def __init__(self, config: PretrainedConfig): super().__init__() self.d_conv = config.conv_kernel - self.d_inner = config.rnn_hidden_size + self.d_inner = config.rnn_hidden_size // config.mapping.tp_size n_layer = config.num_hidden_layers self.residual_in_fp32 = config.residual_in_fp32 if config.vocab_size % config.pad_vocab_size_multiple != 0: config.vocab_size += config.pad_vocab_size_multiple - ( config.vocab_size % config.pad_vocab_size_multiple) - self.vocab_embedding = Embedding( - config.vocab_size, - config.hidden_size, - dtype=config.dtype, - share_embedding_table=config.share_embedding_table) + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) self.layers = ModuleList( [MambaLayer(config, i) for i in range(n_layer)]) if config.rms_norm: @@ -180,10 +181,10 @@ def __init__(self, config: PretrainedConfig): self.config = config self.mamba_version = config.mamba_version - self.d_inner = config.rnn_hidden_size + self.d_inner = config.rnn_hidden_size // config.mapping.tp_size self.d_conv = config.conv_kernel self.d_state = config.state_size - self.conv_dim = config.rnn_conv_dim_size + self.conv_dim = config.rnn_conv_dim_size // config.mapping.tp_size self.gather_context_logits = False if isinstance(logits_dtype, str): @@ -193,11 +194,13 @@ def __init__(self, config: PretrainedConfig): self._logits_dtype = logits_dtype self.backbone = MambaModel(config) - self.lm_head = Linear(config.hidden_size, - config.vocab_size, - bias=False, - dtype=dtype, - gather_output=False) + self.lm_head = ColumnLinear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + gather_output=True) def __post_init__(self): return diff --git a/tensorrt_llm/models/model_weights_loader.py b/tensorrt_llm/models/model_weights_loader.py index a2cca7cd3..3b55f2e30 100644 --- a/tensorrt_llm/models/model_weights_loader.py +++ b/tensorrt_llm/models/model_weights_loader.py @@ -9,8 +9,8 @@ from safetensors import safe_open from tqdm import tqdm -from tensorrt_llm.quantization.layers import (WeightOnlyQuantColumnLinear, - WeightOnlyQuantRowLinear) +from tensorrt_llm.quantization.layers import ( + WeightOnlyGroupwiseQuantColumnLinear, WeightOnlyGroupwiseQuantRowLinear) from .._utils import trt_dtype_to_torch from ..logger import logger @@ -185,6 +185,8 @@ def load_tensor(self, key, tp_size, tp_dim, tp_rank): return tensor[:] else: width = tensor_shape[tp_dim] + if width == 1: + return tensor[:] slice_width = math.ceil(width / tp_size) slice_start = tp_rank * slice_width slice_end = min((tp_rank + 1) * slice_width, width) @@ -196,7 +198,8 @@ def load_tensor(self, key, tp_size, tp_dim, tp_rank): def load(self, tllm_key: str, preprocess: Callable[[int], None] = None, - skip_tp: bool = False): + skip_tp: bool = False, + custom_postprocess_kwargs: dict = {}): """Load tensor from shards This function contains following steps: @@ -226,10 +229,10 @@ 
def load(self, tllm_to_externel_key_dict = sub_module.tllm_to_externel_key_dict if hasattr( sub_module, "tllm_to_externel_key_dict") else {} tp_dim = sub_module.tp_dim if hasattr(sub_module, "tp_dim") else -1 - require_weight_transpose = (isinstance( - sub_module, WeightOnlyQuantColumnLinear) or isinstance( - sub_module, - WeightOnlyQuantRowLinear)) and tllm_key.endswith("weight") + require_weight_transpose = ( + isinstance(sub_module, WeightOnlyGroupwiseQuantColumnLinear) + or isinstance(sub_module, WeightOnlyGroupwiseQuantRowLinear) + ) and tllm_key.endswith("weight") if tp_dim >= 0 and require_weight_transpose: tp_dim = 1 - tp_dim tp_size = sub_module.tp_size if hasattr(sub_module, "tp_size") else 1 @@ -260,7 +263,9 @@ def load(self, else: weight_dict = {tllm_key: v.to(trt_dtype_to_torch(param.dtype))} else: - v = sub_module.postprocess(tllm_key, v) + postprocess_kwargs = {"config": self.model.config} + postprocess_kwargs.update(custom_postprocess_kwargs) + v = sub_module.postprocess(tllm_key, v, **postprocess_kwargs) if isinstance(v, dict): weight_dict = v else: @@ -290,6 +295,9 @@ def check(self, weights): continue w_shape = weights[tllm_key].shape if w_shape != param.shape: + logger.warning( + f'{tllm_key} has invalid shape {w_shape}. Expected {param.shape}.' + ) pad = torch.nn.functional.pad pad_dim = [] for dim in range(weights[tllm_key].dim()): @@ -298,11 +306,12 @@ def check(self, weights): pad_dim.append( max(0, param.shape[current_dim] - w_shape[current_dim])) try: + logger.warning( + f'{tllm_key} is going to be padded by {pad_dim}.') weights[tllm_key] = pad(weights[tllm_key], tuple(pad_dim), value=0) assert weights[tllm_key].shape == param.shape - logger.warning(f'Parameter {tllm_key} is auto padded.') except: raise ValueError( f'Parameter {tllm_key} has invalid shape {weights[tllm_key].shape} compared with expected shape {param.shape}. Auto padding failed.' diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 9b78d12ae..70f7a98db 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -650,7 +650,6 @@ def prepare_inputs( model_inputs['lora_ranks'], model_inputs['lora_weights_pointers'], host_context_lengths=model_inputs['host_context_lengths'], - max_num_tokens=max_num_tokens, host_request_types=model_inputs['host_request_types']) if model_inputs['spec_decoding_params'] is not None: result['spec_decoding_params'] = model_inputs[ @@ -823,6 +822,14 @@ def fuse_gate_mlp( for name, mlp, layer in model.named_modules_with_parent(): if isinstance(mlp, GatedMLP): init_params = get_init_params(mlp) + + hidden_act = init_params["hidden_act"] + if hidden_act not in ["silu", "gelu"]: + logger.warning( + f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping." 
+ ) + continue + init_params["inner_layernorm"] = mlp.inner_layernorm is not None fused_layer = FusedGatedMLP(**init_params) diff --git a/tensorrt_llm/models/opt/model.py b/tensorrt_llm/models/opt/model.py index f66bdf11e..45fc5d08b 100644 --- a/tensorrt_llm/models/opt/model.py +++ b/tensorrt_llm/models/opt/model.py @@ -112,11 +112,9 @@ def __init__(self, config: PretrainedConfig): super().__init__() self.do_layer_norm_before = config.do_layer_norm_before - self.vocab_embedding = Embedding( - config.vocab_size, - config.hidden_size, - dtype=config.dtype, - share_embedding_table=config.share_embedding_table) + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) self.position_embedding = Embedding(config.max_position_embeddings, config.hidden_size, dtype=config.dtype) diff --git a/tensorrt_llm/models/phi/model.py b/tensorrt_llm/models/phi/model.py index 5cacb835f..71f64a640 100644 --- a/tensorrt_llm/models/phi/model.py +++ b/tensorrt_llm/models/phi/model.py @@ -144,6 +144,8 @@ def forward( class PhiForCausalLM(DecoderModelForCausalLM): config_class = PhiConfig + config_class = PhiConfig + def __init__(self, config: PretrainedConfig): self.check_config(config) transformer = PhiModel(config) diff --git a/tensorrt_llm/models/qwen/convert.py b/tensorrt_llm/models/qwen/convert.py index 124313f1a..d78c71a6e 100644 --- a/tensorrt_llm/models/qwen/convert.py +++ b/tensorrt_llm/models/qwen/convert.py @@ -1086,7 +1086,7 @@ def convert_hf_qwen(hf_model, if mapping.is_last_pp_rank(): if hf_model.config.tie_word_embeddings: # lm_head.weight has the same weights as embedding - lm_head_weights = v + lm_head_weights = v.clone() else: lm_head_weights = get_weight(model_params, 'lm_head', dtype) diff --git a/tensorrt_llm/models/recurrentgemma/model.py b/tensorrt_llm/models/recurrentgemma/model.py index 856cf90d1..d555fc5c3 100644 --- a/tensorrt_llm/models/recurrentgemma/model.py +++ b/tensorrt_llm/models/recurrentgemma/model.py @@ -149,11 +149,9 @@ def __init__(self, config: PretrainedConfig) -> None: self.lru_width = config.rnn_hidden_size n_layer = config.num_hidden_layers - self.vocab_embedding = Embedding( - config.vocab_size, - config.hidden_size, - dtype=config.dtype, - share_embedding_table=config.share_embedding_table) + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) self.layers = ModuleList( [ResidualLayer(config, layer_idx=i) for i in range(n_layer)]) diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py index 15333bd1a..5b24dd85f 100644 --- a/tensorrt_llm/plugin/plugin.py +++ b/tensorrt_llm/plugin/plugin.py @@ -23,7 +23,7 @@ import tensorrt as trt -from .._ipc_utils import IpcMemory +from .._ipc_utils import IpcMemory, can_access_peer from ..logger import logger from ..mapping import Mapping @@ -171,6 +171,7 @@ class PluginConfig(metaclass=PluginConfigMeta): _paged_state: bool = field(default=True, init=False) _streamingllm: bool = field(default=False, init=False) _manage_weights: bool = field(default=False, init=False) + _use_fused_mlp: bool = field(default=True, init=False) def update_from_dict(self, config: dict): for name in config.keys(): @@ -297,6 +298,7 @@ def set_nccl_plugin(self, dtype: str = "auto"): "paged_state", "streamingllm", "reduce_fusion", + "use_fused_mlp", ] @@ -378,10 +380,9 @@ def max_workspace_size_auto(tp_size: int) -> int: @staticmethod def allocate_workspace(mapping: Mapping, - size: int, - is_p2p_supported: bool = True - ) -> Tuple[List[IpcMemory], 
"torch.tensor"]: + size: int) -> Tuple[List[IpcMemory], "torch.tensor"]: import torch + is_p2p_supported = can_access_peer(mapping) ipc_buffers_ping = IpcMemory(mapping, size * mapping.tp_size, is_p2p_supported) ipc_buffers_pong = IpcMemory(mapping, size * mapping.tp_size, diff --git a/tensorrt_llm/quantization/functional.py b/tensorrt_llm/quantization/functional.py index fe32b5525..4b781362c 100644 --- a/tensorrt_llm/quantization/functional.py +++ b/tensorrt_llm/quantization/functional.py @@ -655,3 +655,34 @@ def postprocess_weight_only(tllm_key, weights, quant_mode): } else: return {tllm_key: weights} # Bias + + +def postprocess_fp8_rowwise(tllm_key, weights, **kwargs): + if tllm_key.endswith("per_channel_scale"): + return {} + + config = kwargs.get("config", None) + if weights[1] is not None: + assert weights[0].dtype == torch.float8_e4m3fn + scale = weights[1].to(torch.float32).reshape(-1) + return { + tllm_key: weights[0], + tllm_key.replace("weight", "per_channel_scale"): scale + } + else: + clamp_val = config.quantization.clamp_val + # activation range bound. + x = weights[0].to(torch.float32).clamp(clamp_val[0], clamp_val[1]) + xmax = x.abs().max(-1, keepdim=True).values + # minimum scaling factor. + torch_weight_scales = (xmax / 448.0).clamp(min=1.0 / (448.0 * 512.0)) + out = x / torch_weight_scales + torch_weight_scales = torch_weight_scales.reshape(-1) + out = torch.clamp(out, -448, 448) + processed_torch_weights = out.to(torch.float8_e4m3fn) + processed_torch_weights = processed_torch_weights.to( + torch.float8_e4m3fn) + return { + tllm_key: processed_torch_weights, + tllm_key.replace("weight", "per_channel_scale"): torch_weight_scales + } diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index 042590103..caf8604b5 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -33,13 +33,15 @@ from ..layers.linear import Linear, RowLinear from ..module import Module from ..parameter import Parameter -from .functional import (change_qkv_leading_dim, dequantize, fp8_rowwise_gemm, - fp8_rowwise_rms_norm, postprocess_weight_only, - quantize, quantize_fp8_per_token, quantize_per_token, - quantize_tensor, smooth_quant_gemm, - smooth_quant_layer_norm, smooth_quant_rms_norm, - weight_only_groupwise_quant_matmul, - weight_only_quant_matmul) + +# isort: off +from .functional import ( + change_qkv_leading_dim, dequantize, fp8_rowwise_gemm, fp8_rowwise_rms_norm, + postprocess_fp8_rowwise, postprocess_weight_only, quantize, + quantize_fp8_per_token, quantize_per_token, quantize_tensor, + smooth_quant_gemm, smooth_quant_layer_norm, smooth_quant_rms_norm, + weight_only_groupwise_quant_matmul, weight_only_quant_matmul) +# isort: on from .mode import QuantMode @@ -443,6 +445,7 @@ def __init__(self, dtype="float32") self.quant_mode = quant_mode + self.tllm_to_externel_key_dict = {"weight": ["weight", "weight_scale"]} def forward(self, x, lora_runtime_params=None): assert lora_runtime_params is None, "lora is not supported on SmoothQuantLinear now" @@ -461,6 +464,9 @@ def forward(self, x, lora_runtime_params=None): return x + def postprocess(self, tllm_key, weights, **kwargs): + return postprocess_fp8_rowwise(tllm_key, weights, **kwargs) + Fp8RowwiseColumnLinear = Fp8RowwiseLinear @@ -500,6 +506,7 @@ def __init__( dtype="float32") self.quant_mode = quant_mode + self.tllm_to_externel_key_dict = {"weight": ["weight", "weight_scale"]} def forward(self, x, lora_runtime_params=None, reduce_fusion_params=None): assert 
lora_runtime_params is None, "lora is not supported on SmoothQuantRowLinear now" @@ -529,6 +536,9 @@ def forward(self, x, lora_runtime_params=None, reduce_fusion_params=None): return x + def postprocess(self, tllm_key, weights, **kwargs): + return postprocess_fp8_rowwise(tllm_key, weights, **kwargs) + class WeightOnlyQuantLinear(Linear): @@ -600,16 +610,10 @@ def forward(self, x, lora_runtime_params=None): return x - def postprocess(self, - tllm_key, - weights, - using_head_as_leading_dim=False, - num_heads=-1): + def postprocess(self, tllm_key, weights, **kwargs): if "per_channel_scale" in tllm_key: return {} - weights = super().postprocess(tllm_key, weights, - using_head_as_leading_dim, - num_heads)[tllm_key] + weights = super().postprocess(tllm_key, weights, **kwargs)[tllm_key] weights = weights.to(str_dtype_to_torch(self.dtype)) return postprocess_weight_only(tllm_key, weights, self.weight_only_quant_mode) @@ -681,7 +685,7 @@ def forward(self, x, lora_runtime_params=None, reduce_fusion_params=None): return x - def postprocess(self, tllm_key, weights): + def postprocess(self, tllm_key, weights, **kwargs): if "per_channel_scale" in tllm_key: return {} weights = weights.to(str_dtype_to_torch(self.dtype)) @@ -749,16 +753,17 @@ def unpack_int32_into_int8(w_packed): def pad_like(w, target_shape, value=0): - if w.shape == target_shape: - return w + if w.shape != target_shape: + pad_dim = [] + for dim in range(len(target_shape)): + current_dim = -1 - dim + pad_dim.append(0) + pad_dim.append( + max(0, target_shape[current_dim] - w.shape[current_dim])) + res = F.pad(w, pad_dim, value=value) + return res else: - if w.dim() == 1: - return F.pad(w, (0, max(0, target_shape[-1] - w.shape[-1])), - value=value) - else: - return F.pad(w, (0, max(0, target_shape[-1] - w.shape[-1]), 0, - max(0, target_shape[-2] - w.shape[-2])), - value=value) + return w class WeightOnlyGroupwiseQuantLinear(Linear): @@ -811,7 +816,6 @@ def __init__( scale_shape = (self.in_features // group_size, self.out_features) self.weights_scaling_factor = Parameter(shape=scale_shape, dtype=dtype) - self.transposed_weight = True self.tp_rank = tp_rank if self.is_padded: self.tp_dim = -1 @@ -866,11 +870,11 @@ def forward(self, x, lora_runtime_params=None): return x - def postprocess(self, - tllm_key, - weights, - using_head_as_leading_dim=False, - num_heads=-1): + def postprocess(self, tllm_key, weights, **kwargs): + using_head_as_leading_dim = kwargs.get("using_head_as_leading_dim", + False) + config = kwargs.get("config", None) + num_heads = config.num_heads if not (tllm_key.endswith("bias") or tllm_key.endswith("weight")): return {} if self.is_qkv and type(weights) is list and len(weights) > 3: @@ -921,7 +925,7 @@ def postprocess(self, scales_fp16 = pad_like(scales_fp16, self.weights_scaling_factor.shape, 1) qzeros_unpacked_int32 = pad_like(qzeros_unpacked_int32, - self.zero.shape) + self.zero.shape, 7) zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 7) * scales_fp16 zeros_x_scales_fp16 = zeros_x_scales_fp16.to( str_dtype_to_torch(self.dtype)) @@ -956,14 +960,6 @@ def __init__( ): multiple = max((128 if use_w4a8_awq else 64), group_size) * tp_size self.is_padded = False - if in_features % multiple > 0: - in_features = math.ceil(in_features / multiple) * multiple - self.is_padded = True - if out_features % multiple > 0: - out_features = math.ceil(out_features / multiple) * multiple - self.is_padded = True - multiple = max((128 if use_w4a8_awq else 64), group_size) * tp_size - self.is_padded = False if in_features % multiple > 0: 
in_features = math.ceil(in_features / multiple) * multiple self.is_padded = True @@ -991,11 +987,6 @@ def __init__( scale_shape = (self.in_features // group_size, self.out_features) self.weights_scaling_factor = Parameter(shape=scale_shape, dtype=dtype) - self.transposed_weight = True - self.tp_rank = tp_rank - if self.is_padded: - self.tp_dim = -1 - self.transposed_weight = True self.tp_rank = tp_rank if self.is_padded: self.tp_dim = -1 @@ -1051,7 +1042,7 @@ def forward(self, x, lora_runtime_params=None, reduce_fusion_params=None): return x - def postprocess(self, tllm_key, weights): + def postprocess(self, tllm_key, weights, **kwargs): if not (tllm_key.endswith("bias") or tllm_key.endswith("weight")): return {} if tllm_key.endswith("bias"): @@ -1088,7 +1079,7 @@ def postprocess(self, tllm_key, weights): scales_fp16 = pad_like(scales_fp16, self.weights_scaling_factor.shape, 1) qzeros_unpacked_int32 = pad_like(qzeros_unpacked_int32, - self.zero.shape) + self.zero.shape, 7) zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 7) * scales_fp16 zeros_x_scales_fp16 = zeros_x_scales_fp16.to( str_dtype_to_torch(self.dtype)) @@ -1368,13 +1359,13 @@ def forward(self, x, lora_runtime_params=None): lora_runtime_params=lora_runtime_params, lora_hidden_state=lora_hidden_state) - def postprocess(self, tllm_key, weights): + def postprocess(self, tllm_key, weights, **kwargs): # TODO: add FP8 modelopt format support if self.is_qkv: if tllm_key.endswith("scaling_factor"): return 448.0 / max(weights).unsqueeze(0) else: - return super().postprocess(tllm_key, weights) + return super().postprocess(tllm_key, weights, **kwargs) if tllm_key.endswith("scaling_factor"): return 448.0 / weights.unsqueeze(0) else: @@ -1463,7 +1454,7 @@ def forward(self, x, lora_runtime_params=None, reduce_fusion_params=None): reduce_fusion_params=reduce_fusion_params) return ret - def postprocess(self, tllm_key, weights): + def postprocess(self, tllm_key, weights, **kwargs): # TODO: add FP8 modelopt format support if tllm_key.endswith("scaling_factor"): return 448.0 / weights.unsqueeze(0) diff --git a/tensorrt_llm/runtime/enc_dec_model_runner.py b/tensorrt_llm/runtime/enc_dec_model_runner.py index b69634728..d61b735a7 100644 --- a/tensorrt_llm/runtime/enc_dec_model_runner.py +++ b/tensorrt_llm/runtime/enc_dec_model_runner.py @@ -7,7 +7,6 @@ import tensorrt as trt from ..logger import logger -from .._ipc_utils import set_peer_access from .._utils import torch_to_numpy, trt_dtype_to_torch, mpi_world_size, mpi_rank from ..plugin.plugin import CustomAllReduceHelper from .generation import ModelConfig, SamplingConfig, LoraManager, GenerationSession @@ -329,11 +328,10 @@ def encoder_run(self, device=self.device).contiguous() if self.encoder_runtime_mapping.tp_size > 1: - is_p2p_supported = set_peer_access(self.encoder_runtime_mapping) ipc_buffers, all_reduce_workspace = CustomAllReduceHelper.allocate_workspace( self.encoder_runtime_mapping, CustomAllReduceHelper.max_workspace_size_auto( - self.encoder_runtime_mapping.tp_size), is_p2p_supported) + self.encoder_runtime_mapping.tp_size)) inputs['all_reduce_workspace'] = all_reduce_workspace if self.encoder_model_config.lora_plugin: diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 1fd563cfb..9e8c46c70 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -32,7 +32,6 @@ from tensorrt_llm.runtime.redrafter_utils import * -from .._ipc_utils import set_peer_access from .._utils import (pad_vocab_size, str_dtype_to_torch, 
torch_to_numpy, trt_dtype_to_torch) from ..bindings import KVCacheType @@ -518,6 +517,7 @@ class ModelConfig: # ReDrafter redrafter_num_beams: int = 0 redrafter_draft_len_per_beam: int = 0 + num_kv_heads_per_layer: Optional[List[int]] = None @dataclass @@ -736,10 +736,12 @@ def __init__(self, self.first_layer:self.last_layer] self.attn_to_general_idx = {} + self.general_to_attn_idx = {} attn_layer_idx = 0 for i in range(self.first_layer, self.last_layer): if self.layer_types[i] == 'attention': self.attn_to_general_idx[attn_layer_idx] = i + self.general_to_attn_idx[i] = attn_layer_idx attn_layer_idx += 1 # Cyclic KV cache buffer names. @@ -773,11 +775,10 @@ def __init__(self, self.mapping.pp_size, self.decoder_logits_dtype) if self.mapping.tp_size > 1: - is_p2p_supported = set_peer_access(self.mapping) self.ipc_buffers, self.all_reduce_workspace = CustomAllReduceHelper.allocate_workspace( self.mapping, CustomAllReduceHelper.max_workspace_size_auto( - self.mapping.tp_size), is_p2p_supported) + self.mapping.tp_size)) self.gather_tree = torch.ops.tensorrt_llm.gather_tree @@ -997,8 +998,17 @@ def tokens_per_block(self): def remove_input_padding(self): return self._model_config.remove_input_padding - @property - def num_heads_kv(self): + def get_num_heads_kv(self, layer_idx: Optional[int] = None) -> int: + if layer_idx is None or self._model_config.num_kv_heads_per_layer is None: + return self._model_config.num_kv_heads + + if self._model_config.layer_types: + assert self._model_config.layer_types[ + layer_idx] == "attention", f"Layer {layer_idx} is not an attention layer" + + if self._model_config.num_kv_heads_per_layer: + return self._model_config.num_kv_heads_per_layer[layer_idx] + return self._model_config.num_kv_heads @property @@ -1690,7 +1700,7 @@ def setup(self, num_blocks, self.num_attn_layers, 2, - self.num_heads_kv, + self.get_num_heads_kv(), self.tokens_per_block, self.head_size, ) @@ -1706,7 +1716,7 @@ def setup(self, cross_num_blocks, self.num_layers, 2, - self.num_heads_kv, + self.get_num_heads_kv(), self.tokens_per_block, self.head_size, ) @@ -1714,15 +1724,15 @@ def setup(self, dtype=kv_cache_type, device=self.device) elif self.has_attn_layers: - cache_shape = ( - batch_size, - 2, - self.num_heads_kv, - self.max_attention_window_size, - self.head_size, - ) for i in range(self.first_layer, self.last_layer): if self.layer_types[i] == 'attention': + cache_shape = ( + batch_size, + 2, + self.get_num_heads_kv(self.general_to_attn_idx[i]), + self.max_attention_window_size, + self.head_size, + ) self.buffer[f'present_key_value_{i}'] = torch.empty( cache_shape, dtype=kv_cache_type, @@ -1732,7 +1742,7 @@ def setup(self, cross_cache_shape = ( batch_size, 2, - self.num_heads_kv, + self.get_num_heads_kv(), self.encoder_max_input_length, self.head_size, ) @@ -1894,7 +1904,7 @@ def add_tensor_with_bs(x, name, bs): if self.cross_qkv_reuse is None: # see Attention's self.qkv output dim cross_qkv_out_dim = self.num_heads * self.head_size + ( - 2 * self.num_heads_kv * self.head_size) + 2 * self.get_num_heads_kv() * self.head_size) cross_qkv_shape = encoder_output.shape[:-1] + ( cross_qkv_out_dim, ) cross_qkv_reuse = torch.empty(cross_qkv_shape, @@ -1980,7 +1990,9 @@ def add_tensor_with_bs(x, name, bs): for idx in range(self.first_layer, self.last_layer): if not self.use_gpt_attention_plugin and self.layer_types[ idx] == 'attention': - kv_cache_shape = (batch_size, 2, self.num_heads_kv, 0, + kv_cache_shape = (batch_size, 2, + self.get_num_heads_kv( + self.general_to_attn_idx[idx]), 0, 
self.head_size) # for empty tensor, TRT does not really use the tensor data, so any dtype is fine kv_cache_buffer = torch.zeros((1, ), @@ -1994,7 +2006,7 @@ def add_tensor_with_bs(x, name, bs): if self.cross_attention: cross_kv_cache_shape = (batch_size, 2, - self.num_heads_kv, 0, + self.get_num_heads_kv(), 0, self.head_size) # for empty tensor, TRT does not really use the tensor data, so any dtype is fine cross_kv_cache_buffer = torch.zeros((1, ), @@ -2269,7 +2281,8 @@ def add_tensor_with_shape(x, name, shape): if not self.paged_kv_cache: for attn_idx, layer_idx in self.attn_to_general_idx.items(): if not self.use_gpt_attention_plugin: - next_shape = (batch_size * beam_width, 2, self.num_heads_kv, + next_shape = (batch_size * beam_width, 2, + self.get_num_heads_kv(), max_context_length + step, self.head_size) # We will make current layer's output KV-cache overwrite previous layers input KV-cache # buffer id: ... 5, 6, 7, 8, 9, ... @@ -2785,7 +2798,7 @@ def reorder_kv_cache_for_beam_search( assert self.buffer is not None assert self.parent_ids.shape[:2] == (batch_size, beam_width) - cache_shape = (batch_size * beam_width, 2, self.num_heads_kv, + cache_shape = (batch_size * beam_width, 2, self.get_num_heads_kv(), max_context_length + step, self.head_size) import functools @@ -3738,7 +3751,8 @@ def decode(self, self.buffer[f'host_kv_cache_pool_pointers'] = torch.tensor( [self.kv_cache_pool.data_ptr(), 0], dtype=torch.int64) - block_size = self.num_heads_kv * self.tokens_per_block * self.head_size + block_size = self.get_num_heads_kv( + ) * self.tokens_per_block * self.head_size self.kv_cache_manager = KVCacheManager( num_layers=self.num_attn_layers, num_blocks=num_blocks, @@ -3760,7 +3774,8 @@ def decode(self, [self.cross_kv_cache_pool.data_ptr(), 0], dtype=torch.int64) - cross_block_size = self.num_heads_kv * self.tokens_per_block * self.head_size + cross_block_size = self.get_num_heads_kv( + ) * self.tokens_per_block * self.head_size self.cross_kv_cache_manager = KVCacheManager( num_layers=self.num_layers, num_blocks=cross_num_blocks, @@ -3802,7 +3817,7 @@ def decode(self, if self.paged_kv_cache: self.kv_cache_updater.init_paged_kv_cache( - self.num_layers, self.num_heads_kv, self.head_size, + self.num_layers, self.get_num_heads_kv(), self.head_size, kv_cache_type, self.kv_cache_manager, self.buffer[f'host_kv_cache_pool_pointers']) else: @@ -3811,7 +3826,7 @@ def decode(self, for i in range(self.first_layer, self.last_layer) ] self.kv_cache_updater.init_linear_kv_cache( - self.num_layers, self.num_heads_kv, self.head_size, + self.num_layers, self.get_num_heads_kv(), self.head_size, kv_cache_type, past_key_value_list) stop_words_lens = None diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 717c3f76d..6c381fcf0 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -465,6 +465,22 @@ def from_engine( 'Build config doesn\'t have kv_cache_type, you might need to rebuild your enigne.' ) + # TODO(oargov): this is a hack, make it prettier! 
+ if hasattr(pretrained_config, "get_layer_num_kv_heads"): + # each layer has a different number of kv heads + attention_layers = [ + layer_idx for layer_idx, layer_type in enumerate( + pretrained_config.layer_types) if layer_type == "attention" + ] if hasattr(pretrained_config, "layer_types") else list( + range(pretrained_config.num_hidden_layers)) + num_kv_heads_per_layer = [ + pretrained_config.get_layer_num_kv_heads(layer_idx) + if layer_idx in attention_layers else 0 + for layer_idx in range(pretrained_config.num_hidden_layers) + ] + else: + num_kv_heads_per_layer = None + model_config = ModelConfig( max_batch_size=build_config.max_batch_size, max_beam_width=build_config.max_beam_width, @@ -498,6 +514,7 @@ def from_engine( pretrained_config, 'num_medusa_heads') else 0, **rnn_configs_kwargs, gpu_weights_percent=gpu_weights_percent, + num_kv_heads_per_layer=num_kv_heads_per_layer, redrafter_num_beams=pretrained_config.redrafter_num_beams if hasattr(pretrained_config, 'redrafter_num_beams') else 0, redrafter_draft_len_per_beam=pretrained_config. diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index 516306f65..dd63de63a 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -79,6 +79,7 @@ def from_dir( sink_token_length: Optional[int] = None, kv_cache_free_gpu_memory_fraction: Optional[float] = None, medusa_choices: list[list[int]] | None = None, + lookahead_config: list[int] | None = None, debug_mode: bool = False, lora_ckpt_source: str = "hf", gpu_weights_percent: float = 1, @@ -235,6 +236,11 @@ def from_dir( if multi_block_mode is not None: multi_block_mode = False # Medusa doesn't support multi-block mode. + if lookahead_config is not None: + [w, n, g] = lookahead_config + decoding_config.lookahead_decoding_config = trtllm.LookaheadDecodingConfig( + w, n, g) + if max_batch_size is None: max_batch_size = model_config.max_batch_size else: diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index a68d887f7..cef9e1f45 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.13.0.dev2024081300" +__version__ = "0.13.0.dev2024082000" diff --git a/tests/bindings/test_executor_bindings.py b/tests/bindings/test_executor_bindings.py index 2ff552059..60b70bd5d 100644 --- a/tests/bindings/test_executor_bindings.py +++ b/tests/bindings/test_executor_bindings.py @@ -516,9 +516,16 @@ def validate_results_shapes(result, input_length, max_output_len, assert result.context_logits is None if return_generation_logits: assert len(result.generation_logits.shape) == 3 - assert list(result.generation_logits.shape) == [ - beam_width, max_output_len, vocab_size_padded - ] + if streaming: + assert list(result.generation_logits.shape) == [ + max_output_len, beam_width, vocab_size_padded + ] or list(result.generation_logits.shape) == [ + 1, beam_width, vocab_size_padded + ] + else: + assert list(result.generation_logits.shape) == [ + beam_width, max_output_len, vocab_size_padded + ] def verify_output(beam_tokens, test_data, given_input_lengths): for batch_id, tokens in beam_tokens.items(): @@ -610,6 +617,76 @@ def verify_output(beam_tokens, test_data, given_input_lengths): verify_output(tokens, test_data, given_input_lengths) +@pytest.mark.parametrize("streaming", [False, True]) +@pytest.mark.parametrize("beam_width", [1]) +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_finish_reason(streaming: bool, beam_width: int, model_files, + model_path): + if streaming and beam_width > 1: + pytest.skip("Test does not support streaming with beam search") + executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, + trtllm.ExecutorConfig(beam_width)) + requests = [ + # Finish due to length. + trtllm.Request([1, 2, 3, 4], 5, streaming, + trtllm.SamplingConfig(beam_width)), + # Finish due to end id. + trtllm.Request([1, 2, 3, 4], + 5, + streaming, + trtllm.SamplingConfig(beam_width), + end_id=4), + # Finish due to stop word. + trtllm.Request([1, 2, 3, 4], + 5, + streaming, + trtllm.SamplingConfig(beam_width), + stop_words=[[4, 2]]), + ] + req_ids = executor.enqueue_requests(requests) + req_to_batch_id = {req_ids[i]: i for i in range(len(requests))} + + num_finished = 0 + i = 0 + num_responses = 0 + max_wait_ms = 10000 + while num_finished < len(requests) and i < max_wait_ms: + wait_time = datetime.timedelta(milliseconds=1) + responses = executor.await_responses(wait_time) + for response in responses: + num_responses += 1 + assert not response.has_error( + ), f"Request id {response.request_id} failed with err {response.error_msg}" + result = response.result + num_finished += result.is_final + batch_id = req_to_batch_id[response.request_id] + + # Non final results should have "NOT_FINISHED". Revise this when streaming + beam_width > 1 is enabled. + if not result.is_final: + assert all([ + r == trtllm.FinishReason.NOT_FINISHED + for r in result.finish_reasons + ]) + # Check if finish reason is correct. 
+ elif batch_id == 0: + assert all([ + r == trtllm.FinishReason.LENGTH + for r in result.finish_reasons + ]) + elif batch_id == 1: + assert all([ + r == trtllm.FinishReason.END_ID + for r in result.finish_reasons + ]) + elif batch_id == 2: + assert all([ + r == trtllm.FinishReason.STOP_WORDS + for r in result.finish_reasons + ]) + i += 1 + assert i < max_wait_ms + + @skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture def test_gpt_executor_timed_out(model_files, model_path): beam_width = 1 @@ -784,7 +861,9 @@ def test_request(): "external_draft_tokens_config": trtllm.ExternalDraftTokensConfig([1, 2, 3]), "prompt_tuning_config": trtllm.PromptTuningConfig(torch.ones(100, 64)), - "lora_config": trtllm.LoraConfig(1) + "lora_config": trtllm.LoraConfig(1), + "logits_post_processor_name": "my_logits_pp", + "client_id": 1234 } request = trtllm.Request(**kwargs) for k, v in kwargs.items(): @@ -809,12 +888,16 @@ def test_result(): result.log_probs = [[1.0, 2.0, 3.0]] result.context_logits = torch.ones(3, 100) result.generation_logits = torch.ones(1, 3, 100) + result.encoder_output = torch.ones(1, 1) + result.finish_reasons = [trtllm.FinishReason.LENGTH] assert result.is_final == True assert result.output_token_ids == [[1, 2, 3]] assert result.cum_log_probs == [1.0, 2.0, 3.0] assert result.log_probs == [[1.0, 2.0, 3.0]] assert (result.context_logits == torch.ones(3, 100)).all() assert (result.generation_logits == torch.ones(1, 3, 100)).all() + assert (result.encoder_output == torch.ones(1, 1)).all() + assert result.finish_reasons == [trtllm.FinishReason.LENGTH] def test_response(): @@ -973,6 +1056,24 @@ def test_speculative_decoding_config(): assert config.medusa_choices == [[0, 0], [0, 1]] +def test_logits_post_processor_config(): + config = trtllm.LogitsPostProcessorConfig() + assert config.processor_map == None + assert config.processor_batched == None + assert config.replicate == True + + kwargs = { + "processor_map": { + "test_pp": None + }, + "processor_batched": None, + "replicate": False + } + config = trtllm.LogitsPostProcessorConfig(**kwargs) + for k, v in kwargs.items(): + assert getattr(config, k) == v + + def test_executor_config(): config = trtllm.ExecutorConfig() assert config.max_beam_width == 1 @@ -986,10 +1087,9 @@ def test_executor_config(): assert config.batching_type == trtllm.BatchingType.INFLIGHT assert config.parallel_config is None assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig) - assert config.logits_post_processor_map is None - assert config.logits_post_processor_batched is None - assert config.replicate_logits_post_processor == True + assert config.logits_post_processor_config is None assert config.decoding_config is None + assert config.debug_config is None kwargs = { "max_beam_width": @@ -1014,11 +1114,16 @@ def test_executor_config(): trtllm.ParallelConfig(), "peft_cache_config": trtllm.PeftCacheConfig(10), - "logits_post_processor_map": {}, - "replicate_logits_post_processor": - False, + "logits_post_processor_config": + trtllm.LogitsPostProcessorConfig(), "decoding_config": trtllm.DecodingConfig(trtllm.DecodingMode.TopKTopP()), + "extended_runtime_perf_knob_config": + trtllm.ExtendedRuntimePerfKnobConfig(multi_block_mode=True), + "debug_config": + trtllm.DebugConfig(dump_input_tensors=True, + dump_output_tensors=True, + debug_tensor_names=["test"]) } config = trtllm.ExecutorConfig(**kwargs) for k, v in kwargs.items(): @@ -1029,6 +1134,10 @@ def test_executor_config(): assert isinstance(config.kv_cache_config, 
trtllm.KvCacheConfig) assert isinstance(config.parallel_config, trtllm.ParallelConfig) assert isinstance(config.peft_cache_config, trtllm.PeftCacheConfig) + assert config.extended_runtime_perf_knob_config.multi_block_mode == True + assert isinstance(config.debug_config, trtllm.DebugConfig) + assert isinstance(config.logits_post_processor_config, + trtllm.LogitsPostProcessorConfig) def test_parallel_config(): @@ -1103,9 +1212,8 @@ def logits_post_processor(req_id: int, logits: torch.Tensor, # Create executor beam_width = 1 executor_config = trtllm.ExecutorConfig(beam_width) - executor_config.logits_post_processor_map = { - "my_logits_pp": logits_post_processor - } + executor_config.logits_post_processor_config = trtllm.LogitsPostProcessorConfig( + {"my_logits_pp": logits_post_processor}) executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, executor_config) @@ -1162,7 +1270,8 @@ def logits_post_processor_batched( # Create executor beam_width = 1 executor_config = trtllm.ExecutorConfig(beam_width) - executor_config.logits_post_processor_batched = logits_post_processor_batched + executor_config.logits_post_processor_config = trtllm.LogitsPostProcessorConfig( + None, logits_post_processor_batched) executor = trtllm.Executor(model_path, trtllm.ModelType.DECODER_ONLY, executor_config) @@ -1298,18 +1407,99 @@ def test_peft_cache_config_pickle(): assert config.host_cache_size == config_copy.host_cache_size +def test_decoding_config_pickle(): + config = trtllm.DecodingConfig( + decoding_mode=trtllm.DecodingMode.BeamSearch()) + config_copy = pickle.loads(pickle.dumps(config)) + assert config_copy.decoding_mode.isBeamSearch + assert config.lookahead_decoding_config == config_copy.lookahead_decoding_config + assert config.medusa_choices == config_copy.medusa_choices + + +def test_debug_config_pickle(): + config = trtllm.DebugConfig(dump_input_tensors=True, + dump_output_tensors=True, + debug_tensor_names=["test"]) + config_copy = pickle.loads(pickle.dumps(config)) + assert config.dump_input_tensors == config_copy.dump_input_tensors + assert config.dump_output_tensors == config_copy.dump_output_tensors + assert config.debug_tensor_names == config_copy.debug_tensor_names + + +def test_logits_post_processor_config_pickle(): + kwargs = { + "processor_map": { + "test_pp": None + }, + "processor_batched": None, + "replicate": False + } + config = trtllm.LogitsPostProcessorConfig(**kwargs) + config_copy = pickle.loads(pickle.dumps(config)) + for k in kwargs: + assert getattr(config, k) == getattr(config_copy, k) + + def test_executor_config_pickle(): beam_width = 2 config = trtllm.ExecutorConfig(beam_width) - config.scheduler_config = trtllm.SchedulerConfig() - config.kv_cache_config = trtllm.KvCacheConfig() - config.parallel_config = trtllm.ParallelConfig() - config.peft_cache_config = trtllm.PeftCacheConfig(1) + + kwargs = { + "max_beam_width": + 2, + "max_batch_size": + 8, + "max_num_tokens": + 128, + "scheduler_config": + trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.MAX_UTILIZATION), + "kv_cache_config": + trtllm.KvCacheConfig(enable_block_reuse=True), + "enable_chunked_context": + True, + "normalize_log_probs": + False, + "iter_stats_max_iterations": + 100, + "batching_type": + trtllm.BatchingType.STATIC, + "parallel_config": + trtllm.ParallelConfig(), + "peft_cache_config": + trtllm.PeftCacheConfig(10), + "logits_post_processor_config": + trtllm.LogitsPostProcessorConfig(), + "decoding_config": + trtllm.DecodingConfig(trtllm.DecodingMode.TopKTopP()), + 
"extended_runtime_perf_knob_config": + trtllm.ExtendedRuntimePerfKnobConfig(multi_block_mode=True), + "debug_config": + trtllm.DebugConfig(dump_input_tensors=True, + dump_output_tensors=True, + debug_tensor_names=["test"]) + } + config = trtllm.ExecutorConfig(**kwargs) + for k, v in kwargs.items(): + if "config" not in k: + assert getattr(config, k) == v + pickle.dumps(config) config_copy = pickle.loads(pickle.dumps(config)) assert config.max_beam_width == config_copy.max_beam_width + assert config.max_batch_size == config_copy.max_batch_size + assert config.max_num_tokens == config_copy.max_num_tokens assert config.scheduler_config.capacity_scheduler_policy == config_copy.scheduler_config.capacity_scheduler_policy assert config.kv_cache_config.enable_block_reuse == config_copy.kv_cache_config.enable_block_reuse + assert config.enable_chunked_context == config_copy.enable_chunked_context + assert config.normalize_log_probs == config_copy.normalize_log_probs + assert config.normalize_log_probs == config_copy.normalize_log_probs + assert config.iter_stats_max_iterations == config_copy.iter_stats_max_iterations + assert config.batching_type == config_copy.batching_type + assert config.parallel_config.communication_type == config_copy.parallel_config.communication_type + assert config.peft_cache_config.num_host_module_layer == config_copy.peft_cache_config.num_host_module_layer + assert config_copy.decoding_config.decoding_mode.isTopKandTopP + assert config.extended_runtime_perf_knob_config.multi_block_mode == config_copy.extended_runtime_perf_knob_config.multi_block_mode + assert config.debug_config.dump_input_tensors == config_copy.debug_config.dump_input_tensors def test_return_full_tokens(): diff --git a/tests/bindings/test_gpt_manager.py b/tests/bindings/test_gpt_manager.py index 947eaeff8..a90c53eb8 100644 --- a/tests/bindings/test_gpt_manager.py +++ b/tests/bindings/test_gpt_manager.py @@ -32,7 +32,7 @@ def get_model_spec() -> model_spec.ModelSpec: model_spec_obj = model_spec.ModelSpec( 'input_tokens.npy', _tb.DataType.HALF).use_gpt_plugin().set_kv_cache_type( - model_spec.KVCacheType.PAGED).use_packed_input() + _tb.KVCacheType.PAGED).use_packed_input() get_model_spec.model_spec_obj = model_spec_obj return get_model_spec.model_spec_obj diff --git a/tests/functional/test_moe.py b/tests/functional/test_moe.py index 1e1e7e0a7..c6819f3f1 100644 --- a/tests/functional/test_moe.py +++ b/tests/functional/test_moe.py @@ -16,6 +16,7 @@ import math import unittest from collections import OrderedDict +from itertools import product import numpy as np @@ -32,6 +33,7 @@ from tensorrt_llm import Tensor from tensorrt_llm._utils import (torch_to_numpy, trt_dtype_to_str, trt_dtype_to_torch) +from tensorrt_llm.layers.lora import Lora, LoraParams from tensorrt_llm.layers.moe import MoeConfig, MoeOOTB from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization import QuantAlgo, QuantMode @@ -362,6 +364,145 @@ def create_weights(self, num_experts, hidden_size, ffn_hidden_size, bias, self.activation_scaling_factor_1 = None self.activation_scaling_factor_2 = None + def create_lora_weights(self, num_experts, hidden_size, ffn_hidden_size, + dtype, num_reqs, lora_rank): + genfn = torch.randn + + self.lora_rank = lora_rank + + fc1_weight_rescale_1 = math.sqrt(2.0 / lora_rank) + fc1_weight_rescale_2 = math.sqrt(2.0 / ffn_hidden_size) + fc2_weight_rescale_1 = math.sqrt(2.0 / lora_rank) + fc2_weight_rescale_2 = math.sqrt(2.0 / hidden_size) + + self.lora_fc1_weights_1 = (genfn( + 
(num_experts, lora_rank, hidden_size), + dtype=trt_dtype_to_torch(dtype), + device="cuda", + ) * fc1_weight_rescale_1) + self.lora_fc1_weights_2 = (genfn( + (num_experts, ffn_hidden_size, lora_rank), + dtype=trt_dtype_to_torch(dtype), + device="cuda", + ) * fc1_weight_rescale_2) + + self.lora_fc1_weights_ptrs = torch.tensor( + (self.lora_fc1_weights_1.data_ptr(), + self.lora_fc1_weights_2.data_ptr()), + dtype=torch.int64, + ).repeat(num_reqs, 1) + self.lora_fc1_ranks = torch.tensor((lora_rank, ), + dtype=torch.int32).repeat(num_reqs) + + self.lora_gated_weights_1 = (genfn( + (num_experts, lora_rank, hidden_size), + dtype=trt_dtype_to_torch(dtype), + device="cuda", + ) * fc1_weight_rescale_1) + self.lora_gated_weights_2 = (genfn( + (num_experts, ffn_hidden_size, lora_rank), + dtype=trt_dtype_to_torch(dtype), + device="cuda", + ) * fc1_weight_rescale_2) + + self.lora_gated_weights_ptrs = torch.tensor( + (self.lora_gated_weights_1.data_ptr(), + self.lora_gated_weights_2.data_ptr()), + dtype=torch.int64, + ).repeat(num_reqs, 1) + self.lora_gated_ranks = torch.tensor((lora_rank, ), + dtype=torch.int32).repeat(num_reqs) + + self.lora_fc2_weights_1 = (genfn( + (num_experts, lora_rank, ffn_hidden_size), + dtype=trt_dtype_to_torch(dtype), + device="cuda", + ) * fc2_weight_rescale_1) + self.lora_fc2_weights_2 = (genfn( + (num_experts, hidden_size, lora_rank), + dtype=trt_dtype_to_torch(dtype), + device="cuda", + ) * fc2_weight_rescale_2) + + self.lora_fc2_weights_ptrs = torch.tensor( + (self.lora_fc2_weights_1.data_ptr(), + self.lora_fc2_weights_2.data_ptr()), + dtype=torch.int64, + ).repeat(num_reqs, 1) + self.lora_fc2_ranks = torch.tensor((lora_rank, ), + dtype=torch.int32).repeat(num_reqs) + + def create_lora_params(self, num_reqs): + + moe_h_to_4h_weights_pointers = Tensor( + shape=(num_reqs, 2), + dtype=tensorrt_llm.str_dtype_to_trt("int64"), + name="moe_h_to_4h_weights_pointers", + ) + moe_h_to_4h_lora_ranks = Tensor( + shape=(num_reqs, ), + dtype=tensorrt_llm.str_dtype_to_trt("int32"), + name="moe_h_to_4h_lora_ranks", + ) + moe_4h_to_h_weights_pointers = Tensor( + shape=(num_reqs, 2), + dtype=tensorrt_llm.str_dtype_to_trt("int64"), + name="moe_4h_to_h_weights_pointers", + ) + moe_4h_to_h_lora_ranks = Tensor( + shape=(num_reqs, ), + dtype=tensorrt_llm.str_dtype_to_trt("int32"), + name="moe_4h_to_h_lora_ranks", + ) + moe_gate_weights_pointers = Tensor( + shape=(num_reqs, 2), + dtype=tensorrt_llm.str_dtype_to_trt("int64"), + name="moe_gate_weights_pointers", + ) + moe_gate_lora_ranks = Tensor( + shape=(num_reqs, ), + dtype=tensorrt_llm.str_dtype_to_trt("int32"), + name="moe_gate_lora_ranks", + ) + host_context_lengths = Tensor( + shape=(num_reqs, ), + dtype=tensorrt_llm.str_dtype_to_trt("int32"), + name="host_context_lengths", + ) + host_request_types = Tensor( + shape=(num_reqs, ), + dtype=tensorrt_llm.str_dtype_to_trt("int32"), + name="host_request_types", + ) + + self.lora_params = LoraParams( + lora_ranks=[{ + "moe_h_to_4h_lora_ranks": moe_h_to_4h_lora_ranks, + "moe_4h_to_h_lora_ranks": moe_4h_to_h_lora_ranks, + "moe_gate_lora_ranks": moe_gate_lora_ranks, + "mlp_h_to_4h_lora_ranks": moe_h_to_4h_lora_ranks, + "mlp_4h_to_h_lora_ranks": moe_4h_to_h_lora_ranks, + "mlp_gate_lora_ranks": moe_gate_lora_ranks, + }], + lora_weights_pointers=[{ + "moe_h_to_4h_lora_weights_pointers": + moe_h_to_4h_weights_pointers, + "moe_4h_to_h_lora_weights_pointers": + moe_4h_to_h_weights_pointers, + "moe_gate_lora_weights_pointers": + moe_gate_weights_pointers, + "mlp_h_to_4h_lora_weights_pointers": + 
moe_h_to_4h_weights_pointers, + "mlp_4h_to_h_lora_weights_pointers": + moe_4h_to_h_weights_pointers, + "mlp_gate_lora_weights_pointers": + moe_gate_weights_pointers, + }], + host_context_lengths=host_context_lengths, + host_request_types=host_request_types, + weight_index=0, + ) + def create_fp8_scaling_factors(self, max_act1, max_act2): self.activation_scaling_factor_1 = torch.tensor([max_act1 ]).float() / 440. @@ -580,10 +721,176 @@ def MLP(network, trt_key): 'int8': 2e-1, 'int4': 2e-1, } - torch.testing.assert_close(outputs['output'], - outputs['mlp_output'], - rtol=tolerances[dtype_str], - atol=tolerances[dtype_str]) + torch.testing.assert_close( + outputs["output"], + outputs["mlp_output"], + rtol=tolerances[dtype_str], + atol=tolerances[dtype_str], + ) + + @parameterized.expand(list( + product(["float16", "bfloat16", "int4", "int8"], ["gelu", "geglu"], + [True], [32, 64])), + name_func=unittest_name_func) + def test_mlp_lora_comparison(self, dtype_str, actfn, use_plugin, lora_rank): + """This test uses one expert and compares the result to a plain MLP""" + skip_bf16_pre_ampere(dtype_str) + + use_int4_weights = dtype_str == "int4" + weight_dtype = (trt.int8 if use_int4_weights else + tensorrt_llm.str_dtype_to_trt(dtype_str)) + + dtype = weight_dtype + quant_mode = QuantMode(0) + hidden_size = 8 + if dtype_str == "int8" or dtype_str == "int4": + dtype = tensorrt_llm.str_dtype_to_trt("float16") + hidden_size = 64 + quant_mode = QuantMode.use_weight_only( + use_int4_weights=use_int4_weights) + + num_sequences = 4 + sequence_lengths = 4 + num_experts = 1 + top_k = 1 + bias = False + ffn_hidden_size = 4 * hidden_size + self.create_weights( + num_experts, + hidden_size, + ffn_hidden_size, + bias, + dtype, + weight_dtype, + is_gated=is_gated_activation(actfn), + ) + + self.create_lora_weights( + num_experts, + hidden_size, + ffn_hidden_size, + dtype, + num_sequences, + lora_rank, + ) + + input_data = gen_uniform_weights( + (num_sequences, sequence_lengths, hidden_size), + dtype=trt_dtype_to_torch(dtype), + ) + + def MLP(network, trt_key, lora_params): + mlp_type = (tensorrt_llm.layers.GatedMLP if + is_gated_activation(actfn) else tensorrt_llm.layers.MLP) + mlp = mlp_type( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + hidden_act=gated2act(actfn), + bias=bias, + quant_mode=quant_mode, + dtype=dtype, + ) + + mlp.fc.lora = Lora( + in_hidden_size=hidden_size, + out_hidden_sizes=[ffn_hidden_size], + max_low_rank=lora_rank, + ) + + mlp.proj.lora = Lora( + in_hidden_size=ffn_hidden_size, + out_hidden_sizes=[hidden_size], + max_low_rank=lora_rank, + ) + + if is_gated_activation(actfn): + mlp.gate.lora = Lora( + in_hidden_size=hidden_size, + out_hidden_sizes=[ffn_hidden_size], + max_low_rank=lora_rank, + ) + # Quantize the weights manually so the results are comparable + fc1_qd = quant_dequant(self.fc1_weights[0].cpu(), quant_mode) + if is_gated_activation(actfn): + # Note that the MLP uses the opposite convention to the GLU paper for naming, + # the gate is the matrix the activations are NOT applied to + gate, fc1_qd = fc1_qd.chunk(2, dim=0) + mlp.gate.weight.value = np.ascontiguousarray( + torch_to_numpy(gate)) + + mlp.fc.weight.value = np.ascontiguousarray(torch_to_numpy(fc1_qd)) + fc2_qd = quant_dequant(self.fc2_weights[0].cpu(), quant_mode) + mlp.proj.weight.value = np.ascontiguousarray(torch_to_numpy(fc2_qd)) + if bias: + fc1_bias = self.fc1_bias[0].cpu() + + if is_gated_activation(actfn): + gate, fc1_bias = fc1_bias.chunk(2, dim=0) + mlp.gate.bias.value = np.ascontiguousarray( 
+ torch_to_numpy(gate)) + + mlp.fc.bias.value = np.ascontiguousarray( + torch_to_numpy(fc1_bias)) + mlp.proj.bias.value = np.ascontiguousarray( + torch_to_numpy(self.fc2_bias[0].cpu())) + + output = mlp(trt_key, lora_params) + output.mark_output("mlp_output", dtype) + + session = self.create_trt_session( + tuple(input_data.shape), + num_experts, + top_k, + hidden_size, + ffn_hidden_size, + actfn, + bias, + dtype, + weight_dtype, + quant_mode, + norm_mode=MoeConfig.ExpertScaleNormalizationMode.NONE, + custom_network=MLP, + use_plugin=use_plugin, + use_lora=True, + ) + + inputs = { + "input_hidden_states": + input_data, + "moe_h_to_4h_weights_pointers": + self.lora_fc1_weights_ptrs, + "moe_h_to_4h_lora_ranks": + self.lora_fc1_ranks, + "moe_4h_to_h_weights_pointers": + self.lora_fc2_weights_ptrs, + "moe_4h_to_h_lora_ranks": + self.lora_fc2_ranks, + "moe_gate_weights_pointers": + self.lora_gated_weights_ptrs, + "moe_gate_lora_ranks": + self.lora_gated_ranks, + "host_context_lengths": + torch.tensor((sequence_lengths, ), + dtype=torch.int32).repeat(num_sequences), + "host_request_types": + torch.tensor((0, ), dtype=torch.int32).repeat(num_sequences), + } + outputs = run_session(session, inputs) + + tolerances = { + "float32": 1e-2, + "float16": (2e-2 if getSMVersion() >= 75 else + 1e-1), # Some issues for geglu on volta + "bfloat16": 1e-1, + "int8": 2e-1, + "int4": 2e-1, + } + torch.testing.assert_close( + outputs["output"], + outputs["mlp_output"], + rtol=tolerances[dtype_str], + atol=tolerances[dtype_str], + ) def set_weight_layer(self, input_weights, @@ -614,21 +921,24 @@ def set_weight_layer(self, moe_weight_wrapper.weight.value = np.ascontiguousarray( torch_to_numpy(input_weights)) - def create_trt_session(self, - input_shape, - num_experts, - top_k, - hidden_size, - ffn_hidden_size, - actfn, - bias, - dtype: trt.DataType, - weight_dtype: trt.DataType, - quant_mode, - norm_mode, - custom_network=None, - use_plugin=True, - max_sizes=None): + def create_trt_session( + self, + input_shape, + num_experts, + top_k, + hidden_size, + ffn_hidden_size, + actfn, + bias, + dtype: trt.DataType, + weight_dtype: trt.DataType, + quant_mode, + norm_mode, + custom_network=None, + use_plugin=True, + max_sizes=None, + use_lora=False, + ): builder = tensorrt_llm.Builder() network = builder.create_network() @@ -649,6 +959,13 @@ def create_trt_session(self, network.plugin_config.moe_plugin = trt_dtype_to_str(dtype) + lora_params = None + if use_lora: + network.plugin_config.lora_plugin = trt_dtype_to_str(dtype) + network.plugin_config.remove_input_padding = False + self.create_lora_params(input_shape[0]) + lora_params = self.lora_params + moe_config = MoeConfig(num_experts=num_experts, top_k=top_k, normalization_mode=norm_mode) @@ -662,6 +979,9 @@ def create_trt_session(self, quant_mode=quant_mode) moe.router.weight.value = torch_to_numpy(self.router_weights.cpu()) + if use_lora: + moe.max_low_rank = self.lora_rank + self.set_weight_layer(self.fc1_weights, moe.fc, quant_mode, self.weight_scaling_factor_1) self.set_weight_layer(self.fc2_weights, moe.proj, quant_mode, @@ -682,7 +1002,10 @@ def create_trt_session(self, moe.proj.bias.value = torch_to_numpy(self.fc2_bias.cpu()) if custom_network: - custom_network(network, trt_key) + if use_lora: + custom_network(network, trt_key, lora_params) + else: + custom_network(network, trt_key) if not use_plugin: quant_config = None @@ -692,9 +1015,8 @@ def create_trt_session(self, kv_cache_quant_algo=QuantAlgo.FP8) moe = moe.to(MoeOOTB, quant_config=quant_config) - output = 
moe(trt_key) - output.mark_output('output', dtype) - + output = moe(trt_key, lora_layer_params=lora_params) + output.mark_output("output", dtype) # trt run session = create_session(builder, network, diff --git a/tests/functional/test_nccl.py b/tests/functional/test_nccl.py index 0ec927d6e..1abdfbf7d 100644 --- a/tests/functional/test_nccl.py +++ b/tests/functional/test_nccl.py @@ -28,7 +28,6 @@ import tensorrt_llm from tensorrt_llm import Mapping, Tensor -from tensorrt_llm._ipc_utils import peer_access from tensorrt_llm.functional import (AllReduceConfig, AllReduceStrategy, allreduce) from tensorrt_llm.plugin.plugin import current_all_reduce_helper @@ -97,25 +96,24 @@ def test_allreduce(self, dtype: str, strategy: AllReduceStrategy, input = self.reference_tensors[self.rank][:size].to(torch_dtype) inner_loop = 5 - with peer_access(self.mapping): - with tensorrt_llm.net_guard(network): + with tensorrt_llm.net_guard(network): - x = Tensor(name='x', - shape=input.shape, - dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - current_all_reduce_helper().set_workspace_tensor(self.mapping) + x = Tensor(name='x', + shape=input.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + current_all_reduce_helper().set_workspace_tensor(self.mapping) - current = x - for i in range(inner_loop): - current = allreduce(current, self.mapping.tp_group, - strategy, config) + current = x + for i in range(inner_loop): + current = allreduce(current, self.mapping.tp_group, strategy, + config) - current.mark_output('output', dtype) + current.mark_output('output', dtype) - # trt run - session = create_session(builder, network, precision=dtype) - inputs = {'x': input, 'all_reduce_workspace': workspace} - outputs = run_session(session, inputs) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'x': input, 'all_reduce_workspace': workspace} + outputs = run_session(session, inputs) # compare diff torch.testing.assert_close(outputs['output'], diff --git a/tests/functional/test_reduce_norm.py b/tests/functional/test_reduce_norm.py index c4c13006d..be9dc759c 100644 --- a/tests/functional/test_reduce_norm.py +++ b/tests/functional/test_reduce_norm.py @@ -29,7 +29,6 @@ import tensorrt_llm as tllm from tensorrt_llm import Mapping, Tensor -from tensorrt_llm._ipc_utils import peer_access from tensorrt_llm.functional import (AllReduceConfig, AllReduceFusionOp, AllReduceFusionParams, AllReduceStrategy, allreduce) @@ -105,66 +104,67 @@ def test_allreduce(self, dtype: str, strategy: AllReduceStrategy, input = self.reference_tensors[self.rank][:size].to( torch_dtype).reshape(token_num, hidden_size) - with peer_access(self.mapping): - with tllm.net_guard(net): - network = tllm.default_trtnet() - - x = Tensor(name='x', - shape=input.shape, - dtype=tllm.str_dtype_to_trt(dtype)) - y = Tensor(name='y', - shape=bias.shape, - dtype=tllm.str_dtype_to_trt(dtype)) - z = Tensor(name='z', - shape=residual.shape, - dtype=tllm.str_dtype_to_trt(dtype)) - w = Tensor(name='w', - shape=weight.shape, - dtype=tllm.str_dtype_to_trt(dtype)) - current_all_reduce_helper().set_workspace_tensor(self.mapping) - - current = x - current, z = allreduce( - current, - self.mapping.tp_group, - strategy, - config, - reduce_fusion_params=AllReduceFusionParams( - AllReduceFusionOp.RESIDUAL_RMS_NORM, - bias=y, - residual=z, - norm_weight=w, - eps=eps)) - output = current.trt_tensor - - output.name = 'output' - output.dtype = tllm.str_dtype_to_trt(dtype) - network.mark_output(output) - - build_engine = EngineFromNetwork( - (builder.trt_builder, 
net.trt_network), - config=CreateConfig( - fp16=(dtype == 'float16'), - bf16=(dtype == 'bfloat16'), - precision_constraints='obey', - )) - - output = torch.zeros_like(input) - - stream = torch.cuda.current_stream() - feed_dict = { - 'x': input, - 'y': bias, - 'z': residual, - 'w': weight, - 'all_reduce_workspace': workspace - } - - session = tllm.runtime.Session.from_engine(build_engine()) - session.run(inputs=feed_dict, - outputs={"output": output}, - stream=stream.cuda_stream) - torch.cuda.synchronize() + with tllm.net_guard(net): + network = tllm.default_trtnet() + + x = Tensor(name='x', + shape=input.shape, + dtype=tllm.str_dtype_to_trt(dtype)) + y = Tensor(name='y', + shape=bias.shape, + dtype=tllm.str_dtype_to_trt(dtype)) + z = Tensor(name='z', + shape=residual.shape, + dtype=tllm.str_dtype_to_trt(dtype)) + w = Tensor(name='w', + shape=weight.shape, + dtype=tllm.str_dtype_to_trt(dtype)) + current_all_reduce_helper().set_workspace_tensor(self.mapping) + + current = x + current, z = allreduce( + current, + self.mapping.tp_group, + strategy, + config, + reduce_fusion_params=AllReduceFusionParams( + AllReduceFusionOp.RESIDUAL_RMS_NORM, + bias=y, + residual=z, + norm_weight=w, + eps=eps), + ) + output = current.trt_tensor + + output.name = 'output' + output.dtype = tllm.str_dtype_to_trt(dtype) + network.mark_output(output) + + build_engine = EngineFromNetwork( + (builder.trt_builder, net.trt_network), + config=CreateConfig( + fp16=(dtype == 'float16'), + bf16=(dtype == 'bfloat16'), + precision_constraints='obey', + ), + ) + + output = torch.zeros_like(input) + + stream = torch.cuda.current_stream() + feed_dict = { + 'x': input, + 'y': bias, + 'z': residual, + 'w': weight, + 'all_reduce_workspace': workspace + } + + session = tllm.runtime.Session.from_engine(build_engine()) + session.run(inputs=feed_dict, + outputs={"output": output}, + stream=stream.cuda_stream) + torch.cuda.synchronize() close = torch.isclose(allreduce_ref, output, rtol=1e-2, atol=1e-3) if not torch.all(close): diff --git a/tests/hlapi/apps/_test_llm_server.py b/tests/hlapi/apps/_test_llm_server.py index 3d72e6a41..2261f6fa1 100644 --- a/tests/hlapi/apps/_test_llm_server.py +++ b/tests/hlapi/apps/_test_llm_server.py @@ -13,14 +13,17 @@ from test_llm import llama_model_path -@pytest.fixture +@pytest.fixture(scope="module") def client(): llm = LLM(llama_model_path) kv_cache_config = KvCacheConfig() app_instance = LlmServer(llm, kv_cache_config) client = TestClient(app_instance.app) - return client + yield client + + del llm + del app_instance.llm def test_generate(client): diff --git a/tests/hlapi/test_llm.py b/tests/hlapi/test_llm.py index c6e9e6cd0..a75b7b1f1 100644 --- a/tests/hlapi/test_llm.py +++ b/tests/hlapi/test_llm.py @@ -534,15 +534,13 @@ def logits_post_processor(req_id: int, logits: torch.Tensor, @force_ampere def test_generate_block_reuse(): - llm = LLM( - model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4, - enable_block_reuse=True), - ) - - # Check the configurations are correctly set - assert llm.args.build_config.plugin_config.use_paged_context_fmha is True - assert llm.args.build_config.plugin_config.paged_kv_cache is True + build_config = BuildConfig() + build_config.plugin_config._use_paged_context_fmha = True + build_config.plugin_config._paged_kv_cache = True + llm = LLM(model=llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4, + enable_block_reuse=True), + build_config=build_config) sampling_params = SamplingParams(max_new_tokens=6) diff 
--git a/tests/hlapi/test_llm_models.py b/tests/hlapi/test_llm_models.py index b7c74dc17..559629f73 100644 --- a/tests/hlapi/test_llm_models.py +++ b/tests/hlapi/test_llm_models.py @@ -174,10 +174,13 @@ def test_llm_phi_3_mini_4k(): "examples/phi/requirements.txt") command = f"pip install -r {phi_requirement_path}" subprocess.run(command, shell=True, check=True, env=os.environ) - llm_test_harness(phi_3_mini_4k_model_path, - prompts=['A B C'], - references=[' D E F G H I J K L M'], - sampling_params=sampling_params) + phi3_mini_4k_sampling_params = SamplingParams(max_new_tokens=13) + + llm_test_harness( + phi_3_mini_4k_model_path, + prompts=["I am going to Paris, what should I see?"], + references=["\n\nAssistant: Paris is a city rich in history,"], + sampling_params=phi3_mini_4k_sampling_params) @force_ampere @@ -222,7 +225,7 @@ def test_llm_gemma_2b(): sampling_params=sampling_params) -@force_ampere +@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/4800391") def test_llm_gemma_2b_int4weight_only(): quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16) llm_test_harness(gemma_2b_model_path, @@ -232,7 +235,7 @@ def test_llm_gemma_2b_int4weight_only(): quant_config=quant_config) -@force_ampere +@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/4800404") def test_llm_gemma_2_9b_it(): llm_test_harness(gemma_2_9b_it_model_path, prompts=['A B C'], @@ -292,7 +295,7 @@ def test_llm_baichuan2_7b_int4weight_only(): quant_config=quant_config) -@skip_pre_ampere +@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/4800424") def test_llm_qwen(): llm_test_harness(qwen_model_path, prompts=['A B C'], @@ -300,7 +303,7 @@ def test_llm_qwen(): sampling_params=sampling_params) -@skip_pre_ampere +@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/4800424") def test_llm_qwen1_5(): llm_test_harness(qwen1_5_model_path, prompts=['A B C'], diff --git a/tests/hlapi/test_llm_multi_gpu.py b/tests/hlapi/test_llm_multi_gpu.py index 8c86db288..385991877 100644 --- a/tests/hlapi/test_llm_multi_gpu.py +++ b/tests/hlapi/test_llm_multi_gpu.py @@ -202,3 +202,7 @@ def test_llm_multi_node(engine_from_checkpoint: tempfile.TemporaryDirectory): command = f"mpirun --allow-run-as-root -n {nworkers} trtllm-hlapi-launch python3 {test_case_file} --model_dir {engine_from_checkpoint.name} --tp_size {nworkers}" subprocess.run(command, shell=True, check=True, env=os.environ) # nosec B603 + + +if __name__ == '__main__': + test_llm_pp2() diff --git a/tests/hlapi/test_llm_perf_evaluator.py b/tests/hlapi/test_llm_perf_evaluator.py index 6155e05c2..a2422add9 100644 --- a/tests/hlapi/test_llm_perf_evaluator.py +++ b/tests/hlapi/test_llm_perf_evaluator.py @@ -5,7 +5,7 @@ import time from pathlib import Path -from tensorrt_llm.hlapi import KvCacheConfig +from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig from tensorrt_llm.hlapi._perf_evaluator import (LLMPerfEvaluator, MemoryContinuousMonitorThread) @@ -50,12 +50,16 @@ def test_perf_evaluator(): # try to set some flags kvcache_config = KvCacheConfig(enable_block_reuse=True) + build_config = BuildConfig() + build_config.plugin_config._use_paged_context_fmha = True + evaluator = LLMPerfEvaluator.create( model=llama_model_path, num_samples=10, samples_path=samples_path, warmup=10, kv_cache_config=kvcache_config, + build_config=build_config, ) assert evaluator report = evaluator.run() diff --git a/tests/microbenchamarks/README.md b/tests/microbenchamarks/README.md new file mode 100644 index 000000000..276bcb170 --- /dev/null +++ b/tests/microbenchamarks/README.md 
@@ -0,0 +1,2 @@
+!!! WARNING: This is not intended for external users to benchmark the performance numbers of the TRT-LLM product.
+!!! This folder contains the benchmark script used internally to assist TRT-LLM development.
diff --git a/tests/microbenchamarks/build_time_benchmark.py b/tests/microbenchamarks/build_time_benchmark.py
new file mode 100644
index 000000000..0ea5026eb
--- /dev/null
+++ b/tests/microbenchamarks/build_time_benchmark.py
@@ -0,0 +1,134 @@
+import argparse
+import os
+import pathlib
+import time
+
+import tensorrt_llm
+from tensorrt_llm import (AutoConfig, AutoModelForCausalLM, BuildConfig,
+                          Mapping, build)
+
+# model name -> (sub dir under the llm-models path, tp_size, pp_size)
+models_name_to_path = {
+    'gpt2': ("gpt2", 1, 1),
+    'phi2': ('phi-2', 1, 1),
+    'llama-7b': ("llama-models/llama-7b-hf", 1, 1),
+    'falcon-7b': ("falcon-7b-instruct", 1, 1),
+    'gptj-6b': ("gpt-j-6b", 1, 1),
+    'llama2-7b': ("llama-models-v2/llama-v2-7b-hf/", 1, 1),
+    'llama2-70b.TP4': ("llama-models-v2/llama-v2-70b-hf", 4, 1),
+    'mixtral-8x22b.TP4': ("Mixtral-8x22B-v0.1", 4, 1),
+    'mixtral-8x7b.TP4': ("Mixtral-8x7B-v0.1", 4, 1),
+    'mistral-7b': ("mistral-7b-v0.1", 1, 1)
+}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description=
+        "One microbenchmark to measure the engine build time for common models")
+
+    parser.add_argument("--models_root",
+                        type=str,
+                        default=os.environ.get("LLM_MODELS_ROOT"),
+                        help="The llm-models root path")
+    parser.add_argument("--model",
+                        type=str,
+                        default='gpt2',
+                        choices=list(models_name_to_path.keys()) + ["ALL"],
+                        help="The model subdir under the models_root")
+    parser.add_argument("--dtype",
+                        type=str,
+                        choices=['auto', 'float32', 'float16', 'bfloat16'],
+                        default='auto',
+                        help="The data type of the fake weights for the model")
+    parser.add_argument("--verbose",
+                        '-v',
+                        default=False,
+                        action='store_true',
+                        help="Turn on verbose log")
+    parser.add_argument("--load",
+                        default=False,
+                        action='store_true',
+                        help="Load Hugging Face weights")
+    parser.add_argument("--opt",
+                        default=3,
+                        type=int,
+                        choices=[0, 1, 2, 3, 4, 5],
+                        help="Builder optimization level")
+    parser.add_argument("--gemm",
+                        type=str,
+                        default='ootb',
+                        choices=['plugin', 'ootb'],
+                        help="Use plugin or TRT for GEMM")
+    parser.add_argument("--strong_type",
+                        default=False,
+                        action="store_true",
+                        help="Use a strongly typed TensorRT network")
+    parser.add_argument("--managed_weights",
+                        default=False,
+                        action="store_true",
+                        help="Turn on TRT-LLM managed weights")
+    return parser.parse_args()
+
+
+def build_from_hf(args, model_tag, hf_model_dir, dtype, load_weights, tp, pp):
+    '''Build the model and init the executor from a Hugging Face model config and fake weights; useful for benchmarking.
+    '''
+    world_size = tp * pp
+    # TODO: Only rank 0 is built for now; all ranks should have a similar build time.
+    # Shall we build all ranks in parallel?
+    mapping = Mapping(world_size=world_size, rank=0, tp_size=tp, pp_size=pp)
+
+    phase_and_time = []
+    if load_weights:
+        start = time.time()
+        trtllm_model = AutoModelForCausalLM.from_hugging_face(
+            hf_model_dir, dtype, mapping)
+        phase_and_time.append(('load_and_convert', time.time() - start))
+
+    else:  # fake weights
+        trtllm_config = AutoConfig.from_hugging_face(hf_model_dir, dtype,
+                                                     mapping)
+        trtllm_model = AutoModelForCausalLM.get_trtllm_model_class(
+            hf_model_dir)(trtllm_config)
+
+    start = time.time()
+    build_config = BuildConfig(max_input_len=1024, max_batch_size=16)
+
+    build_config.builder_opt = args.opt
+    build_config.plugin_config.manage_weights = args.managed_weights
+    if args.gemm == 'plugin':
+        build_config.plugin_config.gemm_plugin = 'auto'
+    else:
+        assert args.gemm == 'ootb'
+        build_config.plugin_config.gemm_plugin = None
+    build_config.strongly_typed = args.strong_type
+
+    engine = build(trtllm_model, build_config)
+    assert engine is not None
+
+    phase_and_time.append(('build_engine', time.time() - start))
+    for (p, t) in phase_and_time:
+        tensorrt_llm.logger.info(
+            f"===BuildTime==== {p} {model_tag} {t} seconds")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    if args.verbose:
+        tensorrt_llm.logger.set_level('verbose')
+    else:
+        tensorrt_llm.logger.set_level('info')
+
+    target_models = args.model
+    if target_models == "ALL":
+        target_models = models_name_to_path.keys()
+    else:
+        target_models = [target_models]
+
+    for model in target_models:
+        model_dir, tp, pp = models_name_to_path[model]
+        model_dir = pathlib.Path(args.models_root) / model_dir
+        assert model_dir.exists()
+        build_from_hf(args, model, str(model_dir), args.dtype, args.load, tp,
+                      pp)
diff --git a/tests/model/test_decilm.py b/tests/model/test_decilm.py
new file mode 100644
index 000000000..083db6671
--- /dev/null
+++ b/tests/model/test_decilm.py
@@ -0,0 +1,602 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools +import os +import sys +import tempfile +import unittest +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import transformers +from parameterized import parameterized + +import tensorrt_llm +from tensorrt_llm import logger +from tensorrt_llm._utils import str_dtype_to_torch +from tensorrt_llm.builder import Builder +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.deci.config import DeciConfig, DeciLayerConfig +from tensorrt_llm.models.deci.convert import _ffn_mult_to_intermediate_size +from tensorrt_llm.models.deci.layer_config import (AttentionImplementation, + FFNImplementation) +from tensorrt_llm.models.deci.model import DeciLMForCausalLM +from tensorrt_llm.network import Network, net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.runtime.generation import _Runtime + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.llm_data import llm_models_root +from utils.util import unittest_name_func + + +class TestDeciLM(unittest.TestCase): + + def _make_decilm_config(self, + layer_configs: List[Union[DeciLayerConfig, + Dict[str, Dict[str, + Any]]]], + dtype: str = 'bfloat16', + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = None, + hidden_size: int = 4096, + intermediate_size: int = 16384, + vocab_size: int = 32128, + max_positions_embedding: int = 1024, + norm_epsilon: float = 1e-05) -> DeciConfig: + config = { + 'architecture': 'DeciLMForCausalLM', + 'num_hidden_layers': len(layer_configs), + 'num_attention_heads': num_attention_heads, + 'num_key_value_heads': num_key_value_heads, + 'dtype': dtype, + 'logits_dtype': dtype, + 'hidden_size': hidden_size, + 'intermediate_size': intermediate_size, + 'vocab_size': vocab_size, + 'position_embedding_type': 'rope_gpt_neox', + 'max_position_embeddings': max_positions_embedding, + 'hidden_act': 'silu', + 'norm_epsilon': norm_epsilon, + 'layer_configs': layer_configs + } + + config = DeciConfig.from_dict(config) + return config + + def _gen_tensorrt_llm_network(self, network: Network, + decilm: DeciLMForCausalLM, batch_size: int, + beam_width: int, input_len: int, + output_len: int, rank: int, + tensor_parallel: int, **opt_flags): + list(range(tensor_parallel)) + + with net_guard(network): + # optimize_model(decilm, **opt_flags) + # Prepare + network.set_named_parameters(decilm.named_parameters()) + inputs = decilm.prepare_inputs(max_batch_size=batch_size, + max_input_len=input_len, + max_seq_len=input_len + output_len, + max_num_tokens=batch_size * + input_len, + use_cache=True, + max_beam_width=beam_width) + # Forward + decilm(**inputs) + return network + + def _gen_tensorrt_llm_engine( + self, + rank: int, + world_size: int, + decilm: DeciLMForCausalLM, + model_name: str, + use_plugin: bool, + batch_size: int, + beam_width: int, + input_len: int, + output_len: int, + use_refit: bool, + use_gemm: bool = False, + context_fmha_flag: ContextFMHAType = ContextFMHAType.disabled, + enable_remove_input_padding: bool = False, + **opt_flags) -> trt.IHostMemory: + + builder = Builder() + dtype = decilm.config.dtype + + with tempfile.TemporaryDirectory(): + builder_config = builder.create_builder_config( + name=model_name, + precision=dtype, + timing_cache='model.cache', + tensor_parallel=world_size, # TP only + use_refit=use_refit, + strongly_typed=True, + ) + network = builder.create_network() + network.plugin_config.to_legacy_setting() + if 
use_plugin: + network.plugin_config.gpt_attention_plugin = dtype + if use_gemm: + network.plugin_config.gemm_plugin = dtype + if enable_remove_input_padding: + network.plugin_config.remove_input_padding = True + network.plugin_config.set_context_fmha(context_fmha_flag) + + self._gen_tensorrt_llm_network(network=network, + decilm=decilm, + batch_size=batch_size, + beam_width=beam_width, + input_len=input_len, + output_len=output_len, + rank=rank, + tensor_parallel=world_size, + **opt_flags) + engine_buffer = builder.build_engine(network, builder_config) + return engine_buffer + + def _gen_tensorrt_llm_runtime( + self, + log_level: str, + world_size: int, + rank: int, + decilm: DeciLMForCausalLM, + model_name: str, + use_plugin: bool, + batch_size: int, + beam_width: int, + input_len: int, + output_len: int, + use_refit: bool, + use_gemm: bool = False, + context_fmha_flag: ContextFMHAType = ContextFMHAType.disabled, + enable_remove_input_padding: bool = False, + **opt_flags) -> Tuple[_Runtime, trt.IHostMemory]: + logger.set_level(log_level) + mapping = Mapping(world_size, rank, tp_size=world_size) + engine_buffer = self._gen_tensorrt_llm_engine( + rank=rank, + world_size=world_size, + decilm=decilm, + model_name=model_name, + use_plugin=use_plugin, + batch_size=batch_size, + beam_width=beam_width, + input_len=input_len, + output_len=output_len, + use_refit=use_refit, + use_gemm=use_gemm, + context_fmha_flag=context_fmha_flag, + enable_remove_input_padding=enable_remove_input_padding, + **opt_flags) + runtime = _Runtime(engine_buffer, mapping) + return runtime, engine_buffer + + def test_config_to_from_dict(self) -> None: + config = self._make_decilm_config(layer_configs=[{ + "attention": { + "num_key_value_heads": 4 + }, + "ffn": {} + }, { + "attention": { + "num_key_value_heads": 2 + }, + "ffn": { + "impl": "no_op" + } + }, { + "attention": { + "impl": "no_op" + }, + "ffn": { + "intermediate_size": 8192 + } + }]) + + config2 = DeciConfig.from_dict(config.to_dict()) + self.assertListEqual(config.layer_configs, config2.layer_configs) + + def test_save_load_config(self) -> None: + config = self._make_decilm_config(layer_configs=[{ + "attention": { + "num_key_value_heads": 4 + }, + "ffn": {} + }, { + "attention": { + "num_key_value_heads": 2 + }, + "ffn": { + "impl": "no_op" + } + }, { + "attention": { + "impl": "no_op" + }, + "ffn": { + "intermediate_size": 8192 + } + }]) + + with tempfile.TemporaryDirectory( + prefix="test_save_load_checkpoint") as ckpt_dir: + config_file = f"{ckpt_dir}/config.json" + config.to_json_file(config_file) + config2 = DeciConfig.from_json_file(config_file) + + self.assertDictEqual(config.to_dict(), config2.to_dict()) + self.assertListEqual(config.layer_configs, config2.layer_configs) + + def get_loader_test_cases(): + model_root = llm_models_root(check=True) + test_models_base_path = Path(model_root, "nvsmall/tests") + + models_path = [ + os.path.join(test_models_base_path, x) + for x in os.listdir(test_models_base_path) + ] + test_cases = list( + itertools.product(models_path, ["bfloat16", "float16"])) + + return test_cases + + @parameterized.expand(get_loader_test_cases, name_func=unittest_name_func) + def test_allclose_to_hf(self, hf_model_dir, dtype): + if hf_model_dir is None: + self.skipTest( + f"Missing nvsmall checkpoint, define a valid checkpoint path with the NVSMALL_CKPT environment variable" + ) + + dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) + + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + hf_model_dir, 
trust_remote_code=True, torch_dtype=dtype).cuda() + decilm = DeciLMForCausalLM.from_hugging_face(hf_model) + config = decilm.config + + log_level = "warning" + batch_size = 1 + beam_width = 1 + input_len = 4 + output_len = 2 + max_seq_len = input_len + output_len + dtype = config.dtype + enable_remove_input_padding = False + use_gpt_plugin = True + use_gemm = True + + runtime, engine_buffer = self._gen_tensorrt_llm_runtime( + log_level=log_level, + decilm=decilm, + batch_size=batch_size, + beam_width=beam_width, + input_len=input_len, + output_len=output_len, + rank=0, + world_size=1, + model_name="decilm", + use_gemm=use_gemm, + use_plugin=use_gpt_plugin, + use_refit=False) + + key_value_cache_buffers = [] + head_size = config.hidden_size // config.num_attention_heads + + attn_layer_idx = [ + i for i in range(config.num_hidden_layers) + if config.get_layer_config(i).attention.needs_kv_cache + ] + for layer_idx in attn_layer_idx: + layer_config = config.get_layer_config(layer_idx) + new_cache = torch.zeros(( + batch_size, + 2, + layer_config.attention.num_key_value_heads, + max_seq_len, + head_size, + ), + dtype=str_dtype_to_torch(dtype), + device='cuda') + key_value_cache_buffers.append(new_cache) + + # compare context + ctx_ids = torch.randint(100, (batch_size, input_len)).int().cuda() + ctx_context_lengths = input_len * torch.ones( + (batch_size), dtype=torch.int32, device='cuda') + ctx_position_ids = torch.tensor(range(input_len), + dtype=torch.int32).reshape([ + 1, input_len + ]).expand([batch_size, + input_len]).cuda() + ctx_last_token_ids = ctx_context_lengths.clone() + ctx_host_request_types = torch.tensor([0] * batch_size, + dtype=torch.int32) + + # We need sequence_lengths start as context_lengths for step 0, + # and it will be added one after each step. 
+ sequence_length_buffer = ctx_context_lengths.detach().clone() + + with torch.no_grad(): + hf_outputs = hf_model.forward(ctx_ids, + output_hidden_states=True, + output_attentions=True) + + torch.cuda.synchronize() + ref = hf_outputs.logits[:, -1, :] + + if enable_remove_input_padding: + ctx_ids = ctx_ids.view([batch_size * input_len]) + ctx_position_ids = ctx_position_ids.view([batch_size * input_len]) + ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, dim=0).int() + + cache_indirections = [ + torch.full(( + batch_size, + beam_width, + max_seq_len, + ), + 0, + dtype=torch.int32, + device='cuda'), + torch.full(( + batch_size, + beam_width, + max_seq_len, + ), + 0, + dtype=torch.int32, + device='cuda') + ] # ping-pong buffers + + perf_knob_tensor_size = 16 + # runtime_perf_knobs is not used in context phase + context_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size, + dtype=torch.int64) + + ctx_buffer = { + 'input_ids': ctx_ids, + 'context_lengths': ctx_context_lengths, + 'position_ids': ctx_position_ids, + 'last_token_ids': ctx_last_token_ids, + 'cache_indirection': cache_indirections[0], + 'host_request_types': ctx_host_request_types, + 'host_runtime_perf_knobs': context_runtime_perf_knobs, + } + if enable_remove_input_padding: + ctx_buffer['host_context_lengths'] = ctx_context_lengths.cpu() + + ctx_shape = {k: v.shape for k, v in ctx_buffer.items()} + + ctx_buffer[f'host_max_attention_window_sizes'] = torch.tensor( + [max_seq_len] * len(attn_layer_idx), dtype=torch.int32) + ctx_shape[f'host_max_attention_window_sizes'] = (len(attn_layer_idx), ) + for layer_idx, buf in zip(attn_layer_idx, key_value_cache_buffers): + layer_config = config.get_layer_config(layer_idx) + kv_shape = (batch_size, 2, + layer_config.attention.num_key_value_heads, max_seq_len, + head_size) + ctx_shape[f'past_key_value_{layer_idx}'] = kv_shape + ctx_buffer[f'past_key_value_{layer_idx}'] = buf + ctx_buffer[f'present_key_value_{layer_idx}'] = buf + + ctx_buffer['sequence_length'] = sequence_length_buffer + ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape + ctx_shape['host_past_key_value_lengths'] = (batch_size, ) + ctx_buffer['host_past_key_value_lengths'] = torch.tensor( + [0] * batch_size, dtype=torch.int32) + ctx_shape['host_sink_token_length'] = (1, ) + ctx_buffer['host_sink_token_length'] = torch.tensor([0], + dtype=torch.int32) + + context = runtime.ctx_context + runtime._set_shape(context, ctx_shape) + runtime._set_buffer(context, ctx_buffer) + runtime._run(context) + torch.cuda.synchronize() + + res = ctx_buffer['logits'] + np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(), + res.to(torch.float32).cpu().numpy(), + atol=0.12) + + # compare generation + step = 1 + step1_id = torch.randint(100, (batch_size, 1)).int().cuda() + gen_context_lengths = ctx_context_lengths.clone() + gen_position_ids = torch.ones_like(step1_id).int().cuda() * input_len + gen_last_token_ids = torch.zeros_like(gen_context_lengths).int().cuda() + gen_host_request_types = torch.tensor([1] * batch_size, + dtype=torch.int32) + gen_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size, + dtype=torch.int64) + + with torch.no_grad(): + hf_outputs = hf_model.forward( + step1_id, + past_key_values=hf_outputs.past_key_values, + use_cache=True, + output_hidden_states=True) + + torch.cuda.synchronize() + ref = hf_outputs.logits[:, -1, :] + + if enable_remove_input_padding: + step1_id = step1_id.view([batch_size]) + gen_position_ids = gen_position_ids.view([batch_size]) + gen_last_token_ids = 
torch.ones_like( + gen_context_lengths).int().cuda() + gen_last_token_ids = torch.cumsum(gen_last_token_ids, dim=0).int() + + step1_buffer = { + 'input_ids': step1_id, + 'context_lengths': gen_context_lengths, + 'position_ids': gen_position_ids, + 'last_token_ids': gen_last_token_ids, + 'host_request_types': gen_host_request_types, + 'cache_indirection': cache_indirections[1], + 'host_runtime_perf_knobs': gen_runtime_perf_knobs, + } + if enable_remove_input_padding: + step1_buffer['host_context_lengths'] = gen_context_lengths.cpu() + + step1_shape = {k: v.shape for k, v in step1_buffer.items()} + + sequence_length_buffer = torch.add(sequence_length_buffer, step) + step1_buffer[f'host_max_attention_window_sizes'] = torch.tensor( + [max_seq_len] * len(attn_layer_idx), dtype=torch.int32) + step1_shape[f'host_max_attention_window_sizes'] = ( + len(attn_layer_idx), ) + for layer_idx, buf in zip(attn_layer_idx, key_value_cache_buffers): + layer_config = config.get_layer_config(layer_idx) + kv_shape = (batch_size, 2, + layer_config.attention.num_key_value_heads, max_seq_len, + head_size) + step1_shape[f"past_key_value_{layer_idx}"] = kv_shape + step1_buffer[f"past_key_value_{layer_idx}"] = buf + step1_buffer[f"present_key_value_{layer_idx}"] = buf + + step1_buffer['sequence_length'] = sequence_length_buffer + step1_shape['sequence_length'] = ctx_buffer['sequence_length'].shape + step1_shape['sequence_length'] = (batch_size, ) + step1_shape['host_past_key_value_lengths'] = (batch_size, ) + step1_buffer[ + 'host_past_key_value_lengths'] = sequence_length_buffer.cpu() + step1_shape['host_sink_token_length'] = (1, ) + step1_buffer['host_sink_token_length'] = torch.tensor([0], + dtype=torch.int32) + + context = runtime.context_1 + runtime._set_shape(context, step1_shape) + runtime._set_buffer(context, step1_buffer) + runtime._run(context) + torch.cuda.synchronize() + res = step1_buffer['logits'] + + np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(), + res.to(torch.float32).cpu().numpy(), + atol=0.12) + + @parameterized.expand( + itertools.product( + (os.getenv("NVSMALL_CKPT"), ), # "deci/decilm-7b"), + (True, False), + (1, 2), + (1, 2), + ("auto", "float16", "bfloat16"))) + def test_convert_config_from_hf(self, ckpt_path: Optional[str], + preloaded: bool, tp_size: int, pp_size: int, + dtype: str) -> None: + if ckpt_path is None: + self.skipTest( + f"Missing nvsmall checkpoint, define a valid checkpoint path with the NVSMALL_CKPT environment variable" + ) + + hf_config = transformers.AutoConfig.from_pretrained( + ckpt_path, trust_remote_code=True) + + mapping = Mapping(world_size=(tp_size * pp_size), + rank=0, + gpus_per_node=1, + tp_size=tp_size, + pp_size=pp_size) + + config = DeciConfig.from_hugging_face( + hf_config if preloaded else ckpt_path, + dtype=dtype, + mapping=mapping, + trust_remote_code=not preloaded) + + if getattr(hf_config, "num_key_value_heads_per_layer", + None) is not None: + # verify layers for old config + for layer_idx, num_kv_heads in enumerate( + hf_config.num_key_value_heads_per_layer): + layer_config = config.get_layer_config(layer_idx) + self.assertEqual(layer_config.attention.impl, + AttentionImplementation.ATTENTION) + self.assertEqual(num_kv_heads, + layer_config.attention.num_key_value_heads) + self.assertEqual(layer_config.ffn.impl, FFNImplementation.MLP) + self.assertEqual(layer_config.ffn.intermediate_size, + config.intermediate_size) + + elif getattr(hf_config, "block_configs", None) is not None: + # verify layers for new config + for layer_idx, 
block_config in enumerate(hf_config.block_configs): + layer_config = config.get_layer_config(layer_idx) + if layer_config.attention.impl == AttentionImplementation.ATTENTION: + self.assertFalse(block_config.attention.no_op) + self.assertFalse(block_config.attention.replace_with_linear) + self.assertEqual( + config.num_attention_heads // + block_config.attention.n_heads_in_group, + layer_config.attention.num_key_value_heads) + elif layer_config.attention.impl == AttentionImplementation.NO_OP: + self.assertTrue(block_config.attention.no_op) + elif layer_config.attention.impl == AttentionImplementation.LINEAR: + self.assertTrue(block_config.attention.replace_with_linear) + + if layer_config.ffn.impl == FFNImplementation.MLP: + self.assertFalse(block_config.ffn.no_op) + self.assertFalse(block_config.ffn.replace_with_linear) + self.assertEqual( + _ffn_mult_to_intermediate_size( + block_config.ffn.ffn_mult, config.hidden_size), + layer_config.ffn.intermediate_size) + elif layer_config.ffn.impl == FFNImplementation.NO_OP: + self.assertTrue(block_config.ffn.no_op) + elif layer_config.ffn.impl == FFNImplementation.LINEAR: + self.assertTrue(block_config.ffn.replace_with_linear) + + # verify config is valid enough for model creation + DeciLMForCausalLM(config) + + @parameterized.expand( + itertools.product( + (os.getenv("NVSMALL_CKPT"), ), # "deci/decilm-7b"), + (True, False), + (1, 2), + (1, 2), + ("auto", "float16", "bfloat16"))) + def test_convert_model_from_hf(self, ckpt_path: Optional[str], + preloaded: bool, tp_size: int, pp_size: int, + dtype: str) -> None: + if ckpt_path is None: + self.skipTest( + f"Missing nvsmall checkpoint, define a valid checkpoint path with the NVSMALL_CKPT environment variable" + ) + + if preloaded: + hf_model_or_dir = transformers.AutoModelForCausalLM.from_pretrained( + ckpt_path, trust_remote_code=True) + else: + hf_model_or_dir = ckpt_path + + mapping = Mapping(world_size=(tp_size * pp_size), + rank=0, + gpus_per_node=1, + tp_size=tp_size, + pp_size=pp_size) + + DeciLMForCausalLM.from_hugging_face(hf_model_or_dir=hf_model_or_dir, + dtype=dtype, + mapping=mapping, + trust_remote_code=not preloaded) diff --git a/tests/model/test_llama.py b/tests/model/test_llama.py index 20d9b8bf9..344d52172 100644 --- a/tests/model/test_llama.py +++ b/tests/model/test_llama.py @@ -208,14 +208,12 @@ def load_test_cases(): dict())) # GQA test_cases.append((False, True, ContextFMHAType.enabled_with_fp32_acc, False, 'float16', 4, 'silu', dict())) # GQA - test_cases.append((False, True, ContextFMHAType.disabled, False, - 'float16', 2, 'gelu', { - "use_fused_mlp": True - })) # GQA - test_cases.append((False, True, ContextFMHAType.disabled, False, - 'float16', 2, 'silu', { - "use_fused_mlp": True - })) # GQA + test_cases.append( + (False, True, ContextFMHAType.disabled, False, 'float16', 2, 'gelu', + dict())) # GQA + test_cases.append( + (False, True, ContextFMHAType.disabled, False, 'float16', 2, 'silu', + dict())) # GQA return test_cases @parameterized.expand(load_test_cases, name_func=unittest_name_func) @@ -553,7 +551,6 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): }, 'use_parallel_embedding': use_parallel_embedding, 'embedding_sharding_dim': embedding_sharding_dim, - 'use_fused_mlp': False, } config = PretrainedConfig.from_dict(config) diff --git a/tests/model/test_mamba.py b/tests/model/test_mamba.py index b7a5fa88a..9f16397d7 100644 --- a/tests/model/test_mamba.py +++ b/tests/model/test_mamba.py @@ -68,13 +68,19 @@ def _gen_tensorrt_llm_mamba(self, hf_config, hf_path, 
hf_mamba, load_mode, 'conv_kernel': hf_config.conv_kernel, 'use_bias': hf_config.use_bias, 'mamba_version': 'Mamba1', + 'mapping': { + 'world_size': 1, + 'tp_size': 1, + 'pp_size': 1 + }, } - config = tensorrt_llm.models.PretrainedConfig.from_dict(config) if load_mode == 'from_checkpoint': - weights = convert_from_hf_checkpoint(model_dir=hf_path, dtype=dtype) + weights = convert_from_hf_checkpoint(mamba_config=config, + model_dir=hf_path, + dtype=dtype) else: weights = convert_hf_mamba(hf_mamba, rank=0, dtype=dtype) - + config = tensorrt_llm.models.PretrainedConfig.from_dict(config) tensorrt_llm_mamba = tensorrt_llm.models.MambaForCausalLM(config) tensorrt_llm_mamba.load(weights) return tensorrt_llm_mamba diff --git a/tests/model/test_mistral.py b/tests/model/test_mistral.py index a479d5d10..5000566aa 100644 --- a/tests/model/test_mistral.py +++ b/tests/model/test_mistral.py @@ -84,7 +84,6 @@ def _gen_tensorrt_llm_network(self, network, hf_mistral, "top_k": 0, "normalization_mode": 1, }, - 'use_fused_mlp': False, } # Initialize model @@ -505,7 +504,6 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): }, 'use_parallel_embedding': use_parallel_embedding, 'embedding_sharding_dim': embedding_sharding_dim, - 'use_fused_mlp': False, } config = PretrainedConfig.from_dict(config) diff --git a/tests/utils/cpp_paths.py b/tests/utils/cpp_paths.py index 5b4853aa4..02a8abff4 100644 --- a/tests/utils/cpp_paths.py +++ b/tests/utils/cpp_paths.py @@ -45,7 +45,7 @@ def engine_path(resource_path: _pl.Path) -> _pl.Path: def get_base_model_spec() -> model_spec.ModelSpec: model_spec_obj = model_spec.ModelSpec('input_tokens.npy', _tb.DataType.HALF) model_spec_obj.use_gpt_plugin().set_kv_cache_type( - model_spec.KVCacheType.PAGED).use_packed_input() + _tb.KVCacheType.PAGED).use_packed_input() return model_spec_obj
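
For reference, the new `tests/microbenchamarks/build_time_benchmark.py` script added above is driven entirely by the argparse flags shown in its diff. The snippet below is only an illustrative sketch of how it might be invoked; it assumes `LLM_MODELS_ROOT` points at the internal llm-models tree and uses the `llama2-7b` and `ALL` entries from `models_name_to_path` — the actual paths are internal.

```shell
# Measure engine build time for one model with fake weights (no HF weight loading).
export LLM_MODELS_ROOT=/path/to/llm-models
python tests/microbenchamarks/build_time_benchmark.py --model llama2-7b --dtype float16 --gemm plugin --managed_weights

# Sweep every entry in models_name_to_path with verbose build-time logging.
python tests/microbenchamarks/build_time_benchmark.py --model ALL --opt 3 -v
```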