+* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫
+🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647)
+
* [2024/07/30] Introducing🍊 @SliceXAI ELM Turbo 🤖 train ELM once ⚡ #TensorRT #LLM optimize ☁️ deploy anywhere
[➡️ link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms)
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 575769842..00f450319 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -7,5 +7,5 @@ There are currently three workflows to benchmark TensorRT-LLM:
- The recommended workflow that uses TensorRT-LLM C++ API and can take advantage of the latest features of TensorRT-LLM.
* [Python benchmarks](./python)
- The Python benchmarking scripts can only benchmark the Python runtime, which do not support the latest features, such as in-flight batching.
-* [The Python benchmarking suite](./suite)
+* [The Python benchmarking suite](./Suite.md)
- This benchmarking suite is a current work in progress and is prone to large changes.
diff --git a/benchmarks/Suite.md b/benchmarks/Suite.md
new file mode 100644
index 000000000..f447b73e7
--- /dev/null
+++ b/benchmarks/Suite.md
@@ -0,0 +1,316 @@
+# TensorRT-LLM Benchmarking
+
+> [!WARNING] Work in Progress
+> This benchmarking suite is a current work in progress and is prone to large changes.
+
+TensorRT-LLM provides a packaged benchmarking utility that is accessible via the `trtllm-bench` CLI tool.
+
+#### Supported Networks for Benchmarking
+
+- [`tiiuae/falcon-180B`](https://huggingface.co/tiiuae/falcon-180B)
+- [`meta-llama/Llama-2-7b-hf`](https://huggingface.co/meta-llama/Llama-2-7b-hf)
+- [`meta-llama/Llama-2-70b-hf`](https://huggingface.co/meta-llama/Llama-2-70b-hf)
+- [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
+- [`meta-llama/Meta-Llama-3-70B`](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
+- [`EleutherAI/gpt-j-6b`](https://huggingface.co/EleutherAI/gpt-j-6b)
+
+#### Supported Quantization Modes
+
+TensorRT-LLM supports a number of quantization modes. For more information about quantization, see the
+[documentation](https://nvidia.github.io/TensorRT-LLM/precision.html).
+
+- None (no quantization applied)
+- W8A16
+- W4A16
+- W4A16_AWQ
+- W4A8_AWQ
+- W4A16_GPTQ
+- FP8
+- INT8
+
+> [!NOTE] Please see the supported quantization methods for each network [here](https://nvidia.github.io/TensorRT-LLM/precision.html#support-matrix)
+
+
+## Inflight Benchmarking with a Dataset
+
+This section covers how to benchmark TensorRT-LLM using inflight batching.
+
+
+### Quickstart
+
+For this quick start guide, we will focus on running a short max throughput benchmark on
+`meta-llama/Llama-2-7b-hf` with a synthetic dataset of uniformly distributed prompts and an ISL:OSL
+of 128:128. To run the benchmark from start to finish, simply run the following commands:
+
+```shell
+python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1400 > /tmp/synthetic_128_128.txt
+trtllm-bench --model meta-llama/Llama-2-7b-hf build --dataset /tmp/synthetic_128_128.txt --quantization FP8
+trtllm-bench --model meta-llama/Llama-2-7b-hf throughput --dataset /tmp/synthetic_128_128.txt --engine-path /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
+```
+
+And that's it! Once the benchmark completes, a summary of the key metrics is printed.
+
+```
+===========================================================
+= ENGINE DETAILS
+===========================================================
+Model: meta-llama/Llama-2-7b-hf
+Engine Directory: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
+TensorRT-LLM Version: 0.12.0.dev2024073000
+Dtype: float16
+KV Cache Dtype: FP8
+Quantization: FP8
+Max Input Length: 2048
+Max Sequence Length: 4098
+
+===========================================================
+= WORLD + RUNTIME INFORMATION
+===========================================================
+TP Size: 1
+PP Size: 1
+Max Runtime Batch Size: 4096
+Max Runtime Tokens: 8192
+Scheduling Policy: Guaranteed No Evict
+KV Memory Percentage: 99.0%
+Issue Rate (req/sec): 3.680275266452667e+18
+===========================================================
+= STATISTICS
+===========================================================
+Number of requests: 3000
+Average Input Length (tokens): 128.0
+Average Output Length (tokens): 128.0
+Token Throughput (tokens/sec): 23405.927228471104
+Request Throughput (req/sec): 182.8588064724305
+Total Latency (seconds): 16.406100739
+===========================================================
+```
+
+### Workflow
+
+The workflow for `trtllm-bench` is composed of the following steps:
+
+1. Prepare a dataset to drive the inflight batching benchmark.
+2. Build a benchmark engine using the `trtllm-bench build` subcommand.
+3. Run the max throughput benchmark using the `trtllm-bench throughput` subcommand.
+
+#### Preparing a Dataset
+
+The inflight benchmark utilizes a fixed JSON schema so that requests are simple and
+straightforward to specify. The schema is defined as follows:
+
+| Key | Required | Type | Description |
+| :- | :-: | :-: | :- |
+| `task_id`| Y | String | Unique identifier for the request. |
+| `prompt` | N* | String | Input text for a generation request. |
+| `logits` | N* | List[Integer] | List of logits that make up the request prompt. |
+| `output_tokens` | Y | Integer | Number of generated tokens for this request. |
+
+> [!NOTE] Prompt and logits are mutually exclusive*
+> You do not need to provide both `prompt` and `logits`, but at least one of them is required.
+> If `logits` is specified, the `prompt` entry is ignored for request generation.
+
+Examples of valid entries for the inflight benchmark are:
+
+- Entries with a human-readable prompt and no logits.
+```json
+{"task_id": 1, "prompt": "Generate an infinite response to the following: This is the song that never ends, it goes on and on my friend.", "output_tokens": 1000}
+{"task_id": 2, "prompt": "Generate an infinite response to the following: Na, na, na, na", "output_tokens": 1000}
+```
+
+- Entries which contain logits.
+```json
+{"task_id":0,"logits":[863,22056,25603,11943,8932,13195,3132,25032,21747,22213],"output_tokens":128}
+{"task_id":1,"logits":[14480,13598,15585,6591,1252,8259,30990,26778,7063,30065,21764,11023,1418],"output_tokens":128}
+```
+
+> [!INFO] A whole entry is on a line!
+> To simplify data passing, each line contains one complete JSON entry so that the benchmarker
+> can read a line and assume a complete request. When creating a dataset, be sure that a complete
+> JSON entry is on every line.
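+
+As a quick sanity check (an informal sketch, not part of the benchmarking suite), a dataset file can
+be validated line by line with a short Python script that mirrors the schema above:
+
+```python
+import json
+import sys
+
+
+def validate_dataset(path: str) -> None:
+    """Verify that every line is a complete JSON entry with the required keys."""
+    with open(path) as f:
+        for lineno, line in enumerate(f, start=1):
+            entry = json.loads(line)  # fails if the line is not a complete JSON entry
+            assert "task_id" in entry, f"line {lineno}: missing task_id"
+            assert "output_tokens" in entry, f"line {lineno}: missing output_tokens"
+            # At least one of `prompt` or `logits` must be present.
+            assert "prompt" in entry or "logits" in entry, f"line {lineno}: needs prompt or logits"
+
+
+if __name__ == "__main__":
+    validate_dataset(sys.argv[1])
+```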
+
+#### Using `prepare_dataset` to Create Synthetic Datasets
+
+To prepare a synthetic dataset, use the script provided in the `benchmarks/cpp`
+directory. For example, to generate a synthetic dataset of 1000 requests with a uniform ISL/OSL of
+128/128 for [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), simply run:
+
+```shell
+benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > $PATH_TO_DATASET
+```
+
+You can redirect the output of the above command to a file to reuse the same dataset, or pipe its
+output directly into the benchmark.
+
+### Building a Benchmark Engine
+
+Once you have a dataset, the next thing you need is an engine to benchmark against. To build a
+pre-configured engine for one of the supported ISL:OSL combinations, run the following with the
+dataset you generated via `prepare_dataset.py` to build an FP8-quantized engine:
+
+```shell
+trtllm-bench --model $HF_MODEL_NAME build --dataset $PATH_TO_DATASET --quantization FP8
+```
+
+or manually set the maximum sequence length that you plan to run with:
+
+```shell
+trtllm-bench --model $HF_MODEL_NAME build --max_seq_len $MAX_SEQ_LEN --quantization FP8
+```
+
+In either case, the engine is written to the `/tmp/$HF_MODEL_NAME/tp_1_pp_1/` directory.
+
+### Running a Max Throughput Benchmark
+
+The `trtllm-bench` command line tool provides a max throughput benchmark that is accessible via the
+`throughput` subcommand. This benchmark tests a TensorRT-LLM engine under maximum load to provide an
+upper bound throughput number.
+
+#### How the Benchmarker Works
+
+The benchmarker reads a data file or standard input (stdin) as a stream in which each line contains
+a complete JSON request entry. The benchmarker proceeds as follows:
+
+1. Iterate over all input requests. If `logits` is specified, construct the request using the specified
+list of logits. Otherwise, tokenize the `prompt` with the tokenizer specified by `--model $HF_MODEL_NAME`
+(a rough sketch of this step follows the list).
+2. Submit the dataset to the TensorRT-LLM `Executor` API as fast as possible (offline mode).
+3. Wait for all requests to return, compute statistics, then report the results.
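+
+The following is a rough Python sketch of the request-construction step only, assuming a Hugging Face
+tokenizer loaded via `transformers`; it is an illustration, not the actual `trtllm-bench`
+implementation, and the submission and statistics steps (which use the `Executor` API) are not shown.
+
+```python
+import json
+import sys
+from typing import List, Tuple
+
+from transformers import AutoTokenizer
+
+
+def load_requests(model_name: str, stream=sys.stdin) -> List[Tuple[int, List[int], int]]:
+    """Turn each JSON line into (task_id, input token ids, output length)."""
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    requests = []
+    for line in stream:
+        entry = json.loads(line)
+        # Prefer pre-tokenized `logits`; otherwise tokenize the `prompt`.
+        token_ids = entry.get("logits") or tokenizer(entry["prompt"])["input_ids"]
+        requests.append((entry["task_id"], token_ids, entry["output_tokens"]))
+    return requests
+```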
+
+To run the benchmarker, run the following with the engine and dataset generated above:
+
+```
+trtllm-bench --model $HF_MODEL_NAME throughput --dataset $PATH_TO_DATASET --engine_dir /tmp/$HF_MODEL_NAME/tp_1_pp_1/
+```
+
+When the benchmark runs, you will see output similar to the following:
+
+```
+Preparing to run throughput benchmark...
+Setting up benchmarker and infrastructure.
+Initializing Throughput Benchmark. [rate=%d req/s]
+Ready to start benchmark.
+Initializing Executor.
+[TensorRT-LLM][INFO] Engine version 0.12.0.dev2024073000 found in the config file, assuming engine(s) built by new builder API.
+[TensorRT-LLM][INFO] Initializing MPI with thread mode 3
+[TensorRT-LLM][INFO] Initialized MPI
+[TensorRT-LLM][INFO] Engine version 0.12.0.dev2024073000 found in the config file, assuming engine(s) built by new builder API.
+[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0
+[TensorRT-LLM][INFO] Rank 0 is using GPU 0
+[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 4096
+[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 4096
+[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 1
+[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 4098
+[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0
+[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: 4098
+[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
+[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 1
+[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 8192
+[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 4097 = maxSequenceLen - 1 since chunked context is enabled
+[TensorRT-LLM][INFO] Capacity Scheduler Policy: GUARANTEED_NO_EVICT
+[TensorRT-LLM][INFO] Context Chunking Scheduler Policy: FIRST_COME_FIRST_SERVED
+[TensorRT-LLM][INFO] Loaded engine size: 6214 MiB
+[TensorRT-LLM][INFO] [MemUsageChange] Allocated 928.77 MiB for execution context memory.
+[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
+[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
+[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
+[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
+[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
+[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
+[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
+[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
+[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
+[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
+[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 1. Please ensure there are no enqueued operations pending in this context prior to switching profiles
+[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
+[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
+[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
+[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
+[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
+[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 2. Please ensure there are no enqueued operations pending in this context prior to switching profiles
+[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
+[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
+[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
+[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
+[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
+[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 3. Please ensure there are no enqueued operations pending in this context prior to switching profiles
+[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
+[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
+[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
+[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
+[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
+[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 4. Please ensure there are no enqueued operations pending in this context prior to switching profiles
+[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
+[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
+[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
+[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
+[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
+[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 5. Please ensure there are no enqueued operations pending in this context prior to switching profiles
+[TensorRT-LLM][INFO] [MemUsageChange] Allocated 1.14 GB GPU memory for runtime buffers.
+[TensorRT-LLM][INFO] [MemUsageChange] Allocated 4.35 GB GPU memory for decoder.
+[TensorRT-LLM][INFO] Memory usage when calculating max tokens in paged kv cache: total: 79.10 GiB, available: 63.62 GiB
+[TensorRT-LLM][INFO] Number of blocks in KV cache primary pool: 4607
+[TensorRT-LLM][INFO] Number of blocks in KV cache secondary pool: 0, onboard blocks to primary memory before reuse: true
+[TensorRT-LLM][INFO] Max KV cache pages per sequence: 65
+[TensorRT-LLM][INFO] Number of tokens per block: 64.
+[TensorRT-LLM][INFO] [MemUsageChange] Allocated 62.99 GiB for max tokens in paged KV cache (294848).
+[TensorRT-LLM][INFO] Executor instance created by worker
+Starting response daemon...Executor started.
+
+Request serving started.
+Starting statistics collection.
+Collecting live stats...
+Benchmark started.
+Request serving stopped.
+Collecting last stats...
+Ending statistics collection.
+Stop received.
+Stopping response parsing.
+Collecting last responses before shutdown.
+Completed request parsing.
+Parsing stopped.
+Request generator successfully joined.
+Statistics process successfully joined.
+===========================================================
+= ENGINE DETAILS
+===========================================================
+Model: meta-llama/Llama-2-7b-hf
+Engine Directory: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
+TensorRT-LLM Version: 0.12.0.dev2024073000
+Dtype: float16
+KV Cache Dtype: FP8
+Quantization: FP8
+Max Input Length: 2048
+Max Sequence Length: 4098
+
+===========================================================
+= WORLD + RUNTIME INFORMATION
+===========================================================
+TP Size: 1
+PP Size: 1
+Max Runtime Batch Size: 4096
+Max Runtime Tokens: 8192
+Scheduling Policy: Guaranteed No Evict
+KV Memory Percentage: 99.0%
+Issue Rate (req/sec): 3.680275266452667e+18
+===========================================================
+= STATISTICS
+===========================================================
+Number of requests: 3000
+Average Input Length (tokens): 128.0
+Average Output Length (tokens): 128.0
+Token Throughput (tokens/sec): 23405.927228471104
+Request Throughput (req/sec): 182.8588064724305
+Total Latency (seconds): 16.406100739
+===========================================================
+
+Benchmark Shutdown called!
+Shutting down ExecutorServer.
+[TensorRT-LLM][INFO] Orchestrator sendReq thread exiting
+[TensorRT-LLM][INFO] Orchestrator recv thread exiting
+Executor shutdown.
+[TensorRT-LLM][INFO] Leader sendThread exiting
+[TensorRT-LLM][INFO] Leader recvReq thread exiting
+```
+
+> [!WARNING] Some statistics are not reported.
+> Some statistics in the summary are reported as 0.0; these statistics are not currently
+> available.
diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp
index d3861a2f3..488f71a19 100644
--- a/benchmarks/cpp/gptManagerBenchmark.cpp
+++ b/benchmarks/cpp/gptManagerBenchmark.cpp
@@ -24,6 +24,7 @@
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/tensor.h"
+#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
@@ -173,6 +174,9 @@ struct BenchmarkParams
// Decoding params
std::optional<std::vector<std::vector<SizeType32>>> medusaChoices;
+
+ std::optional<texec::LookaheadDecodingConfig> executorLookaheadConfig;
+ std::optional<texec::LookaheadDecodingConfig> requestLookaheadConfig;
};
class InferenceRequestsAsyncSend
@@ -509,6 +513,7 @@ class Recorder
{
if (!mStreaming)
{
+ TLLM_LOG_DEBUG("response.getResult().outputTokenIds");
auto outputTokenIds = response.getResult().outputTokenIds;
int32_t outSeqLen = 0;
@@ -824,9 +829,11 @@ class ExecutorServer
executorConfig.setMaxNumTokens(benchmarkParams.maxNumTokens.value());
}
- executorConfig.setDecodingConfig(texec::DecodingConfig(
- benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(),
- std::nullopt, benchmarkParams.medusaChoices));
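+ // Select the decoding mode: Medusa when Medusa choices are provided, Lookahead when an executor
+ // lookahead config is provided, and Auto otherwise.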
+ executorConfig.setDecodingConfig(
+ texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
+ : benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
+ : texec::DecodingMode::Auto(),
+ benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices));
executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig);
if (executorModelType == texec::ModelType::kDECODER_ONLY)
@@ -910,7 +917,7 @@ class ExecutorServer
for (auto const& response : responses)
{
auto const reqId = response.getRequestId();
-
+ TLLM_LOG_DEBUG("response.getResult().isFinal");
if (response.getResult().isFinal)
{
mActiveCount--;
@@ -1323,7 +1330,8 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const&
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr,
- ITensor::SharedPtr const& loraConfig = nullptr)
+ ITensor::SharedPtr const& loraConfig = nullptr,
+ std::optional<texec::LookaheadDecodingConfig> lookaheadConfig = std::nullopt)
{
auto request = std::make_shared(reqId);
auto const& inputIds = sample.inputIds;
@@ -1361,6 +1369,10 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const&
{
request->setLoraConfig(loraConfig);
}
+ if (lookaheadConfig)
+ {
+ request->setLookaheadConfig(lookaheadConfig.value());
+ }
if (streaming)
{
request->setIsStreaming(true);
@@ -1372,18 +1384,20 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
std::optional const& eosId, std::optional const& padId, bool streaming = false,
bool const& returnContextLogits = false, bool const& returnGenerationLogits = false,
std::optional const& loraConfig = std::nullopt,
+ std::optional<texec::LookaheadDecodingConfig> const& lookaheadConfig = std::nullopt,
std::optional encoderInputTokenIds = std::nullopt)
{
auto samplingConfig = texec::SamplingConfig{beamWidth};
auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false};
return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId,
- std::nullopt, // badWords
- std::nullopt, // stopWords
- std::nullopt, // embeddingBias
- std::nullopt, // speculativeDecoding
- std::nullopt, // pTuning
- loraConfig,
- std::nullopt, // logitsPostProcessorName
+ std::nullopt, // badWords
+ std::nullopt, // stopWords
+ std::nullopt, // embeddingBias
+ std::nullopt, // speculativeDecoding
+ std::nullopt, // pTuning
+ loraConfig, // loraConfig
+ lookaheadConfig, // lookaheadConfig
+ std::nullopt, // logitsPostProcessorName
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
}
@@ -1429,9 +1443,11 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
optionalParams.maxBatchSize = benchmarkParams.maxBatchSize;
optionalParams.maxNumTokens = benchmarkParams.maxNumTokens;
optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy};
- optionalParams.decodingConfig = texec::DecodingConfig(
- benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(),
- std::nullopt, benchmarkParams.medusaChoices);
+ optionalParams.decodingConfig
+ = texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
+ : benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
+ : texec::DecodingMode::Auto(),
+ benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
@@ -1501,8 +1517,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
++reqId;
if (i == terminateReqId)
++reqId;
- auto request = makeRequest(
- reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
+ auto request = makeRequest(reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
+ padIdTensor, bufferManager, nullptr, nullptr, nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
@@ -1517,7 +1533,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
- padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
+ padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor, nullptr,
+ nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
if (i < numSamples - 1)
@@ -1541,7 +1558,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
- padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
+ padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor,
+ nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
@@ -1644,13 +1662,13 @@ void benchmarkExecutor(std::optional const& decoderEngine
{
Sample s{std::vector{decoderStartTokenId}, 1, static_cast(taskId)};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false,
- loraConfig, std::vector{1, 2, 3, 4, 5}));
+ loraConfig, std::nullopt, std::vector{1, 2, 3, 4, 5}));
}
else
{
Sample s{std::vector{1, 2, 3, 4, 5}, 1, static_cast(taskId)};
requests.emplace_back(
- makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig));
+ makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig, std::nullopt));
}
}
executorServer->enqueue(std::move(requests), true);
@@ -1668,12 +1686,14 @@ void benchmarkExecutor(std::optional const& decoderEngine
{
Sample s{std::vector{decoderStartTokenId}, samples[0].outputLen, samples[0].taskId};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming,
- returnContextLogits, returnGenerationLogits, std::nullopt, samples[0].inputIds));
+ returnContextLogits, returnGenerationLogits, std::nullopt,
+ benchmarkParams.requestLookaheadConfig, samples[0].inputIds));
}
else
{
requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId,
- benchmarkParams.streaming, returnContextLogits, returnGenerationLogits));
+ benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, std::nullopt,
+ benchmarkParams.requestLookaheadConfig));
}
}
executorServer->enqueue(std::move(requests), true);
@@ -1699,12 +1719,14 @@ void benchmarkExecutor(std::optional const& decoderEngine
{
Sample s{std::vector{decoderStartTokenId}, samples[i].outputLen, samples[i].taskId};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming,
- returnContextLogits, returnGenerationLogits, loraConfig, samples[i].inputIds));
+ returnContextLogits, returnGenerationLogits, loraConfig, benchmarkParams.requestLookaheadConfig,
+ samples[i].inputIds));
}
else
{
requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId,
- benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig));
+ benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig,
+ benchmarkParams.requestLookaheadConfig));
}
}
@@ -1789,6 +1811,25 @@ std::vector> parseVectorOfVectors(std::string const& inp
return result;
}
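+// Parses a lookahead config string of the form "[max_window_size, max_ngram_size, max_verification_set_size]";
+// falls back to a default-constructed LookaheadDecodingConfig (with a warning) if parsing fails.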
+texec::LookaheadDecodingConfig parseLookaheadConfig(std::string const& input)
+{
+ std::regex regex("\\[ *(\\d+) *, *(\\d+) *, *(\\d+) *\\]");
+ std::smatch match;
+ if (std::regex_match(input, match, regex))
+ {
+ TLLM_CHECK(match.size() == 4);
+ auto w = std::stoi(match[1]);
+ auto n = std::stoi(match[2]);
+ auto g = std::stoi(match[3]);
+ return texec::LookaheadDecodingConfig(w, n, g);
+ }
+ else
+ {
+ TLLM_LOG_WARNING("cannot parse lookahead config from '%s'", input.c_str());
+ return texec::LookaheadDecodingConfig();
+ }
+}
+
} // namespace
int main(int argc, char* argv[])
@@ -1898,6 +1939,14 @@ int main(int argc, char* argv[])
options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
cxxopts::value()->default_value("false"));
+ options.add_options()("executor_lookahead_config",
+ "lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size]",
+ cxxopts::value<std::string>());
+
+ options.add_options()("request_lookahead_config",
+ "lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size], and each <= "
+ "executor lookahead config",
+ cxxopts::value<std::string>());
auto result = options.parse(argc, argv);
@@ -2055,6 +2104,16 @@ int main(int argc, char* argv[])
{
benchmarkParams.medusaChoices = parseVectorOfVectors(result["medusa_choices"].as());
}
+ if (result.count("executor_lookahead_config"))
+ {
+ benchmarkParams.executorLookaheadConfig
+ = parseLookaheadConfig(result["executor_lookahead_config"].as());
+ }
+ if (result.count("request_lookahead_config"))
+ {
+ benchmarkParams.requestLookaheadConfig
+ = parseLookaheadConfig(result["request_lookahead_config"].as());
+ }
// Argument: multi_block_mode
benchmarkParams.multiBlockMode = result["multi_block_mode"].as();
diff --git a/benchmarks/python/all_reduce.py b/benchmarks/python/all_reduce.py
index ae7cb8868..d91cdd0d4 100644
--- a/benchmarks/python/all_reduce.py
+++ b/benchmarks/python/all_reduce.py
@@ -23,7 +23,6 @@
import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor
-from tensorrt_llm._ipc_utils import peer_access
from tensorrt_llm._utils import OMPI_COMM_TYPE_HOST, mpi_comm
from tensorrt_llm.functional import AllReduceStrategy, allreduce
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
@@ -106,18 +105,18 @@ def allreduce_benchmark(dtype: str,
_, start = cuda.cuEventCreate(0)
_, stop = cuda.cuEventCreate(0)
runtimes = []
- with peer_access(mapping):
- tllm.mpi_barrier()
-
- for _ in range(10):
- cuda.cuEventRecord(start, stream.cuda_stream)
- session.run(inputs=feed_dict,
- outputs={"output": output},
- stream=stream.cuda_stream)
- cuda.cuEventRecord(stop, stream.cuda_stream)
- torch.cuda.synchronize()
- _, ms = cuda.cuEventElapsedTime(start, stop)
- runtimes.append(ms)
+
+ tllm.mpi_barrier()
+
+ for _ in range(10):
+ cuda.cuEventRecord(start, stream.cuda_stream)
+ session.run(inputs=feed_dict,
+ outputs={"output": output},
+ stream=stream.cuda_stream)
+ cuda.cuEventRecord(stop, stream.cuda_stream)
+ torch.cuda.synchronize()
+ _, ms = cuda.cuEventElapsedTime(start, stop)
+ runtimes.append(ms)
median_ms = sorted(runtimes)[len(runtimes) // 2]
assert torch.allclose(output, (input * world_size)**inner_loop)
diff --git a/benchmarks/suite/README.md b/benchmarks/suite/README.md
deleted file mode 100644
index bba21609e..000000000
--- a/benchmarks/suite/README.md
+++ /dev/null
@@ -1,234 +0,0 @@
-# TensorRT-LLM Benchmarking
-
-> [!WARNING] Work in Progress
-> This benchmarking suite is a current work in progress and is prone to large changes.
-
-This package is the official benchmarking suite for TensorRT-LLM. This benchmark will be updated
-as development of TensorRT-LLM continues.
-
-## Installation
-
-From this folder, run `pip install -r requirements.txt` to install the extra dependencies required for this tool.
-
-### Available Build and Benchmark Options
-
-The following model options are available for benchmarking models.
-
-| Option | Required | Default | Description |
-| :- | :-: | :-: | :- |
-| `--model` | Y | - | The name of the model to benchmark. |
-| `--dtype` | N | `float16` | The datatype of the weights. |
-| `--max-batch-size` | Y | - | The batch size to build the engine with for the benchmark. |
-| `--kv-dtype` | N | `float16` | The datatype to store the KV Cache in. |
-| `--kv-cache-free-gpu-mem-fraction` | N | `0.98` | The percentage of free memory that the KV cache is allowed to occupy. |
-| `--quantization` | N | `None` |The quantization algorithm to be used when benchmarking. See the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html) for more information|
-| `--workspace` | N | `/tmp` | The directory to store benchmarking intermediate files. |
-| `--tensor-parallel-size` | N | `1` | Number of tensor parallel shards to run the benchmark with. |
-| `--pipeline-parallel-size` | N | `1` | Number of pipeline parallel shards to run the benchmark with. |
-
-#### Supported Networks for Benchmarking
-
-- [`tiiuae/falcon-7b`](https://huggingface.co/tiiuae/falcon-7b)
-- [`tiiuae/falcon-40b`](https://huggingface.co/tiiuae/falcon-40b)
-- [`tiiuae/falcon-180B`](https://huggingface.co/tiiuae/falcon-180B)
-- [`meta-llama/Llama-2-7b-hf`](https://huggingface.co/meta-llama/Llama-2-7b-hf)
-- [`meta-llama/Llama-2-13b-hf`](https://huggingface.co/meta-llama/Llama-2-13b-hf)
-- [`meta-llama/Llama-2-70b-hf`](https://huggingface.co/meta-llama/Llama-2-70b-hf)
-- [`EleutherAI/gpt-j-6b`](https://huggingface.co/EleutherAI/gpt-j-6b)
-
-#### Support Quantization Modes
-
-TensorRT-LLM supports a number of quanization modes. For more information about quantization, see the
-[documentation](https://nvidia.github.io/TensorRT-LLM/precision.html).
-
-- None (no quantization applied)
-- W8A16
-- W4A16
-- W4A16_AWQ
-- W4A8_AWQ
-- W4A16_GPTQ
-- FP8
-- INT8
-
-> [!NOTE] Please see the supported quantization methods for each network [here](https://nvidia.github.io/TensorRT-LLM/precision.html#support-matrix)
-
-## Static Benchmarking a Network
-
-In order to benchmark a static batch for a network, run a command like the following:
-
-```shell
-cd tensorrt_llm_bench/
-python benchmark.py --model tiiuae/falcon-7b static --isl 128 --osl 128 --max-batch-size 1
-```
-
-This command line will build a unique engine for the configuration and run the benchmark using
-the `gptSessionBenchmark` binary. You need to build the TensorRT-LLM wheel with the `--benchmarks` flag for this binary to be compiled:
-
-```shell
-python3 ./scripts/build_wheel.py --benchmarks
-```
-
-If you've already compiled the wheel without benchmarks, you can build the benchmarking binaries with the following after the fact:
-
-```shell
-pushd cpp/build/
-make -j benchmarks
-popd
-```
-
-The complete list of arguments for static benchmarking are as follows:
-| Option | Required | Default | Description |
-| :- | :-: | :-: | :- |
-| `--isl` | Y | - | The input sequence length to pass in during benchmark. |
-| `--osl` | Y | - | The output sequence length to generate in the benchmark. |
-| `--gpt-session-path` | N | `../../cpp/build/benchmarks/gptSessionBenchmark` | The path to the built gptSessionBenchmark binary. |
-| `--warm-up-runs` | N | `2` | The number of warm up runs to run before benchmarking actual results. |
-| `--num-runs` | N | `10` | The number runs to generate benchmarking results from. |
-| `--duration` | N | `60` | The minimum iteration time, in seconds, to measure. |
-
-> [!WARNING]
-> `gptSession` will be deprecated for the 1.0 release of TensorRT-LLM. This command line will change in order to match and update benchmarks accordingly.
-
-
-## Inflight Benchmarking with a Dataset
-
-This section covers how to benchmark TensorRT-LLM using inflight batching.
-
-### Workflow
-
-The workflow for inflight batching is slightly different than the [static scenario](#static-benchmarking-a-network) as it requires a workload of requests instead of a single static batch. The following is the workflow for benchmarking using inflight batching:
-
-1. Prepare a dataset to drive the inflight batching benchmark.
-2. Run the `inflight` benchmarking subcommand and provide the dataset from step 1.
-
-#### Preparing a Dataset
-
-The inflight benchmark utilizes a fixed JSON schema so that it is simple and
-straightforward to specify requests. The schema is defined as follows:
-
-| Key | Required | Type | Description |
-| :- | :-: | :-: | :- |
-| `task_id`| Y | String | Unique identifier for the request. |
-| `prompt` | N* | String | Input text for a generation request. |
-| `logits` | N* | List[Integer] | List of logits that make up the request prompt. |
-| `output_tokens` | Y | Integer | Number of generated tokens for this request. |
-
-> [!NOTE] Prompt and logits are mutually exclusive*
-> While having both `prompt` and `logits` is not required, at least one is required.
-> If `logits` are specified, the `prompt` entry is ignored for request generation.
-
-Examples of valid entries for the inflight benchmark are:
-
-- Entries with a human-readable prompt and no logits.
-```json
-{"task_id": 1, "prompt": "Generate an infinite response to the following: This is the song that never ends, it goes on and on my friend.", "output_tokens": 1000}
-{"task_id": 2, "prompt": "Generate an infinite response to the following: Na, na, na, na", "output_tokens": 1000}
-```
-
-- Entries which contain logits.
-```json
-{"task_id":0,"logits":[863,22056,25603,11943,8932,13195,3132,25032,21747,22213],"output_tokens":128}
-{"task_id":1,"logits":[14480,13598,15585,6591,1252,8259,30990,26778,7063,30065,21764,11023,1418],"output_tokens":128}
-```
-
-> [!INFO] A whole entry is on a line!
-> To make the passing of data simpler, a complete JSON entry is on each line so that the benchmarker
-> can simply read a line and assume a complete entry. When creating a dataset, be sure that a complete
-> JSON entry is on every line.
-
-#### Using `prepare_dataset` to Create Synthetic Datasets
-
-In order to prepare a synthetic dataset, you can use the provided script in the `benchmarks/cpp`
-directory. For example, to generate a synthetic dataset of 1000 requests with a uniform ISL/OSL of
-128/128 for [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), simply run:
-
-```shell
-benchmarks/cpp/prepare_dataset.py --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 --stdout
-```
-
-You can pipe the above command to a file to reuse the same dataset, or simply pipe its output to the
-benchmark script (example below).
-
-### Running a Dataset with the Benchmarker
-
-Once you've generated a dataset (see [above](#preparing-a-dataset)), you can run the benchmarker
-in one of two ways:
-
-```shell
-benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE < $DATASET_PATH
-```
-
-> [!INFO] Alternative to piping.
-> There is also a `--dataset` option for `benchmark.py` that can be used instead of piping a file.
-
-or
-
-```shell
-benchmarks/cpp/prepare_dataset.py --tokenizer $HF_MODEL_NAME --input-mean $ISL --output-mean $OSL --num-requests $NUM_REQUESTS --stdout | benchmarks/suite/tensorrt_llm_bench/benchmark.py --model $HF_MODEL_NAME --max-batch-size $BATCH_SIZE --request-rate $REQUEST_RATE
-```
-
-#### How the Benchmarker Works
-
-The benchmarker will read in a data file or standard input (stdin) as a stream where a single line contains
-a complete JSON request entry. The process that the benchmarker is as follows:
-
-1. Iterate over all input requests. If `logits` is specified, construct the request using the specified
-list of logits. Otherwise, tokenize the `prompt` with as specified by `--model $HF_MODEL_NAME`.
-2. Build the TensorRT-LLM engine.
-3. Submit the dataset to the TensorRT-LLM `Executor` API at the request rate specified by `--request-rate $REQUEST_RATE`
-4. Wait for all requests to return, compute statistics, then report out results.
-
-When the benchmark runs successfully, you will see a report out of the run similar to the following:
-
-```
-[RANK 0] Submitting requests...
-[RANK 0] Completed request submission.
-[RANK 0] Calculating results.
-[RANK 0] Reporting...
-[RANK 0] JSON: {'benchmark_cmd': '', 'binary': '', 'build_cmd': 'trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_seq_len 256 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16', 'first_token_latency': 0.0, 'inflight_batching': True, 'kv_mem_fraction': 0.98, 'latency_units': 'ms', 'max_batch_size': 1024, 'max_tokens': 8000, 'model': 'meta-llama/Llama-2-7b-hf', 'peak_gpu_mem_units': 'GB', 'peak_gpu_mem': 0.0, 'scheduler': 'Max Utilization', 'throughput_units': 'tokens/second', 'throughput': 17634.422523488243, 'time_per_output_token': 0.0, 'total_input_tokens': 128000, 'total_latency': 7.258530855178833, 'total_output_tokens': 128000}
-===========================================================
-= METADATA
-===========================================================
-Model: meta-llama/Llama-2-7b-hf
-TP Size: 1
-PP Size: 1
-Scheduling Policy: Max Utilization
-In-flight Batcher?: True
-Dtype: float16
-KV Cache Dtype: FP8
-Quantization: FP8
-KV Memory Percentage: 98.0%
-
-===========================================================
-= ENGINE DETAILS
-===========================================================
-Engine Directory: /tmp/meta-llama/llama-2-7b-hf
-Max Batch Size: 1024
-Total Input Length: 128000
-Total Output Length: 128000
-Max Tokens: 8000
-
-===========================================================
-= STATISTICS
-===========================================================
-Throughput (tokens/second): 17634.422523488243
-Total Latency (ms): 7258.5309
-First Token Latency (ms): 0.0
-Token-to-token Latency (ms): 0.0
-Peak GPU Memory Usage (GB): 0.0
-
-===========================================================
-= COMMANDS
-===========================================================
-Build: trtllm-build --output_dir /tmp/meta-llama/llama-2-7b-hf --model_config /tmp/generated_config.json --workers 1 --max_batch_size 1024 --max_input_len 128 --max_seq_len 256 --max_num_tokens 8000 --context_fmha enable --gpt_attention_plugin float16 --paged_kv_cache enable --multiple_profiles enable --gemm_plugin float16
-Benchmark:
-
-[RANK 0] Terminating.
-```
-
-> [!WARNING] Some statistics are not reported.
-> There are some statistics that are not reported in the summary (typically as 0.0). These statistics
-> are not available currently.
-
-
-That's it! -- you've successfully benchmarked TensorRT-LLM!
diff --git a/benchmarks/suite/requirements.txt b/benchmarks/suite/requirements.txt
deleted file mode 100644
index e75e33990..000000000
--- a/benchmarks/suite/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-pydantic>=2.2.1
-click-option-group == 0.5.6
-aenum == 3.1.15
diff --git a/benchmarks/suite/tensorrt_llm_bench/__init__.py b/benchmarks/suite/tensorrt_llm_bench/__init__.py
deleted file mode 100644
index d6bfb8507..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Module for running TensorRT-LLM benchmarks."""
diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmark.py b/benchmarks/suite/tensorrt_llm_bench/benchmark.py
deleted file mode 100644
index 65e306459..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/benchmark.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from pathlib import Path
-from typing import get_args
-
-import click
-from ifb import executor_benchmark
-from static import static_benchmark
-from utils import VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, VALID_QUANT_ALGOS
-from utils.dataclasses import BenchmarkConfig
-
-
-@click.group(context_settings={'show_default': True})
-@click.option(
- "--model",
- "-m",
- required=True,
- type=str,
- help="The Huggingface name of the model to benchmark.",
-)
-@click.option(
- "--max-batch-size",
- hidden=True,
- default=0,
- type=int,
- help="Maximum batch size to build the benchmark engine with.",
-)
-@click.option(
- "--kv-dtype",
- type=click.Choice(tuple(get_args(VALID_CACHE_DTYPES))),
- default="float16",
- help="The dtype to store the KV Cache in.",
-)
-@click.option(
- "--dtype",
- type=click.Choice(tuple(get_args(VALID_COMPUTE_DTYPES))),
- default="float16",
- help="Activation and plugin data type.",
-)
-@click.option(
- "--quantization",
- "-q",
- type=click.Choice(tuple(get_args(VALID_QUANT_ALGOS))),
- default="None",
- help=
- ("The quantization algorithm to be used when benchmarking. See the "
- "documentations for more information.\n"
- " - https://nvidia.github.io/TensorRT-LLM/precision.html"
- " - https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/quantization-in-TRT-LLM.md"
- ),
-)
-@click.option(
- "--workspace",
- "-w",
- required=False,
- type=click.Path(writable=True, readable=True),
- default="/tmp",
- help="The directory to store benchmarking intermediate files.",
-)
-@click.option(
- "--tensor-parallel-size",
- "-tp",
- type=int,
- default=1,
- required=False,
- help="Number of tensor parallel shards to run the benchmark with.",
-)
-@click.option(
- "--pipeline-parallel-size",
- "-pp",
- type=int,
- default=1,
- required=False,
- help="Number of pipeline parallel shards to run the benchmark with.",
-)
-@click.option(
- "--kv-cache-free-gpu-mem-fraction",
- "-kv-mem",
- type=float,
- default=0.98,
- help="The percentage of free memory that the KV Cache is allowed to occupy.",
-)
-@click.option(
- "--build-opts",
- type=str,
- default="",
- required=False,
- hidden=True,
- help="Passthrough options for trtllm-build to fine-tuning build commands.")
-@click.pass_context
-def benchmark(
- ctx,
- model: str,
- max_batch_size: int,
- workspace: Path,
- dtype: str,
- kv_dtype: str,
- quantization: str,
- tensor_parallel_size: int,
- pipeline_parallel_size: int,
- kv_cache_free_gpu_mem_fraction: float,
- build_opts: str,
-):
- """Utility for using TRT-LLM for benchmarking networks from Huggingface."""
- ctx.obj = BenchmarkConfig(
- model=model,
- max_batch_size=max_batch_size,
- workspace=Path(workspace),
- dtype=dtype,
- cache_dtype=kv_dtype,
- quantization=quantization,
- tensor_parallel=tensor_parallel_size,
- pipeline_parallel=pipeline_parallel_size,
- kv_cache_mem_percentage=kv_cache_free_gpu_mem_fraction,
- build_overrides=build_opts.split(),
- )
-
- # Create the workspace where we plan to store intermediate files.
- ctx.obj.workspace.mkdir(parents=True, exist_ok=True)
-
-
-# Add nested subcommands to main benchmark CLI.
-benchmark.add_command(static_benchmark)
-benchmark.add_command(executor_benchmark)
-
-if __name__ == "__main__":
- benchmark()
diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/__init__.py b/benchmarks/suite/tensorrt_llm_bench/benchmarkers/__init__.py
deleted file mode 100644
index 2cc2877af..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/__init__.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from typing import List, Protocol
-
-from utils.dataclasses import BenchmarkResults, InferenceRequest
-
-
-class Benchmarker(Protocol):
- """Protocol for defining benchmarking classes for building/benchmarking."""
-
- def build(self) -> None:
- """Build a model to be benchmarked."""
- ...
-
- def benchmark(self) -> BenchmarkResults:
- """Benchmark the constructed model container by a benchmarker."""
- ...
-
-
-class DatasetBenchmarker(Protocol):
-
- def benchmark_dataset(self,
- dataset: List[InferenceRequest]) -> BenchmarkResults:
- """_summary_
-
- Args:
- dataset (List[InferenceRequest]): List of inference requests to benchmark.
-
- Returns:
- BenchmarkResults: The results of the benchmark run.
- """
- ...
diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/pybind_executor.py b/benchmarks/suite/tensorrt_llm_bench/benchmarkers/pybind_executor.py
deleted file mode 100644
index 5742e90a5..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/pybind_executor.py
+++ /dev/null
@@ -1,146 +0,0 @@
-from datetime import timedelta
-from time import sleep, time
-from typing import List
-
-from mpi4py.MPI import COMM_WORLD
-from transformers import PreTrainedTokenizer
-from utils.dataclasses import BenchmarkConfig, BenchmarkResults
-from utils.enums import IFBSchedulingPolicy, ResultsSchedulingPolicy
-
-from tensorrt_llm.bindings.executor import (Executor, ExecutorConfig,
- KvCacheConfig, ModelType,
- OutputConfig, Request,
- SchedulerConfig)
-
-from . import InferenceRequest
-
-
-class PybindExecutorBenchmarker:
- """Utility class for running inflight benchmarks via the Executor API."""
-
- def __init__(
- self,
- config: BenchmarkConfig,
- ):
- """Initialize a gptSessionBenchmark instance.
-
- Args:
- config (BenchmarkConfig): Benchmark configuration for build/run.
- """
- self.config: BenchmarkConfig = config
-
- @staticmethod
- def get_request(request: InferenceRequest,
- tokenizer: PreTrainedTokenizer) -> Request:
- return Request(
- input_token_ids=request.logits,
- max_new_tokens=request.output_tokens,
- stop_words=[],
- bad_words=[],
- streaming=False,
- output_config=OutputConfig(exclude_input_from_output=True),
- pad_id=tokenizer.pad_token_id,
- end_id=tokenizer.eos_token_id,
- )
-
- def initialize_executor(self) -> Executor:
- """
- Initialize an Executor instance.
-
- Returns:
- Executor: An instance of a TensorRT-LLM Executor.
- """
- policy = IFBSchedulingPolicy(self.config.scheduling_policy).value
- executor_config: ExecutorConfig = ExecutorConfig(
- max_beam_width=1,
- enable_chunked_context=self.config.chunking,
- scheduler_config=SchedulerConfig(
- capacity_scheduler_policy=policy, ),
- kv_cache_config=KvCacheConfig(
- free_gpu_memory_fraction=self.config.kv_cache_mem_percentage, ),
- )
-
- executor: Executor = Executor(
- model_path=self.config.engine_path,
- model_type=ModelType.DECODER_ONLY,
- executor_config=executor_config,
- )
-
- return executor
-
- def benchmark_dataset(self, rate: int,
- dataset: List[InferenceRequest]) -> BenchmarkResults:
- """Benchmark the Executor Pybind interface.
-
- Args:
- dataset (List[InferenceRequest]): List of inference requests to
- benchmark with.
-
- Returns:
- BenchmarkResults: Final results from running the specified dataset.
- """
- request_ids = []
- num_finished = 0
- num_errored = 0
- num_input_tokens = 0
- num_output_tokens = 0
- delay = 1.0 / float(rate)
- last_request = len(dataset) - 1
- bench_result = None
-
- executor = self.initialize_executor()
- if executor.can_enqueue_requests():
- print(f"[RANK {COMM_WORLD.rank}] Submitting requests...")
- start = time()
- for i, request in enumerate(dataset):
- sleep_time = delay if i != last_request else 0
- request_ids.append(executor.enqueue_request(request))
- num_input_tokens += len(request.input_token_ids)
- sleep(sleep_time)
- print(f"[RANK {COMM_WORLD.rank}] Completed request submission.")
-
- while num_finished <= last_request:
- responses = executor.await_responses(timeout=timedelta(
- milliseconds=1))
- for response in responses:
- has_error = response.has_error()
- num_finished += 1
- num_errored += 1 if has_error else 0
-
- if not has_error:
- result = response.result
- for out_tokens in result.output_token_ids:
- num_output_tokens += len(out_tokens)
- end = time()
- print(f"[RANK {COMM_WORLD.rank}] Calculating results.")
- e2e_time = end - start
- e2e_time * 1000.0
- policy = ResultsSchedulingPolicy(
- IFBSchedulingPolicy(self.config.scheduling_policy).value)
-
- bench_result = BenchmarkResults(
- model=self.config.model,
- dtype=self.config.dtype.value,
- quantization=str(self.config.quantization.value),
- max_batch_size=self.config.max_batch_size,
- total_input_tokens=num_input_tokens,
- total_output_tokens=num_output_tokens,
- tp_size=self.config.tensor_parallel,
- pp_size=self.config.pipeline_parallel,
- kv_mem_fraction=self.config.kv_cache_mem_percentage,
- scheduler=policy.value,
- max_tokens=self.config.max_tokens,
- inflight_batching=True,
- total_latency=e2e_time,
- first_token_latency=0,
- time_per_output_token=0,
- latency_units="ms",
- throughput=num_output_tokens / e2e_time,
- throughput_units="tokens/second",
- peak_gpu_mem=0.0,
- peak_gpu_mem_units="GB",
- build_cmd="",
- benchmark_cmd="",
- )
-
- return bench_result
diff --git a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py b/benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py
deleted file mode 100644
index b5b3c49ea..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/benchmarkers/static.py
+++ /dev/null
@@ -1,208 +0,0 @@
-import platform
-from pathlib import Path
-from subprocess import CompletedProcess
-from typing import Dict, List
-
-from utils import command_logger, process_error_check, run_process
-from utils.dataclasses import BenchmarkConfig, BenchmarkResults
-from utils.trtllm_config import TRTLLMConfig
-
-
-class gptSessionBenchmarker:
- """Utility class for running static benchmarks with gptSessionBenchmark."""
-
- def __init__(
- self,
- config: BenchmarkConfig,
- benchmark_binary: Path,
- batch_size: int,
- isl: int,
- osl: int,
- warm_up_runs: int,
- num_runs: int,
- duration: int,
- kv_cache_free_fraction: float = .9,
- ):
- """Initialize a gptSessionBenchmark instance.
-
- Args:
- config (BenchmarkConfig): Benchmark configuration for build/run.
- benchmark_binary (Path): Path to the benchmarking binary.
- batch_size (int): Batch size to configure the build with.
- isl (int): Input sequence length to configure the build with.
- osl (int): Output sequence length to configure the build with.
- kv_cache_free_fraction (float, optional): The amount of remaining
- GPU memory after model loading to save for the KV Cache. Defaults
- to .9.
- """
- self.config: BenchmarkConfig = config
- self.gpt_session_path = Path(benchmark_binary).absolute()
- self.batch_size = batch_size
- self.input_length = isl
- self.output_length = osl
- self.warm_up = warm_up_runs
- self.num_runs = num_runs
- self.duration = duration
- self.kv_cache_mem = kv_cache_free_fraction
- self.result = None
-
- def get_build_command(self) -> List[str]:
- """Build the engine command for TRT-LLM.
-
- Returns:
- List[str]: A list of command line arguments to run a build command.
- """
- model = self.config.model
- tp = self.config.tensor_parallel
- pp = self.config.pipeline_parallel
- dtype = self.config.dtype.value
- kv_dtype = self.config.cache_dtype
- quant_algo = self.config.quantization.value
- output_dir = self.config.engine_path
- max_batch_size = self.batch_size
- max_isl = self.input_length
- max_osl = self.output_length
- workspace = self.config.workspace
-
- # Generate the TRT-LLM Configuration file using the dataclass
- # NOTE: This method does not use weights.
- trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
- kv_dtype.value)
- # Write the generated configuration file to the benchmark workspace.
- trtllm_config.to_json(workspace)
-
- # Return the full command for building TRT-LLM via subprocess call.
- cmd = [
- "trtllm-build",
- "--output_dir",
- output_dir,
- "--model_config",
- Path(workspace, "generated_config.json"),
- "--workers",
- self.config.world_size,
- # Define the maximums the engine can accept.
- "--max_batch_size",
- max_batch_size,
- "--max_input_len",
- max_isl,
- "--max_seq_len",
- max_osl + max_isl,
- "--context_fmha",
- "enable",
- # Set the attention plugin data type.
- "--gpt_attention_plugin",
- dtype,
- # Disable paged cache since we aren't batching on the fly.
- "--paged_kv_cache",
- "disable",
- ] + kv_dtype.get_build_options(dtype)
-
- return [str(arg) for arg in cmd]
-
- @command_logger(prefix="BUILD COMMAND: ")
- @process_error_check
- def _run_build(self, cmd: List[str]) -> CompletedProcess:
- """Wrapper for calling the build for TRT-LLM.
-
- Purpose of this wrapper is so that we can decorate it/log it.
-
- Args:
- cmd (List[str]): List of command line arguments for running.
-
- Returns:
- CompletedProcess: Completed process information for parsing and
- reporting.
- """
- return run_process(
- cmd,
- self.config.workspace,
- )
-
- def build(self) -> None:
- """Build the engine for benchmarking."""
- self._run_build(self.get_build_command())
-
- @command_logger(prefix="BENCHMARK COMMAND: ")
- @process_error_check
- def _run_benchmark(self, cmd: List[str]) -> CompletedProcess:
- """Run the benchmark command in the configured workspace.
-
- Args:
- cmd (List[str]): List of command line arguments to run via
- subprocess.
-
- Returns:
- CompletedProcess: Completed process information for reporting.
- """
- return run_process(cmd, run_dir=self.config.workspace, use_environ=True)
-
- @staticmethod
- def parse_benchmark_result(benchmark_line: str) -> Dict[str, str]:
- pass
-
- def benchmark(self):
- """Benchmarks a TRT-LLM for a configured instance."""
-
- # Compile the command for running
- cmd = ["mpiexec", "-n", self.config.world_size]
- cmd += ["-allow-run-as-root"] if platform.system() != "Windows" else ""
- cmd += [
- self.gpt_session_path,
- "--engine_dir",
- self.config.engine_path,
- "--batch_size",
- self.batch_size,
- "--log_level",
- "info",
- "--kv_cache_free_gpu_mem_fraction",
- self.kv_cache_mem,
- "--beam_width",
- "1",
- "--warm_up",
- self.warm_up,
- "--num_runs",
- self.num_runs,
- "--duration",
- self.duration,
- "--input_output_len",
- f"{self.input_length},{self.output_length};{self.input_length},1",
- ]
- cmd = [str(arg) for arg in cmd]
- # Run the benchmark using the provided gptSession benchmark binary.
- bench_return = self._run_benchmark(cmd)
- results = [
- x.split(" ") for x in bench_return.stdout.split("\n")
- if "[BENCHMARK]" in x
- ]
-
- ttft = float(results[1][8])
- gen_time = float(results[0][8]) - ttft
- total_out = int(results[0][2]) * int(results[0][6])
- total_in = int(results[0][2]) * int(results[0][4])
- batch_size = int(results[0][2])
-
- bench_result = BenchmarkResults(
- model=self.config.model,
- dtype=self.config.dtype.value,
- quantization=str(self.config.quantization.value),
- max_batch_size=batch_size,
- total_input_tokens=total_in,
- total_output_tokens=total_out,
- tp_size=self.config.tensor_parallel,
- pp_size=self.config.pipeline_parallel,
- kv_mem_fraction=self.kv_cache_mem,
- scheduler="Static",
- inflight_batching=False,
- total_latency=results[0][8],
- first_token_latency=ttft,
- time_per_output_token=gen_time / (total_out - batch_size),
- latency_units="ms",
- throughput=results[0][10],
- throughput_units="tokens/second",
- peak_gpu_mem=results[0][16],
- peak_gpu_mem_units="GB",
- binary=str(self.gpt_session_path),
- build_cmd=" ".join(self.get_build_command()),
- benchmark_cmd=" ".join(cmd))
-
- return bench_result
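
The static benchmarker above recovers its metrics by splitting the `[BENCHMARK]` lines printed by `gptSessionBenchmark` at fixed whitespace-separated positions (index 2 = batch size, 4 = input length, 6 = output length, 8 = latency, 10 = throughput, 16 = peak GPU memory) and then derives time-per-output-token as (total latency minus TTFT) divided by the tokens generated after the first one of each sequence. The sketch below only restates that post-processing in isolation; the field layout is assumed from the indices used above rather than from a documented output format.

```python
from typing import Dict


def parse_benchmark_line(line: str) -> Dict[str, float]:
    """Split a '[BENCHMARK]' line and read the fixed field positions relied on
    above (2=batch size, 4=ISL, 6=OSL, 8=latency ms, 10=throughput, 16=peak GPU mem)."""
    fields = line.split(" ")
    return {
        "batch_size": float(fields[2]),
        "input_len": float(fields[4]),
        "output_len": float(fields[6]),
        "latency_ms": float(fields[8]),
        "throughput": float(fields[10]),
        "peak_gpu_mem_gb": float(fields[16]),
    }


def time_per_output_token(total_latency_ms: float, ttft_ms: float,
                          total_output_tokens: int, batch_size: int) -> float:
    """Generation time divided by the tokens produced after the first token of
    each sequence, mirroring the arithmetic in benchmark() above."""
    return (total_latency_ms - ttft_ms) / (total_output_tokens - batch_size)
```

For example, with a total latency of 500 ms, a TTFT of 60 ms, 1024 generated tokens, and a batch size of 8, the per-token latency works out to (500 - 60) / (1024 - 8), roughly 0.43 ms per token.
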
diff --git a/benchmarks/suite/tensorrt_llm_bench/ifb.py b/benchmarks/suite/tensorrt_llm_bench/ifb.py
deleted file mode 100644
index 67299c082..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/ifb.py
+++ /dev/null
@@ -1,338 +0,0 @@
-import json
-import os
-import subprocess
-import sys
-from functools import partial
-from pathlib import Path
-from typing import List, TextIO, Tuple
-
-import click
-from benchmarkers.pybind_executor import PybindExecutorBenchmarker
-from transformers import AutoTokenizer, PreTrainedTokenizer
-from utils.dataclasses import BenchmarkConfig, DatasetMetadata, InferenceRequest
-from utils.trtllm_config import TRTLLMConfig
-
-from tensorrt_llm.logger import logger
-
-
-def create_dataset_from_stream(
- tokenizer: PreTrainedTokenizer,
- max_input_length: int = 0,
- max_output_length: int = 0,
- stream: TextIO = sys.stdin,
-) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
- """Generate metadata and a list of requests to drive benchmarking.
-
- Args:
- tokenizer (PreTrainedTokenizer): HuggingFace tokenizer.
- max_input_length (int): Maximum input length to cap prompts to.
-
- Returns:
- DatasetMetadata: Dataclass of dataset statistics.
- List[InferenceRequest]: A list of inference requests for benchmarking.
- """
- # Initialize dataset list, and metadata tracking variables.
- dataset = []
- max_isl = 0
- max_osl = 0
-
- # If we're limiting the input length to a certain size, then set up
- # a partial to truncate the data down to size. Otherwise, just use the
- # unmodified tokenizer callable.
- tokenize = (partial(
- tokenizer,
- padding="max_length",
- max_length=max_input_length,
- truncation=True,
- ) if max_input_length > 0 else tokenizer)
-
- # If we need to limit the output length, fill in a partial callable
- # for max, otherwise a lambda that just returns x with no bounds.
- output_limiter = (partial(max, max_output_length)
- if max_output_length > 0 else lambda x: x)
-
- # For each line in the standard input, parse out the JSON string we expect
- # to see.
- # Note the := walrus -- we're assigning and checking the condition.
- while line := stream.readline():
- # We expect the data to come in as a JSON string.
- # For example:
- # {"prompt": "Generate an infinite response to the following: There once was a man who.", "output_tokens": 1000}
- # Each line should be a complete JSON dictionary with no indentation
- # or newline characters.
- data = json.loads(line)
- logits = data.get("logits", None)
- prompt = data.get("prompt", None)
- task_id = data["task_id"]
- osl = data["output_tokens"]
- # If the request comes in with logits, just use the provided.
- # Otherwise we need to tokenize it.
- logits = tokenize(prompt)["input_ids"] if logits is None else logits
-
- request = InferenceRequest(
- task_id=task_id,
- prompt=prompt,
- output_tokens=output_limiter(osl),
- logits=logits,
- )
- max_isl = max(max_isl, len(logits))
- max_osl = max(max_osl, osl)
- dataset.append(request)
-
- # Fill in basic dataset metrics here
- # TODO: Maybe fill this out to be more complete?
- metadata = DatasetMetadata(
- max_isl=max_isl,
- max_osl=max_osl,
- num_requests=len(dataset),
- )
-
- return metadata, dataset
-
-
-def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
- """Initialize a tokenizer.
-
- Args:
- model_name (str): The name of the HuggingFace model to pull a
- tokenizer from.
-
- Returns:
- PreTrainedTokenizer: An initialized HuggingFace tokenizer.
- """
- # Initialize the tokenizer specific to the model that we are planning
- # to benchmark.
- tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
- if tokenizer.pad_token_id is None:
- tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-
- return tokenizer
-
-
-def get_trtllm_build_command(benchmark_cfg: BenchmarkConfig) -> List[str]:
- model = benchmark_cfg.model
- tp = benchmark_cfg.tensor_parallel
- pp = benchmark_cfg.pipeline_parallel
- dtype = benchmark_cfg.dtype.value
- kv_dtype = benchmark_cfg.cache_dtype
- quant_algo = benchmark_cfg.quantization.value
- output_dir = benchmark_cfg.engine_path
- max_batch_size = benchmark_cfg.max_batch_size
- max_isl = benchmark_cfg.engine_isl
- max_osl = benchmark_cfg.engine_osl
- max_tokens = benchmark_cfg.max_tokens
- workspace = benchmark_cfg.workspace
-
- # Generate the TRT-LLM Configuration file using the dataclass
- # NOTE: This method does not use weights.
- trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
- kv_dtype.value)
- # Write the generated configuration file to the benchmark workspace.
- trtllm_config.to_json(workspace)
- # Return the full command for building TRT-LLM via subprocess call.
- cmd = [
- "trtllm-build",
- "--output_dir",
- output_dir,
- "--model_config",
- Path(workspace, "generated_config.json"),
- "--workers",
- benchmark_cfg.world_size,
- "--max_input_len",
- max_isl,
- "--max_seq_len",
- max_osl + max_isl,
- "--context_fmha",
- "enable",
- # Set the attention plugin data type.
- "--gpt_attention_plugin",
- dtype,
- # Enable paged KV Cache for IFB.
- "--paged_kv_cache",
- "enable",
- ] + kv_dtype.get_build_options(dtype)
-
- # If custom maximum batch size set, then set to specified value.
- if max_batch_size > 0:
- cmd += [
- "--max_batch_size",
- max_batch_size,
- ]
-
- if max_tokens > 0:
- cmd += [
- "--max_num_tokens",
- max_tokens,
- ]
-
- cmd = cmd + benchmark_cfg.build_overrides
-
- return cmd
-
-
-@click.command("inflight")
-@click.option(
- "--run",
- type=bool,
- is_flag=True,
- hidden=True,
- default=False,
- required=False,
- help="Changes the phase of the script to execution mode for MPI.",
-)
-@click.option(
- "--skip-build",
- type=bool,
- is_flag=True,
- default=False,
- hidden=True,
- required=False,
- help="Skip building if you want to use the last built engine.",
-)
-@click.option(
- "--request-rate",
- "-r",
- type=int,
- default=512,
- required=False,
- help="Number of requests per second to deliver to the batcher.",
-)
-@click.option(
- "--max-num-tokens",
- type=int,
- default=0,
- hidden=True,
- help="Maximum number of tokens the engine can accept.",
-)
-@click.option(
- "--scheduling-policy",
- type=click.Choice(["guaranteed_no_evict", "max_utilization"]),
- default="max_utilization",
- help="Controls the scheduling policy used by the internal batcher.",
-)
-@click.option(
- "--dataset",
- type=click.Path(exists=True,
- readable=True,
- path_type=Path,
- resolve_path=True),
- default=None,
- required=False,
- help="Pass in a dataset file for parsing instead of stdin.",
-)
-@click.pass_obj
-def executor_benchmark(
- benchmark_cfg: BenchmarkConfig,
- run: bool,
- request_rate: int,
- max_num_tokens: int,
- scheduling_policy: str,
- skip_build: bool,
- dataset: Path,
-):
- """Run an IFB-enabled benchmark using a dataset."""
- # Initialize the tokenizer and generate the dataset
- logger.set_level("info")
- DATASET_PATH = Path(benchmark_cfg.workspace, "tokenized_dataset.txt")
- TOKENIZER = initialize_tokenizer(benchmark_cfg.model)
- final_dataset = []
- benchmark_cfg.max_tokens = max_num_tokens
- benchmark_cfg.scheduling_policy = scheduling_policy
-
- if not run:
- try:
- stream = sys.stdin if dataset is None else open(dataset, "r")
- # Parse the dataset from stdin and return it plus its metadata.
- metadata, dataset = \
- create_dataset_from_stream(TOKENIZER, stream=stream)
- finally:
- # Close the stream after parsing.
- stream.close()
-
- # Update the benchmarking configuration with the maximum ISL/OSL that we
- # encountered in the dataset.
- benchmark_cfg.engine_isl = metadata.max_isl
- benchmark_cfg.engine_osl = metadata.max_osl
-
- # Build engine
- logger.info("Building engine...")
- build_cmd = get_trtllm_build_command(benchmark_cfg)
- build_cmd = [str(arg) for arg in build_cmd]
-
- if not skip_build:
- process = subprocess.run(build_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- cwd=benchmark_cfg.workspace)
- logger.info(f"BUILD CMD: {' '.join(process.args)}")
-
- # If the build failed, raise an exception.
- if process.returncode != 0:
- logger.error(process.stderr.decode())
- raise RuntimeError(
- "TensorRT-LLM build process failed. Command used:\n"
- f"{' '.join(process.args)}\n", )
-
- with open(DATASET_PATH, "w") as ds_out:
- while dataset:
- request = dataset.pop()
- ds_out.write(f"{request.model_dump_json()}\n")
- del request
-
- # Launch via a subprocess with MPI
- # We have two modes for this script: the initial launch-and-parse mode,
- # and the run mode where we re-launch the script under MPI to run
- # the benchmark itself.
- logger.info("Launching benchmark...")
- bench_cmd = \
- ["mpiexec", "-n", f"{benchmark_cfg.world_size}", "python"] + \
- sys.argv + ["--run"]
- process = subprocess.Popen(
- bench_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- env=os.environ,
- )
- stdout, _ = process.communicate()
- logger.info("Benchmark complete.")
- logger.info(stdout.decode("ascii"))
- else:
- from mpi4py.MPI import COMM_WORLD
-
- if COMM_WORLD.Get_rank() == 0:
- logger.info(f"[RANK {COMM_WORLD.rank}] Loading dataset...")
- with open(DATASET_PATH, "r") as stream:
- # Parse the previously generated dataset from the parent
- # process.
- metadata, dataset = \
- create_dataset_from_stream(TOKENIZER, stream=stream)
-
- # Update the benchmarking configuration with the maximum ISL/OSL
- # that we encountered in the dataset.
- benchmark_cfg.engine_isl = metadata.max_isl
- benchmark_cfg.engine_osl = metadata.max_osl
-
- # Parse the dataset into the Executor Request type.
- logger.info("Preparing dataset...")
- while dataset:
- entry = dataset.pop()
- request = PybindExecutorBenchmarker.get_request(
- entry, TOKENIZER)
- final_dataset.append(request)
- del entry
- logger.info("Dataset prepared.")
- logger.info(f"DATASET METADATA: {metadata.model_dump()}")
-
- logger.info(f"[RANK {COMM_WORLD.rank}] Initializing benchmarker...")
- # Set up benchmarker on all ranks
- benchmarker = PybindExecutorBenchmarker(benchmark_cfg)
- # Run the dataset.
- result = benchmarker.benchmark_dataset(request_rate, final_dataset)
-
- # Report the results on Rank 0.
- if COMM_WORLD.rank == 0:
- logger.info(f"[RANK {COMM_WORLD.rank}] Reporting...\n"
- f"JSON: {result.model_dump_json()}\n"
- f"{result.get_summary(benchmarker.config)}")
-
- logger.info(f"[RANK {COMM_WORLD.rank}] Terminating.")
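
The `create_dataset_from_stream` helper above consumes one JSON object per line, each carrying a `task_id`, a `prompt` (or a pre-tokenized `logits` list), and the requested `output_tokens`. A minimal sketch of producing such a JSON-lines file is shown below; the file path and prompts are illustrative only.

```python
import json

# One complete JSON dictionary per line, in the shape create_dataset_from_stream()
# expects: task_id, prompt (or "logits"), and output_tokens.
requests = [
    {"task_id": 0, "prompt": "Summarize the plot of Hamlet.", "output_tokens": 64},
    {"task_id": 1, "prompt": "Write a haiku about GPUs.", "output_tokens": 32},
]

with open("/tmp/tiny_dataset.txt", "w") as dataset_file:
    for request in requests:
        dataset_file.write(json.dumps(request) + "\n")
```

A file produced this way can be passed via `--dataset` instead of being piped over stdin.
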
diff --git a/benchmarks/suite/tensorrt_llm_bench/static.py b/benchmarks/suite/tensorrt_llm_bench/static.py
deleted file mode 100644
index 3390c8439..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/static.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import os
-from pathlib import Path
-
-import click
-from benchmarkers.static import gptSessionBenchmarker
-from utils.dataclasses import BenchmarkConfig, BenchmarkResults
-
-
-@click.command("static")
-@click.option(
- "--batch",
- required=True,
- type=int,
- help="Batch size to build and run the static benchmark with.",
-)
-@click.option("--isl",
- type=int,
- required=True,
- help="Input sequence length (in tokens).")
-@click.option("--osl",
- type=int,
- required=True,
- help="Output sequence length (in tokens).")
-@click.option(
- "--gpt-session-path",
- "-b",
- type=click.Path(),
- default=Path(os.path.dirname(os.path.realpath(__file__)), "../../..",
- "cpp/build/benchmarks/gptSessionBenchmark").absolute(),
- help="Path to TRT-LLM gptSession benchmark binary.")
-@click.option("--warm-up-runs",
- type=int,
- default=2,
- help="Number of warm up runs before benchmarking")
-@click.option("--num-runs",
- type=int,
- default=10,
- help="Number of times to run benchmark")
-@click.option("--duration",
- type=int,
- default=60,
- help="Minimum duration of iteration to measure, in seconds")
-@click.pass_obj
-def static_benchmark(benchmark_cfg: BenchmarkConfig, batch: int, isl: int,
- osl: int, gpt_session_path: Path, warm_up_runs: int,
- num_runs: int, duration: int):
- """Run a static benchmark with a fixed batch size, ISL, and OSL."""
-
- benchmark_cfg.max_batch_size = batch
- benchmarker = gptSessionBenchmarker(
- benchmark_cfg,
- gpt_session_path,
- benchmark_cfg.max_batch_size,
- isl,
- osl,
- warm_up_runs,
- num_runs,
- duration,
- benchmark_cfg.kv_cache_mem_percentage,
- )
-
- print(f"Building TRT-LLM engine for '{benchmark_cfg.model}'...")
- benchmarker.build()
-
- print("Build complete. Running benchmark...")
- result: BenchmarkResults = benchmarker.benchmark()
-
- print(f"JSON: {result.model_dump_json()}")
- print(result.get_summary(benchmarker.config))
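
The `static` command above follows the suite's common `click` wiring: options are declared per subcommand and the shared configuration object arrives through `@click.pass_obj`. The sketch below shows only that wiring with a hypothetical `smoke` subcommand and a `SimpleNamespace` standing in for `BenchmarkConfig`; none of these names belong to the suite itself.

```python
from types import SimpleNamespace

import click


@click.group()
@click.option("--model", default="meta-llama/Llama-2-7b-hf")
@click.pass_context
def cli(ctx, model):
    # The real suite builds a BenchmarkConfig here; a namespace stands in for it.
    ctx.obj = SimpleNamespace(model=model)


@cli.command("smoke")  # hypothetical subcommand, for illustration only
@click.option("--isl", type=int, default=128, help="Input sequence length (in tokens).")
@click.option("--osl", type=int, default=128, help="Output sequence length (in tokens).")
@click.pass_obj
def smoke_benchmark(benchmark_cfg, isl: int, osl: int) -> None:
    """Echo the shared configuration object passed down by the parent group."""
    click.echo(f"Would benchmark '{benchmark_cfg.model}' with ISL={isl}, OSL={osl}")


if __name__ == "__main__":
    cli()
```
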
diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/benchmarkers.py b/benchmarks/suite/tensorrt_llm_bench/utils/benchmarkers.py
deleted file mode 100644
index 4f7f83bb6..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/utils/benchmarkers.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from typing import Protocol
-
-from utils.dataclasses import BenchmarkResults
-
-
-class Benchmarker(Protocol):
- """Protocol for defining benchmarking classes for building/benchmarking."""
-
- def build(self) -> None:
- """Build a model to be benchmarked."""
- ...
-
- def benchmark(self) -> BenchmarkResults:
- """Benchmark the constructed model container by a benchmarker."""
- ...
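
Because `Benchmarker` is a `typing.Protocol`, any class exposing `build()` and `benchmark()` satisfies it structurally, without inheriting from it. The toy sketch below uses a simplified protocol and a made-up `NoopBenchmarker` purely to illustrate that structural typing.

```python
from typing import Any, Dict, Protocol


class Benchmarker(Protocol):
    """Simplified stand-in for the protocol above."""

    def build(self) -> None: ...

    def benchmark(self) -> Dict[str, Any]: ...


class NoopBenchmarker:
    """Satisfies Benchmarker structurally; it never inherits from it."""

    def build(self) -> None:
        print("Building engine (no-op).")

    def benchmark(self) -> Dict[str, Any]:
        return {"throughput": 0.0, "units": "tokens/second"}


def run(benchmarker: Benchmarker) -> Dict[str, Any]:
    benchmarker.build()
    return benchmarker.benchmark()


print(run(NoopBenchmarker()))
```
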
diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py b/benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py
deleted file mode 100644
index 8adafd917..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py
+++ /dev/null
@@ -1,189 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-from typing import List, Literal, Optional, Union, get_args
-
-from pydantic import (BaseModel, Field, ValidationError, computed_field,
- field_validator, model_validator)
-from transformers import AutoConfig
-from utils import VALID_MODELS, VALID_SCHEDULING_POLICIES
-from utils.enums import (ComputeDtypeEnum, KVCacheDtypeEnum, ModelArchitecture,
- QuantizationAlgo)
-
-
-class InferenceRequest(BaseModel):
- task_id: int
- prompt: Optional[str] = None
- output_tokens: int
- logits: Optional[List[int]] = None
-
- @model_validator(mode="after")
- def verify_prompt_and_logits(self) -> InferenceRequest:
- if self.prompt is None and self.logits is None:
- raise ValueError(
- f"Prompt and logits for {self.task_id} are both None.")
- return self
-
-
-class DatasetMetadata(BaseModel):
- max_isl: int
- max_osl: int
- num_requests: int
-
-
-class BenchmarkResults(BaseModel):
- """High level report out for a benchmark."""
-
- benchmark_cmd: str = ""
- binary: str = ""
- build_cmd: str = ""
- first_token_latency: float
- inflight_batching: bool
- kv_mem_fraction: float
- latency_units: str
- max_batch_size: int
- max_tokens: int = 0
- model: Union[VALID_MODELS, Path]
- peak_gpu_mem_units: str
- peak_gpu_mem: float
- scheduler: Literal["Static", "No Evict", "Max Utilization"]
- throughput_units: str
- throughput: float
- time_per_output_token: float
- total_input_tokens: int
- total_latency: float
- total_output_tokens: int
-
- def get_summary(self, config: BenchmarkConfig) -> str:
- """Generate the summary information.
-
- Args:
- config (BenchmarkConfig): Configuration for the run that generated
- this result.
-
- Returns:
- str: Summary output for printing.
- """
- return (
- "===========================================================\n"
- "= METADATA\n"
- "===========================================================\n"
- f"Model:\t\t\t{config.model}\n"
- f"TP Size:\t\t{config.tensor_parallel}\n"
- f"PP Size:\t\t{config.pipeline_parallel}\n"
- f"Scheduling Policy:\t{self.scheduler}\n"
- f"In-flight Batcher?:\t{self.inflight_batching}\n"
- f"Dtype:\t\t\t{config.dtype.value}\n"
- f"KV Cache Dtype:\t\t{config.cache_dtype.value}\n"
- f"Quantization:\t\t{config.quantization.value}\n"
- f"KV Memory Percentage:\t{self.kv_mem_fraction * 100}%\n"
- f"\n"
- "===========================================================\n"
- "= ENGINE DETAILS\n"
- "===========================================================\n"
- f"Engine Directory:\t{config.engine_path}\n"
- f"Max Batch Size:\t\t{self.max_batch_size}\n"
- f"Total Input Length:\t{self.total_input_tokens}\n"
- f"Total Output Length:\t{self.total_output_tokens}\n"
- f"Max Tokens:\t\t{self.max_tokens}\n"
- f"\n"
- "===========================================================\n"
- "= STATISTICS\n"
- "===========================================================\n"
- f"Throughput ({self.throughput_units}):\t{self.throughput}\n"
- f"Total Latency ({self.latency_units}):"
- f"\t\t{self.total_latency * 1000.0:.4f}\n"
- f"First Token Latency ({self.latency_units}):\t{self.first_token_latency}\n"
- f"Token-to-token Latency ({self.latency_units}):\t{self.time_per_output_token}\n"
- f"Peak GPU Memory Usage ({self.peak_gpu_mem_units}):\t{self.peak_gpu_mem}\n"
- f"\n"
- "===========================================================\n"
- "= COMMANDS\n"
- "===========================================================\n"
- f"Build: {self.build_cmd}\n"
- f"Benchmark: {self.benchmark_cmd}\n")
-
-
-class BenchmarkConfig(BaseModel):
- """Basic configuration of a benchmark."""
-
- model: Union[VALID_MODELS, Path]
- workspace: Path
- max_batch_size: int
- dtype: ComputeDtypeEnum
- cache_dtype: KVCacheDtypeEnum
- quantization: QuantizationAlgo
- tensor_parallel: int
- pipeline_parallel: int
- max_tokens: int = 0
- kv_cache_mem_percentage: float = .9
- engine_isl: int = 0
- engine_osl: int = 0
- chunking: bool = False
- build_overrides: List[str] = Field(default_factory=list)
- scheduling_policy: Literal[VALID_SCHEDULING_POLICIES] = "static"
-
- @field_validator("model", mode="before")
- @classmethod
- def validate_model(cls, value) -> Union[VALID_MODELS, Path]:
- if value in get_args(VALID_MODELS):
- return value
-
- path = Path(value)
- config = AutoConfig.from_pretrained(str(path.absolute()))
- for arch in config.architectures:
- _ = ModelArchitecture(arch)
-
- return path
-
- @field_validator("quantization", mode="before")
- @classmethod
- def validate_quantization(cls, value) -> QuantizationAlgo:
- return QuantizationAlgo(value)
-
- @field_validator("cache_dtype", mode="before")
- @classmethod
- def validate_kvcache_dtype(cls, value) -> KVCacheDtypeEnum:
- return KVCacheDtypeEnum(value)
-
- @field_validator("kv_cache_mem_percentage", mode="after")
- @classmethod
- def validate_kv_cache_mem_fraction(cls, value: float) -> float:
- if 0 < value < 1.0:
- return value
- else:
- raise ValidationError(
- "KV cache memory percentage must be between 0 and 1.0.")
-
- @field_validator("build_overrides", mode="before")
- @classmethod
- def validate_build_overrides(cls, value) -> List[str]:
- # If we encounter a list, scan it to make sure all entries are strings.
- if isinstance(value, list):
- if not all([isinstance(x, str) for x in value]):
- raise ValidationError(
- "Found a non-string entry in list of options.")
- return value
- elif isinstance(value, str):
- # Handle the case where we receive a single string of command
- # options.
- overrides = []
- if value:
- overrides = [str(x) for x in value.split()]
- return overrides
- else:
- raise ValidationError(
- "Invalid value specified for build overrides.")
-
- @computed_field
- def engine_path(self) -> Path:
- """Path to the engine workspace."""
- if self.model in get_args(VALID_MODELS):
- return Path(self.workspace.absolute(), self.model.lower())
- else:
- return Path(self.workspace.absolute(), "engine")
-
- @computed_field
- def world_size(self) -> int:
- """Total world size needed to run the model."""
- return self.tensor_parallel * self.pipeline_parallel
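
`BenchmarkConfig` leans on the pydantic v2 features used above: `field_validator` for input checking and `computed_field` for derived values such as `world_size`. The reduced sketch below illustrates that pattern; note that validators conventionally raise `ValueError`, which pydantic surfaces to the caller as a `ValidationError`. The `MiniConfig` model and its fields are illustrative, not part of the suite.

```python
from pathlib import Path

from pydantic import BaseModel, computed_field, field_validator


class MiniConfig(BaseModel):
    """Reduced stand-in for BenchmarkConfig showing the validation pattern."""

    workspace: Path
    tensor_parallel: int = 1
    pipeline_parallel: int = 1
    kv_cache_mem_percentage: float = 0.9

    @field_validator("kv_cache_mem_percentage")
    @classmethod
    def check_kv_fraction(cls, value: float) -> float:
        # Raise ValueError here; pydantic reports it as a ValidationError.
        if not 0 < value < 1.0:
            raise ValueError("KV cache memory percentage must be between 0 and 1.0.")
        return value

    @computed_field
    @property
    def world_size(self) -> int:
        """Total world size needed to run the model."""
        return self.tensor_parallel * self.pipeline_parallel


cfg = MiniConfig(workspace=Path("/tmp/bench"), tensor_parallel=2, pipeline_parallel=2)
assert cfg.world_size == 4
```
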
diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py b/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py
deleted file mode 100644
index 4dd6797f3..000000000
--- a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py
+++ /dev/null
@@ -1,323 +0,0 @@
-import json
-import os
-from argparse import ArgumentParser
-from typing import Literal, Optional
-
-from pydantic import AliasChoices, AliasPath, BaseModel, Field, model_validator
-from transformers import AutoConfig
-from utils import VALID_QUANT_ALGOS
-
-PET_dict = {
- "tiiuae/falcon-7b": "rope_gpt_neox",
- "tiiuae/falcon-40b": "rope_gpt_neox",
- "tiiuae/falcon-180B": "rope_gpt_neox",
- "meta-llama/Llama-2-7b-hf": "rope_gpt_neox",
- "meta-llama/Llama-2-13b-hf": "rope_gpt_neox",
- "meta-llama/Llama-2-70b-hf": "rope_gpt_neox",
- "meta-llama/Meta-Llama-3-8B": "rope_gpt_neox",
- "meta-llama/Meta-Llama-3-70B": "rope_gpt_neox",
- "gpt-j-6b": "rope_gptj",
- "bigscience/bloom-560m": "alibi",
- "mistralai/Mistral-7B-v0.1": "rope_gpt_neox",
- "mistralai/Mixtral-8x7B-v0.1": "rope_gpt_neox",
- "mistralai/Mixtral-8x22B-v0.1": "rope_gpt_neox",
- "01-ai/Yi-6B": "rope_gpt_neox",
- "01-ai/Yi-34B": "rope_gpt_neox",
- "codellama/CodeLlama-7b-hf": "rope_gpt_neox",
- "codellama/CodeLlama-13b-hf": "rope_gpt_neox",
- "codellama/CodeLlama-34b-hf": "rope_gpt_neox",
- "codellama/CodeLlama-70b-hf": "rope_gpt_neox",
- "facebook/opt-125m": "learned_absolute",
- "facebook/opt-350m": "learned_absolute",
- "facebook/opt-1.3b": "learned_absolute",
- "facebook/opt-2.7b": "learned_absolute",
- "facebook/opt-13b": "learned_absolute",
- "facebook/opt-30b": "learned_absolute",
- "facebook/opt-66b": "learned_absolute",
- "google/gemma-7b": "rope_gpt_neox",
- "google/gemma-2b": "rope_gpt_neox",
-}
-HA_dict = {
- "tiiuae/falcon-7b": "gelu",
- "tiiuae/falcon-40b": "gelu",
- "tiiuae/falcon-180B": "gelu",
- "bigscience/bloom-560m": "gelu",
- "mistralai/Mixtral-8x7B-v0.1": "swiglu",
-}
-ALLOWED_MODELS = list(PET_dict.keys())
-
-
-class TRTLLM_Mapping(BaseModel):
- world_size: int = 1
- tp_size: int = 1
- pp_size: int = 1
-
- @model_validator(mode="after")
- def check_world_size(self) -> "TRTLLM_Mapping":
- self.world_size = self.tp_size * self.pp_size
- return self
-
-
-class TRTLLM_Quantization(BaseModel):
- quant_algo: Optional[VALID_QUANT_ALGOS] = None
- kv_cache_quant_algo: Optional[Literal[None, "FP8", "INT8"]] = None
- group_size: int = 128
- has_zero_point: bool = False
- pre_quant_scale: bool = False
- exclude_modules: Optional[list] = None
-
-
-class TRTLLMConfig(BaseModel):
- _VALID_EMBED_TYPE = Literal["learned_absolute", "rope_gptj",
- "rope_gpt_neox", "alibi", "alibi_with_scale",
- "relative", "chatglm", ]
-
- architecture: str = Field(validation_alias=AliasChoices(
- 'architecture', AliasPath("architectures", 0)))
- num_hidden_layers: int = Field(validation_alias=AliasChoices(
- "num_hidden_layers", "n_layer", "n_layers"))
- num_attention_heads: int = Field(validation_alias=AliasChoices(
- "num_attention_heads", "n_head", "n_heads"))
- num_key_value_heads: int = Field(
- default=None,
- validation_alias=AliasChoices("num_key_value_heads", "num_kv_heads"),
- )
-
- hidden_size: int = Field(
- validation_alias=AliasChoices("hidden_size", "n_embd", "d_model"))
- norm_epsilon: float = Field(
- default=1e-5,
- validation_alias=AliasChoices("norm_epsilon", "layer_norm_epsilon",
- "rms_norm_eps"),
- )
- vocab_size: int
- max_position_embeddings: Optional[int] = Field(
- default=None,
- validation_alias=AliasChoices("max_position_embeddings", "n_positions"),
- )
- head_size: Optional[int] = None
- hidden_act: str = Field(
- validation_alias=AliasChoices("hidden_act", "activation_function"))
- # falcon options
- bias: Optional[bool] = None
- parallel_attention: Optional[bool] = Field(
- default=None, validation_alias=AliasChoices("parallel_attn"))
- new_decoder_architecture: Optional[bool] = None
- # opt options
- do_layer_norm_before: Optional[bool] = None
- # gptj options
- rotary_dim: Optional[int] = None
-
- # dtype has priority over torch_dtype, the latter of which is usually defined in the HF config
- dtype: Literal["float16", "bfloat16"] = Field(
- validation_alias=AliasChoices("dtype", "torch_dtype"))
- logits_dtype: str = "float32"
- position_embedding_type: _VALID_EMBED_TYPE = "learned_absolute"
- use_parallel_embedding: bool = False
- embedding_sharding_dim: int = 0
- share_embedding_table: bool = False
- intermediate_size: int = None
- use_prompt_tuning: bool = False
-
- sliding_window: Optional[int] = None
-
- moe_num_experts: Optional[int] = Field(
- default=0, validation_alias=AliasChoices("num_local_experts"))
- moe_top_k: Optional[int] = Field(
- default=0, validation_alias=AliasChoices("num_experts_per_tok"))
- rotary_base: Optional[float] = Field(
- default=10000.0, validation_alias=AliasChoices("rope_theta"))
-
- mapping: TRTLLM_Mapping
- quantization: TRTLLM_Quantization
-
- @property
- def kv_dtype(self) -> str:
- if self.quantization.kv_cache_quant_algo == "FP8":
- return "fp8"
- elif self.quantization.kv_cache_quant_algo == "INT8":
- return "int8"
- else:
- return self.dtype
-
- @model_validator(mode="after")
- def set_values_if_none(self) -> "TRTLLMConfig":
- if self.num_key_value_heads is None:
- self.num_key_value_heads = self.num_attention_heads
- if self.head_size is None:
- self.head_size = self.hidden_size // self.num_attention_heads
- return self
-
- @classmethod
- def populate_build_config(cls,
- model_name,
- tp,
- pp,
- dtype=None,
- quant_dtype=None,
- kv_cache_quant_dtype=None):
- """
- Common function to populate build parameters, regardless of network
- """
- build_config = {
- "mapping": {
- "tp_size": tp,
- "pp_size": pp,
- },
- "quantization": {},
- }
- if dtype:
- build_config["dtype"] = dtype
- if quant_dtype:
- if not kv_cache_quant_dtype:
- # will throw errors during validation if the type is invalid
- kv_cache_quant_dtype = quant_dtype
- build_config["quantization"] = {
- "quant_algo": quant_dtype,
- "kv_cache_quant_algo": kv_cache_quant_dtype,
- }
- for name, pet in PET_dict.items():
- if name in str(model_name):
- build_config["position_embedding_type"] = pet
- return build_config
-
- @classmethod
- def from_hf(cls,
- hf_model_name,
- tp,
- pp,
- dtype=None,
- quant_dtype=None,
- kv_cache_quant_dtype=None):
- """
- Use transformers.AutoConfig to load a model's config from a HF name
- """
- build_config = cls.populate_build_config(hf_model_name, tp, pp, dtype,
- quant_dtype,
- kv_cache_quant_dtype)
- hf_config = AutoConfig.from_pretrained(hf_model_name).to_dict()
- if hf_model_name in HA_dict:
- hf_config["hidden_act"] = HA_dict[hf_model_name]
- return cls(**hf_config, **build_config)
-
- @classmethod
- def from_json(cls,
- model_name,
- tp,
- pp,
- dtype=None,
- quant_dtype=None,
- kv_cache_quant_dtype=None):
- """
- Load model parameters from a custom json file
- A full path can be specified. Otherwise, look for ./trtllm_configs/(model_name).json
- """
- build_config = cls.populate_build_config(model_name, tp, pp, dtype,
- quant_dtype,
- kv_cache_quant_dtype)
- if os.path.exists(model_name):
- path_to_json = model_name
- else:
- path_to_json = os.path.join(os.path.dirname(__file__),
- f"trtllm_configs/{model_name}.json")
- if not os.path.exists(path_to_json):
- raise FileNotFoundError(f"{path_to_json} not found")
- json_config = json.load(open(path_to_json))
- return cls(**json_config, **build_config)
-
- @classmethod
- def from_name(cls,
- model,
- tp,
- pp,
- dtype=None,
- quant_dtype=None,
- kv_cache_quant_dtype=None):
- """
- Attempts to create a config based on model name. Performs the following steps:
- 1. Tries to load the HF config using AutoConfig. This will only work if the network name exists on HF.
- 2. If this fails, tries to load a custom config from trtllm_configs/(model_name).json next to this module.
- """
- try:
- trtllm_config = cls.from_hf(model, tp, pp, dtype, quant_dtype,
- kv_cache_quant_dtype)
- except EnvironmentError:
- try:
- trtllm_config = cls.from_json(model, tp, pp, dtype, quant_dtype,
- kv_cache_quant_dtype)
- except FileNotFoundError as e:
- raise NameError(
- f"Unable to create PretrainedConfig from {model} due to {e}"
- )
-
- return trtllm_config
-
- # future possibilities
- # def from_nemo_config (self, nemo_model_name)
-
- def to_json(self, output_dir):
- with open(os.path.join(output_dir, "generated_config.json"), "w") as f:
- json.dump(self.model_dump(), f, indent=4)
-
-
-if __name__ == "__main__":
- parser = ArgumentParser()
- parser.add_argument(
- "--model",
- required=True,
- type=str,
- help="HF model name",
- )
- parser.add_argument(
- "--tp_size",
- type=int,
- default=1,
- help="TP degree",
- )
- parser.add_argument(
- "--pp_size",
- type=int,
- default=1,
- help="PP degree",
- )
- parser.add_argument(
- "--dtype",
- type=str,
- help="Datatype",
- )
- parser.add_argument(
- "--quant_dtype",
- type=str,
- help="Quantization datatype",
- )
- parser.add_argument(
- "--kv_cache_quant_dtype",
- type=str,
- help="KV cache datatype",
- )
- parser.add_argument(
- "--position_embedding_type",
- type=str,
- help="TRT-LLM argument",
- )
- parser.add_argument(
- "--hidden_act",
- type=str,
- help="TRT-LLM argument",
- )
- parser.add_argument(
- "--populate_hf_cache",
- action='store_true',
- help="Populate the HF cache with all the supported networks",
- )
- args = parser.parse_args()
-
- if args.populate_hf_cache:
- for net in PET_dict.keys():
- _ = AutoConfig.from_pretrained(net)
- else:
- trtllm_config = TRTLLMConfig.from_name(args.model, args.tp_size,
- args.pp_size, args.dtype,
- args.quant_dtype,
- args.kv_cache_quant_dtype)
- trtllm_config.to_json(os.getcwd())
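
For reference, the removed `TRTLLMConfig` helper was driven as sketched below: it maps a supported Hugging Face config into TRT-LLM checkpoint fields and writes `generated_config.json` into a workspace, which `trtllm-build` then consumes via `--model_config`. The model name, parallelism settings, and workspace path are illustrative, and the import assumes the suite's `tensorrt_llm_bench` directory is on `PYTHONPATH` (the module no longer exists after this change).

```python
# Assumes benchmarks/suite/tensorrt_llm_bench is on PYTHONPATH so that the
# (now removed) utils.trtllm_config module resolves.
from utils.trtllm_config import TRTLLMConfig

config = TRTLLMConfig.from_hf(
    "meta-llama/Llama-2-7b-hf",  # any network listed in PET_dict above
    tp=1,
    pp=1,
    dtype="float16",
)
config.to_json("/tmp/workspace")  # writes /tmp/workspace/generated_config.json
```
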
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 056dbfef6..c275be095 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -289,7 +289,7 @@ set(CMAKE_CUDA_RUNTIME_LIBRARY Static)
find_library(RT_LIB rt)
set_ifndef(ENABLE_MULTI_DEVICE 1)
-if(ENABLE_MULTI_DEVICE EQUAL 1)
+if(ENABLE_MULTI_DEVICE)
# NCCL dependencies
set_ifndef(NCCL_LIB_DIR /usr/lib/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu/)
set_ifndef(NCCL_INCLUDE_DIR /usr/include/)
diff --git a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h
index 8a91daa07..45d951341 100644
--- a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h
@@ -18,6 +18,7 @@
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/namedTensor.h"
+#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include
@@ -165,6 +166,21 @@ class GenericInferenceRequest
mLogitsPostProcessor = cb;
}
+ [[nodiscard]] std::optional<executor::LookaheadDecodingConfig> getLookaheadConfig() const
+ {
+ return mLookaheadConfig;
+ }
+
+ void setLookaheadConfig(executor::LookaheadDecodingConfig config)
+ {
+ mLookaheadConfig = config;
+ }
+
+ void clearLookaheadConfig()
+ {
+ mLookaheadConfig = std::nullopt;
+ }
+
std::optional getLogitsPostProcessor()
{
return mLogitsPostProcessor;
@@ -282,6 +298,7 @@ class GenericInferenceRequest
bool mIsStreaming;
TensorMap mInputTensors;
std::optional mLogitsPostProcessor;
+ std::optional<executor::LookaheadDecodingConfig> mLookaheadConfig;
};
class InferenceRequest : public GenericInferenceRequest
diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 42536fc33..e10701d0d 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -20,6 +20,7 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/bufferManager.h"
+#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
@@ -73,7 +74,8 @@ class GenericLlmRequest
std::optional promptEmbeddingTable = std::nullopt,
std::optional promptVocabSize = std::nullopt,
std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt,
- std::optional loraConfig = std::nullopt, bool returnLogProbs = false,
+ std::optional loraConfig = std::nullopt,
+ std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional> draftTokens = std::nullopt,
std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false,
@@ -94,6 +96,7 @@ class GenericLlmRequest
, mClientId(clientId)
, mIsStreaming(isStreaming)
, mOrigPromptLen(mPromptLen)
+ , mNumPreDecodedTokens(samplingConfig.beamWidth, 0)
, mMaxSentTokenLen(mPromptLen)
, mEmbeddingBias(std::move(embeddingBias))
, mBadWordsList(std::move(badWordsList))
@@ -103,6 +106,7 @@ class GenericLlmRequest
, mLoraTaskId(loraTaskId)
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
+ , mLookaheadConfig(std::move(lookaheadConfig))
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
, mLogProbs(samplingConfig.beamWidth)
@@ -118,6 +122,7 @@ class GenericLlmRequest
, mReturnEncoderOutput(returnEncoderOutput)
, mDecodingIter(0)
, mPriority(priority)
+ , mFinishReasons(samplingConfig.beamWidth)
{
if (mEncoderTokens.has_value())
{
@@ -137,6 +142,7 @@ class GenericLlmRequest
, mClientId(req.getClientId())
, mIsStreaming(req.getStreaming())
, mOrigPromptLen(mPromptLen)
+ , mNumPreDecodedTokens(mSamplingConfig.beamWidth, 0)
, mMaxSentTokenLen(mPromptLen)
, mEmbeddingBias(std::nullopt)
, mBadWordsList(std::nullopt)
@@ -146,6 +152,7 @@ class GenericLlmRequest
, mLoraTaskId(std::nullopt)
, mLoraWeights(std::nullopt)
, mLoraConfig(std::nullopt)
+ , mLookaheadConfig(std::nullopt)
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
, mLogProbs(mSamplingConfig.beamWidth)
@@ -161,6 +168,8 @@ class GenericLlmRequest
, mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput)
, mDecodingIter(0)
, mPriority(req.getPriority())
+ , mFinishReasons(mSamplingConfig.beamWidth)
+ , mContextPhaseParams(req.getContextPhaseParams())
{
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && !mReturnAllGeneratedTokens)
{
@@ -172,6 +181,14 @@ class GenericLlmRequest
"length).");
mReturnAllGeneratedTokens = true;
}
+ if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnGenerationLogits == true)
+ {
+ TLLM_LOG_WARNING(
+ "Returning generation logits when streaming is enabled and beamWidth > 1 is not allowed. "
+ "This is because the beams are gathered while the logits are not, so the returned "
+ "logits may be in a mismatched order. Disabling returnGenerationLogits.");
+ mReturnGenerationLogits = false;
+ }
if (req.getEncoderInputTokenIds())
{
mState = REQUEST_STATE_ENCODER_INIT;
@@ -219,6 +236,11 @@ class GenericLlmRequest
}
}
+ auto lookaheadConfig = req.getLookaheadConfig();
+ if (lookaheadConfig)
+ {
+ }
+
auto externalDraftTokensConfig = req.getExternalDraftTokensConfig();
if (externalDraftTokensConfig)
{
@@ -295,12 +317,27 @@ class GenericLlmRequest
mExcludeInputFromOutput = exclude;
}
+ /// @brief Get the params of the context
+ /// @return The params of the context
+ std::optional<executor::ContextPhaseParams> const& getContextPhaseParams() const noexcept
+ {
+ return mContextPhaseParams;
+ }
+
+ /// @brief Get the state params of the context
+ /// @return The state params of the context
+ executor::ContextPhaseState const& getContextPhaseState() const
+ {
+ TLLM_CHECK(mContextPhaseParams.has_value());
+ return *static_cast<executor::ContextPhaseState const*>(mContextPhaseParams.value().getState());
+ }
+
/// @brief Get total number of tokens for this req (prompt + generated)
/// @param beam The beam index
/// @return The number of tokens
[[nodiscard]] SizeType32 getNumTokens(SizeType32 beam) const
{
- return mTokens.at(beam).size();
+ return mTokens.at(beam).size() - mNumPreDecodedTokens[beam];
}
/// @brief Get max number of tokens across all beams
@@ -310,7 +347,7 @@ class GenericLlmRequest
SizeType32 maxTokens = 0;
for (SizeType32 beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
{
- maxTokens = std::max(maxTokens, static_cast<SizeType32>(mTokens.at(beam).size()));
+ maxTokens = std::max(maxTokens, getNumTokens(beam));
}
return maxTokens;
}
@@ -405,6 +442,14 @@ class GenericLlmRequest
}
}
+ /// @brief Set the number of pre-decoded tokens
+ /// @param num_tokens The number of pre-decoded tokens
+ /// @param beam The beam to which to set the number of pre-decoded tokens
+ void setNumPreDecodedTokens(SizeType32 num_tokens, SizeType32 beam)
+ {
+ mNumPreDecodedTokens[beam] = num_tokens;
+ }
+
/// @brief Sets the generated tokens for all beams after gatherTree. Erases all previous generated tokens.
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
void setGeneratedTokens(BeamTokens const& generatedBeamTokens)
@@ -540,6 +585,21 @@ class GenericLlmRequest
mLoraConfig = std::nullopt;
}
+ [[nodiscard]] std::optional<executor::LookaheadDecodingConfig> getLookaheadConfig() const
+ {
+ return mLookaheadConfig;
+ }
+
+ void setLookaheadConfig(executor::LookaheadDecodingConfig config)
+ {
+ mLookaheadConfig = config;
+ }
+
+ void clearLookaheadConfig()
+ {
+ mLookaheadConfig = std::nullopt;
+ }
+
[[nodiscard]] std::optional getEmbeddingBias() const
{
return mEmbeddingBias;
@@ -725,6 +785,11 @@ class GenericLlmRequest
mReturnAllGeneratedTokens = returnAllGeneratedTokens;
}
+ [[nodiscard]] bool getReturnAllGeneratedTokens()
+ {
+ return mReturnAllGeneratedTokens;
+ }
+
void setReturnContextLogits(bool const returnContextLogits)
{
mReturnContextLogits = returnContextLogits;
@@ -737,6 +802,8 @@ class GenericLlmRequest
void setReturnGenerationLogits(bool const returnGenerationLogits)
{
+ TLLM_CHECK_WITH_INFO(!(mIsStreaming && mSamplingConfig.beamWidth > 1 && returnGenerationLogits),
+ "returnGenerationLogits must be false if streaming AND beam search are used.");
mReturnGenerationLogits = returnGenerationLogits;
}
@@ -777,8 +844,21 @@ class GenericLlmRequest
void allocGenerationLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
- mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
- runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}), logitsDataType);
+ if (mIsStreaming)
+ {
+ // If streaming mode, the complete generation logits shape will be [1, beamWidth, vocabSizePadded],
+ // or [allGeneratedTokens, beamWidth, vocabSizePadded] if mReturnAllGeneratedTokens is True.
+ // This could reduce unnecessary format conversions and allows the data to be returned directly.
+ mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
+ runtime::ITensor::makeShape({mMaxNewTokens, mSamplingConfig.beamWidth, vocabSizePadded}),
+ logitsDataType);
+ }
+ else
+ {
+ mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
+ runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}),
+ logitsDataType);
+ }
}
void allocTargetModelAcceptedTokenLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
@@ -992,7 +1072,17 @@ class GenericLlmRequest
if (getReturnGenerationLogits())
{
- result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost());
+ if (isStreaming())
+ {
+ auto startGenTokenPos = startTokenPos - getOrigPromptLen();
+ TensorPtr generationLogitsHostCurrentStep
+ = runtime::ITensor::slice(getGenerationLogitsHost(), startGenTokenPos, maxNbTokensOut);
+ result.generationLogits = executor::detail::ofITensor(generationLogitsHostCurrentStep);
+ }
+ else
+ {
+ result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost());
+ }
}
if (getReturnEncoderOutput())
@@ -1000,6 +1090,8 @@ class GenericLlmRequest
result.encoderOutput = executor::detail::ofITensor(getEncoderOutputHost());
}
+ result.finishReasons = mFinishReasons;
+
// Update position of last sent response
setMaxSentTokenLen(maxNbTokens);
@@ -1013,6 +1105,11 @@ class GenericLlmRequest
}
}
+ void setFinishedReason(executor::FinishReason reason, SizeType32 beam)
+ {
+ mFinishReasons.at(beam) = reason;
+ }
+
RequestIdType mRequestId;
SizeType32 mPromptLen;
SizeType32 mMaxNewTokens;
@@ -1038,6 +1135,11 @@ class GenericLlmRequest
VecTokens mLastTokens;
BeamTokens mTokens;
SizeType32 mOrigPromptLen;
+ // A list of numbers of pre-decoded tokens on the last PP rank when using pipeline parallelism.
+ // It is introduced as a WAR to solve the hanging problem caused by overestimating the used KV cache on the last PP
+ // rank (because new tokens are decoded earlier). By excluding the numbers of pre-decoded tokens, the used KV cache
+ // can be estimated correctly.
+ std::vector<SizeType32> mNumPreDecodedTokens;
// Number of tokens already in KV cache before context phase.
// A value > 0 indicates cached KV cache blocks were reused.
// Up to inputLen - 1 tokens can be reused.
@@ -1054,6 +1156,7 @@ class GenericLlmRequest
std::optional mLoraTaskId;
std::optional mLoraWeights;
std::optional mLoraConfig;
+ std::optional<executor::LookaheadDecodingConfig> mLookaheadConfig;
// To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one,
// the size of the context chunk needs to be an integer multiple of the kv-cache block size. The meaning
@@ -1090,6 +1193,8 @@ class GenericLlmRequest
SizeType32 mDecodingIter;
executor::PriorityType mPriority;
+ std::vector<executor::FinishReason> mFinishReasons;
+ std::optional<executor::ContextPhaseParams> mContextPhaseParams;
private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
@@ -1162,7 +1267,8 @@ class LlmRequest : public GenericLlmRequest
std::optional promptEmbeddingTable = std::nullopt,
std::optional promptVocabSize = std::nullopt,
std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt,
- std::optional loraConfig = std::nullopt, bool returnLogProbs = false,
+ std::optional loraConfig = std::nullopt,
+ std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional> draftTokens = std::nullopt,
std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false,
@@ -1174,9 +1280,9 @@ class LlmRequest : public GenericLlmRequest
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
- returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
- excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched,
- std::move(encoderInputTokens), returnEncoderOutput, clientId, priority)
+ std::move(lookaheadConfig), returnLogProbs, returnContextLogits, returnGenerationLogits,
+ std::move(draftTokens), std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
+ applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority)
{
}
@@ -1187,6 +1293,7 @@ class LlmRequest : public GenericLlmRequest
{
mLogitsPostProcessor = std::move(logitsPostProcessor);
mApplyLogitsPostProcessorBatched = applyLogitsPostProcessorBatched;
+ mLookaheadConfig = Request.getLookaheadConfig();
}
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
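
The pre-decoded-token comment above describes a workaround: tokens that the last pipeline-parallel rank has already produced are subtracted in `getNumTokens` so that KV-cache usage is not overestimated. The Python sketch below is not part of the codebase; it only illustrates that bookkeeping with a toy request object.

```python
from typing import List


class MiniRequest:
    """Toy illustration of GenericLlmRequest's pre-decoded token accounting."""

    def __init__(self, beam_width: int, prompt_len: int) -> None:
        self.tokens: List[List[int]] = [[0] * prompt_len for _ in range(beam_width)]
        self.num_pre_decoded: List[int] = [0] * beam_width

    def num_tokens(self, beam: int) -> int:
        # Mirrors getNumTokens(): exclude tokens decoded ahead of time.
        return len(self.tokens[beam]) - self.num_pre_decoded[beam]

    def max_num_tokens(self) -> int:
        return max(self.num_tokens(beam) for beam in range(len(self.tokens)))


req = MiniRequest(beam_width=1, prompt_len=128)
req.tokens[0].extend([11, 12, 13])   # three tokens decoded early on the last PP rank
req.num_pre_decoded[0] = 3           # mark them as pre-decoded
assert req.num_tokens(0) == 128      # the KV-cache estimate stays at the prompt length
```
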
diff --git a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h
index 1dbeed000..4fcb1e127 100644
--- a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h
+++ b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h
@@ -45,7 +45,8 @@ class TrtGptModelOptionalParams
std::optional maxNumTokens = std::nullopt,
executor::SchedulerConfig const& schedulerConfig = executor::SchedulerConfig{},
executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig
- = executor::ExtendedRuntimePerfKnobConfig{})
+ = executor::ExtendedRuntimePerfKnobConfig{},
+ std::optional<executor::DebugConfig> debugConfig = std::nullopt)
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
@@ -59,6 +60,7 @@ class TrtGptModelOptionalParams
, maxNumTokens(maxNumTokens)
, schedulerConfig{schedulerConfig}
, extendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig)
+ , debugConfig{std::move(debugConfig)}
{
}
@@ -70,17 +72,26 @@ class TrtGptModelOptionalParams
executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}),
executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(),
executorConfig.getMaxNumTokens(), executorConfig.getSchedulerConfig(),
- executorConfig.getExtendedRuntimePerfKnobConfig())
+ executorConfig.getExtendedRuntimePerfKnobConfig(), executorConfig.getDebugConfig())
{
}
bool operator==(TrtGptModelOptionalParams const& other) const
{
- return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap
- && deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs
- && enableChunkedContext == other.enableChunkedContext && decodingConfig == other.decodingConfig
- && gpuWeightsPercent == other.gpuWeightsPercent
- && extendedRuntimePerfKnobConfig == other.extendedRuntimePerfKnobConfig;
+ return kvCacheConfig == other.kvCacheConfig //
+ && enableTrtOverlap == other.enableTrtOverlap //
+ && deviceIds == other.deviceIds //
+ && normalizeLogProbs == other.normalizeLogProbs //
+ && enableChunkedContext == other.enableChunkedContext //
+ && decodingConfig == other.decodingConfig //
+ && gpuWeightsPercent == other.gpuWeightsPercent //
+ && maxBeamWidth == other.maxBeamWidth //
+ && maxBatchSize == other.maxBatchSize //
+ && maxNumTokens == other.maxNumTokens //
+ && schedulerConfig == other.schedulerConfig //
+ && extendedRuntimePerfKnobConfig == other.extendedRuntimePerfKnobConfig //
+ && debugConfig == other.debugConfig //
+ ;
}
friend std::ostream& operator<<(std::ostream& os, TrtGptModelOptionalParams const& self);
@@ -100,6 +111,7 @@ class TrtGptModelOptionalParams
std::optional maxNumTokens;
executor::SchedulerConfig schedulerConfig;
executor::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig;
+ std::optional<executor::DebugConfig> debugConfig;
};
} // namespace tensorrt_llm::batch_manager
diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index 4ed912100..ed402ca69 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -42,6 +42,7 @@ char const* version() noexcept;
class Model;
class Serialization;
+class ContextPhaseState;
/// @brief Sampling configuration
class SamplingConfig
@@ -233,6 +234,71 @@ class LoraConfig
std::optional mConfig;
};
+struct LookaheadDecodingConfig
+{
+ LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);
+
+ explicit LookaheadDecodingConfig()
+ : LookaheadDecodingConfig(1, 1, 0)
+ {
+ }
+
+ bool operator==(LookaheadDecodingConfig const& other) const;
+ [[nodiscard]] std::tuple get() const;
+ [[nodiscard]] SizeType32 getWindowSize() const;
+ [[nodiscard]] SizeType32 getNgramSize() const;
+ [[nodiscard]] SizeType32 getVerificationSetSize() const;
+
+ /// @brief return
+ std::tuple calculateSpeculativeResource() const;
+
+ /// @brief return true when `this` can be executed on resources defined by `that`
+ bool isLE(LookaheadDecodingConfig const& that) const;
+
+ /// @brief return true when the parameter combination is valid.
+ static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept;
+
+private:
+ friend class Serialization;
+
+ // Number of NGrams in lookahead branch per step.
+ SizeType32 mWindowSize;
+ // Number of tokens per NGram.
+ SizeType32 mNgramSize;
+ // Number of NGrams in verification branch per step.
+ SizeType32 mVerificationSetSize;
+};
+
+class ContextPhaseParams
+{
+public:
+ explicit ContextPhaseParams(VecTokens firstGenTokens);
+ ContextPhaseParams(VecTokens firstGenTokens, void* state);
+
+ ContextPhaseParams(ContextPhaseParams const&);
+ ContextPhaseParams(ContextPhaseParams&&);
+ ContextPhaseParams& operator=(ContextPhaseParams const&);
+ ContextPhaseParams& operator=(ContextPhaseParams&&);
+
+ [[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;
+
+ [[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
+ [[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
+ [[nodiscard]] void const* getState() const noexcept;
+ [[nodiscard]] void* getState() noexcept;
+
+private:
+ friend class Serialization;
+ static void deleter(void const* data);
+ using StatePtr = std::unique_ptr;
+
+ /// @brief The first tokens generated by context executor
+ VecTokens mFirstGenTokens;
+
+ /// @brief Context phase state of this request
+ StatePtr mState{nullptr, deleter};
+};
+
/// @brief A class that holds information about the request
class Request
{
@@ -269,9 +335,11 @@ class Request
std::optional externalDraftTokensConfig = std::nullopt,
std::optional pTuningConfig = std::nullopt,
std::optional loraConfig = std::nullopt,
+ std::optional<LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
std::optional logitsPostProcessorName = std::nullopt,
std::optional encoderInputTokenIds = std::nullopt, std::optional clientId = std::nullopt,
- bool returnAllGeneratedTokens = false, PriorityType priority = kDefaultPriority);
+ bool returnAllGeneratedTokens = false, PriorityType priority = kDefaultPriority,
+ std::optional<ContextPhaseParams> contextPhaseParams = std::nullopt);
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
@@ -295,11 +363,13 @@ class Request
[[nodiscard]] std::optional getExternalDraftTokensConfig() const;
[[nodiscard]] std::optional getPromptTuningConfig() const;
[[nodiscard]] std::optional getLoraConfig() const;
+ [[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadConfig() const;
[[nodiscard]] std::optional getLogitsPostProcessorName() const;
[[nodiscard]] std::optional getEncoderInputTokenIds() const;
[[nodiscard]] std::optional getClientId() const;
[[nodiscard]] PriorityType getPriority() const;
[[nodiscard]] bool getReturnAllGeneratedTokens() const;
+ [[nodiscard]] std::optional<ContextPhaseParams> const& getContextPhaseParams() const;
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig const& config);
@@ -312,11 +382,13 @@ class Request
void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig);
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
void setLoraConfig(LoraConfig const& loraConfig);
+ void setLookaheadConfig(LookaheadDecodingConfig const& lookaheadConfig);
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds);
void setClientId(IdType clientId);
void setPriority(PriorityType priority);
void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens);
+ void setContextPhaseParams(ContextPhaseParams contextPhaseParams);
private:
friend class Serialization;
@@ -342,11 +414,20 @@ struct Result
/// @brief The context logits. Size [promptLen, vocabSizePadded]
std::optional contextLogits;
- /// @brief The context logits. Size [beamSize, maxNewTokens, vocabSizePadded]
+ /// @brief The generation logits. Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming)
+ /// or [maxNewTokens, beamSize, vocabSizePadded] (streaming and allGeneratedTokens)
+ /// or [1, beamSize, vocabSizePadded] (streaming and non-allGeneratedTokens)
std::optional generationLogits;
/// @brief The encoder output. Size [encoderLen, hiddenSize]
std::optional encoderOutput;
+
+ /// @brief The reason why the model stopped generating tokens for each beam in this request. Size [beamSize].
+ /// Currently only supported when beamSize is 1 and when using BatchingType::kINFLIGHT.
+ std::vector<FinishReason> finishReasons;
+
+ /// @brief The params of the context phase.
+ std::optional<ContextPhaseParams> contextPhaseParams;
};
/// @brief Class that holds either an error or a result
@@ -370,11 +451,11 @@ class Response
/// @brief Get the error msg for this response
/// Will throw an exception if hasError is false
- [[nodiscard]] std::string getErrorMsg() const;
+ [[nodiscard]] std::string const& getErrorMsg() const;
/// @brief Get the result for this response
/// Will throw an exception if hasResult is true
- [[nodiscard]] Result getResult() const;
+ [[nodiscard]] Result const& getResult() const;
private:
friend class Serialization;
@@ -390,6 +471,8 @@ class SchedulerConfig
CapacitySchedulerPolicy capacitySchedulerPolicy = CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
std::optional contextChunkingPolicy = std::nullopt);
+ bool operator==(SchedulerConfig const& other) const;
+
[[nodiscard]] CapacitySchedulerPolicy getCapacitySchedulerPolicy() const;
[[nodiscard]] std::optional getContextChunkingPolicy() const;
@@ -469,17 +552,17 @@ class ExtendedRuntimePerfKnobConfig
public:
explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = false, bool enableContextFMHAFP32Acc = false);
- [[nodiscard]] bool getMultiBlockMode() const;
- [[nodiscard]] bool getEnableContextFMHAFP32Acc() const;
-
- void setMultiBlockMode(bool const multiBlockMode);
- void setEnableContextFMHAFP32Acc(bool const enableContextFMHAFP32Acc);
-
bool operator==(ExtendedRuntimePerfKnobConfig const& other) const
{
return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc;
}
+ [[nodiscard]] bool getMultiBlockMode() const;
+ [[nodiscard]] bool getEnableContextFMHAFP32Acc() const;
+
+ void setMultiBlockMode(bool multiBlockMode);
+ void setEnableContextFMHAFP32Acc(bool enableContextFMHAFP32Acc);
+
private:
friend class Serialization;
@@ -490,6 +573,35 @@ class ExtendedRuntimePerfKnobConfig
bool mEnableContextFMHAFP32Acc;
};
+/// @brief Configuration class for debugging output
+class DebugConfig
+{
+ using StringVec = std::vector<std::string>;
+
+public:
+ explicit DebugConfig(bool dumpInputTensors = false, bool dumpOuputTensors = false, StringVec debugTensorNames = {});
+
+ bool operator==(DebugConfig const& other) const;
+
+ [[nodiscard]] bool getDumpInputTensors() const;
+ [[nodiscard]] bool getDumpOutputTensors() const;
+ [[nodiscard]] StringVec const& getDebugTensorNames() const;
+
+ void setDumpInputTensors(bool dumpInputTensors);
+ void setDumpOuputTensors(bool dumpOuputTensors);
+ void setDebugTensorNames(StringVec const& debugTensorNames);
+
+private:
+ friend class Serialization;
+
+ /// @brief If true, dump all input tensors.
+ bool mDumpInputTensors;
+ /// @brief If true, dump all output tensors.
+ bool mDumpOuputTensors;
+ /// @brief If not empty, only dump tensors in this list.
+ StringVec mDebugTensorNames;
+};
+
SizeType32 const kDefaultIterStatsMaxIterations = 1000;
// Per request stats may have additional overhead due to going through all requests. Turned off by default.
SizeType32 const kDefaultRequestStatsMaxIterations = 0;
@@ -616,42 +728,43 @@ class PeftCacheConfig
std::optional mHostCacheSize;
};
-struct LookaheadDecodingConfig
-{
- LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);
-
- explicit LookaheadDecodingConfig()
- : LookaheadDecodingConfig(1, 1, 0)
- {
- }
-
- bool operator==(LookaheadDecodingConfig const& other) const;
- [[nodiscard]] std::tuple get() const;
- [[nodiscard]] SizeType32 getWindowSize() const;
- [[nodiscard]] SizeType32 getNgramSize() const;
- [[nodiscard]] SizeType32 getVerificationSetSize() const;
-
- /// @brief return
- std::tuple calculateSpeculativeResource() const;
-
- /// @brief return true when `this` can be executed on resources defined by `that`
- bool isLE(LookaheadDecodingConfig const& that) const;
-
- /// @brief return true when the parameter combination is valid.
- static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept;
-
-private:
- friend class Serialization;
-
- // Number of NGrams in lookahead branch per step.
- SizeType32 mWindowSize;
- // Number of tokens per NGram.
- SizeType32 mNgramSize;
- // Number of NGrams in verification branch per step.
- SizeType32 mVerificationSetSize;
-};
-
/// @brief Configuration class for the speculative decoding.
+// struct LookaheadDecodingConfig
+//{
+// LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);
+//
+// explicit LookaheadDecodingConfig()
+// : LookaheadDecodingConfig(1, 1, 0)
+// {
+// }
+//
+// bool operator==(LookaheadDecodingConfig const& other) const;
+// [[nodiscard]] std::tuple get() const;
+// [[nodiscard]] SizeType32 getWindowSize() const;
+// [[nodiscard]] SizeType32 getNgramSize() const;
+// [[nodiscard]] SizeType32 getVerificationSetSize() const;
+//
+// /// @brief return
+// std::tuple calculateSpeculativeResource() const;
+//
+// /// @brief return true when `this` can be executed on resources defined by `that`
+// bool isLE(LookaheadDecodingConfig const& that) const;
+//
+// /// @brief return true when the parameter combination is valid.
+// static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept;
+//
+// private:
+// friend class Serialization;
+//
+// // Number of NGrams in lookahead branch per step.
+// SizeType32 mWindowSize;
+// // Number of tokens per NGram.
+// SizeType32 mNgramSize;
+// // Number of NGrams in verification branch per step.
+// SizeType32 mVerificationSetSize;
+// };
+
+/// @brief Configuration class for the decoding.
class DecodingConfig
{
public:
@@ -687,6 +800,29 @@ class DecodingConfig
std::optional mMedusaChoices;
};
+class LogitsPostProcessorConfig
+{
+public:
+ explicit LogitsPostProcessorConfig(std::optional<LogitsPostProcessorMap> processorMap = std::nullopt,
+ std::optional<LogitsPostProcessorBatched> processorBatched = std::nullopt, bool replicate = true);
+
+ [[nodiscard]] std::optional<LogitsPostProcessorMap> getProcessorMap() const;
+ [[nodiscard]] std::optional<LogitsPostProcessorBatched> getProcessorBatched() const;
+ [[nodiscard]] bool getReplicate() const;
+
+ void setProcessorMap(LogitsPostProcessorMap const& processorMap);
+ void setProcessorBatched(LogitsPostProcessorBatched const& processorBatched);
+ void setReplicate(bool replicate);
+
+private:
+ /// @brief mapping from post processor names to non-batched post processors
+ std::optional<LogitsPostProcessorMap> mProcessorMap;
+ /// @brief single batched post processor
+ std::optional<LogitsPostProcessorBatched> mProcessorBatched;
+ /// @brief If set to true, logits post processor will run on all TP ranks in last PP rank
+ bool mReplicate;
+};
+
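
For reference, a hedged sketch of how the three former `ExecutorConfig` knobs (processor map, batched processor, replicate flag) fold into the new class; the `tle` alias and helper name are assumptions, and the map is taken as already built by the caller:

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <optional>
#include <utility>

namespace tle = tensorrt_llm::executor;

// Previously: setLogitsPostProcessorMap(...), setLogitsPostProcessorBatched(...)
// and setReplicateLogitsPostProcessor(...). Now one object carries all three.
void attachLogitsPostProcessors(tle::ExecutorConfig& executorConfig, tle::LogitsPostProcessorMap processorMap)
{
    tle::LogitsPostProcessorConfig logitsConfig(
        std::move(processorMap), /*processorBatched=*/std::nullopt, /*replicate=*/true);
    executorConfig.setLogitsPostProcessorConfig(logitsConfig);
}
```
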
/// @brief Configuration class for the model executor
class ExecutorConfig
{
@@ -699,11 +835,11 @@ class ExecutorConfig
std::optional<SizeType32> maxNumTokens = std::nullopt,
std::optional<ParallelConfig> parallelConfig = std::nullopt,
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
- std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
- std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt,
- bool replicateLogitsPostProcessor = true, std::optional<DecodingConfig> decodingConfig = std::nullopt,
- float gpuWeightsPercent = 1, std::optional<SizeType32> maxQueueSize = std::nullopt,
- ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig());
+ std::optional<LogitsPostProcessorConfig> logitsPostProcessorConfig = std::nullopt,
+ std::optional<DecodingConfig> decodingConfig = std::nullopt, float gpuWeightsPercent = 1,
+ std::optional<SizeType32> maxQueueSize = std::nullopt,
+ ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig(),
+ std::optional<DebugConfig> debugConfig = std::nullopt);
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@@ -717,13 +853,12 @@ class ExecutorConfig
[[nodiscard]] std::optional<SizeType32> getMaxNumTokens() const;
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
- [[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
- [[nodiscard]] std::optional<LogitsPostProcessorBatched> getLogitsPostProcessorBatched() const;
- [[nodiscard]] bool getReplicateLogitsPostProcessor() const;
+ [[nodiscard]] std::optional<LogitsPostProcessorConfig> getLogitsPostProcessorConfig() const;
[[nodiscard]] std::optional<DecodingConfig> getDecodingConfig() const;
[[nodiscard]] float getGpuWeightsPercent() const;
[[nodiscard]] std::optional<SizeType32> getMaxQueueSize() const;
[[nodiscard]] ExtendedRuntimePerfKnobConfig getExtendedRuntimePerfKnobConfig() const;
+ [[nodiscard]] std::optional<DebugConfig> getDebugConfig() const;
void setMaxBeamWidth(SizeType32 maxBeamWidth);
void setMaxBatchSize(SizeType32 maxBatchSize);
@@ -737,13 +872,12 @@ class ExecutorConfig
void setBatchingType(BatchingType batchingType);
void setParallelConfig(ParallelConfig const& parallelConfig);
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
- void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
- void setLogitsPostProcessorBatched(LogitsPostProcessorBatched const& logitsPostProcessorBatched);
- void setReplicateLogitsPostProcessor(bool const replicateLogitsPostProcessor);
+ void setLogitsPostProcessorConfig(LogitsPostProcessorConfig const& logitsPostProcessorConfig);
void setDecodingConfig(DecodingConfig const& decodingConfig);
void setGpuWeightsPercent(float const& gpuWeightsPercent);
void setMaxQueueSize(std::optional<SizeType32> const& maxQueueSize);
- void setExtendedRuntimePerfKnobConfig(ExtendedRuntimePerfKnobConfig const& ExtendedRuntimePerfKnobConfig);
+ void setExtendedRuntimePerfKnobConfig(ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);
+ void setDebugConfig(DebugConfig const& debugConfig);
private:
friend class Serialization;
@@ -781,10 +915,9 @@ class ExecutorConfig
/// @brief The parallel execution configuration.
std::optional<ParallelConfig> mParallelConfig;
std::optional<PeftCacheConfig> mPeftCacheConfig;
- std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
- std::optional<LogitsPostProcessorBatched> mLogitsPostProcessorBatched;
- /// @brief If set to true, logits post processor will run on all TP ranks in last PP rank
- bool mReplicateLogitsPostProcessor;
+
+ /// @brief Logits post processor configuration
+ std::optional<LogitsPostProcessorConfig> mLogitsPostProcessorConfig;
/// @brief Decoding configuration.
std::optional<DecodingConfig> mDecodingConfig;
@@ -797,6 +930,9 @@ class ExecutorConfig
/// @brief Config for perf knobs that can be set in runtime.
ExtendedRuntimePerfKnobConfig mExtendedRuntimePerfKnobConfig;
+
+ /// @brief Debugging configuration.
+ std::optional<DebugConfig> mDebugConfig;
};
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h
index 4dace9acb..5f29afde2 100644
--- a/cpp/include/tensorrt_llm/executor/serialization.h
+++ b/cpp/include/tensorrt_llm/executor/serialization.h
@@ -53,6 +53,11 @@ class Serialization
static void serialize(LoraConfig const& config, std::ostream& os);
[[nodiscard]] static size_t serializedSize(LoraConfig const& config);
+ // ContextPhaseParams
+ [[nodiscard]] static ContextPhaseParams deserializeContextPhaseParams(std::istream& is);
+ static void serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os);
+ [[nodiscard]] static size_t serializedSize(ContextPhaseParams const& contextPhaseParams);
+
// Request
[[nodiscard]] static Request deserializeRequest(std::istream& is);
static void serialize(Request const& request, std::ostream& os);
@@ -122,6 +127,11 @@ class Serialization
static void serialize(DecodingConfig const& decodingConfig, std::ostream& os);
static size_t serializedSize(DecodingConfig const& decodingConfig);
+ // DebugConfig
+ static DebugConfig deserializeDebugConfig(std::istream& is);
+ static void serialize(DebugConfig const& debugConfig, std::ostream& os);
+ static size_t serializedSize(DebugConfig const& debugConfig);
+
// ExecutorConfig
static ExecutorConfig deserializeExecutorConfig(std::istream& is);
static void serialize(ExecutorConfig const& executorConfig, std::ostream& os);
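
The new `DebugConfig` entry points follow the same pattern as the other config classes. A small round-trip sketch using only the declarations shown above; the stream choice and function name are illustrative:

```cpp
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"

#include <cassert>
#include <sstream>

namespace tle = tensorrt_llm::executor;

// Write a DebugConfig to a byte stream and read it back; the copies should
// compare equal through DebugConfig::operator==.
void roundTripDebugConfig(tle::DebugConfig const& debugConfig)
{
    std::stringstream stream;
    tle::Serialization::serialize(debugConfig, stream);
    auto const restored = tle::Serialization::deserializeDebugConfig(stream);
    assert(restored == debugConfig);
}
```
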
diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h
index 4861a39aa..1dce228cd 100644
--- a/cpp/include/tensorrt_llm/executor/types.h
+++ b/cpp/include/tensorrt_llm/executor/types.h
@@ -348,6 +348,22 @@ struct RequestStatsPerIteration
std::vector<RequestStats> requestStats;
};
+/// @brief The reason why the model stopped generating tokens for a request.
+enum class FinishReason
+{
+ /// @brief The request is not finished.
+ kNOT_FINISHED = 0,
+
+ /// @brief The request finished because the end id was generated.
+ kEND_ID = 1,
+
+ /// @brief The request finished because a stop word was generated.
+ kSTOP_WORDS = 2,
+
+ /// @brief The request finished because the maximum number of tokens was reached.
+ kLENGTH = 3,
+};
+
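
The enum maps naturally onto short labels, e.g. for logging why a request ended. A sketch of such a helper (the function itself is not part of the API):

```cpp
#include "tensorrt_llm/executor/types.h"

#include <string_view>

namespace tle = tensorrt_llm::executor;

// Translate a FinishReason into a short label for logs or client responses.
constexpr std::string_view finishReasonToString(tle::FinishReason reason)
{
    switch (reason)
    {
    case tle::FinishReason::kNOT_FINISHED: return "not finished";
    case tle::FinishReason::kEND_ID: return "end id generated";
    case tle::FinishReason::kSTOP_WORDS: return "stop word matched";
    case tle::FinishReason::kLENGTH: return "max token limit reached";
    }
    return "unknown";
}
```
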
/// @brief mode of the decoder
class DecodingMode
{
diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h
index 6c4d7c805..4f92cbd06 100644
--- a/cpp/include/tensorrt_llm/runtime/decodingInput.h
+++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h
@@ -80,7 +80,7 @@ class DecodingInput
batchSlots; //!< [batchSize], address map of the linear batch id to to the seq slots, int32_t, pinned
// optional parameters
- TensorConstPtr finished; //!< [batchSize, beamWidth], finished states at current iteration.
+ TensorConstPtr finishReasons; //!< [batchSize, beamWidth], finished states at current iteration.
//!< If true for some request, the decoding step of it is skipped, on gpu
TensorConstPtr
sequenceLimitLength; //!< [batchSize], on gpu. The maximum sequence length for each sequence in the batch.
@@ -129,9 +129,16 @@ class DecodingInput
TensorConstPtr seqSlots; //!< [batchSize]
};
+ struct LookaheadInputs
+ {
+ TensorPtr tokensPerStep;
+ };
+
std::optional<MedusaInputs> medusaInputs;
std::optional<ExplicitDraftTokensInputs> explicitDraftTokensInputs;
+
+ std::optional<LookaheadInputs> lookaheadInputs;
};
} // namespace tensorrt_llm::runtime
diff --git a/cpp/include/tensorrt_llm/runtime/decodingOutput.h b/cpp/include/tensorrt_llm/runtime/decodingOutput.h
index c07ae057b..146db40a4 100644
--- a/cpp/include/tensorrt_llm/runtime/decodingOutput.h
+++ b/cpp/include/tensorrt_llm/runtime/decodingOutput.h
@@ -20,9 +20,15 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
#include "tensorrt_llm/runtime/iTensor.h"
+#include "tensorrt_llm/runtime/lookaheadBuffers.h"
#include
#include
+namespace tensorrt_llm::batch_manager
+{
+class LookaheadDecodingBuffers;
+}
+
namespace tensorrt_llm::runtime
{
class DecodingOutput
@@ -81,10 +87,10 @@ class DecodingOutput
// Vector of views on newTokensSteps for each token
// optional parameters
- TensorPtr finished; // [BS, BM], set to true by decoding if any of the stop conditions are met or if
- // DecodingInput.finished is true. In beam search and to determine whether to stop according to
- // DecodingInput.sequenceLimitLength
- TensorPtr finishedSum; // [BS], the sum of finished sequences per request, in pinned memory
+ TensorPtr finishReasons; // [BS, BM], set to FinishedState by decoding if any of the stop conditions are met or if
+ // DecodingInput.finished is true. In beam search and to determine whether to stop
+ // according to DecodingInput.sequenceLimitLength
+ TensorPtr finishedSum; // [BS], the sum of finished sequences per request, in pinned memory
// mandatory parameters for beam search
TensorPtr logProbs; // [BS, BM, MSL], must be float*
@@ -110,6 +116,8 @@ class DecodingOutput
std::optional speculativeDecodingOutputs;
std::optional explicitDraftTokensBuffers;
+
+ std::optional lookaheadOutputs;
};
} // namespace tensorrt_llm::runtime
diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h
index 1753f24c6..8b0dc994b 100644
--- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h
+++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h
@@ -16,11 +16,13 @@
#pragma once
+#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
+#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include
@@ -52,7 +54,8 @@ class IGptDecoder
virtual ~IGptDecoder() = default;
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
- std::optional<DecodingOutput> const& output = std::nullopt)
+ std::optional<DecodingOutput> const& output = std::nullopt,
+ std::optional<std::vector<decoder_batch::Request> const> const& requests = std::nullopt)
= 0;
virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
@@ -95,7 +98,8 @@ class GptDecoder : public virtual IGptDecoder
std::shared_ptr speculativeDecodingModule = nullptr);
void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
- std::optional<DecodingOutput> const& output = std::nullopt) override;
+ std::optional<DecodingOutput> const& output = std::nullopt,
+ std::optional<std::vector<decoder_batch::Request> const> const& requests = std::nullopt) override;
void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;
diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
index bbef3ae7a..358826f50 100644
--- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
+++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -54,6 +54,8 @@ class GptDecoderBatched : public IGptDecoderBatched
void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) override;
+ void setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) override;
+
void newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
@@ -77,6 +79,12 @@ class GptDecoderBatched : public IGptDecoderBatched
return {mFinished.begin(), mFinished.begin() + mActualBatchSize};
}
+ //! @returns [batchSize, beamWidth], FinishedState value, on gpu
+ [[nodiscard]] TensorPtr getFinishReasons() const override
+ {
+ return ITensor::slice(mJointDecodingOutput->finishReasons, 0, mActualBatchSize);
+ }
+
//! @param batchIdx index of the batch
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
//! padding for request `batchIdx`, on gpu. In case of beam search, contains the ungathered data.
@@ -242,6 +250,9 @@ class GptDecoderBatched : public IGptDecoderBatched
//! @brief Setup buffers for speculative decoding.
void setupSpeculativeDecoding(ModelConfig const& modelConfig);
+ //! @brief Setup buffers for lookahead decoding.
+ void setupLookahead(ModelConfig const& modelConfig);
+
//! @brief Setups decoder internal tensors for new speculative decoding request
void newRequestSpeculativeDecoding(
SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
diff --git a/cpp/include/tensorrt_llm/runtime/iBuffer.h b/cpp/include/tensorrt_llm/runtime/iBuffer.h
index 46fb3972e..5675de5ad 100644
--- a/cpp/include/tensorrt_llm/runtime/iBuffer.h
+++ b/cpp/include/tensorrt_llm/runtime/iBuffer.h
@@ -18,6 +18,7 @@
#include "tensorrt_llm/common/arrayView.h"
#include "tensorrt_llm/common/dataType.h"
+#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include
@@ -323,6 +324,12 @@ struct TRTDataType
static constexpr auto value = TRTDataType::value;
};
+template <>
+struct TRTDataType
+{
+ static constexpr auto value = TRTDataType::value;
+};
+
template <>
struct TRTDataType
{
diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
index 4495c102a..11464f80e 100644
--- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
+++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -21,6 +21,7 @@
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
+#include "tensorrt_llm/runtime/lookaheadBuffers.h"
#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
@@ -100,6 +101,9 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder
//! @brief Setup buffers for ExplicitDraftTokens decoding.
virtual void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) = 0;
+ //! @brief Setup buffers for Lookahead decoding.
+ virtual void setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) = 0;
+
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
@@ -135,6 +139,9 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder
//! @returns [batchSize (actual)], marks finished requests (per batch)
[[nodiscard]] virtual std::vector getFinished() const = 0;
+ //! @returns [batchSize, beamWidth], FinishedState value, on gpu
+ [[nodiscard]] virtual TensorPtr getFinishReasons() const = 0;
+
//! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu
[[nodiscard]] virtual TensorPtr getCumLogProbs() const = 0;
diff --git a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h
index f5e0f142d..4719e4902 100644
--- a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h
+++ b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h
@@ -29,6 +29,11 @@
#include
+namespace tensorrt_llm::batch_manager
+{
+struct DecoderBuffers;
+}
+
namespace tensorrt_llm::runtime
{
diff --git a/cpp/include/tensorrt_llm/runtime/iTensor.h b/cpp/include/tensorrt_llm/runtime/iTensor.h
index 04937fc0d..faa14ca90 100644
--- a/cpp/include/tensorrt_llm/runtime/iTensor.h
+++ b/cpp/include/tensorrt_llm/runtime/iTensor.h
@@ -50,6 +50,7 @@ class ITensor : virtual public IBuffer
using SharedConstPtr = std::shared_ptr;
using Shape = nvinfer1::Dims;
using DimType64 = std::remove_reference_t;
+ using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
static_assert(std::is_same_v, "This version of TRT-LLM requires TensorRT 10.0 or later.");
diff --git a/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h
new file mode 100644
index 000000000..56504bd94
--- /dev/null
+++ b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/executor/executor.h"
+#include "tensorrt_llm/runtime/iTensor.h"
+#include "tensorrt_llm/runtime/modelConfig.h"
+#include "tensorrt_llm/runtime/tllmRuntime.h"
+#include "tensorrt_llm/runtime/worldConfig.h"
+
+namespace tensorrt_llm::runtime
+{
+
+class LookaheadDecodingBuffers
+{
+public:
+ using SizeType32 = runtime::SizeType32;
+ using TensorPtr = runtime::ITensor::SharedPtr;
+ using ITensor = tensorrt_llm::runtime::ITensor;
+ LookaheadDecodingBuffers(
+ SizeType32 maxNumSequences, SizeType32 maxTokensPerStep, runtime::BufferManager const& bufferManager);
+ TensorPtr generationLengths; // [mMaxNumRequests]
+ TensorPtr positionOffsets; // [mMaxNumRequests, maxTokensPerStep]
+ TensorPtr packedMasks; // [mMaxNumRequests, maxTokensPerStep, divUp(maxTokensPerStep, 32)]
+ TensorPtr positionIds;
+};
+
+class LookaheadRuntimeBuffers
+{
+public:
+ using SizeType32 = tensorrt_llm::runtime::SizeType32;
+ using ITensor = tensorrt_llm::runtime::ITensor;
+ using TensorPtr = runtime::ITensor::SharedPtr;
+ using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
+
+ LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::BufferManager const& manager,
+ runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+ executor::DecodingConfig const& decodingConfig, runtime::TllmRuntime const& runtime);
+
+ void setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSequences, runtime::ITensor const& requestTypes,
+ ITensor const& seqSlots, LookaheadDecodingBuffers const& decoderLookaheadBuffers,
+ runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
+ runtime::WorldConfig const& worldConfig) const;
+
+ void reshape(SizeType32 numCtxSequences, SizeType32 numGenSequences, SizeType32 tokensPerStep);
+
+ void insertInputTensors(
+ TensorMap& inputBuffers, TensorMap& outputBuffers, runtime::WorldConfig const& worldConfig) const;
+
+public:
+ TensorPtr packedMasksDevice; // [forwardBatchSize, tokensPerStep, numPackedMasks], on gpu
+ TensorPtr generationLengthsDevice; // [forwardBatchSize], on gpu
+ TensorPtr positionOffsetsDevice; // [forwardBatchSize, tokensPerStep], on gpu
+ TensorPtr positionIdsDevice; // [forwardBatchSize, tokensPerStep], on gpu
+
+ TensorPtr packedMaskHost;
+ TensorPtr generationLengthsHost;
+ TensorPtr positionOffsetsHost;
+ TensorPtr positionIdsHost;
+
+ TensorPtr packedMaskHostCopy;
+ TensorPtr generationLengthsHostCopy;
+ TensorPtr positionOffsetsHostCopy;
+ TensorPtr positionIdsHostCopy;
+
+ TensorPtr batchSlotsHostCopy;
+};
+
+} // namespace tensorrt_llm::runtime
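
A brief sketch of how the decoder-side buffers declared above could be created and sanity-checked, assuming a `BufferManager` is already available from the surrounding runtime; the helper name and the sizes are illustrative:

```cpp
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/lookaheadBuffers.h"

#include <cassert>

namespace tlr = tensorrt_llm::runtime;

// Allocate lookahead decoding buffers for a fixed number of sequences and
// tokens per step, then check the documented 2-D layout of positionOffsets.
tlr::LookaheadDecodingBuffers makeLookaheadBuffers(tlr::BufferManager const& bufferManager)
{
    constexpr tlr::SizeType32 maxNumSequences = 8;
    constexpr tlr::SizeType32 maxTokensPerStep = 64;

    tlr::LookaheadDecodingBuffers buffers(maxNumSequences, maxTokensPerStep, bufferManager);
    assert(buffers.positionOffsets->getShape().nbDims == 2); // [maxNumSequences, maxTokensPerStep]
    return buffers;
}
```
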
diff --git a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
index e3103ea91..8226c411c 100644
--- a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
+++ b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
@@ -108,8 +108,7 @@ class SpeculativeDecodingMode
[[nodiscard]] bool constexpr needsDecoderPrologue() const
{
- // Potentially lookahead should require it too.
- return anyBitSet(kExplicitDraftTokens);
+ return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding);
}
using UnderlyingType = std::uint8_t;
diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt
index df6e62c84..a8b6f276a 100644
--- a/cpp/tensorrt_llm/CMakeLists.txt
+++ b/cpp/tensorrt_llm/CMakeLists.txt
@@ -22,7 +22,7 @@ set(API_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cutlass_extensions/include
${API_INCLUDE_DIR})
-if(ENABLE_MULTI_DEVICE EQUAL 1)
+if(ENABLE_MULTI_DEVICE)
find_package(MPI REQUIRED)
message(STATUS "Using MPI_C_INCLUDE_DIRS: ${MPI_C_INCLUDE_DIRS}")
message(STATUS "Using MPI_C_LIBRARIES: ${MPI_C_LIBRARIES}")
@@ -269,7 +269,7 @@ set(TRTLLM_LINK_LIBS
runtime_src
${DECODER_SHARED_TARGET})
-if(ENABLE_MULTI_DEVICE EQUAL 1)
+if(ENABLE_MULTI_DEVICE)
set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} ${MPI_C_LIBRARIES} ${NCCL_LIB})
endif()
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index 1781be133..f9fe9ff33 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:84a6439038eb0a7d2913c3fe051684ab7779e42635c074bc6df30bfc46807929
-size 4358834
+oid sha256:460b75a97c0de65941839ccd5e0458cf5929574b9345b3cb723a695ae5a056e0
+size 4404838
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index f90fc6d03..555117707 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:1b321340ea622b28ed7d38fe18ff7707091d2efa414af40f6db516959a4fa2f4
-size 4466694
+oid sha256:645bbbad2c38b573df7c6e56588a6728d356a58444ac7c2f881d773faaca7593
+size 4516944
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
index 2eff4c1d7..484d3274b 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-99062b35da1cb99df9c79368da1ff9de libtensorrt_llm_batch_manager_static.a
-72e6a44f7636bb6d48b016db1c62cdc7 libtensorrt_llm_batch_manager_static.pre_cxx11.a
-90dd0ad72954a5cc7cc2e298495e784906fe49b1 commit
\ No newline at end of file
+a348613d480961aa14d4e77939be8a34 libtensorrt_llm_batch_manager_static.a
+317ec85caec48184c9c8b9cbd3eb44b1 libtensorrt_llm_batch_manager_static.pre_cxx11.a
+49402939d007b39393cabaa8fe96c110d16f5b35 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index a2560a5a6..3e658d50e 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:1df6689f399313cac54ec1f4422975d6060957edbc698468a96f3a2c2a6542bc
-size 4221016
+oid sha256:a785c4459bdb4a7dad9df0c832211f26f699a331ce0b2b9516e7a666f83b895a
+size 4272894
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index f9251acc4..26532800d 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:b118f1bbccd6fe8e5e001916f4de19ec34886e7dcc288b91dcdadf5500f0eb50
-size 4205756
+oid sha256:9a6b98589222f8bf8e82f122110cb1824b1728646bad45a41e8b9ada632539dc
+size 4248190
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
index f69634e79..e3f219a93 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:47162c3eaab9b6f60bca8927eef0423c5521d7750f112b87a2f9175156ccb6cd
-size 24807904
+oid sha256:7daa6c306a2fb738bbe8b3d30324691c83d59aa933c79b0e48342976edb4e356
+size 25540884
diff --git a/cpp/tensorrt_llm/common/cublasMMWrapper.cpp b/cpp/tensorrt_llm/common/cublasMMWrapper.cpp
index 224ca77f6..27f179236 100644
--- a/cpp/tensorrt_llm/common/cublasMMWrapper.cpp
+++ b/cpp/tensorrt_llm/common/cublasMMWrapper.cpp
@@ -37,16 +37,12 @@ CublasMMWrapper::CublasMMWrapper(std::shared_ptr cublasHandle,
{
}
-CublasMMWrapper::~CublasMMWrapper()
-{
- mMutex = nullptr;
-}
+CublasMMWrapper::~CublasMMWrapper() {}
CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
: mCublasHandle(wrapper.mCublasHandle)
, mCublasLtHandle(wrapper.mCublasLtHandle)
, mStream(wrapper.mStream)
- , mMutex(wrapper.mMutex)
{
}
@@ -135,8 +131,6 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, i
half h_alpha = (half) (f_alpha);
half h_beta = (half) (f_beta);
- std::lock_guard lock(*mMutex);
-
// TODO: default cublas libs
usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
@@ -179,8 +173,6 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
- std::lock_guard lock(*mMutex);
-
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta);
@@ -198,7 +190,6 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
- std::lock_guard lock(*mMutex);
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta);
diff --git a/cpp/tensorrt_llm/common/cublasMMWrapper.h b/cpp/tensorrt_llm/common/cublasMMWrapper.h
index 9418b0faf..21062f2f2 100644
--- a/cpp/tensorrt_llm/common/cublasMMWrapper.h
+++ b/cpp/tensorrt_llm/common/cublasMMWrapper.h
@@ -21,7 +21,6 @@
#include
#include
#include