diff --git a/README.md b/README.md
index 0b88d361f..0da5128a5 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,8 @@ TensorRT-LLM
[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.5.1-green)](https://developer.nvidia.com/cuda-downloads)
-[![trt](https://img.shields.io/badge/TRT-10.3.0-green)](https://developer.nvidia.com/tensorrt)
-[![version](https://img.shields.io/badge/release-0.13.0.dev-green)](./tensorrt_llm/version.py)
+[![trt](https://img.shields.io/badge/TRT-10.4.0-green)](https://developer.nvidia.com/tensorrt)
+[![version](https://img.shields.io/badge/release-0.14.0.dev-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/)
@@ -17,6 +17,12 @@ TensorRT-LLM
## Latest News
+* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
+[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
+
+
+
+
* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)
diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp
index 45632350c..b901a17bc 100644
--- a/benchmarks/cpp/gptManagerBenchmark.cpp
+++ b/benchmarks/cpp/gptManagerBenchmark.cpp
@@ -159,6 +159,8 @@ struct BenchmarkParams
std::optional<SizeType32> sinkTokenLength{std::nullopt};
bool multiBlockMode{true};
bool enableContextFMHAFP32Acc{false};
+ bool cudaGraphMode{false};
+ SizeType32 cudaGraphCacheSize{0};
// lora / peft params
std::optional<std::string> loraDir{std::nullopt};
@@ -470,7 +472,38 @@ class Recorder
mRequestBenchInfos[requestId].firstTokenSeen = true;
}
- mRequestBenchInfos[requestId].outputLength += 1;
+ mRequestBenchInfos[requestId].decodingIter += 1;
+ }
+
+ void recordToken(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
+ {
+ int32_t outputLength = 1;
+ for (auto& tensor : responseTensors)
+ {
+ if (tensor.name == inference_request::kSequenceLengthTensorName)
+ {
+ // Tensor of shape nBeams, and we only need the first one
+ outputLength = *(bufferCast<int32_t>(*(tensor.tensor)));
+ break;
+ }
+ }
+
+ mRequestBenchInfos[requestId].outputLength += outputLength;
+ this->recordToken(requestId);
+ }
+
+ void recordToken(uint64_t requestId, texec::Response const& response)
+ {
+ auto outputTokenIds = response.getResult().outputTokenIds;
+
+ int32_t outputLength = 1;
+ for (auto const& beam : outputTokenIds)
+ {
+ outputLength = std::max(static_cast<int32_t>(beam.size()), outputLength);
+ }
+
+ mRequestBenchInfos[requestId].outputLength += outputLength;
+ this->recordToken(requestId);
}
void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors, bool hasError)
@@ -500,7 +533,7 @@ class Recorder
}
else
{
- this->recordToken(requestId);
+ this->recordToken(requestId, responseTensors);
}
}
@@ -532,7 +565,7 @@ class Recorder
}
else
{
- this->recordToken(requestId);
+ this->recordToken(requestId, response);
}
}
}
@@ -821,8 +854,9 @@ class ExecutorServer
benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks);
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
std::nullopt, benchmarkParams.loraHostCacheSize);
- texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(
- benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
+ texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
+ benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode,
+ benchmarkParams.cudaGraphCacheSize);
texec::ExecutorConfig executorConfig(
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent);
@@ -940,7 +974,7 @@ class ExecutorServer
{
if (!warmup && !response.hasError())
{
- mRecorder->recordToken(reqId);
+ mRecorder->recordToken(reqId, response);
}
}
}
@@ -1228,7 +1262,7 @@ class GptServer
{
if (errMsg.empty())
{
- mRecorder->recordToken(requestId);
+ mRecorder->recordToken(requestId, response_tensors);
}
}
}
@@ -1458,8 +1492,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
- optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
- benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
+ optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
+ benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize);
auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
@@ -1895,7 +1929,8 @@ int main(int argc, char* argv[])
options.add_options()("return_generation_logits", "Whether to return generation logits.",
cxxopts::value()->default_value("false"));
- options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
+ options.add_options()("scheduler_policy",
+ "Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.",
cxxopts::value()->default_value("guaranteed_no_evict"));
options.add_options()("first_batch_delay",
@@ -1946,6 +1981,12 @@ int main(int argc, char* argv[])
cxxopts::value()->default_value("true"));
options.add_options()(
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value());
+ options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.",
+ cxxopts::value()->default_value("false"));
+ options.add_options()("cuda_graph_cache_size",
+ "Specify how many cuda graphs are cached in the runtime. Larger cache gives better perf, but consumes more GPU "
+ "memory.",
+ cxxopts::value()->default_value("0"));
options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
cxxopts::value()->default_value("false"));
@@ -2131,6 +2172,12 @@ int main(int argc, char* argv[])
// Argument: enable_context_fmha_fp32_acc
benchmarkParams.enableContextFMHAFP32Acc = result["enable_context_fmha_fp32_acc"].as<bool>();
+ // Argument: cuda_graph_mode
+ benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as<bool>();
+
+ // Argument: cuda_graph_cache_size
+ benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as<SizeType32>();
+
std::optional padId;
// Argument: Padding token id
if (result.count("pad_id"))
@@ -2168,6 +2215,10 @@ int main(int argc, char* argv[])
{
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT;
}
+ else if (capacitySchedulerPolicyArg == "static_batch")
+ {
+ capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kSTATIC_BATCH;
+ }
else
{
TLLM_LOG_ERROR("Unexpected scheduler policy: " + capacitySchedulerPolicyArg);
diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py
index 04ba2ab0f..ce06c9f9f 100644
--- a/benchmarks/python/gpt_benchmark.py
+++ b/benchmarks/python/gpt_benchmark.py
@@ -80,7 +80,7 @@ def __init__(self, args, batch_sizes, in_out_lens, gpu_weights_percents,
kv_cache_type = KVCacheType.CONTINUOUS
if hasattr(self, 'kv_cache_type'):
- kv_cache_type = self.kv_cache_type
+ kv_cache_type = KVCacheType(self.kv_cache_type)
else:
if hasattr(self, 'paged_kv_cache'):
kv_cache_type = KVCacheType.PAGED if self.paged_kv_cache == True else KVCacheType.CONTINUOUS
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index 5f26170e9..959e2c39c 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -282,14 +282,37 @@ class GenerationRequest
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
};
-// BlockManager manages overall metadata of KVCacheBlocks in a layer of the
-// network. Layers are expected to be symmetric, so the metadata can be
-// reused for all layers of the network.
-// The array of cache blocks for a layer is called a pool.
-// Each pool has shape [max_blocks, 2, num_heads, tokens_per_block, head_size].
-// Size per block and number of blocks per pool are pre-determined and set in
-// constructor. These should not be changed after.
-// Block shape is [2, num_heads, tokens_per_block, head_size].
+// attach metadata to a pool pointer
+class KVCacheBlockPool
+{
+public:
+ SizeType32 numKvHeads;
+ SizeType32 numLayers;
+ SizeType32 blockSize;
+
+ // Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
+ runtime::ITensor::SharedPtr primaryPtr;
+ runtime::ITensor::SharedPtr secondaryPtr;
+
+ KVCacheBlockPool(SizeType32 numKvHeads, SizeType32 numLayers, SizeType32 blockSize,
+ runtime::ITensor::SharedPtr primaryPtr = nullptr, runtime::ITensor::SharedPtr secondaryPtr = nullptr)
+ : numKvHeads(numKvHeads)
+ , numLayers(numLayers)
+ , blockSize(blockSize)
+ , primaryPtr(std::move(primaryPtr))
+ , secondaryPtr(std::move(secondaryPtr))
+ {
+ }
+};
+
+// The BlockManager manages the metadata of KVCacheBlocks.
+// It manages multiple arrays of cache blocks called pools.
+// Layers with the same number of kv heads are grouped under the same pool.
+// Each pool has shape [max_blocks, num_layers, 2, num_kv_heads, tokens_per_block, head_size], where num_layers refers
+// to the number of layers with the same num_kv_heads that share that pool.
+// The metadata of KVCacheBlocks is shared between layers, so each block spans all of the managed pool - an allocated
+// block matches some chunk of memory in each pool. The shape of the chunk in every pool is [2, num_kv_heads,
+// tokens_per_block, head_size]. The size per block and number of blocks are pre-determined and set in the constructor.
// BlockManager maintains a list of free blocks at any time.
// Alloc pops off the block at the front, and Free pushes it back to the vector.
// BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks
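To make the new pool layout concrete, here is a minimal reviewer's sketch (not part of the patch) of how layers with the same KV-head count could be grouped into pools and how the per-layer block size follows from the shapes described in the comment above; every name in it is illustrative.

```cpp
// Illustrative sketch only: bucket layers by KV-head count into "pools" and derive
// the per-layer block size, i.e. the volume of [numKvHeads, tokensPerBlock, sizePerHead].
#include <cstdint>
#include <map>
#include <vector>

using SizeType32 = std::int32_t;

struct PoolSketch
{
    SizeType32 numKvHeads;
    SizeType32 numLayers;
    SizeType32 blockSize;
};

std::vector<PoolSketch> groupLayersIntoPools(
    std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 tokensPerBlock, SizeType32 sizePerHead)
{
    std::map<SizeType32, SizeType32> layersPerHeadCount; // KV-head count -> number of layers sharing it
    for (auto const heads : numKvHeadsPerLayer)
    {
        ++layersPerHeadCount[heads];
    }
    std::vector<PoolSketch> pools;
    for (auto const& [heads, layers] : layersPerHeadCount)
    {
        pools.push_back({heads, layers, heads * tokensPerBlock * sizePerHead});
    }
    return pools;
}
```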
@@ -300,7 +323,7 @@ class BlockManager
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
- explicit BlockManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead,
+ explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType = CacheType::kSELF);
@@ -338,7 +361,7 @@ class BlockManager
[[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept
{
- return mFreePrimaryBlocks.size();
+ return mFreePrimaryBlocksSize;
}
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
@@ -381,21 +404,26 @@ class BlockManager
return mTokensPerBlock;
}
- //! \brief Get size of one K/V cache block in one layer.
- //! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead]
- [[nodiscard]] SizeType32 getBlockSize() const
+ //! \brief Get size of one K/V cache block in one layer for the specified pool.
+ //! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead] in the specified pool.
+ [[nodiscard]] SizeType32 getBlockSize(SizeType32 poolIdx) const
+ {
+ return mPools.at(poolIdx).blockSize;
+ }
+
+ [[nodiscard]] SizeType32 getNumPools() const noexcept
{
- return mBlockSize;
+ return mPools.size();
}
- [[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool() const noexcept
+ [[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const
{
- return mPrimaryPool;
+ return mPools.at(poolIdx).primaryPtr;
}
- [[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool() const noexcept
+ [[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool(SizeType32 poolIdx) const
{
- return mSecondaryPool;
+ return mPools.at(poolIdx).secondaryPtr;
}
[[nodiscard]] SizeType32 getNumLayers() const
@@ -403,10 +431,32 @@ class BlockManager
return mNumLayers;
}
+ [[nodiscard]] SizeType32 getNumPrimaryBlocks() const
+ {
+ return mNumPrimaryBlocks;
+ }
+
+ [[nodiscard]] SizeType32 getNumSecondaryBlocks() const
+ {
+ return mNumSecondaryBlocks;
+ }
+
+ [[nodiscard]] CacheType getCacheType() const
+ {
+ return mCacheType;
+ }
+
+ [[nodiscard]] SizeType32 getLayerPoolIdx(SizeType32 layerIdx) const
+ {
+ return mLayerToPool.at(layerIdx);
+ }
+
//! \brief Get index in pool to K or V block.
//! \param blockId the blockId as returned by getBlockId()
//! \param fieldIdx either 0 (K) or 1 (V),
- [[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(KVCacheBlock::IdType blockId, SizeType32 fieldIdx) const;
+ //! \param poolIdx the index of the pool for which the index is calculated (each pool has different strides)
+ [[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(
+ KVCacheBlock::IdType blockId, SizeType32 fieldIdx, SizeType32 poolIdx) const;
//! \brief Bring offloaded block from secondary to primary memory.
//! \details Does nothing if the block is already in primary memory.
@@ -451,7 +501,8 @@ class BlockManager
void claimLeafBlock(KVCacheBlock& block);
//! \brief Compute pointer to raw KV block (K & V, all layers).
- [[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(std::shared_ptr block) const;
+ [[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(
+ std::shared_ptr<KVCacheBlock> block, SizeType32 poolIdx) const;
//! \brief Copy content of src block to dst.
void copyBlock(BlockPtr src, BlockPtr dst);
@@ -460,23 +511,30 @@ class BlockManager
// Number of blocks in pools
SizeType32 mNumPrimaryBlocks;
SizeType32 mNumSecondaryBlocks;
- // List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory,
- // we maintain separate queues for these.
+ // List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory.
+ // We maintain separate queues for these.
+ // We cache the size of each queue instead of calling std::list::size, because size() is an O(N) function.
+ SizeType32 mFreePrimaryBlocksSize;
+ SizeType32 mFreeSecondaryBlocksSize;
FreeBlocksQueue mFreePrimaryBlocks;
FreeBlocksQueue mFreeSecondaryBlocks;
// List of allocated blocks for each sequences
std::vector<std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
- // Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
- runtime::ITensor::SharedPtr mPrimaryPool;
- runtime::ITensor::SharedPtr mSecondaryPool;
+
+ // Pool per unique numKvHeads in the model
+ std::vector<KVCacheBlockPool> mPools;
+ // Matching of model layers to their pools
+ std::vector<SizeType32> mLayerToPool;
+
// Whether offloaded blocks should be onboarded before reuse.
bool mOnboardBlocks;
// Buffer manager
runtime::BufferManager mBufferManager;
+
+ // Size of a single KV head
+ SizeType32 mSizePerHead;
// Number of layers
SizeType32 mNumLayers;
- // Volume of [numKvHeads, tokensPerBlock, sizePerHead]
- SizeType32 mBlockSize;
// Used to keep track of number of free blocks during scheduling
SizeType32 mSchedulingNumFreeBlocks;
// Number of tokens per one block
@@ -502,12 +560,18 @@ class KVCacheManager
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
- KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
+ KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true,
CacheType cacheType = CacheType::kSELF);
+ KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
+ SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
+ SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
+ CudaStreamPtr stream, bool enableBlockReuse = true, bool onboardBlocks = true,
+ CacheType cacheType = CacheType::kSELF);
+
void allocatePools(nvinfer1::DataType dtype, bool useUvm = false);
void startScheduling();
@@ -577,11 +641,11 @@ class KVCacheManager
/// @return The number of blocks
[[nodiscard]] SizeType32 getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
- /// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for
- /// maxNewTokens)
+ /// @brief Function that computes the number of KV cache blocks remaining to advance a request to completion (i.e.
+ /// for maxNewTokens); the allocated blocks are excluded
/// @param req The request for which we need to calculate the number of needed KV cache blocks
/// @return The number of blocks
- [[nodiscard]] SizeType32 getNeededBlocksToCompletion(LlmRequest const& req) const;
+ [[nodiscard]] SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req) const;
void addContextTokens(SizeType32 seqSlotIdx, SizeType32 numTokens);
@@ -603,6 +667,8 @@ class KVCacheManager
[[nodiscard]] runtime::ITensor::UniquePtr getBlockPoolPointers() const;
+ [[nodiscard]] runtime::ITensor::UniquePtr getLayerToPoolMapping() const;
+
void getBlockOffsetsOfBatch(
runtime::ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const;
@@ -610,18 +676,16 @@ class KVCacheManager
SizeType32 copyBlockOffsets(
runtime::ITensor& output, SizeType32 outputSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;
- // Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
- [[nodiscard]] static SizeType32 constexpr calculatePageSize(tensorrt_llm::runtime::ModelConfig const& modelConfig)
- {
- return 2 * modelConfig.getNbKvHeads() * modelConfig.getTokensPerBlock() * modelConfig.getSizePerHead();
- }
-
- // numLayers * 2 * numKvHeads * sizePerHead
- [[nodiscard]] static SizeType32 constexpr calculateCacheSizePerToken(
+ // Sum of numLayers * 2 * numKvHeads * sizePerHead for each pool
+ [[nodiscard]] static SizeType32 calculateCacheSizePerToken(
tensorrt_llm::runtime::ModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
- return modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
- * modelConfig.getSizePerHead();
+ // NOTE: We expect the initialization of modelConfig to have already taken the tp size into account and do not
+ // address it here
+ // consider only local layers for the calculation
+ return modelConfig.getSumLocalKvHeads(
+ worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank())
+ * 2 * modelConfig.getSizePerHead();
}
[[nodiscard]] static std::tuple<SizeType32, SizeType32> const calculateMaxNumBlocks(KvCacheConfig const& config,
@@ -640,7 +704,7 @@ class KVCacheManager
[[nodiscard]] bool isCrossKv() const
{
- return mCacheType == CacheType::kCROSS;
+ return mBlockManager.getCacheType() == CacheType::kCROSS;
}
//! \brief Find the first new block that must be allocated for the context phase and return its concatenated token vector.
@@ -691,8 +755,6 @@ class KVCacheManager
runtime::ITensor::SharedPtr mSequenceBlockIndices;
// Whether to cache KV pages for reuse
bool mEnableBlockReuse;
- // KV cache type (self or cross)
- CacheType mCacheType;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
index 81b91e24a..1738cc428 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -91,9 +91,9 @@ class BlockIterator
};
[[nodiscard]] BlockIterator getBlockBeginIt(
- KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
+ KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx);
[[nodiscard]] BlockIterator getBlockEndIt(
- KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
+ KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx);
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 46c808de0..fed9dd21e 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -400,7 +400,8 @@ class GenericLlmRequest
TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.has_value() && mInputTokenExtraIds.value(),
"Input token extra ids must be provided when enabling kv cache reuse with prompt table");
TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.value()->size() == static_cast<size_t>(mOrigPromptLen),
- "inputTokenExtraIds vector size must be the same as input token vector size.");
+ "inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).",
+ mInputTokenExtraIds.value()->size(), static_cast<size_t>(mOrigPromptLen));
}
}
@@ -411,7 +412,7 @@ class GenericLlmRequest
/// @brief Get the params of the context
/// @return The params of the context
- std::optional<executor::ContextPhaseParams> const& getContextPhaseParams() const noexcept
+ [[nodiscard]] std::optional<executor::ContextPhaseParams> const& getContextPhaseParams() const noexcept
{
return mContextPhaseParams;
}
@@ -423,10 +424,10 @@ class GenericLlmRequest
/// @brief Get the state params of the context
/// @return The state params of the context
- executor::ContextPhaseState const& getContextPhaseState() const
+ [[nodiscard]] executor::DataTransceiverState const& getDataTransceiverState() const
{
TLLM_CHECK(mContextPhaseParams.has_value());
- return *static_cast<executor::ContextPhaseState const*>(mContextPhaseParams.value().getState());
+ return *static_cast<executor::DataTransceiverState const*>(mContextPhaseParams.value().getState());
}
/// @brief Get total number of tokens for this req (prompt + generated)
@@ -659,6 +660,11 @@ class GenericLlmRequest
return mSequenceIndex > 0;
}
+ [[nodiscard]] RequestIdType getParentRequestId() const
+ {
+ return mParentRequestId;
+ }
+
/// @brief Return a vector of the last-generated tokens of shape [num_beams]
[[nodiscard]] VecTokens const& getLastTokens()
{
@@ -858,14 +864,46 @@ class GenericLlmRequest
return mOrigPromptLen;
}
- void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen)
+ [[nodiscard]] SizeType32 getPromptLen() const
{
- mPrepopulatedPromptLen = prepopulatedPromptLen;
+ return mPromptLen;
}
- [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
+ void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
{
- return mPrepopulatedPromptLen;
+ auto const promptLen = getPromptLen();
+ TLLM_CHECK(prepopulatedPromptLen < promptLen);
+
+ if (prepopulatedPromptLen > 0)
+ {
+ // Currently, the runtime first allocates the cache and only then determines prepopulation.
+ // Use the prepopulated length to advance the context position and decrease chunk size if necessary.
+ if (isFullContextRequest())
+ {
+ setContextCurrentPosition(prepopulatedPromptLen);
+ setContextChunkSize(promptLen);
+ }
+ else
+ {
+ auto chunkSize = getContextChunkSize();
+ if (prepopulatedPromptLen + chunkSize < promptLen)
+ {
+ // make sure to end at block boundary after current chunk
+ auto const flooredEndPosition
+ = (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
+ chunkSize = flooredEndPosition - prepopulatedPromptLen;
+ TLLM_CHECK(chunkSize <= getContextChunkSize());
+ }
+ setContextCurrentPosition(prepopulatedPromptLen);
+ setContextChunkSize(chunkSize);
+ }
+ if (!isLastContextChunk())
+ {
+ TLLM_CHECK_WITH_INFO((getContextCurrentPosition() + getContextChunkSize()) % kvTokensPerBlock == 0,
+ "To prevent cache fragmentation, the context position after current chunk should be divisible "
+ "by the number of tokens per block, except for the last chunk.");
+ }
+ }
}
void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
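The chunk-size flooring in setPrepopulatedPromptLen above can be checked with a tiny standalone sketch; the numbers and the helper name are illustrative only.

```cpp
// Minimal sketch of the flooring logic above: the end of a non-final context chunk
// is snapped down to a multiple of kvTokensPerBlock so reused blocks stay aligned.
#include <cassert>
#include <cstdint>

using SizeType32 = std::int32_t;

SizeType32 floorChunkToBlockBoundary(SizeType32 prepopulatedPromptLen, SizeType32 chunkSize, SizeType32 kvTokensPerBlock)
{
    auto const flooredEndPosition = (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
    return flooredEndPosition - prepopulatedPromptLen;
}

int main()
{
    // 70 prepopulated tokens, a 64-token chunk, 32 tokens per block:
    // the chunk end 134 is floored to 128, so the chunk shrinks to 58 tokens.
    assert(floorChunkToBlockBoundary(70, 64, 32) == 58);
    return 0;
}
```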
@@ -1276,7 +1314,7 @@ class GenericLlmRequest
}
// TODO: fill the rank ids
result.contextPhaseParams = executor::ContextPhaseParams{
- std::move(firstGenTokens), mContextPhaseParams.value().releaseState()};
+ std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState()};
}
auto const calculateNbTokensOut = [this](SizeType32 maxNbTokens)
@@ -1513,8 +1551,8 @@ class GenericLlmRequest
{
if (mInputTokenExtraIds.value()->size() != inputTokens.size())
{
- std::string errStr = "inputTokenExtraIds vector size must be the same as input token vector size.";
- TLLM_THROW(errStr);
+ TLLM_THROW("inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).",
+ mInputTokenExtraIds.value()->size(), inputTokens.size());
}
VecTokenExtraIds tokenExtraIds = *mInputTokenExtraIds.value();
for (std::size_t i = 0; i < inputTokens.size(); ++i)
diff --git a/cpp/include/tensorrt_llm/common/cudaUtils.h b/cpp/include/tensorrt_llm/common/cudaUtils.h
index 3bb203f8e..023f97d87 100644
--- a/cpp/include/tensorrt_llm/common/cudaUtils.h
+++ b/cpp/include/tensorrt_llm/common/cudaUtils.h
@@ -161,7 +161,7 @@ inline std::optional<bool> isCudaLaunchBlocking()
return result;
}
-inline void syncAndCheck(char const* const file, int const line)
+inline bool doCheckError()
{
auto const cudaLaunchBlocking = isCudaLaunchBlocking();
#ifndef NDEBUG
@@ -170,10 +170,15 @@ inline void syncAndCheck(char const* const file, int const line)
bool const checkError = cudaLaunchBlocking.value_or(false);
#endif
- if (checkError)
+ return checkError;
+}
+
+inline void syncAndCheck(char const* const file, int const line)
+{
+ if (doCheckError())
{
- cudaError_t result = cudaDeviceSynchronize();
- check(result, "cudaDeviceSynchronize", file, line);
+ check(cudaGetLastError(), "cudaGetLastError", file, line);
+ check(cudaDeviceSynchronize(), "cudaDeviceSynchronize", file, line);
}
}
diff --git a/cpp/include/tensorrt_llm/common/mpiUtils.h b/cpp/include/tensorrt_llm/common/mpiUtils.h
index edf3da004..4a7bb53ae 100644
--- a/cpp/include/tensorrt_llm/common/mpiUtils.h
+++ b/cpp/include/tensorrt_llm/common/mpiUtils.h
@@ -380,6 +380,10 @@ class MpiComm
void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const;
+
+ void allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
+ std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const;
+
void barrier() const;
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index a5c6cf03c..807382c4a 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -43,7 +43,7 @@ char const* version() noexcept;
class Model;
class Serialization;
-class ContextPhaseState;
+class DataTransceiverState;
/// @brief Sampling configuration
class SamplingConfig
@@ -283,8 +283,10 @@ struct LookaheadDecodingConfig
class ContextPhaseParams
{
public:
- explicit ContextPhaseParams(VecTokens firstGenTokens);
- ContextPhaseParams(VecTokens firstGenTokens, void* state);
+ using RequestIdType = std::uint64_t;
+
+ explicit ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId);
+ ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state);
ContextPhaseParams(ContextPhaseParams const&);
ContextPhaseParams(ContextPhaseParams&&);
@@ -295,6 +297,8 @@ class ContextPhaseParams
[[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
[[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
+ [[nodiscard]] RequestIdType getReqId() const noexcept;
+
[[nodiscard]] void const* getState() const noexcept;
[[nodiscard]] void* getState() noexcept;
[[nodiscard]] void* releaseState() noexcept;
@@ -304,6 +308,9 @@ class ContextPhaseParams
static void deleter(void const* data);
using StatePtr = std::unique_ptr;
+ /// @brief This request corresponds to the request ID in the context phase.
+ RequestIdType mReqId{0};
+
/// @brief The first tokens generated by context executor
VecTokens mFirstGenTokens;
@@ -593,18 +600,24 @@ class KvCacheConfig
class ExtendedRuntimePerfKnobConfig
{
public:
- explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false);
+ explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false,
+ bool cudaGraphMode = false, SizeType32 cudaGraphCacheSize = 0);
bool operator==(ExtendedRuntimePerfKnobConfig const& other) const
{
- return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc;
+ return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc
+ && mCudaGraphMode == other.mCudaGraphMode && mCudaGraphCacheSize == other.mCudaGraphCacheSize;
}
[[nodiscard]] bool getMultiBlockMode() const;
[[nodiscard]] bool getEnableContextFMHAFP32Acc() const;
+ [[nodiscard]] bool getCudaGraphMode() const;
+ [[nodiscard]] SizeType32 getCudaGraphCacheSize() const;
void setMultiBlockMode(bool multiBlockMode);
void setEnableContextFMHAFP32Acc(bool enableContextFMHAFP32Acc);
+ void setCudaGraphMode(bool cudaGraphMode);
+ void setCudaGraphCacheSize(SizeType32 cacheSize);
private:
friend class Serialization;
@@ -614,6 +627,13 @@ class ExtendedRuntimePerfKnobConfig
/// @brief If enable FMHA runner FP32 accumulation.
bool mEnableContextFMHAFP32Acc;
+
+ /// @brief Controls whether CUDA graph execution is enabled.
+ bool mCudaGraphMode;
+
+ /// @brief Number of cuda graphs to be cached in the runtime.
+ /// The larger the cache, the better the perf, but more GPU memory is consumed.
+ SizeType32 mCudaGraphCacheSize;
};
/// @brief Configuration class for debugging output
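A hedged usage sketch of the new CUDA-graph knobs follows; the constructor and setters are taken from the header above, while attaching the config to an ExecutorConfig (presumably via a setter) is an assumption and is not shown.

```cpp
// Sketch only: build the extended runtime perf knobs with CUDA graphs enabled.
// Only the API declared in this header is used; how the knobs are attached to the
// executor configuration is left out (assumed to be a setter on ExecutorConfig).
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

texec::ExtendedRuntimePerfKnobConfig makeCudaGraphKnobs()
{
    texec::ExtendedRuntimePerfKnobConfig knobs(/*multiBlockMode=*/true,
        /*enableContextFMHAFP32Acc=*/false, /*cudaGraphMode=*/true, /*cudaGraphCacheSize=*/16);
    knobs.setCudaGraphCacheSize(32); // setters can also adjust the knobs after construction
    return knobs;
}
```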
diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h
index 11d22c3f0..9fe197dc9 100644
--- a/cpp/include/tensorrt_llm/executor/serialization.h
+++ b/cpp/include/tensorrt_llm/executor/serialization.h
@@ -75,10 +75,10 @@ class Serialization
static void serialize(kv_cache::CacheState const& state, std::ostream& os);
[[nodiscard]] static size_t serializedSize(kv_cache::CacheState const& state);
- // ContextPhaseState
- [[nodiscard]] static ContextPhaseState deserializeContextPhaseState(std::istream& is);
- static void serialize(ContextPhaseState const& contextPhaseState, std::ostream& os);
- [[nodiscard]] static size_t serializedSize(ContextPhaseState const& contextPhaseState);
+ // DataTransceiverState
+ [[nodiscard]] static DataTransceiverState deserializeDataTransceiverState(std::istream& is);
+ static void serialize(DataTransceiverState const& dataTransceiverState, std::ostream& os);
+ [[nodiscard]] static size_t serializedSize(DataTransceiverState const& dataTransceiverState);
// ContextPhaseParams
[[nodiscard]] static ContextPhaseParams deserializeContextPhaseParams(std::istream& is);
diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h
index 39e6693d7..a2476caff 100644
--- a/cpp/include/tensorrt_llm/executor/types.h
+++ b/cpp/include/tensorrt_llm/executor/types.h
@@ -198,6 +198,10 @@ enum class CapacitySchedulerPolicy
/// @brief GUARANTEED_NO_EVICT uses KV cache more conservatively guaranteeing that a request, once started, will run
/// to completion without eviction.
kGUARANTEED_NO_EVICT = 1,
+
+ /// @brief kSTATIC_BATCH does not schedule new requests until all requests in the current batch are completed.
+ /// Similar to kGUARANTEED_NO_EVICT, requests will run to completion without eviction.
+ kSTATIC_BATCH = 2
};
std::ostream& operator<<(std::ostream& os, CapacitySchedulerPolicy policy);
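For reference, a small helper in the style of gptManagerBenchmark's argument parsing shows how the new policy value maps from a string; the helper name is illustrative, and kMAX_UTILIZATION is assumed to be the existing first enumerator of this enum.

```cpp
// Sketch: map a --scheduler_policy string to the enum, including the new static_batch value.
#include <optional>
#include <string>

#include "tensorrt_llm/executor/types.h"

namespace texec = tensorrt_llm::executor;

std::optional<texec::CapacitySchedulerPolicy> parseCapacitySchedulerPolicy(std::string const& arg)
{
    if (arg == "max_utilization")
    {
        return texec::CapacitySchedulerPolicy::kMAX_UTILIZATION;
    }
    if (arg == "guaranteed_no_evict")
    {
        return texec::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT;
    }
    if (arg == "static_batch")
    {
        return texec::CapacitySchedulerPolicy::kSTATIC_BATCH;
    }
    return std::nullopt; // unexpected policy string
}
```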
diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
index 358826f50..2db8fcc18 100644
--- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
+++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -62,12 +62,12 @@ class GptDecoderBatched : public IGptDecoderBatched
void newRequests(std::vector<SizeType32> const& seqSlots, std::vector<decoder_batch::Request> const& requests,
std::vector<SamplingConfig> const& samplingConfigs) override;
- TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
+ DecoderFinishedEventPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
- void forwardSync(decoder_batch::Token const& token) override;
+ void forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent) override;
- void forwardSync(
- decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input) override;
+ void forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent, decoder_batch::Output& output,
+ decoder_batch::Input const& input) override;
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
@@ -271,7 +271,7 @@ class GptDecoderBatched : public IGptDecoderBatched
void newRequestExplicitDraftTokens(SizeType32 batchIdx, decoder_batch::Request const& request);
//! @brief Updates finished state on host for all active requests
- void updateFinished(decoder_batch::Token const& token);
+ void updateFinished(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent);
//! @brief Sets inputs for explicit draft tokens.
void setExplicitDraftTokensInputs(decoder_batch::Input const& input);
@@ -289,7 +289,7 @@ class GptDecoderBatched : public IGptDecoderBatched
CudaStreamPtr mRuntimeStream;
CudaStreamPtr mDecoderStream;
BufferManager mBufferManager;
- TokenPtr mForwardToken;
+ DecoderFinishedEventPtr mDecoderFinishEvent;
CudaEvent mForwardEvent;
using GptDecoderPtr = std::unique_ptr<IGptDecoder>;
diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
index 11464f80e..048fa05a7 100644
--- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
+++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -75,11 +75,11 @@ class Input
using Output = decoder::Output;
-// TODO: is this a bad name to mix up with token concept in LLM? Would 'Event' be better? And should move to common.h
-class Token
+// Used just as a container for easy returning / passing to functions
+class DecoderFinishedEvent
{
public:
- explicit Token(CudaEvent&& event, std::vector<bool> const& active)
+ explicit DecoderFinishedEvent(CudaEvent&& event, std::vector<bool> const& active)
: event(std::move(event))
, active(active)
{
@@ -96,7 +96,7 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder
public:
using CudaStreamPtr = std::shared_ptr<CudaStream>;
using TensorPtr = std::shared_ptr<ITensor>;
- using TokenPtr = std::unique_ptr;
+ using DecoderFinishedEventPtr = std::unique_ptr;
//! @brief Setup buffers for ExplicitDraftTokens decoding.
virtual void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) = 0;
@@ -105,15 +105,15 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder
virtual void setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) = 0;
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
- virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
+ virtual DecoderFinishedEventPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
//! @brief Call decoder forwardSync and wait for the call to `forwardAsync` associated with a token to complete.
- virtual void forwardSync(
- decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input)
+ virtual void forwardSync(decoder_batch::DecoderFinishedEvent const& token, decoder_batch::Output& output,
+ decoder_batch::Input const& input)
= 0;
//! @brief Wait for the call to `forwardAsync` associated with a token to complete.
- virtual void forwardSync(decoder_batch::Token const& token) = 0;
+ virtual void forwardSync(decoder_batch::DecoderFinishedEvent const& token) = 0;
//! @brief Run one step for all requests and wait for completion on the host.
virtual void forward(decoder_batch::Output& output, decoder_batch::Input const& input)
diff --git a/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h
index 56504bd94..3c6fe731a 100644
--- a/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h
+++ b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h
@@ -62,6 +62,7 @@ class LookaheadRuntimeBuffers
TensorMap& inputBuffers, TensorMap& outputBuffers, runtime::WorldConfig const& worldConfig) const;
public:
+ TensorPtr cumSumLength; // [1] the cumulative sum of generation length, on pinned
TensorPtr packedMasksDevice; // [forwardBatchSize, tokensPerStep, numPackedMasks], on gpu
TensorPtr generationLengthsDevice; // [forwardBatchSize], on gpu
TensorPtr positionOffsetsDevice; // [forwardBatchSize, tokensPerStep], on gpu
diff --git a/cpp/include/tensorrt_llm/runtime/loraModule.h b/cpp/include/tensorrt_llm/runtime/loraModule.h
index 15ac50fa8..d76178eaf 100644
--- a/cpp/include/tensorrt_llm/runtime/loraModule.h
+++ b/cpp/include/tensorrt_llm/runtime/loraModule.h
@@ -179,7 +179,7 @@ class LoraModule
static std::vector<LoraModule> createLoraModules(std::vector<std::string> const& loraModuleNames,
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
- SizeType32 attentionHeadSize, SizeType32 tpSize);
+ SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts);
static ModuleType constexpr toModuleType(std::string_view const& name)
{
diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h
index fc3ac2928..b1b495e75 100644
--- a/cpp/include/tensorrt_llm/runtime/modelConfig.h
+++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h
@@ -60,6 +60,9 @@ class ModelConfig
{
kATTENTION,
kRECURRENT,
+ // NOTE: Linear and noop are attention alternatives introduced in Nemotron-NAS. They do not use the KV cache.
+ kLINEAR,
+ kNOOP,
};
enum class KVCacheType : std::int32_t
@@ -97,13 +100,13 @@ class ModelConfig
kEnabled,
};
- explicit ModelConfig(SizeType32 vocabSize, SizeType32 nbAttentionLayers, SizeType32 nbRnnLayers, SizeType32 nbHeads,
- SizeType32 hiddenSize, nvinfer1::DataType dtype)
+ explicit ModelConfig(SizeType32 vocabSize, SizeType32 nbLayers, SizeType32 nbAttentionLayers,
+ SizeType32 nbRnnLayers, SizeType32 nbHeads, SizeType32 hiddenSize, nvinfer1::DataType dtype)
: mVocabSize(vocabSize)
+ , mNbLayers(nbLayers)
, mNbAttentionLayers(nbAttentionLayers)
, mNbRnnLayers(nbRnnLayers)
, mNbHeads(nbHeads)
- , mNbKvHeads(nbHeads)
, mHiddenSize(hiddenSize)
, mSizePerHead(mHiddenSize / mNbHeads)
, mDataType(dtype)
@@ -134,6 +137,10 @@ class ModelConfig
, mUseShapeInference(true)
, mManageWeightsType(ManageWeightsType::kDisabled)
{
+ TLLM_CHECK_WITH_INFO(mNbLayers >= mNbAttentionLayers + mNbRnnLayers,
+ "Number of layers (%d) expected to be >= number of attention (%d) + number of rnn layers (%d)", mNbLayers,
+ mNbAttentionLayers, mNbRnnLayers);
+ setNbKvHeads(mNbHeads);
}
[[nodiscard]] static std::vector getOptProfilesSplitPoints() noexcept
@@ -151,14 +158,55 @@ class ModelConfig
return (mVocabSize + worldSize - 1) / worldSize * worldSize;
}
- [[nodiscard]] SizeType32 constexpr getNbAttentionLayers(SizeType32 pipelineParallelism = 1) const
+ [[nodiscard]] SizeType32 countLocalLayers(
+ LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
- return mNbAttentionLayers / pipelineParallelism;
+ TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
+ auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
+ auto const firstLocalLayerIt = mLayerTypes.cbegin() + (numLocalLayers * pipelineParallelismRank);
+ return std::count(firstLocalLayerIt, firstLocalLayerIt + numLocalLayers, layerType);
}
- [[nodiscard]] SizeType32 constexpr getNbRnnLayers(SizeType32 pipelineParallelism = 1) const
+ [[nodiscard]] SizeType32 countLowerRankLayers(
+ LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
- return mNbRnnLayers / pipelineParallelism;
+ auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
+ auto const firstLocalLayer = numLocalLayers * pipelineParallelismRank;
+ // count number of previous non-local attention layers
+ return std::count(mLayerTypes.cbegin(), mLayerTypes.cbegin() + firstLocalLayer, layerType);
+ }
+
+ [[nodiscard]] SizeType32 getNbLayers(SizeType32 pipelineParallelism = 1) const
+ {
+ return mNbLayers / pipelineParallelism; // WARNING: assume no remainder
+ }
+
+ [[nodiscard]] SizeType32 getNbAttentionLayers(
+ SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
+ {
+ // TODO(oargov): get rid of this invalid state
+ if (mLayerTypes.empty())
+ {
+ // this assumption might be wrong in a few cases, for example:
+ // layer types: [attention, recurrent, recurrent], pp=2 ==> first rank has 1 attention layer, not 0
+ TLLM_LOG_DEBUG("Assuming uniform distribution of attention layers between ranks");
+ return mNbAttentionLayers / pipelineParallelism;
+ }
+ return countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
+ }
+
+ [[nodiscard]] SizeType32 getNbRnnLayers(
+ SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
+ {
+ // TODO(oargov): get rid of this invalid state
+ if (mLayerTypes.empty())
+ {
+ // this assumption might be wrong in a few cases, for example:
+ // layer types: [attention, attention, recurrent], pp=2 ==> second rank has 1 rnn layer, not 0
+ TLLM_LOG_DEBUG("Assuming uniform distribution of recurrent layers between ranks");
+ return mNbRnnLayers / pipelineParallelism;
+ }
+ return countLocalLayers(LayerType::kRECURRENT, pipelineParallelism, pipelineParallelismRank);
}
[[nodiscard]] SizeType32 constexpr getNbHeads() const noexcept
@@ -166,14 +214,16 @@ class ModelConfig
return mNbHeads;
}
- [[nodiscard]] SizeType32 constexpr getNbKvHeads() const noexcept
+ [[nodiscard]] SizeType32 getNbKvHeads(SizeType32 layerIdx) const
{
- return mNbKvHeads;
+ TLLM_CHECK_WITH_INFO(layerIdx < mNbAttentionLayers, "Layer index %d is out of bounds", layerIdx);
+ return mNumKvHeadsPerAttentionLayer[layerIdx];
}
- void constexpr setNbKvHeads(SizeType32 nbKvHeads) noexcept
+ // set the number of kv heads for all layers
+ void setNbKvHeads(SizeType32 nbKvHeads)
{
- mNbKvHeads = nbKvHeads;
+ mNumKvHeadsPerAttentionLayer = std::vector<SizeType32>(mNbAttentionLayers, nbKvHeads);
}
[[nodiscard]] SizeType32 constexpr getHiddenSize() const noexcept
@@ -645,12 +695,46 @@ class ModelConfig
mModelName = modelName;
}
+ [[nodiscard]] std::vector<SizeType32> const& getNumKvHeadsPerLayer() const
+ {
+ return mNumKvHeadsPerAttentionLayer;
+ }
+
+ [[nodiscard]] std::pair<std::vector<SizeType32>::const_iterator, std::vector<SizeType32>::const_iterator>
+ getNumKvHeadsPerLayerLocalRange(SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
+ {
+ TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
+ // count number of previous non-local attention layers
+ auto const numPrevAttnLayers
+ = countLowerRankLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
+ auto const firstLocalAttentionLayerIt = mNumKvHeadsPerAttentionLayer.cbegin() + numPrevAttnLayers;
+ auto const numLocalAttentionLayers
+ = countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
+ return std::make_pair(firstLocalAttentionLayerIt, firstLocalAttentionLayerIt + numLocalAttentionLayers);
+ }
+
+ void setNumKvHeadsPerLayer(std::vector<SizeType32> const& headsPerLayer)
+ {
+ auto const numElems = static_cast<SizeType32>(headsPerLayer.size());
+ TLLM_CHECK_WITH_INFO(numElems == mNbAttentionLayers,
+ "Length of head_per_layer (%d) must match number of attention layers (%d)", numElems, mNbAttentionLayers);
+ mNumKvHeadsPerAttentionLayer = headsPerLayer;
+ }
+
+ [[nodiscard]] SizeType32 getSumLocalKvHeads(
+ SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
+ {
+ auto [cbegin, cend] = getNumKvHeadsPerLayerLocalRange(pipelineParallelism, pipelineParallelismRank);
+ auto const sumLocalHeads = std::reduce(cbegin, cend);
+ return sumLocalHeads;
+ }
+
private:
SizeType32 mVocabSize;
+ SizeType32 mNbLayers;
SizeType32 mNbAttentionLayers;
SizeType32 mNbRnnLayers;
SizeType32 mNbHeads;
- SizeType32 mNbKvHeads;
SizeType32 mHiddenSize;
SizeType32 mSizePerHead;
nvinfer1::DataType mDataType;
@@ -703,6 +787,7 @@ class ModelConfig
bool mUseShapeInference;
ManageWeightsType mManageWeightsType;
std::string mModelName;
+ std::vector<SizeType32> mNumKvHeadsPerAttentionLayer;
};
} // namespace tensorrt_llm::runtime
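The per-layer KV-head bookkeeping above (getNumKvHeadsPerLayerLocalRange / getSumLocalKvHeads) boils down to slicing the attention-layer head counts owned by a pipeline-parallel rank and summing them; the following self-contained sketch reproduces that arithmetic with illustrative types and names.

```cpp
// Standalone sketch of the local KV-head sum for one pipeline-parallel rank.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

using SizeType32 = std::int32_t;

enum class LayerType { kAttention, kRecurrent, kLinear, kNoop };

SizeType32 sumLocalKvHeads(std::vector<LayerType> const& layerTypes,
    std::vector<SizeType32> const& kvHeadsPerAttentionLayer, SizeType32 ppSize, SizeType32 ppRank)
{
    auto const numLocalLayers = static_cast<SizeType32>(layerTypes.size()) / ppSize; // assumes no remainder
    auto const firstLocal = layerTypes.cbegin() + numLocalLayers * ppRank;
    // Attention layers owned by lower ranks give the offset into kvHeadsPerAttentionLayer.
    auto const numPrevAttn = std::count(layerTypes.cbegin(), firstLocal, LayerType::kAttention);
    auto const numLocalAttn = std::count(firstLocal, firstLocal + numLocalLayers, LayerType::kAttention);
    auto const begin = kvHeadsPerAttentionLayer.cbegin() + numPrevAttn;
    return std::reduce(begin, begin + numLocalAttn, SizeType32{0});
}
```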
diff --git a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
index 8226c411c..e739e8188 100644
--- a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
+++ b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h
@@ -97,8 +97,7 @@ class SpeculativeDecodingMode
[[nodiscard]] bool constexpr variableDraftLength() const
{
- // Add Lookahead, when lookahead supports it.
- return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens);
+ return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding);
}
[[nodiscard]] bool constexpr hasDraftLogits() const
diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt
index 10debf560..2ff2de09e 100644
--- a/cpp/tensorrt_llm/CMakeLists.txt
+++ b/cpp/tensorrt_llm/CMakeLists.txt
@@ -348,9 +348,11 @@ endif()
if(NOT WIN32) # Unix-like compilers
set(UNDEFINED_FLAG "-Wl,--no-undefined")
set(AS_NEEDED_FLAG "-Wl,--as-needed")
+ set(NO_AS_NEEDED_FLAG "-Wl,--no-as-needed")
else() # Windows
set(UNDEFINED_FLAG "")
set(AS_NEEDED_FLAG "")
+ set(NO_AS_NEEDED_FLAG "")
endif()
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index b18a4d582..c54da94de 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:e08b60b89bb4934490ee61383c55c22d831fa1cfcccedea5735400e3574aadbc
-size 4671466
+oid sha256:10b940475c5acd80a61674d8ce4e42cc4ef3d806bafb245bbed26751378274e3
+size 4904726
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index d5279318b..ac692bb61 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:2b6b3bf449c4b4d67f0bb9879af6b8eda6f46f272eaa5b7305582a2cc8c73e17
-size 4775694
+oid sha256:b2754f7887a1b5c37ba3d589320e16144039cfe5dc6a6c78ee71925861d7d511
+size 5015842
diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
index 975995460..c8b35c9c0 100644
--- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-f229593e4699180b52e38f99c8ac31dc libtensorrt_llm_batch_manager_static.a
-440b3ae47982d88fc8517c5f01f67b3c libtensorrt_llm_batch_manager_static.pre_cxx11.a
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+ff71eabd0ac6ede5398b5b6ce4e26dcf libtensorrt_llm_batch_manager_static.a
+846eb112a182973e7c3b0b193300b4b8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
index 2b9c3f003..2b867222c 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:cb9d3d05ef4b08df0fc02f39c053a4435b58f9431d1ce269439b2c1f0a055b21
-size 4523116
+oid sha256:13b8701dd767b414a5376a91905985979ad9d2b975465ac00835c04656ee6508
+size 4766226
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
index 54354fcf8..64680e7ae 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:6597b35faffe93244d89595dc369ece59729c984871ad5aab531d714d39c8e49
-size 4487214
+oid sha256:cd0b73a017fc5c663235dcd724eb104ecc49d12ff29b6e3744be6ea952d027db
+size 4722522
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt
index 1b565a5fa..833efc826 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-c015186a31c789891a27e44f5a9ab9ec libtensorrt_llm_batch_manager_static.a
-cac21708838abf82b18e1846c40b5c79 libtensorrt_llm_batch_manager_static.pre_cxx11.a
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+1eb5c88f894f3361445d7254cbc29b03 libtensorrt_llm_batch_manager_static.a
+4e73341b23e8fb20b732ba08e03a54a8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
index c234b864e..9fd773218 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:85f18d38a66cd8b15c7e447be16171b9db854f2b2fe9dc49daa4f93fae9bc125
-size 30145896
+oid sha256:b4ac61c0b0816477c11bd6c66ec4c2f23f7b6e1400eacd8c07c333f79dec0bea
+size 30794956
diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt
index 063e584b1..db6e80406 100644
--- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt
@@ -1,2 +1,2 @@
-8ebb7d383e97bcd738cc24b00d58a2d0 tensorrt_llm_batch_manager_static.lib
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+eefe7310a60098897724f46cf4aa54f8 tensorrt_llm_batch_manager_static.lib
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/common/mpiUtils.cpp b/cpp/tensorrt_llm/common/mpiUtils.cpp
index 6022dfd6a..b637e57f1 100644
--- a/cpp/tensorrt_llm/common/mpiUtils.cpp
+++ b/cpp/tensorrt_llm/common/mpiUtils.cpp
@@ -314,6 +314,18 @@ void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType d
#endif // ENABLE_MULTI_DEVICE
}
+void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
+ std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
+{
+#if ENABLE_MULTI_DEVICE
+ MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
+ getMpiDtype(recvtype), mComm));
+
+#else
+ TLLM_THROW("Multi device support is disabled.");
+#endif // ENABLE_MULTI_DEVICE
+}
+
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
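A hedged sketch of how the new MpiComm::allgatherv wrapper might be used follows: per-rank counts are exchanged with a fixed-size allgather, then the variable-size payloads are gathered. MpiComm::world(), getSize(), and MpiType::kINT32 are assumptions about the surrounding mpiUtils API and may differ.

```cpp
// Sketch only: gather variable-length int32 buffers from all ranks.
#include <cstdint>
#include <numeric>
#include <vector>

#include "tensorrt_llm/common/mpiUtils.h"

namespace mpi = tensorrt_llm::mpi;

std::vector<int32_t> gatherVariableLengths(std::vector<int32_t> const& local)
{
    auto const& comm = mpi::MpiComm::world(); // assumed accessor
    int const worldSize = comm.getSize();     // assumed accessor
    // Exchange per-rank element counts first (one int per rank).
    std::vector<int> recvcounts(worldSize);
    int const sendcount = static_cast<int>(local.size());
    comm.allgather(&sendcount, recvcounts.data(), 1, mpi::MpiType::kINT32); // kINT32 assumed
    // Build displacements, then gather the payloads with the new wrapper.
    std::vector<int> displs(worldSize, 0);
    std::partial_sum(recvcounts.begin(), recvcounts.end() - 1, displs.begin() + 1);
    std::vector<int32_t> out(displs.back() + recvcounts.back());
    comm.allgatherv(local.data(), sendcount, mpi::MpiType::kINT32, out.data(), recvcounts, displs, mpi::MpiType::kINT32);
    return out;
}
```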
diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a
index 8726ff8ae..d7f58205a 100644
--- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a
+++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:41eae80923f8634a635f2fce84fdbe33101ee6cf86c0a98ed4ce30a7f4cea350
-size 1782460
+oid sha256:ebab2cc2c62a826ddec02597178b8e0c9bc316726f37f8eef37c06795aebcf03
+size 1784658
diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
index 482344588..b8e5962bf 100644
--- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:dd4fb660b1c3e664a012e26cea949dcded855e133aa6aadd01157a15df3e0d44
-size 1808956
+oid sha256:4b630f89708614e63c67871e21b6e32bfde71acc51549b650c57048c0fa343e7
+size 1812686
diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
index af0c5cd81..a4434f2dd 100644
--- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-2744c4784cfd34c7311148c7f7614757 libtensorrt_llm_executor_static.a
-d56af9e74a9d49e32860d89dcca024d0 libtensorrt_llm_executor_static.pre_cxx11.a
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+136f1b9d2168cbb9011a341b267af9a2 libtensorrt_llm_executor_static.a
+183bd079377d6cd698d46370168a5726 libtensorrt_llm_executor_static.pre_cxx11.a
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
index a9670e009..d1c437693 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:e9a5c64c347400297bc8b6907b4feaa890305aaf5c1b45ce57aca8fcae3e881f
-size 1846898
+oid sha256:e04c76f6441a49db4d3996c62b4055395ae018384d8ee2f02ea5f0c4c0843902
+size 1853180
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
index 0edd3b394..61c25133c 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:e4b5912ce9a1c13554f4b16d29deb6f2ad51477c56810b758ba488212f8e5dc9
-size 1757522
+oid sha256:95ba1a4b6bdcecbb592bbb42b4998bcb0eb1f45a318163635183bcde6950c4bf
+size 1764982
diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt
index 6717291c8..ad7ba2bf9 100644
--- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-69dc85fe48625b6f8f684487f2048458 libtensorrt_llm_executor_static.a
-d46f0be3543e24c4df51ae287086ca52 libtensorrt_llm_executor_static.pre_cxx11.a
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+dfbd0d424c150253ff758aa5bd37a971 libtensorrt_llm_executor_static.a
+e82866739fef1d6df8293541967924bf libtensorrt_llm_executor_static.pre_cxx11.a
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib
index cc641f706..2799dc524 100644
--- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib
+++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:fc94de246db88ca06e5a7c20d04f78057cc3c2928b3fa3e79f49e8b9d90b76da
-size 19683718
+oid sha256:aa8ba34fb98c5407e3d6944245086158c61b2c784b15c7b923fdd156b942224d
+size 19670642
diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt
index e88bce29b..d2e341ae7 100644
--- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt
@@ -1,2 +1,2 @@
-0e361ba639fa897f489f6d0f48cfe13f tensorrt_llm_executor_static.lib
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+784ad1fabd3d02466f95fbc463b64f5b tensorrt_llm_executor_static.lib
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/cumsumLastDim.cu b/cpp/tensorrt_llm/kernels/cumsumLastDim.cu
index 1b7977708..8989e95fc 100644
--- a/cpp/tensorrt_llm/kernels/cumsumLastDim.cu
+++ b/cpp/tensorrt_llm/kernels/cumsumLastDim.cu
@@ -115,6 +115,11 @@ template
void invokeCumsumLastDim(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input,
void* __restrict__ output, void* deviceTempStorage, size_t tempStorageBytes, cudaStream_t stream)
{
+ // For empty tensor support
+ if (batchSize == 0)
+ {
+ return;
+ }
if (deviceTempStorage != nullptr) // we need to use DeviceScan
{
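
The guard added above simply skips the launch when the batch is empty, so the downstream CUB scan never sees a zero-size workload. A host-side sketch of the same pattern (hypothetical `cumsumLastDim` helper, not the TensorRT-LLM launcher) looks like this:

```cpp
#include <cstddef>

// Host-side sketch of the zero-batch guard added to invokeCumsumLastDim:
// an empty tensor means there is nothing to scan and no kernel to launch.
void cumsumLastDim(std::size_t batchSize, std::size_t inputLength, float const* input, float* output)
{
    if (batchSize == 0)
    {
        return; // empty tensor support: avoid a zero-size device scan
    }
    for (std::size_t b = 0; b < batchSize; ++b)
    {
        float running = 0.0f;
        for (std::size_t i = 0; i < inputLength; ++i)
        {
            running += input[b * inputLength + i];
            output[b * inputLength + i] = running;
        }
    }
}
```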
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt
index 49e66bdce..6f3b1ed98 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt
index f5da2f14a..d3923a7d2 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
index fadecd1af..643b3b831 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:6fc6f35c712d83404e40a7840a0c9d1f5157df61df91a7207c4e4131783f4676
+oid sha256:1471e322bb44cd65b98ee30e0befa32ae4c86e828f0b4fd4f02d4af4e710d08f
size 1128448
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib
index 0592bf8c5..cfe4399d6 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:e74ab8e65851dfc44e015714fe166f521649b781c85bd0215d42b488218e9ca5
+oid sha256:e207a8f57b944529163c7ed2ab30639a5f2779c5118602c6ebd50a623d16f845
size 3488
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt
index b9555dc07..6dded519b 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt
@@ -1,3 +1,3 @@
-700fc148d9a0f939e0088bf69e899360 tensorrt_llm_nvrtc_wrapper.lib
-6ea6ac6dff8793afbd79dd5768daae85 tensorrt_llm_nvrtc_wrapper.dll
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib
+f9b1cc37a27dd0574bb41a2763a97be7 tensorrt_llm_nvrtc_wrapper.dll
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
index a6ff47b69..36daef37b 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:f4cc1e6f0b6d1e7bc875284275b591d34c707471e636019b4c2904f30798dbc9
+oid sha256:9117f7cf5eef0ed452c0d0bc79242b84def103e7038c9d3df6e366690801ca92
size 25364090
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
index cc7772d70..7eaca6cd9 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:658a3f0bc5b9877e5ad447437287908dc9b7df87ae0e86f5338aaf81e26f723e
+oid sha256:2b04913f9e9029a5ce5a222d5cc7492ff53323a548079d2fb32d5b2aeb0c2268
size 25768990
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
index b8d933d49..ecfff5209 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-979f3165fbc7a68528df6e343cc54e3f libtensorrt_llm_internal_cutlass_kernels_static.a
-68e84c294a658734a8b26d7270540e1d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+d54fb93f256601f4c4ad7f1c8e6e9919 libtensorrt_llm_internal_cutlass_kernels_static.a
+71028d801074f11138e890391e48591d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
index 61b711b17..715fba593 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:24358d8eb15a5e802cbee6d2d27735033eedb33091e9355d199229c3ba7b6447
+oid sha256:d8c685f8ea2f84838dfdbf448eab41c76fe88fe29db0d4a511d6d6d241ad1832
size 44173632
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
index 0f981d9d4..4f403b38e 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:4b54efb78e98bf6580f73a3b6b689823d7c2eb851c0bab5f36906f4ebbfc44fc
-size 43561142
+oid sha256:b9d75392ba3b59853c43072b4f9949b32cb6724813a39048e4585e9a8fb3e136
+size 43561206
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
index b6f327833..dcd8a686a 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
@@ -1,3 +1,3 @@
-0b49d88b8b5e83c8c6997c725a37f373 libtensorrt_llm_internal_cutlass_kernels_static.a
-7a12fc880d2a13ee5c7cf2b1e169cb19 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+4fc3e1fb0db6a121f88a9141605d9285 libtensorrt_llm_internal_cutlass_kernels_static.a
+253731af750407020dbe6f2fbe50fa2b libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib
index d5007d0d2..e88023db2 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:5ba6f8610db3b967f3de4beeff6394cdd3e56d15916f39110ed932c3c3a65417
-size 88141376
+oid sha256:62af58f5e09d1cf5e347b02ef3bd3a186469162fc9645d038fb2cba23b597722
+size 88140804
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt
index 15334333c..5bb9d18b8 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt
@@ -1,2 +1,2 @@
-5c26d1347bb8b47288d598b6d7444900 tensorrt_llm_internal_cutlass_kernels_static.lib
-7adf157833793b3215570b0a95b9c4b2998a620c commit
\ No newline at end of file
+eb7fc4a105eb6e6f52ba865f2b055233 tensorrt_llm_internal_cutlass_kernels_static.lib
+7f370deb0090d885d7518c2b146399ba3933c004 commit
\ No newline at end of file
diff --git a/cpp/tensorrt_llm/layers/decodingParams.h b/cpp/tensorrt_llm/layers/decodingParams.h
index 0179add1d..07d200704 100644
--- a/cpp/tensorrt_llm/layers/decodingParams.h
+++ b/cpp/tensorrt_llm/layers/decodingParams.h
@@ -210,8 +210,6 @@ struct LookaheadSetupParams : public DecodingSetupParams
TensorPtr positionOffsets;
//! see LookaheadDecodingOutputs::attentionPackedMasks
TensorPtr attentionPackedMasks;
- //! see LookaheadDecodingOutputs::actualGenerationLengths
- TensorPtr actualGenerationLengths;
};
class BaseDecodingInputs
@@ -551,8 +549,6 @@ class LookaheadDecodingOutputs : public SpeculativeDecodingOutputs
TensorPtr positionOffsets;
//! [maxBatchSize, maxDecodingTokens]
TensorPtr positionIds;
- //! The actual decoding tokens length, for debug and for future.
- TensorPtr actualGenerationLengths;
};
class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs
diff --git a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp
index 5b3062be0..db78160b9 100644
--- a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp
+++ b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp
@@ -18,8 +18,12 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
+#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
+#include "tensorrt_llm/runtime/common.h"
+#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
+#include
#include
namespace tensorrt_llm::layers
@@ -27,6 +31,36 @@ namespace tensorrt_llm::layers
using namespace tensorrt_llm::runtime;
+LookaheadAlgorithm::LookaheadAlgorithm(
+ runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id)
+ : mMaxW(maxW)
+ , mMaxN(maxN)
+ , mMaxG(maxG)
+ , mFilling(0)
+ , mPoolManager(maxG)
+ , mId(id)
+ , mGoldenTokensMax(
+ runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32))
+ , mPrefillsMax(runtime::BufferManager::cpu(
+ runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32))
+ , mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32))
+ , mPastTokensMax(
+ runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32))
+ , mGuessTokensMax(
+ runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxG * (maxN - 1)}), nvinfer1::DataType::kINT32))
+{
+ runtime::SizeType32 maxGeneratedLen, maxDraftLen;
+ std::tie(maxGeneratedLen, std::ignore, maxDraftLen, std::ignore)
+ = executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource();
+ mAttentionMask = runtime::BufferManager::cpu(
+ runtime::ITensor::makeShape({maxDraftLen, maxDraftLen}), nvinfer1::DataType::kBOOL);
+ mDraftTokensMax
+ = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
+ mSampledTokensMax
+ = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxGeneratedLen}), nvinfer1::DataType::kINT32);
+ mEncodeMapMax = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
+}
+
void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -36,7 +70,7 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
mW = w;
mN = n;
mG = g;
- std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, std::ignore)
+ std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, mRuntimeMaxDraftPathLen)
= executor::LookaheadDecodingConfig(mW, mN, mG).calculateSpeculativeResource();
mPoolManager.setup(mG);
@@ -81,8 +115,8 @@ void LookaheadAlgorithm::accept(TensorConstPtr const& generatedTokens)
}
//! lookahead has two phases: prefill the past tokens matrix, then maintain the past tokens matrix.
-runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds,
- TensorPtr const& samplingMask, runtime::SizeType32 offset)
+runtime::SizeType32 LookaheadAlgorithm::lookahead(
+ TensorPtr const& draftTokens, TensorPtr const& positionIds, runtime::SizeType32 startPosId)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -90,7 +124,6 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
SizeType32 len = prefill + mFilling * mW;
TLLM_CHECK(len <= ITensor::volume(draftTokens->getShape()));
TLLM_CHECK(len <= ITensor::volume(positionIds->getShape()));
- TLLM_CHECK(len <= ITensor::volume(samplingMask->getShape()));
BufferRange<TokenIdType> prefillRange(*mPrefills);
BufferRange<TokenIdType> pastRange(*mPastTokens);
BufferRange<TokenIdType> draftRange(*draftTokens);
@@ -112,11 +145,6 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
}
BufferRange<SizeType32> positionIdsRange(*positionIds);
- BufferRange<bool> samplingMaskRange(*samplingMask);
- for (auto& v : samplingMaskRange)
- {
- v = 0;
- }
SizeType32 idx = 0, wj = 0;
auto fillPosition = [&positionIdsRange, &idx](SizeType32 start, SizeType32 len)
{
@@ -127,20 +155,18 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
};
if (prefill >= 0)
{
- fillPosition(offset, prefill);
+ fillPosition(startPosId, prefill);
for (wj = 0; wj < mW; wj++)
{
- fillPosition(offset + prefill + wj, mFilling);
- samplingMaskRange[prefill + wj * mFilling + mFilling - 1] = true;
+ fillPosition(startPosId + prefill + wj, mFilling);
}
}
else
{
- fillPosition(offset, mFilling - 1);
+ fillPosition(startPosId, mFilling - 1);
for (wj = 1; wj < mW; wj++)
{
- fillPosition(offset - 1 + wj, mFilling);
- samplingMaskRange[wj * mFilling + mFilling - 1 - 1] = true;
+ fillPosition(startPosId - 1 + wj, mFilling);
}
}
PRINT_VALUES(positionIds);
@@ -150,7 +176,7 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
}
runtime::SizeType32 LookaheadAlgorithm::guess(TensorPtr const& guessTokens, TensorPtr const& guessIds,
- TensorPtr const& samplingMask, runtime::SizeType32 offset, runtime::TokenIdType lastToken)
+ runtime::SizeType32 startPosId, runtime::TokenIdType lastToken)
{
auto guesses = mPoolManager.guess(lastToken, mW);
@@ -158,67 +184,227 @@ runtime::SizeType32 LookaheadAlgorithm::guess(TensorPtr const& guessTokens, Tens
std::for_each(guesses.begin(), guesses.end(), [&len](auto& a) { len += ITensor::volume(a->getShape()); });
TLLM_CHECK(len <= ITensor::volume(guessTokens->getShape()));
TLLM_CHECK(len <= ITensor::volume(guessIds->getShape()));
- TLLM_CHECK(len <= ITensor::volume(samplingMask->getShape()));
BufferRange<TokenIdType> guessTokensRange(*guessTokens);
BufferRange<SizeType32> guessIdsRange(*guessIds);
- BufferRange<bool> samplingMaskRange(*samplingMask);
SizeType32 cur = 0;
for (auto guess : guesses)
{
BufferRange guessRange(*guess);
std::copy(guessRange.begin(), guessRange.end(), guessTokensRange.begin() + cur);
- SizeType32 tmp = offset;
+ SizeType32 tmp = startPosId;
std::for_each(
guessIdsRange.begin() + cur, guessIdsRange.begin() + cur + mN - 1, [&tmp](auto& v) { v = tmp++; });
cur += ITensor::volume(guess->getShape());
}
- std::for_each(samplingMaskRange.begin(), samplingMaskRange.begin() + len, [](auto& a) { a = true; });
-
return len;
}
+void LookaheadAlgorithm::posIdsToMask(TensorPtr const& mask, TensorConstPtr const& posIds)
+{
+ auto len = ITensor::volume(posIds->getShape());
+ TLLM_CHECK(mask->getDimension<0>() >= len);
+ TLLM_CHECK(mask->getDimension<1>() >= len);
+ auto posIdsRange = BufferRange<SizeType32 const>(*posIds);
+ auto maskLocation = BufferLocation<bool>(*mask);
+
+ for (auto& item : maskLocation)
+ {
+ item = false;
+ }
+
+ if (len > 0)
+ {
+ std::vector<std::pair<SizeType32, SizeType32>> stack;
+ for (auto i = 0; i < len; i++)
+ {
+ auto cur = posIdsRange[i];
+ while (stack.size() > 0 && cur <= stack.back().second)
+ {
+ stack.pop_back();
+ }
+ TLLM_CHECK(stack.size() > 0 ? cur == stack.back().second + 1 : true);
+ stack.push_back(std::make_pair(i, cur));
+ for (auto prev : stack)
+ {
+ maskLocation.at(i, prev.first) = true;
+ }
+ }
+ }
+}
+
+struct TreeValue;
+using TreeMap = std::unordered_map<TokenIdType, TreeValue>;
+
+struct TreeValue
+{
+ TreeValue()
+ : nexts(std::make_shared<TreeMap>())
+ {
+ }
+
+ using Nexts = std::shared_ptr<TreeMap>;
+ Nexts nexts{nullptr};
+ std::list<SizeType32> sources;
+};
+
+using TreeNode = TreeMap::value_type;
+
+template <typename BF, typename AF>
+void treeDFS(TreeNode& node, BF const& visitBefore, AF const& visitAfter)
+{
+ visitBefore(node);
+ for (auto& next : *(node.second.nexts))
+ {
+ treeDFS(next, visitBefore, visitAfter);
+ }
+ visitAfter(node);
+}
+
+SizeType32 LookaheadAlgorithm::treeEncode(
+ TensorPtr const& tokens, TensorPtr const& posIds, TensorPtr const& mask, TensorPtr const& encodeMap)
+{
+ TLLM_CHECK(ITensor::volume(tokens->getShape()) == ITensor::volume(posIds->getShape()));
+ auto len = ITensor::volume(tokens->getShape());
+
+ BufferRange<TokenIdType> tokensRange(*tokens);
+ BufferRange<SizeType32> posIdsRange(*posIds);
+ BufferLocation<bool> maskLocation(*mask);
+ BufferRange<SizeType32> mapRange(*encodeMap);
+
+ auto branches = std::make_shared<TreeMap>();
+
+ for (auto i = 0; i < len; i++)
+ {
+ auto nexts = branches;
+ for (auto j = 0; j <= i; j++)
+ {
+ if (maskLocation.at(i, j))
+ {
+ auto pos = posIdsRange[j];
+ auto tok = tokensRange[j];
+ auto found = nexts->find(tok);
+ if (found != nexts->end())
+ {
+ found->second.sources.push_back(j);
+ nexts = found->second.nexts;
+ }
+ else
+ {
+ auto [inserted, ok] = nexts->insert({tok, TreeValue()});
+ inserted->second.sources.push_back(j);
+ nexts = inserted->second.nexts;
+ }
+ }
+ }
+ }
+
+ for (auto& item : maskLocation)
+ {
+ item = 0;
+ }
+ std::vector<std::pair<SizeType32, TokenIdType>> stack;
+ SizeType32 offset = 0;
+ SizeType32 posId = posIdsRange.size() ? posIdsRange[0] : 0;
+
+ auto visitBefore
+ = [&stack, &maskLocation, &tokensRange, &posIdsRange, &posId, &offset, &mapRange](TreeNode const& node)
+ {
+ stack.push_back(std::make_pair(offset, node.first));
+ for (auto const& source : node.second.sources)
+ {
+ mapRange[source] = offset;
+ }
+ for (auto const& prev : stack)
+ {
+ maskLocation.at(offset, prev.first) = true;
+ }
+ tokensRange[offset] = node.first;
+ posIdsRange[offset] = posId;
+ offset++;
+ posId++;
+ };
+ auto visitAfter = [&stack, &posId](TreeNode const& node)
+ {
+ stack.pop_back();
+ posId--;
+ };
+
+ for (auto& next : *branches)
+ {
+ treeDFS(next, visitBefore, visitAfter);
+ }
+
+ for (SizeType32 i = offset; i < len; i++)
+ {
+ tokensRange[i] = 0;
+ posIdsRange[i] = 0;
+ }
+ for (SizeType32 i = 0; i < len; i++)
+ {
+ for (SizeType32 j = i < offset ? offset : 0; j < len; j++)
+ {
+ maskLocation.at(i, j) = false;
+ }
+ }
+
+ return offset;
+}
+
void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds,
- TensorPtr const& samplingMask, TensorPtr const& length, TensorConstPtr const& offsetPtr,
- TensorConstPtr const& lastTokenPtr)
+ TensorPtr const& draftLengthPtr, TensorPtr const& attentionMask, SizeType32 attentionMaskOffset,
+ TensorConstPtr const& lastPositionIdPtr, TensorConstPtr const& lastTokenPtr)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (mRuntimeMaxDraftLen == 0)
{
- (BufferRange<SizeType32>(*length))[0] = 0;
+ mDraftTokens = ITensor::slice(mDraftTokensMax, 0, 0);
+ mEncodeMap = ITensor::slice(mEncodeMapMax, 0, 0);
+ (BufferRange<SizeType32>(*draftLengthPtr))[0] = 0;
return;
}
auto lastToken = BufferRange<TokenIdType const>(*lastTokenPtr)[0];
- auto offset = BufferRange<SizeType32 const>(*offsetPtr)[0];
+ auto offset = BufferRange<SizeType32 const>(*lastPositionIdPtr)[0];
SizeType32 inputLen = ITensor::volume(draftTokens->getShape());
TLLM_CHECK(inputLen >= mRuntimeMaxDraftLen);
BufferRange<TokenIdType> draftRange(*draftTokens);
BufferRange<SizeType32> positionRange(*positionIds);
- BufferRange<bool> samplingRange(*samplingMask);
SizeType32 filledLen = 0;
filledLen += lookahead(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
- ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
- ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset);
+ ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), offset);
auto guessStart = filledLen;
filledLen += guess(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
- ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
- ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken);
+ ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken);
auto guessEnd = filledLen;
+ std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd,
+ BufferRange<TokenIdType>(*mGuessTokensMax).begin());
mGuessTokens = ITensor::slice(mGuessTokensMax, 0, guessEnd - guessStart);
- std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd,
- BufferRange<TokenIdType>(*mGuessTokens).begin());
+ posIdsToMask(mAttentionMask, ITensor::slice(positionIds, 0, filledLen));
- (BufferRange<SizeType32>(*length))[0] = filledLen;
+ auto draftLen = treeEncode(ITensor::slice(draftTokens, 0, filledLen), ITensor::slice(positionIds, 0, filledLen),
+ mAttentionMask, mEncodeMapMax);
+
+ for (SizeType32 i = 0; i < draftLen; i++)
+ {
+ BufferRange<bool> srcRange(*ITensor::at(mAttentionMask, {i}));
+ BufferRange<bool> dstRange(*ITensor::slice(attentionMask, {i + attentionMaskOffset, attentionMaskOffset}));
+ std::copy(srcRange.begin(), srcRange.end(), dstRange.begin());
+ }
+
+ std::copy(draftRange.begin(), draftRange.begin() + draftLen, BufferRange<TokenIdType>(*mDraftTokensMax).begin());
+ mDraftTokens = ITensor::slice(mDraftTokensMax, 0, draftLen);
+ (BufferRange<SizeType32>(*draftLengthPtr))[0] = draftLen;
+ mEncodeMap = ITensor::slice(mEncodeMapMax, 0, filledLen);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@@ -229,29 +415,31 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
- TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mGuessTokens->getShape()));
+ TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mDraftTokens->getShape()));
BufferRange<TokenIdType const> goldRange(*goldenTokens);
- BufferRange<TokenIdType> guessTokensRange(*mGuessTokens);
- auto guessSize = ITensor::volume(mGuessTokens->getShape());
+ BufferRange<TokenIdType> draftRange(*mDraftTokens);
+ BufferLocation<bool> maskLocation(*mAttentionMask);
+ auto draftSize = ITensor::volume(mDraftTokens->getShape());
+ auto end = *BufferRange<TokenIdType const>(*endToken).begin();
- SizeType32 guesses = (mN - 1 > 0) ? (guessSize / (mN - 1)) : 0;
- SizeType32 hit = 0, maxHit = 0, hitIdx = 0;
- for (SizeType32 i = 0; i < guesses; i++)
+ SizeType32 maxHit = 0, hitIdx = 0;
+ for (SizeType32 i = 0; i < draftSize; i++)
{
SizeType32 hit = 0;
- for (SizeType32 j = 0; j < mN - 1; j++)
+ TokenIdType cur = newLastToken;
+ for (SizeType32 j = 0; j < draftSize; j++)
{
- auto idx = i * (mN - 1) + j;
- bool ok
- = (j == 0) ? (newLastToken == guessTokensRange[idx]) : (goldRange[idx - 1] == guessTokensRange[idx]);
- bool finish = guessTokensRange[idx] == *BufferRange(*endToken).begin();
- if (ok && !finish)
- {
- hit++;
- }
- else
+ if (maskLocation.at(i, j))
{
- break;
+ if (draftRange[j] == cur && draftRange[j] != end)
+ {
+ hit++;
+ cur = goldRange[j];
+ }
+ else
+ {
+ break;
+ }
}
}
if (hit > maxHit)
@@ -261,17 +449,19 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
}
}
- BufferRange acceptedRange(*accepted);
- acceptedRange[0] = newLastToken;
- std::copy(goldRange.begin() + hitIdx * (mN - 1), goldRange.begin() + hitIdx * (mN - 1) + maxHit,
- acceptedRange.begin() + 1);
+ maxHit = maxHit > mRuntimeMaxDraftPathLen ? mRuntimeMaxDraftPathLen : maxHit;
+ SizeType32 acceptedIdx = 0;
+ BufferRange<TokenIdType> acceptedRange(*accepted);
BufferRange<SizeType32> acceptedOffsetsRange(*acceptedOffsets);
- auto lookSize = 1 + mN - 2 - mFilling + mFilling * mW;
- // acceptedOffsetsRange[0] = 0;
- for (SizeType32 i = 0; i < maxHit; i++)
+ acceptedRange[acceptedIdx] = newLastToken;
+ for (SizeType32 j = 0; j < draftSize; j++)
{
- acceptedOffsetsRange[i] = lookSize + hitIdx * (mN - 1) + i - 1;
+ if (maskLocation.at(hitIdx, j) && acceptedIdx < maxHit)
+ {
+ acceptedOffsetsRange[acceptedIdx++] = j;
+ acceptedRange[acceptedIdx] = goldRange[j];
+ }
}
*BufferRange<SizeType32>(*acceptedLength).begin() = maxHit + 1;
@@ -325,7 +515,19 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(ITensor::volume(acceptedTokens->getShape()) >= mN);
- BufferRange<TokenIdType const> sampledRange(*sampledTokens);
+ BufferRange<TokenIdType const> zippedTokensRange(*sampledTokens);
+ BufferRange<TokenIdType> sampledRange(*mSampledTokensMax);
+
+ BufferRange<SizeType32> mapRange(*mEncodeMap);
+ BufferRange<TokenIdType> unzipRange(*mSampledTokensMax);
+ mSampledTokens = ITensor::slice(mSampledTokensMax, 0, mEncodeMap->getShape().d[0] + 1);
+
+ unzipRange[0] = zippedTokensRange[0];
+ for (SizeType32 i = 0; i < mapRange.size(); i++)
+ {
+ unzipRange[i + 1] = zippedTokensRange[mapRange[i] + 1];
+ }
+
BufferRange<TokenIdType> keyRange(*mKeyTokens);
BufferRange<TokenIdType> pastRange(*mPastTokens);
@@ -359,13 +561,15 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
}
auto guessSize = ITensor::volume(mGuessTokens->getShape());
- auto outputSize = ITensor::volume(sampledTokens->getShape());
+ auto outputSize = ITensor::volume(mSampledTokens->getShape());
auto lookSize = 1 + (mN > 1 ? mN - 2 : 0) - mFilling + mFilling * mW;
TLLM_CHECK(guessSize + lookSize == outputSize);
- TensorConstPtr goldenTokens = ITensor::slice(sampledTokens, lookSize, guessSize);
+ TensorConstPtr goldenTokens = ITensor::slice(mSampledTokens, lookSize, guessSize);
+
+ auto& acptLen = *BufferRange<SizeType32>(*acceptedLength).begin();
- verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, goldenTokens, endToken);
+ verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, ITensor::slice(sampledTokens, 1), endToken);
accept(ITensor::slice(acceptedTokens, 0, *BufferRange<SizeType32>(*acceptedLength).begin()));
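
For readers following the new tree-mask logic above: `posIdsToMask` rebuilds the token-tree dependency mask from flat position ids, marking each draft position as attending to itself and its ancestors. A self-contained sketch of the same stack walk, using plain `std::vector` instead of the `ITensor`/`BufferLocation` wrappers (an illustration under that simplification, not the library implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch of the stack-based reconstruction in posIdsToMask. posIds is a
// depth-first flattening of a token tree: a position id that is <= the one on
// top of the stack pops back to the matching ancestor depth. Row i of the
// returned mask marks i itself plus every ancestor of i.
std::vector<std::vector<bool>> posIdsToMask(std::vector<int32_t> const& posIds)
{
    std::size_t const len = posIds.size();
    std::vector<std::vector<bool>> mask(len, std::vector<bool>(len, false));
    std::vector<std::pair<std::size_t, int32_t>> stack; // (flat index, position id)
    for (std::size_t i = 0; i < len; ++i)
    {
        int32_t const cur = posIds[i];
        while (!stack.empty() && cur <= stack.back().second)
        {
            stack.pop_back();
        }
        assert(stack.empty() || cur == stack.back().second + 1);
        stack.emplace_back(i, cur);
        for (auto const& prev : stack)
        {
            mask[i][prev.first] = true; // attend to self and ancestors only
        }
    }
    return mask;
}
```

For example, posIds {5, 6, 7, 6, 7} (two length-two branches hanging off position 5) yield a mask whose rows 3 and 4 skip columns 1 and 2, since those tokens belong to the sibling branch.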
diff --git a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h
index 99df44128..485734c5a 100644
--- a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h
+++ b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h
@@ -21,6 +21,7 @@
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include
+#include
namespace tensorrt_llm::layers
{
@@ -35,24 +36,7 @@ class LookaheadAlgorithm
//! @brief Currently the resource management is to be aligned with batch manager.
//! @param w, n, g is the Jacobi window, n-gram level and guess set size respectively.
LookaheadAlgorithm(
- runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0)
- : mMaxW(maxW)
- , mMaxN(maxN)
- , mMaxG(maxG)
- , mFilling(0)
- , mPoolManager(maxG)
- , mId(id)
- , mGoldenTokensMax(
- runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32))
- , mPrefillsMax(runtime::BufferManager::cpu(
- runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32))
- , mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32))
- , mPastTokensMax(
- runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32))
- , mGuessTokensMax(
- runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxG * (maxN - 1)}), nvinfer1::DataType::kINT32))
- {
- }
+ runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0);
//! @brief setup per request, fill internal states from @param prompt.
void setup(TensorConstPtr const& prompt, runtime::SizeType32 w, runtime::SizeType32 n, runtime::SizeType32 g);
@@ -62,43 +46,55 @@ class LookaheadAlgorithm
void accept(TensorConstPtr const& generatedTokens);
//! @brief combine lookahead and guess to prepare the tensors.
- //! input @param offsetPtr is position id of the last golden token, in a TensorPtr.
+ //! input @param lastPositionIdPtr is position id of the last golden token, in a TensorPtr.
//! input @param lastTokenPtr the last golden token for searching in the pool, in a TensorPtr.
- //! output @param draftTokens, positionIds, samplingMask; including the golden token, the lookahead
- //! and the verification branch information. @param length holds the draft tokens length.
- void prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, TensorPtr const& samplingMask,
- TensorPtr const& length, TensorConstPtr const& offsetPtr, TensorConstPtr const& lastTokenPtr);
+ //! output @param draftTokens, positionIds includes the lookahead and the verification branch information.
+ //! output @param draftLengthPtr holds the draft tokens length.
+ //! output @param attentionMask holds the draft tokens dependency mask, and attentionMaskOffset is the index offset
+ //! in attentionMask.
+ void prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, TensorPtr const& draftLengthPtr,
+ TensorPtr const& attentionMask, runtime::SizeType32 attentionMaskOffset,
+ TensorConstPtr const& lastPositionIdPtr, TensorConstPtr const& lastTokenPtr);
//! @brief update the internal states and generate accepted tokens from @param outputTokens.
- //! input @param sampledTokens is the all the tokens from the language model. The position at samplingMask=1 is
- //! valid. input @param endToken is the end token for `verify` early quit.
- //! output @param acceptedTokens, acceptedOffsets ind @param acceptedLength.
+ //! input @param sampledTokens is all the tokens sampled from the language model.
+ //! input @param endToken is the end token for `verify` early quit.
+ //! output @param acceptedTokens, acceptedOffsets in @param acceptedLength.
void update(TensorPtr const& acceptedTokens, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength,
TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken);
+ //! generate attention @param mask from @param posIds.
+ static void posIdsToMask(TensorPtr const& mask, TensorConstPtr const& posIds);
+
+ //! encode the @param tokens and @param posIds in place according to the attention @param masks, and record the offsets
+ //! in @param encodeMap.
+ static runtime::SizeType32 treeEncode(
+ TensorPtr const& tokens, TensorPtr const& posIds, TensorPtr const& masks, TensorPtr const& encodeMap);
+
private:
//! @brief generate lookahead branch information.
- //! input @param offset the position id of the last golden token.
- //! output @param draftTokens, positionIds, samplingMask of the lookahead branch.
+ //! input @param startPosId is the first position id of the draftTokens.
+ //! output @param draftTokens, positionIds of the lookahead branch.
//! @return the actual filled lookahead length.
- runtime::SizeType32 lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds,
- TensorPtr const& samplingMask, runtime::SizeType32 offset);
+ runtime::SizeType32 lookahead(
+ TensorPtr const& draftTokens, TensorPtr const& positionIds, runtime::SizeType32 startPosId);
//! @brief generate verification branch information. Also save the guessed tokens for future verification.
- //! input @param offset the position id of the last golden token.
+ //! input @param startPosId the first position id.
//! input @param lastToken the last golden token for searching in the pool.
- //! output @param guessTokens, guessIds, samplingMask of the verification branch.
+ //! output @param guessTokens, guessIds of the verification branch.
//! @return the actual filled guess length.
- runtime::SizeType32 guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, TensorPtr const& samplingMask,
- runtime::SizeType32 offset, runtime::TokenIdType lastToken);
+ runtime::SizeType32 guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, runtime::SizeType32 startPosId,
+ runtime::TokenIdType lastToken);
//! @brief verify the guessed tokens results and generate the longest accepted tokens.
//! input @param newLastToken is the new-generated last golden token.
- //! input @param goldenTokens is the guessed token results from the language model.
+ //! input @param sampledTokens is the generated token results from the language model.
//! input @param endToken is the end token for early quit detection.
- //! output @param accepted, acceptedOffsets in @param acceptedLength, .
+ //! output @param accepted in @param acceptedLength, including the first golden one.
+ //! output @param acceptedOffsets is the offsets of draft tokens, excluding the first golden one.
void verify(TensorPtr const& accepted, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength,
- runtime::TokenIdType newLastToken, TensorConstPtr const& goldenTokens, TensorConstPtr const& endToken);
+ runtime::TokenIdType newLastToken, TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken);
private:
LookaheadPoolManager mPoolManager;
@@ -117,6 +113,13 @@ class LookaheadAlgorithm
//! the same guess tokens from `guess` and used in `verify`
TensorPtr mGuessTokensMax; // shape [mMaxG*(mMaxN-1)]
TensorPtr mGuessTokens; // shape [mG*(mN-1)]
+ TensorPtr mDraftTokensMax;
+ TensorPtr mDraftTokens;
+ TensorPtr mAttentionMask;
+ TensorPtr mEncodeMapMax;
+ TensorPtr mEncodeMap;
+ TensorPtr mSampledTokensMax;
+ TensorPtr mSampledTokens;
//! look ahead algorithm parameters, Window size, Level and Guess set size.
//! max for reserving resources and current for current request.
@@ -127,6 +130,7 @@ class LookaheadAlgorithm
runtime::SizeType32 mN{0};
runtime::SizeType32 mG{0};
runtime::SizeType32 mRuntimeMaxDraftLen{0};
+ runtime::SizeType32 mRuntimeMaxDraftPathLen{0};
//! in prefilling mode when mFilling < mN-1.
runtime::SizeType32 mFilling;
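
The encode map documented above records, for every original draft position, the compacted slot that `treeEncode` folded it into. A minimal sketch of how such a map can be consumed to expand the model's sampled tokens back to the original draft layout, mirroring the unzip loop in `LookaheadAlgorithm::update` (plain vectors stand in for the tensor types):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// encodeMap[i] is the compacted slot that original draft position i was folded
// into; slot 0 of the sampled tokens is the golden token in both layouts.
std::vector<int32_t> unzipSampledTokens(
    std::vector<int32_t> const& zippedSampled, std::vector<int32_t> const& encodeMap)
{
    std::vector<int32_t> unzipped(encodeMap.size() + 1);
    unzipped[0] = zippedSampled[0]; // golden token stays in front
    for (std::size_t i = 0; i < encodeMap.size(); ++i)
    {
        unzipped[i + 1] = zippedSampled[encodeMap[i] + 1];
    }
    return unzipped;
}
```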
diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp
index 8214abfb4..32b812967 100644
--- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp
+++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp
@@ -24,6 +24,7 @@
#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
+#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
@@ -80,14 +81,14 @@ LookaheadDecodingLayer::CpuAlgorithmResources::CpuAlgorithmResources(DecoderD
mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
mNextDraftPosIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
mGenerationLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
- mGenerationLengthsMax = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mPositionOffsets
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32);
mPositionIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32);
+ mAttentionMask
+ = BufferManager::cpu(ITensor::makeShape({maxTokensPerStep, maxTokensPerStep}), nvinfer1::DataType::kBOOL);
mPackedMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep,
static_cast<SizeType32>(divUp(maxTokensPerStep, 32))}),
nvinfer1::DataType::kINT32);
- mSamplingMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kBOOL);
mNextDraftLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mSequenceLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
}
@@ -113,7 +114,6 @@ LookaheadDecodingLayer::LookaheadDecodingLayer(
mWorkspaceSize = getTopKWorkspaceSize(maxBatchSize, maxTokensPerStep, maxTopK, vocabSizePadded);
mTargetTokensDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kINT32);
- mSamplingMaskDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kBOOL);
mCurandStatesDevice
= mBufferManager->gpu(ITensor::makeShape({maxBatchSize, sizeof(curandState_t)}), nvinfer1::DataType::kINT8);
@@ -168,6 +168,7 @@ void LookaheadDecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth
{
SizeType32 gbi = batchSlotsRange[bi];
(BufferRange<SizeType32>(*mCpuAlgo->mGenerationLengths))[gbi] = 1;
+ (BufferRange<SizeType32>(*mCpuAlgo->mNextDraftLengths))[gbi] = 0;
BufferLocation<SizeType32>(*mCpuAlgo->mPositionOffsets).at(gbi, 0) = 0;
BufferRange<SizeType32> packedMaskRange(*ITensor::at(mCpuAlgo->mPackedMask, {gbi}));
for (auto& mask : packedMaskRange)
@@ -184,11 +185,6 @@ void LookaheadDecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth
PRINT_SHAPE(setupParams->attentionPackedMasks);
mBufferManager->copy(
*ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}), *ITensor::at(setupParams->generationLengths, {gbi}));
- if (setupParams->actualGenerationLengths)
- {
- mBufferManager->copy(*ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}),
- *ITensor::at(setupParams->actualGenerationLengths, {gbi}));
- }
mBufferManager->copy(
*ITensor::at(mCpuAlgo->mPositionOffsets, {gbi}), *ITensor::at(setupParams->positionOffsets, {gbi}));
mBufferManager->copy(
@@ -261,39 +257,32 @@ size_t LookaheadDecodingLayer::getWorkspaceSize() const noexcept
return std::max(mWorkspaceSize, mSetupWorkspaceSize);
}
-template <typename T>
-void LookaheadDecodingLayer<T>::posIdsToMask(TensorPtr mask, TensorConstPtr posIds)
+inline void initAttentionMask(TensorPtr const& mask, std::shared_ptr<BufferManager>& bufferManager)
{
- auto len = ITensor::volume(posIds->getShape());
- TLLM_CHECK(mask->getDimension<0>() > len);
- TLLM_CHECK(mask->getDimension<1>() * 32 > len);
- auto posIdsRange = BufferRange(*posIds);
- auto maskLocation = BufferLocation(*mask);
-
- for (auto i = 0; i < maskLocation.size(); i++)
+ bufferManager->setZero(*mask);
+ BufferLocation<bool> maskLocation(*mask);
+ auto maskShape = mask->getShape();
+ for (SizeType32 i = 0; i < maskShape.d[0]; i++)
{
- maskLocation[i] = 0;
+ maskLocation.at(i, 0) = true;
}
- maskLocation.at(0, 0) = 1;
+}
- auto setBit = [](SizeType32& x, SizeType32 idx) { x |= (1 << idx); };
- if (len > 0)
+inline void convertBoolToInt32(TensorPtr const& dst, TensorConstPtr const& src)
+{
+ auto dstShape = dst->getShape();
+ auto srcShape = src->getShape();
+ TLLM_CHECK(dstShape.d[0] == srcShape.d[0]);
+ TLLM_CHECK(dstShape.d[1] * 32 >= srcShape.d[1]);
+ BufferLocation<SizeType32> dstLocation(*dst);
+ BufferLocation<bool const> srcLocation(*src);
+
+ auto setBit = [](SizeType32& x, SizeType32 idx, bool value) { x |= (value << idx); };
+ for (auto i = 0; i < srcShape.d[0]; i++)
{
- std::vector> stack;
- stack.emplace_back(0, posIdsRange[0] - 1);
- for (auto i = 1; i < len + 1; i++)
+ for (auto j = 0; j < srcShape.d[1]; j++)
{
- auto cur = posIdsRange[i - 1];
- while (stack.size() > 0 && cur <= stack.back().second)
- {
- stack.pop_back();
- }
- TLLM_CHECK(stack.size() > 0 ? cur == stack.back().second + 1 : true);
- stack.emplace_back(i, cur);
- for (auto prev : stack)
- {
- setBit(maskLocation.at(i, prev.first / 32), prev.first % 32);
- }
+ setBit(dstLocation.at(i, j / 32), j % 32, srcLocation.at(i, j));
}
}
}
@@ -307,12 +296,16 @@ void LookaheadDecodingLayer::forwardSyncCPU(
mCpuAlgo->mBatchSlots->reshape(inputs->batchSlots->getShape());
mBufferManager->copy(*inputs->batchSlots, *mCpuAlgo->mBatchSlots);
mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep);
- mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep);
mBufferManager->copy(*inputs->endIds, *mCpuAlgo->mEndIds);
mBufferManager->copy(*outputs->sequenceLength.value(), *mCpuAlgo->mSequenceLengths);
mBufferManager->copy(*mTargetTokensDevice, *mCpuAlgo->mTargetTokens);
+ if (outputs->prevDraftLengths)
+ {
+ mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, *outputs->prevDraftLengths);
+ }
+
mBufferManager->getStream().synchronize();
auto const batchSize = inputs->localBatchSize;
@@ -325,7 +318,6 @@ void LookaheadDecodingLayer::forwardSyncCPU(
BufferRange<SizeType32> numNewTokensCumSumRange(*mCpuAlgo->mNumNewTokensCumSum);
BufferRange<SizeType32> batchSlotsRange(*mCpuAlgo->mBatchSlots);
BufferRange<SizeType32> generationLengthsRange(*mCpuAlgo->mGenerationLengths);
- BufferRange<SizeType32> generationLengthsMaxRange(*mCpuAlgo->mGenerationLengthsMax);
BufferRange<SizeType32> nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths);
BufferRange<SizeType32> sequenceLengthsRange(*mCpuAlgo->mSequenceLengths);
BufferLocation<SizeType32> pathsOffsetLocation(*mCpuAlgo->mPathsOffsets);
@@ -334,6 +326,7 @@ void LookaheadDecodingLayer::forwardSyncCPU(
mBufferManager->setZero(*mCpuAlgo->mPathsOffsets);
mBufferManager->setZero(*mCpuAlgo->mNumNewTokens);
mBufferManager->setZero(*mCpuAlgo->mNumNewTokensCumSum);
+ mBufferManager->setZero(*mCpuAlgo->mPackedMask);
for (SizeType32 bi = 0; bi < batchSize; bi++)
{
@@ -342,7 +335,6 @@ void LookaheadDecodingLayer::forwardSyncCPU(
SizeType32 const tokensPerStep = generationLengthsRange[gbi];
TensorPtr sampledTokens = ITensor::slice(mCpuAlgo->mTargetTokens, {gbi, 0}, tokensPerStep);
- PRINT_VALUES(sampledTokens);
if (tokensPerStep == 1)
{
@@ -369,14 +361,18 @@ void LookaheadDecodingLayer::forwardSyncCPU(
sequenceLengthsRange[gbi] += numNewTokensRange[gbi];
+ initAttentionMask(mCpuAlgo->mAttentionMask, mBufferManager);
+
theAlgo.prepare( //
ITensor::at(mCpuAlgo->mNextDraftTokens, {gbi}), //
ITensor::at(mCpuAlgo->mNextDraftPosIds, {gbi}), //
- ITensor::at(mCpuAlgo->mSamplingMask, {gbi}), //
ITensor::at(mCpuAlgo->mNextDraftLengths, {gbi}), //
+ mCpuAlgo->mAttentionMask, 1, //
ITensor::at(mCpuAlgo->mSequenceLengths, {gbi}), //
ITensor::at(mCpuAlgo->mOutputIds, {gbi, numNewTokensRange[gbi] - 1}));
+ convertBoolToInt32(ITensor::at(mCpuAlgo->mPackedMask, {gbi}), mCpuAlgo->mAttentionMask);
+
BufferLocation<SizeType32> posIdsLocation(*ITensor::at(mCpuAlgo->mPositionIds, {gbi}));
for (auto& posid : posIdsLocation)
{
@@ -385,20 +381,14 @@ void LookaheadDecodingLayer::forwardSyncCPU(
mBufferManager->copy(*ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]),
*ITensor::slice(mCpuAlgo->mPositionIds, {gbi, 1}, nextDraftLengthsRange[gbi]));
- posIdsToMask( //
- ITensor::at(mCpuAlgo->mPackedMask, {gbi}), //
- ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]));
-
BufferRange<SizeType32> offsetRange(*ITensor::at(mCpuAlgo->mPositionOffsets, {gbi}));
- TLLM_CHECK_WITH_INFO(
- posIdsLocation.size() == offsetRange.size(), "%ld, %ld", posIdsLocation.size(), offsetRange.size());
for (auto i = 0; i < posIdsLocation.size(); i++)
{
offsetRange[i] = posIdsLocation[i] - posIdsLocation[0];
}
+
TensorPtr accepted = ITensor::slice(mCpuAlgo->mOutputIds, {gbi, 0}, numNewTokensRange[gbi]);
TensorPtr draft = ITensor::slice(mCpuAlgo->mNextDraftTokens, {gbi, 0}, nextDraftLengthsRange[gbi]);
-
TLLM_LOG_DEBUG("CPU ALGO [ %d ] forward, %s", gbi, D(sampledTokens).values().c_str());
TLLM_LOG_DEBUG("[%d][%d] CPU ALGO [ %d ] forward, %s, %s", mGlobalSteps, batchSize, gbi,
D(accepted).values().c_str(), D(draft).values().c_str());
@@ -430,29 +420,23 @@ void LookaheadDecodingLayer::forwardSyncCPU(
mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, *outputs->numNewTokensCumSum); //
mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, *outputs->nextDraftTokens);
- mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks);
+ for (SizeType32 bi = 0; bi < batchSize; bi++)
+ {
+ SizeType32 gbi = batchSlotsRange[bi];
+ // nextDraftLengthsRange[gbi] = mDecoderDomain.getMaxDecodingTokens() - 1;
+ generationLengthsRange[gbi] = nextDraftLengthsRange[gbi] + 1;
+ }
if (outputs->nextDraftLengths)
{
mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, *outputs->nextDraftLengths);
}
- for (SizeType32 bi = 0; bi < batchSize; bi++)
- {
- SizeType32 gbi = batchSlotsRange[bi];
- generationLengthsRange[gbi] = nextDraftLengthsRange[gbi] + 1;
- generationLengthsMaxRange[gbi] = mDecoderDomain.getMaxDecodingTokens();
- }
mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks);
- mBufferManager->copy(*mCpuAlgo->mGenerationLengthsMax, *outputs->generationLengths);
+ mBufferManager->copy(*mCpuAlgo->mGenerationLengths, *outputs->generationLengths);
mBufferManager->copy(*mCpuAlgo->mPositionOffsets, *outputs->positionOffsets);
mBufferManager->copy(*mCpuAlgo->mPositionIds, *outputs->positionIds);
- if (outputs->actualGenerationLengths)
- {
- mBufferManager->copy(*mCpuAlgo->mGenerationLengths, *outputs->actualGenerationLengths);
- }
-
mBufferManager->getStream().synchronize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h
index 536d21727..f2470a411 100644
--- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h
+++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h
@@ -48,7 +48,6 @@ class LookaheadDecodingLayer : public BaseLayer
private:
void forwardSyncCPU(std::shared_ptr const& outputs,
std::shared_ptr const& inputs);
- void posIdsToMask(TensorPtr mask, TensorConstPtr posIds);
private:
using Base::mDecoderDomain;
@@ -57,7 +56,6 @@ class LookaheadDecodingLayer : public BaseLayer
size_t mSetupWorkspaceSize{};
TensorPtr mCurandStatesDevice;
TensorPtr mTargetTokensDevice;
- TensorPtr mSamplingMaskDevice;
struct CpuAlgorithmResources
{
@@ -78,11 +76,10 @@ class LookaheadDecodingLayer : public BaseLayer
TensorPtr mNextDraftTokens;
TensorPtr mNextDraftPosIds;
- TensorPtr mSamplingMask;
TensorPtr mNextDraftLengths;
TensorPtr mSequenceLengths;
TensorPtr mGenerationLengths;
- TensorPtr mGenerationLengthsMax;
+ TensorPtr mAttentionMask;
TensorPtr mPackedMask;
TensorPtr mPositionOffsets;
TensorPtr mPositionIds;
diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp
index cd5529b07..946d50bd4 100644
--- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp
@@ -423,6 +423,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
request_seq_len, mNumHeads, mHeadSize, padding_offset, (float*) nullptr, 0, stream);
}
}
+ sync_check_cuda_error();
return 0;
}
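
`sync_check_cuda_error()` is TensorRT-LLM's internal macro; placing it right after the kernel work surfaces launch failures at the call site rather than at some later, unrelated API call. A generic stand-in built only on the public CUDA runtime API might look like the following (the real macro may additionally synchronize the stream in debug builds, which is an assumption here):

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Generic check-after-launch helper: report the last CUDA error where it happened.
#define CHECK_CUDA_LAST_ERROR()                                                                    \
    do                                                                                             \
    {                                                                                              \
        cudaError_t const err_ = cudaGetLastError();                                               \
        if (err_ != cudaSuccess)                                                                   \
        {                                                                                          \
            std::fprintf(stderr, "CUDA error %s at %s:%d\n",                                       \
                cudaGetErrorString(err_), __FILE__, __LINE__);                                     \
            std::abort();                                                                          \
        }                                                                                          \
    } while (0)
```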
diff --git a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp
index ba73171c4..3d0ff1db3 100644
--- a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp
@@ -143,6 +143,7 @@ int CumsumLastDimPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
invokeCumsumLastDim(
batchSize, inputLength, inputs[getInputTensorIdx()], outputs[0], wp, mTempStorageBytes, stream);
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp
index 291900103..71ad6591f 100644
--- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp
+++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp
@@ -164,7 +164,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
memset(&xqaParams, 0, sizeof(XQAParams));
xqaParams.data_type = ConvertMMHAToXQAParamsHelper::data_type;
- xqaParams.layer_idx = mLayerIdx;
+ xqaParams.layer_idx = mLayerIdxInCachePool;
xqaParams.num_q_heads = mNumHeads;
xqaParams.num_kv_heads = mNumKVHeads;
xqaParams.head_size = mHeadSize;
@@ -376,13 +376,13 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \
- const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream); \
+ FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream); \
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \
- const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream); \
+ FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream); \
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \
- const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream); \
+ FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream); \
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \
- const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream);
+ FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream);
INSTANTIATE_MMHA_DISPATCH(float, float)
INSTANTIATE_MMHA_DISPATCH(uint16_t, half)
#ifdef ENABLE_BF16
@@ -391,8 +391,8 @@ INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16)
#undef INSTANTIATE_MMHA_DISPATCH
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length,
- int num_kv_heads, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
- tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
+ int num_kv_heads, int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling,
+ float qk_tanh_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale,
@@ -411,6 +411,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
, mVisionStart(vision_start)
, mVisionLength(vision_length)
, mNumKVHeads(num_kv_heads)
+ , mLayerIdxInCachePool(layer_idx_in_cache_pool)
, mHeadSize(head_size)
, mUnidirectional(unidirectional)
, mQScaling(q_scaling)
@@ -525,6 +526,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
read(d, mVisionStart);
read(d, mVisionLength);
read(d, mNumKVHeads);
+ read(d, mLayerIdxInCachePool);
read(d, mHeadSize);
read(d, mUnidirectional);
read(d, mQScaling);
@@ -721,7 +723,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams
#include
#include
#include
@@ -41,8 +43,8 @@ static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"};
static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length,
- int num_kv_heads, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
- tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
+ int num_kv_heads, int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling,
+ float qk_tanh_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, float rotary_embedding_short_m_scale,
@@ -57,9 +59,9 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha,
bool use_cache, bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable,
int spec_decoding_max_generation_length)
- : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
- unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
- rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
+ : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, layer_idx_in_cache_pool,
+ head_size, unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim,
+ rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size,
tp_rank, unfuse_qkv_gemm, context_fmha_type, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type,
block_sparse_params, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled,
@@ -94,6 +96,7 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const
case IdxEntry::KV_CACHE_BLOCK_OFFSETS: return useKVCache() && mPagedKVCache;
case IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS: return useKVCache() && mPagedKVCache;
case IdxEntry::HOST_KV_CACHE_POOL_POINTERS: return useKVCache() && mPagedKVCache;
+ case IdxEntry::HOST_KV_CACHE_POOL_MAPPING: return useKVCache() && mPagedKVCache;
case IdxEntry::PAST_KEY_VALUE: return useKVCache() && !mPagedKVCache;
case IdxEntry::KV_CACHE_QUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
@@ -244,6 +247,11 @@ bool GPTAttentionPlugin::supportsFormatCombination(
// kv cache pool pointers
return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
}
+ else if (useKVCache() && mPagedKVCache && (pos == getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING)))
+ {
+ // kv cache pool mapping
+ return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
+ }
else if (useKVCache() && mKVCacheQuantMode.hasInt8KvCache()
&& (!mPagedKVCache && (pos == getIdx(IdxEntry::PAST_KEY_VALUE) || pos == nbInputs + 1)))
{
@@ -478,6 +486,7 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
outputDesc, inputs, outputs, workspace, stream);
}
+ sync_check_cuda_error();
TLLM_LOG_TRACE("Attention plugin stop at layer %d", mLayerIdx);
return 0;
@@ -624,27 +633,36 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
auto const& kvCacheBlockOffsets = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)];
auto const& kvCacheBlockOffsetsShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims;
max_blocks_per_sequence = kvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1];
- auto const seqStride = getStride(kvCacheBlockOffsetsShape, 0);
+
+ std::int32_t const* host_pool_mapping
+ = static_cast<std::int32_t const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING)]);
+
+ const int32_t layerToPool = host_pool_mapping[mLayerIdx];
+ auto const seqStride = getStride(kvCacheBlockOffsetsShape, 1);
+ auto const poolStride = getStride(kvCacheBlockOffsetsShape, 0);
auto const seqOffset = seqIdxBeg * seqStride;
+ auto const poolOffset = layerToPool * poolStride;
block_offsets
= reinterpret_cast(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)])
- + seqOffset;
+ + poolOffset + seqOffset;
host_block_offsets
= reinterpret_cast(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)])
- + seqOffset;
+ + poolOffset + seqOffset;
auto const* const typed_host_pool_pointers
= static_cast(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]);
auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
+
auto const blockSize = mTokensPerBlock * mNumKVHeads * mHeadSize;
auto const bytesPerBlock = blockSize * cacheElemSize;
- auto const layerOffset = mLayerIdx * 2 * bytesPerBlock;
+ auto const layerOffset = mLayerIdxInCachePool * 2 * bytesPerBlock;
- host_primary_pool_pointer = reinterpret_cast(typed_host_pool_pointers[0] + layerOffset);
- host_secondary_pool_pointer = reinterpret_cast(typed_host_pool_pointers[1] + layerOffset);
+ host_primary_pool_pointer = reinterpret_cast(typed_host_pool_pointers[layerToPool * 2] + layerOffset);
+ host_secondary_pool_pointer
+ = reinterpret_cast(typed_host_pool_pointers[layerToPool * 2 + 1] + layerOffset);
}
AttentionOutT* context_buf_ = static_cast<AttentionOutT*>(outputs[0])
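
The indexing introduced above is easier to follow in isolation. Below is a minimal, self-contained sketch of the pool-aware lookup, assuming a simplified pool-major block-offset layout of [numPools][numSeqs][maxBlocksPerSeq] and toy sizes; every name and number is a stand-in for illustration, not the plugin's actual types or tensor shapes.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical sketch: locate a layer's KV block offsets and its byte offset inside
// the pool when block offsets are laid out pool-major: [numPools][numSeqs][maxBlocksPerSeq].
int main()
{
    int const numPools = 2, numSeqs = 3, maxBlocksPerSeq = 4;
    // hostPoolMapping[layerIdx] -> index of the pool that owns this layer's KV cache.
    std::vector<int32_t> hostPoolMapping = {0, 0, 1, 1};
    std::vector<int32_t> blockOffsets(numPools * numSeqs * maxBlocksPerSeq, 0);

    int const layerIdx = 2;            // global layer index
    int const layerIdxInCachePool = 0; // index of this layer within its own pool
    int const seqIdxBeg = 1;           // first sequence handled by this call

    auto const layerToPool = hostPoolMapping[layerIdx];
    auto const seqStride = maxBlocksPerSeq;            // stride of the sequence dimension
    auto const poolStride = numSeqs * maxBlocksPerSeq; // stride of the leading pool dimension
    auto const* offsetsForThisCall
        = blockOffsets.data() + layerToPool * poolStride + seqIdxBeg * seqStride;

    // Each pool exposes a (primary, secondary) base pointer pair; a layer's slice starts
    // at layerIdxInCachePool * 2 * bytesPerBlock within that pool.
    std::size_t const bytesPerBlock = 64 * 8 * 128 * sizeof(float); // toy numbers
    std::size_t const layerOffset = static_cast<std::size_t>(layerIdxInCachePool) * 2 * bytesPerBlock;

    std::printf("pool=%d, offsets start at element %td, layer byte offset=%zu\n",
        layerToPool, offsetsForThisCall - blockOffsets.data(), layerOffset);
    return 0;
}
```

The key point is that the pool index comes from the host-side mapping keyed by the global layer index, while the byte offset inside that pool uses the layer's index within the pool, not its global index.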
@@ -962,8 +980,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
auto* obj = new GPTAttentionPlugin(p.getScalar("layer_idx").value(),
p.getScalar("num_heads").value(), p.getScalar("vision_start").value(),
p.getScalar("vision_length").value(), p.getScalar("num_kv_heads").value(),
- p.getScalar("head_size").value(), p.getScalar("unidirectional").value(),
- p.getScalar("q_scaling").value(), p.getScalar("qk_tanh_scale").value(),
+ p.getScalar("layer_idx_in_cache_pool").value(), p.getScalar("head_size").value(),
+ p.getScalar("unidirectional").value(), p.getScalar("q_scaling").value(),
+ p.getScalar("qk_tanh_scale").value(),
static_cast(p.getScalar("position_embedding_type").value()),
p.getScalar("rotary_embedding_dim").value(), p.getScalar("rotary_embedding_base").value(),
static_cast(p.getScalar("rotary_embedding_scale_type").value()),
diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h
index aeeae99ce..7982d3c07 100644
--- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h
+++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h
@@ -85,7 +85,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon
{
public:
GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
- int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
+ int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
@@ -182,6 +182,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon
KV_CACHE_BLOCK_OFFSETS,
HOST_KV_CACHE_BLOCK_OFFSETS,
HOST_KV_CACHE_POOL_POINTERS,
+ HOST_KV_CACHE_POOL_MAPPING,
PAST_KEY_VALUE,
KV_CACHE_QUANTIZATION_SCALE,
KV_CACHE_DEQUANTIZATION_SCALE,
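
Inserting HOST_KV_CACHE_POOL_MAPPING into the middle of this enum is safe because the plugin derives each input's runtime position from the enum order combined with the isEntryUsed() predicate rather than from hard-coded indices. A hypothetical sketch of that counting scheme follows; the enum subset and the predicate are illustrative, not the plugin's real ones.

```cpp
#include <cstdio>

// Hypothetical mirror of the IdxEntry scheme: an input's runtime index is the
// number of *used* entries that precede it in declaration order.
enum class IdxEntry { QKV, KV_CACHE_BLOCK_OFFSETS, HOST_KV_CACHE_POOL_POINTERS,
                      HOST_KV_CACHE_POOL_MAPPING, PAST_KEY_VALUE };

bool isEntryUsed(IdxEntry e, bool pagedKvCache)
{
    switch (e)
    {
    case IdxEntry::QKV: return true;
    case IdxEntry::KV_CACHE_BLOCK_OFFSETS:
    case IdxEntry::HOST_KV_CACHE_POOL_POINTERS:
    case IdxEntry::HOST_KV_CACHE_POOL_MAPPING: return pagedKvCache;
    case IdxEntry::PAST_KEY_VALUE: return !pagedKvCache;
    default: return false;
    }
}

int getIdx(IdxEntry e, bool pagedKvCache)
{
    int idx = 0;
    for (int i = 0; i < static_cast<int>(e); ++i)
    {
        if (isEntryUsed(static_cast<IdxEntry>(i), pagedKvCache))
        {
            ++idx;
        }
    }
    return idx;
}

int main()
{
    std::printf("pool mapping input index (paged KV cache): %d\n",
        getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING, /*pagedKvCache=*/true));
    return 0;
}
```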
diff --git a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp
index b4f3e46dc..97fcd3358 100644
--- a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp
@@ -90,6 +90,7 @@ int IdentityPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer
cudaMemcpyAsync(outputs[0], inputs[0], count, cudaMemcpyDeviceToDevice, stream);
+ sync_check_cuda_error();
return 0;
}
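
Several plugins in this patch gain a sync_check_cuda_error() call right before returning from enqueue. A standalone approximation of what such a helper typically does, assuming it synchronizes only in debug-style builds and then surfaces the sticky CUDA error state; the real macro's behavior may differ.

```cpp
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for sync_check_cuda_error(): in debug builds, synchronize the
// device so asynchronous launch failures surface here, then check the last CUDA error.
#ifndef NDEBUG
#define SYNC_CHECK_CUDA_ERROR() checkCudaError(__FILE__, __LINE__, /*sync=*/true)
#else
#define SYNC_CHECK_CUDA_ERROR() checkCudaError(__FILE__, __LINE__, /*sync=*/false)
#endif

inline void checkCudaError(char const* file, int line, bool sync)
{
    if (sync)
    {
        cudaDeviceSynchronize();
    }
    cudaError_t const err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        std::fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(err), file, line);
        std::exit(EXIT_FAILURE);
    }
}

int main()
{
    void* devPtr = nullptr;
    cudaMalloc(&devPtr, 256);
    cudaMemsetAsync(devPtr, 0, 256); // async work; errors may surface later
    SYNC_CHECK_CUDA_ERROR();         // where the plugins now place the check
    cudaFree(devPtr);
    return 0;
}
```

Placed just before `return 0;`, the check costs little in release builds while making launch failures attributable to the specific plugin during debugging.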
diff --git a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp
index 5b8dae3c9..1e33ac63f 100644
--- a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp
@@ -177,7 +177,7 @@ int LayernormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* input
scale, dynamic_scale, output);
}
#endif
-
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp b/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp
index 780beef6c..7cc5dae9c 100644
--- a/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp
@@ -214,6 +214,7 @@ int lruPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1
{
invokeRGLRUUpdate(lru_params, stream);
}
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp b/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp
index 1f39c40c3..c45f512b2 100644
--- a/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp
@@ -200,6 +200,7 @@ int MambaConv1dPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
{
invokeMambaConv1dGeneration(mambaConv1dParams, stream);
}
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp
index fd3affcb4..c81550106 100644
--- a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp
@@ -203,7 +203,7 @@ int QuantizePerTokenPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
}
#endif // ENABLE_FP8
#endif // ENABLE_BF16
-
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp
index 1d7e7026d..8bf93d146 100644
--- a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp
@@ -134,7 +134,7 @@ int QuantizeTensorPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
stream, mProp.maxGridSize[0]);
}
#endif
-
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp
index f46afadef..372bcf2c8 100644
--- a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp
@@ -236,7 +236,7 @@ int RmsnormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDe
}
#endif // ENABLE_FP8
#endif // ENABLE_BF16
-
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp
index d82d855f6..8cbdd5c7e 100644
--- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp
@@ -347,6 +347,7 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
{
invokeSelectiveScanUpdate(ssm_params, stream);
}
+ sync_check_cuda_error();
return 0;
}
diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt
old mode 100644
new mode 100755
index 65f54c0c3..1f7bf73fb
--- a/cpp/tensorrt_llm/pybind/CMakeLists.txt
+++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt
@@ -41,9 +41,11 @@ set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
"${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(
- ${TRTLLM_PYBIND_MODULE}
- PUBLIC ${SHARED_TARGET} ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python
- ${UNDEFINED_FLAG})
+ ${TRTLLM_PYBIND_MODULE} PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG}
+ ${NO_AS_NEEDED_FLAG})
+target_link_libraries(
+ ${TRTLLM_PYBIND_MODULE} PUBLIC ${Python3_LIBRARIES} ${TORCH_LIBRARIES}
+ torch_python ${UNDEFINED_FLAG})
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE})
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp
index c617cc8ba..2c74104e9 100644
--- a/cpp/tensorrt_llm/pybind/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/bindings.cpp
@@ -178,19 +178,25 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def(py::self != py::self);
py::class_(m, "ModelConfig")
- .def(py::init(),
- py::arg("vocab_size"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"), py::arg("num_heads"),
- py::arg("hidden_size"), py::arg("data_type"))
+ .def(py::init(),
+ py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"),
+ py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
.def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize)
.def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"))
- .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1)
- .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1)
+ .def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
+ .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1,
+ py::arg("pipeline_parallelism_rank") = 0)
+ .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1,
+ py::arg("pipeline_parallelism_rank") = 0)
+ .def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, py::arg("layer_idx"))
+ .def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, py::arg("num_kv_heads"))
.def_property_readonly("num_heads", &tr::ModelConfig::getNbHeads)
.def_property_readonly("hidden_size", &tr::ModelConfig::getHiddenSize)
.def_property_readonly("size_per_head", &tr::ModelConfig::getSizePerHead)
.def_property_readonly("data_type", &tr::ModelConfig::getDataType)
- .def_property("num_kv_heads", &tr::ModelConfig::getNbKvHeads, &tr::ModelConfig::setNbKvHeads)
.def_property("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead)
+ .def_property(
+ "num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer)
.def_property("use_gpt_attention_plugin",
py::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, py::const_),
py::overload_cast<bool>(&tr::ModelConfig::useGptAttentionPlugin))
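
These bindings replace the single num_kv_heads property with a per-layer view plus a uniform setter. A toy sketch of that shape of API; the class and method names below are illustrative, not the real ModelConfig.

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Toy config holding one KV-head count per attention layer, mirroring the intent of
// the new num_kv_heads(layer_idx) / num_kv_heads_per_layer bindings.
class ToyModelConfig
{
public:
    explicit ToyModelConfig(int numAttentionLayers)
        : mNumKvHeadsPerLayer(numAttentionLayers, 0)
    {
    }

    // Uniform MHA/GQA: every attention layer gets the same KV-head count.
    void setNbKvHeads(int numKvHeads)
    {
        std::fill(mNumKvHeadsPerLayer.begin(), mNumKvHeadsPerLayer.end(), numKvHeads);
    }

    // Variable GQA: explicit per-layer counts.
    void setNumKvHeadsPerLayer(std::vector<int> heads) { mNumKvHeadsPerLayer = std::move(heads); }

    int getNbKvHeads(int layerIdx) const { return mNumKvHeadsPerLayer.at(layerIdx); }
    std::vector<int> const& getNumKvHeadsPerLayer() const { return mNumKvHeadsPerLayer; }

private:
    std::vector<int> mNumKvHeadsPerLayer;
};

int main()
{
    ToyModelConfig config(4);
    config.setNumKvHeadsPerLayer({8, 8, 4, 4}); // e.g. later layers use fewer KV heads
    std::printf("layer 2 uses %d KV heads\n", config.getNbKvHeads(2));
    return 0;
}
```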
diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp
index ca0746980..4a79a64ee 100644
--- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp
@@ -93,7 +93,8 @@ void InitBindings(pybind11::module_& m)
py::enum_(m, "CapacitySchedulerPolicy")
.value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION)
- .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT);
+ .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
+ .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH);
py::enum_(m, "ContextChunkingPolicy")
.value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS)
@@ -299,7 +300,8 @@ void InitBindings(pybind11::module_& m)
.def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize);
py::class_(m, "ContextPhaseParams")
- .def(py::init(), py::arg("first_gen_tokens"));
+ .def(py::init(), py::arg("first_gen_tokens"),
+ py::arg("req_id"));
py::class_ request(m, "Request");
request
@@ -631,14 +633,18 @@ void InitBindings(pybind11::module_& m)
auto extendedRuntimePerfKnobConfigSetstate = [](py::tuple state)
{
- if (state.size() != 2)
+ if (state.size() != 4)
{
throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!");
}
- return tle::ExtendedRuntimePerfKnobConfig(state[0].cast<bool>(), state[1].cast<bool>());
+ return tle::ExtendedRuntimePerfKnobConfig(
+ state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(), state[3].cast<SizeType32>());
};
auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self)
- { return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc()); };
+ {
+ return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc(), self.getCudaGraphMode(),
+ self.getCudaGraphCacheSize());
+ };
py::class_(m, "ExtendedRuntimePerfKnobConfig")
.def(
py::init(), py::arg("multi_block_mode") = true, py::arg("enable_context_fmha_fp32_acc") = false)
@@ -646,6 +652,10 @@ void InitBindings(pybind11::module_& m)
&tle::ExtendedRuntimePerfKnobConfig::setMultiBlockMode)
.def_property("enable_context_fmha_fp32_acc", &tle::ExtendedRuntimePerfKnobConfig::getEnableContextFMHAFP32Acc,
&tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc)
+ .def_property("cuda_graph_mode", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphMode,
+ &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphMode)
+ .def_property("cuda_graph_cache_size", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphCacheSize,
+ &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphCacheSize)
.def(py::pickle(extendedRuntimePerfKnobConfigGetstate, extendedRuntimePerfKnobConfigSetstate));
auto executorConfigGetState = [](tle::ExecutorConfig const& self)
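
The pickle support relies on getstate and setstate staying symmetric: the 4-tuple produced above has to be consumed index-for-index (0 through 3). A small standalone illustration with a stand-in struct; the knob names mirror the bindings, the struct itself is hypothetical.

```cpp
#include <cstdio>
#include <tuple>

// Stand-in for the perf-knob config; get/set-state must agree on element count and order.
struct PerfKnobs
{
    bool multiBlockMode{true};
    bool enableContextFMHAFP32Acc{false};
    bool cudaGraphMode{false};
    int cudaGraphCacheSize{0};
};

using KnobState = std::tuple<bool, bool, bool, int>;

KnobState getState(PerfKnobs const& k)
{
    return {k.multiBlockMode, k.enableContextFMHAFP32Acc, k.cudaGraphMode, k.cudaGraphCacheSize};
}

PerfKnobs setState(KnobState const& s)
{
    // Consume every element in the order it was produced (indices 0..3).
    return {std::get<0>(s), std::get<1>(s), std::get<2>(s), std::get<3>(s)};
}

int main()
{
    PerfKnobs knobs{true, false, true, 8};
    PerfKnobs restored = setState(getState(knobs));
    std::printf("cuda_graph_mode=%d cache_size=%d\n", restored.cudaGraphMode, restored.cudaGraphCacheSize);
    return 0;
}
```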
diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
index 5403b8cd2..8e1f57e9f 100644
--- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
+++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
@@ -803,7 +803,7 @@ void GptDecoderBatched::forwardDispatch(
}
}
-GptDecoderBatched::TokenPtr GptDecoderBatched::forwardAsync(
+GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
decoder_batch::Output& output, decoder_batch::Input const& input)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -813,7 +813,7 @@ GptDecoderBatched::TokenPtr GptDecoderBatched::forwardAsync(
CudaEvent eventStop{};
mRuntimeStream->record(eventStop);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
- return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active);
+ return std::make_unique<decoder_batch::DecoderFinishedEvent>(std::move(eventStop), input.active);
}
void GptDecoderBatched::forwardDecoder(
@@ -1019,12 +1019,12 @@ void GptDecoderBatched::forwardDecoder(
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-void GptDecoderBatched::updateFinished(decoder_batch::Token const& token)
+void GptDecoderBatched::updateFinished(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
{
- if (token.active[i] && !mFinished[i])
+ if (decoderFinishEvent.active[i] && !mFinished[i])
{
auto finishedSum = ITensor::slice(mJointDecodingOutput->finishedSum, i, 1);
mFinished[i] = mFinished[i]
@@ -1035,25 +1035,25 @@ void GptDecoderBatched::updateFinished(decoder_batch::Token const& token)
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-void GptDecoderBatched::forwardSync(decoder_batch::Token const& token)
+void GptDecoderBatched::forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
- token.event.synchronize();
+ decoderFinishEvent.event.synchronize();
- updateFinished(token);
+ updateFinished(decoderFinishEvent);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
-void GptDecoderBatched::forwardSync(
- decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input)
+void GptDecoderBatched::forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent,
+ decoder_batch::Output& output, decoder_batch::Input const& input)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
- token.event.synchronize();
+ decoderFinishEvent.event.synchronize();
forwardDispatch(output, input, ForwardType::kSYNC);
- updateFinished(token);
+ updateFinished(decoderFinishEvent);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@@ -1231,7 +1231,7 @@ void GptDecoderBatched::forwardAsync(decoder::Output& output, decoder::Input con
batchOutput.cacheIndirection = output.cacheIndirection;
batchOutput.sequenceLengths = output.sequenceLengths;
- mForwardToken = forwardAsync(batchOutput, batchInput);
+ mDecoderFinishEvent = forwardAsync(batchOutput, batchInput);
mBufferManager.setZero(*mFinishedSum);
kernels::reduce(
*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mRuntimeStream);
@@ -1243,7 +1243,7 @@ void GptDecoderBatched::forwardAsync(decoder::Output& output, decoder::Input con
void GptDecoderBatched::forwardSync()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
- forwardSync(*mForwardToken);
+ forwardSync(*mDecoderFinishEvent);
// wait for mFinishedSum to be updated
mForwardEvent.synchronize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp
index 620be923b..da58300fa 100644
--- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp
+++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp
@@ -85,6 +85,8 @@ std::vector buildLayerTypes(
auto constexpr layerNameAttention = "attention";
auto constexpr layerNameRecurrent = "recurrent";
+ auto constexpr layerNameLinear = "linear";
+ auto constexpr layerNameNoop = "no_op";
// The json field specifies a "group" of layers, which gets repeated multiple times
// Note that the total number of layers does not need to be a multiple of a layer
@@ -102,9 +104,17 @@ std::vector buildLayerTypes(
{
result[i] = ModelConfig::LayerType::kRECURRENT;
}
+ else if (layerStringTypes[i % groupSize] == layerNameLinear)
+ {
+ result[i] = ModelConfig::LayerType::kLINEAR;
+ }
+ else if (layerStringTypes[i % groupSize] == layerNameNoop)
+ {
+ result[i] = ModelConfig::LayerType::kNOOP;
+ }
else
{
- TLLM_LOG_ERROR("Unknown layer type: %s", layerStringTypes[i % groupSize].c_str());
+ TLLM_LOG_WARNING("Unknown layer type: %s, assuming attention", layerStringTypes[i % groupSize].c_str());
}
}
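
With the two new layer names and the softened error, unknown strings now fall back to attention with a warning. A compact standalone version of that parsing logic, using a stand-in enum for ModelConfig::LayerType.

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Stand-in for ModelConfig::LayerType.
enum class LayerType { kATTENTION, kRECURRENT, kLINEAR, kNOOP };

std::vector<LayerType> buildLayerTypes(std::size_t numLayers, std::vector<std::string> const& group)
{
    std::vector<LayerType> result(numLayers, LayerType::kATTENTION);
    if (group.empty())
    {
        return result;
    }
    auto const groupSize = group.size();
    for (std::size_t i = 0; i < numLayers; ++i)
    {
        auto const& name = group[i % groupSize]; // the config lists one "group" that repeats
        if (name == "attention") { result[i] = LayerType::kATTENTION; }
        else if (name == "recurrent") { result[i] = LayerType::kRECURRENT; }
        else if (name == "linear") { result[i] = LayerType::kLINEAR; }
        else if (name == "no_op") { result[i] = LayerType::kNOOP; }
        else { std::printf("Unknown layer type: %s, assuming attention\n", name.c_str()); }
    }
    return result;
}

int main()
{
    auto types = buildLayerTypes(6, {"attention", "recurrent", "no_op"});
    std::printf("layer 5 type id: %d\n", static_cast<int>(types[5]));
    return 0;
}
```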
@@ -147,9 +157,25 @@ ModelConfig createModelConfig(
auto const mlpHiddenSize = parseJsonFieldOptional(config, mlpHiddenSizeField);
- auto modelConfig = ModelConfig{vocabSize, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType};
+ auto numKvHeadsPerAttentionLayer
+ = parseJsonFieldOr<std::vector<SizeType32>>(config, "num_kv_heads_per_layer", std::vector<SizeType32>());
+
+ auto modelConfig
+ = ModelConfig{vocabSize, numLayers, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType};
+
+ if (!numKvHeadsPerAttentionLayer.empty())
+ {
+ std::transform(numKvHeadsPerAttentionLayer.cbegin(), numKvHeadsPerAttentionLayer.cend(),
+ numKvHeadsPerAttentionLayer.begin(),
+ [tensorParallelism](SizeType32 const numKvHeads) { return std::max(numKvHeads / tensorParallelism, 1); });
+ modelConfig.setNumKvHeadsPerLayer(numKvHeadsPerAttentionLayer);
+ }
+ else
+ {
+ modelConfig.setNbKvHeads(numKvHeads);
+ }
+
modelConfig.setSizePerHead(sizePerHead);
- modelConfig.setNbKvHeads(numKvHeads);
modelConfig.setLayerTypes(layerTypes);
// Set logits datatype
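
The transform above shards each layer's KV-head count across tensor-parallel ranks and clamps at one, so layers with fewer KV heads than ranks still get a cache slice. The same step in isolation, with toy values:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    int const tensorParallelism = 4;
    // Per-attention-layer KV head counts as read from the engine config (toy values).
    std::vector<int> numKvHeadsPerLayer = {32, 8, 4, 2};

    // Each rank keeps numKvHeads / tpSize heads, but never fewer than one:
    // layers with fewer KV heads than ranks end up replicating a single head.
    std::transform(numKvHeadsPerLayer.cbegin(), numKvHeadsPerLayer.cend(), numKvHeadsPerLayer.begin(),
        [tensorParallelism](int const numKvHeads) { return std::max(numKvHeads / tensorParallelism, 1); });

    for (auto const h : numKvHeadsPerLayer)
    {
        std::printf("%d ", h);
    }
    std::printf("\n"); // prints: 8 2 1 1
    return 0;
}
```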
@@ -269,9 +295,24 @@ void parseLora(ModelConfig& modelConfig, Json const& json, Json const& pluginCon
if (loraTargetModules.has_value())
{
+ auto const& loraModuleNames = loraTargetModules.value();
+ auto const& numKvHeadsPerLayer = modelConfig.getNumKvHeadsPerLayer();
+ if (!loraModuleNames.empty())
+ {
+ TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayer.cbegin(), numKvHeadsPerLayer.cend(),
+ [firstNumKvHeads = numKvHeadsPerLayer[0]](SizeType32 numKvHeads)
+ { return numKvHeads == firstNumKvHeads; }),
+ "LORA with a VGQA model is not supported");
+ }
+ // TODO(oargov): don't assume all layers have the same num_kv_heads to support VGQA
+ auto const numKvHeads = numKvHeadsPerLayer.empty() ? modelConfig.getNbHeads() : numKvHeadsPerLayer[0];
+ bool hasMoE = !engineVersionNone && json.at("pretrained_config").contains("moe");
+ auto const numExperts = hasMoE
+ ? json.at("pretrained_config").at("moe").at("num_experts").template get()
+ : SizeType32{0};
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(), modelConfig.getHiddenSize(),
- modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), modelConfig.getNbKvHeads(),
- modelConfig.getSizePerHead(), tensorParallelism));
+ modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), numKvHeads, modelConfig.getSizePerHead(),
+ tensorParallelism, numExperts));
}
modelConfig.setMaxLoraRank(loraMaxRank);
diff --git a/cpp/tensorrt_llm/runtime/gptSession.cpp b/cpp/tensorrt_llm/runtime/gptSession.cpp
index c5d4dda55..c5bc84cf1 100644
--- a/cpp/tensorrt_llm/runtime/gptSession.cpp
+++ b/cpp/tensorrt_llm/runtime/gptSession.cpp
@@ -219,8 +219,13 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea
// tokens, when enabling cyclic kv cache.
auto const useOneMoreBlock = maxBeamWidth > 1 && maxSequenceLength > maxAttentionWindow;
- auto const localNbLayers = mModelConfig.getNbAttentionLayers(mWorldConfig.getPipelineParallelism());
- auto const nbKvHeads = mModelConfig.getNbKvHeads();
+ auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = mModelConfig.getNumKvHeadsPerLayerLocalRange(
+ mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
+ TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd,
+ [firstNumKvHeads = *numKvHeadsPerLayerBegin](SizeType32 numKvHeads)
+ { return numKvHeads == firstNumKvHeads; }),
+ "Deprecated session API does not support multiple cache pools, use the newer executor API instead");
+
auto const sizePerHead = mModelConfig.getSizePerHead();
bool constexpr enableBlockReuse{false};
bool enableDiffMaxAttenWin = false;
@@ -235,7 +240,8 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea
TLLM_CHECK_WITH_INFO(maxBeamWidth == 1 || !enableDiffMaxAttenWin,
"Can't support layer-wise max_attention_window with beam search. Please use a unified max_attention_window for "
"all layers.");
- mKvCacheManager = std::make_shared(localNbLayers, nbKvHeads, sizePerHead, tokensPerBlock,
+ mKvCacheManager = std::make_shared(
+ std::vector(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd), sizePerHead, tokensPerBlock,
blocksInPrimaryPool, blocksInSecondaryPool, maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength,
useOneMoreBlock, mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.onboardBlocks);
@@ -253,6 +259,7 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea
for (auto& buffers : mBuffers)
{
buffers->transformerBuffers->setKvPoolPointers(mKvCacheManager.get());
+ buffers->transformerBuffers->setKvPoolMapping(mKvCacheManager.get());
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
diff --git a/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp b/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp
index 8ecb2061c..8f543f9ed 100644
--- a/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp
+++ b/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp
@@ -11,7 +11,9 @@
*/
#include "tensorrt_llm/runtime/lookaheadBuffers.h"
+#include "iTensor.h"
#include "tensorrt_llm/common/logger.h"
+#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm::runtime
@@ -28,8 +30,6 @@ LookaheadDecodingBuffers::LookaheadDecodingBuffers(
, positionIds(
bufferManager.gpu(ITensor::makeShape({maxNumSequences, maxTokensPerStep}), nvinfer1::DataType::kINT32))
{
- TLLM_LOG_DEBUG(
- "LookaheadDecodingBuffers, maxNumSequences = %d, maxTokensPerStep = %d", maxNumSequences, maxTokensPerStep);
}
LookaheadRuntimeBuffers::LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
@@ -40,11 +40,11 @@ LookaheadRuntimeBuffers::LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeTy
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(maxBeamWidth == 1, "Lookahead decoding does not support beam search");
- // auto const tokensPerStep = modelConfig.getMaxTokensPerStep();
auto const tokensPerStep = modelConfig.getMaxDecodingTokens();
auto const numPackedMasks = static_cast(tensorrt_llm::common::divUp(tokensPerStep, 32));
- // Copy buffers to device
+ cumSumLength = manager.pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
+
packedMasksDevice
= manager.gpu(ITensor::makeShape({maxBatchSize * tokensPerStep, numPackedMasks}), nvinfer1::DataType::kINT32);
positionOffsetsDevice = manager.gpu(ITensor::makeShape({maxBatchSize, tokensPerStep}), nvinfer1::DataType::kINT32);
@@ -76,24 +76,59 @@ void LookaheadRuntimeBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType
auto const tokensPerStep = modelConfig.getMaxDecodingTokens();
+ manager.copy(seqSlots, *batchSlotsHostCopy);
+ manager.copy(*decoderLookaheadBuffers.generationLengths, *generationLengthsHostCopy);
manager.copy(*decoderLookaheadBuffers.positionOffsets, *positionOffsetsHostCopy);
manager.copy(*decoderLookaheadBuffers.packedMasks, *packedMaskHostCopy);
manager.copy(*decoderLookaheadBuffers.positionIds, *positionIdsHostCopy);
- manager.copy(seqSlots, *batchSlotsHostCopy);
- manager.copy(*decoderLookaheadBuffers.generationLengths, *generationLengthsHostCopy);
manager.getStream().synchronize();
BufferRange batchSlotsRange(*batchSlotsHostCopy);
+ BufferRange cumSumLengthRange(*cumSumLength);
+
+ SizeType32 maxGenerationLength = 0;
+ for (SizeType32 bi = 0; bi < numGenSequences; bi++)
+ {
+ SizeType32 gbi = batchSlotsRange[bi + numCtxSequences];
+ SizeType32 theLength = BufferRange(*generationLengthsHostCopy)[gbi];
+ maxGenerationLength = std::max(maxGenerationLength, theLength);
+ }
+
+ auto positionOffsetShape = positionOffsetsHost->getShape();
+ positionOffsetShape.d[1] = maxGenerationLength;
+ positionOffsetsHost->reshape(positionOffsetShape);
+ positionOffsetsDevice->reshape(positionOffsetShape);
+
+ auto positionIdsShape = positionIdsHostCopy->getShape();
+ auto positionIdsShape1D = ITensor::makeShape({ITensor::volume(positionIdsShape)});
+ positionIdsHostCopy->reshape(positionIdsShape1D);
+ positionIdsHost->reshape(positionIdsShape1D);
+
+ cumSumLengthRange[0] = 0;
for (SizeType32 bi = 0; bi < numGenSequences; bi++)
{
SizeType32 gbi = batchSlotsRange[bi + numCtxSequences];
+ SizeType32 theLength = BufferRange(*generationLengthsHostCopy)[gbi];
+
manager.copy(*ITensor::at(generationLengthsHostCopy, {gbi}), *ITensor::at(generationLengthsHost, {bi}));
- manager.copy(*ITensor::at(positionOffsetsHostCopy, {gbi}), *ITensor::at(positionOffsetsHost, {bi}));
- manager.copy(*ITensor::slice(packedMaskHostCopy, gbi * tokensPerStep, tokensPerStep),
- *ITensor::slice(packedMaskHost, bi * tokensPerStep, tokensPerStep));
- manager.copy(*ITensor::at(positionIdsHostCopy, {gbi}), *ITensor::at(positionIdsHost, {bi}));
+
+ manager.copy(*ITensor::slice(positionOffsetsHostCopy, {gbi, 0}, theLength),
+ *ITensor::slice(positionOffsetsHost, {bi, 0}, theLength));
+
+ manager.copy(*ITensor::slice(packedMaskHostCopy, gbi * tokensPerStep, theLength),
+ *ITensor::slice(packedMaskHost, cumSumLengthRange[0], theLength));
+
+ manager.copy(*ITensor::slice(positionIdsHostCopy, gbi * tokensPerStep, theLength),
+ *ITensor::slice(positionIdsHost, cumSumLengthRange[0], theLength));
+
+ cumSumLengthRange[0] += theLength;
}
+
+ positionIdsHostCopy->reshape(positionIdsShape);
+ positionIdsHost->reshape(positionIdsShape);
+ positionIdsDevice->reshape(positionIdsShape);
+
manager.copy(*ITensor::slice(generationLengthsHost, 0, numGenSequences),
*ITensor::slice(generationLengthsDevice, 0, numGenSequences));
manager.copy(*ITensor::slice(positionOffsetsHost, 0, numGenSequences),
@@ -102,6 +137,7 @@ void LookaheadRuntimeBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType
*ITensor::slice(packedMasksDevice, 0, numGenSequences * tokensPerStep));
manager.copy(
*ITensor::slice(positionIdsHost, 0, numGenSequences), *ITensor::slice(positionIdsDevice, 0, numGenSequences));
+ positionIdsDevice->reshape(ITensor::makeShape({cumSumLengthRange[0]}));
manager.getStream().synchronize();
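
The new cumSumLength bookkeeping packs variable-length per-sequence rows back-to-back instead of copying fixed tokensPerStep strides. A small sketch of that packing pattern on plain host vectors; the data and names are toy stand-ins for the lookahead buffers.

```cpp
#include <cstdio>
#include <vector>

int main()
{
    int const tokensPerStep = 4;
    // Fixed-stride source: each sequence owns tokensPerStep padded slots.
    std::vector<int> positionIdsSrc = {10, 11, 0, 0, 20, 21, 22, 0, 30, 0, 0, 0};
    std::vector<int> generationLengths = {2, 3, 1};

    std::vector<int> packed(positionIdsSrc.size(), -1);
    int cumSumLength = 0; // running total, mirrors the pinned cumSumLength buffer
    for (std::size_t seq = 0; seq < generationLengths.size(); ++seq)
    {
        int const length = generationLengths[seq];
        for (int t = 0; t < length; ++t)
        {
            packed[cumSumLength + t] = positionIdsSrc[seq * tokensPerStep + t];
        }
        cumSumLength += length; // the next sequence starts right after this one
    }
    packed.resize(cumSumLength); // analogous to reshaping the device tensor to {cumSumLength}

    for (int const v : packed)
    {
        std::printf("%d ", v);
    }
    std::printf("\n"); // prints: 10 11 20 21 22 30
    return 0;
}
```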
diff --git a/cpp/tensorrt_llm/runtime/loraModule.cpp b/cpp/tensorrt_llm/runtime/loraModule.cpp
index 2716e78cb..8a8e2e559 100644
--- a/cpp/tensorrt_llm/runtime/loraModule.cpp
+++ b/cpp/tensorrt_llm/runtime/loraModule.cpp
@@ -21,7 +21,7 @@ namespace tensorrt_llm::runtime
std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames,
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
- SizeType32 attentionHeadSize, SizeType32 tpSize)
+ SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts)
{
auto const hidden = hiddenSize * tpSize;
auto const mlpHidden = mlpHiddenSize * tpSize;
@@ -55,10 +55,10 @@ std::vector LoraModule::createLoraModules(std::vector c
case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break;
// TODO(TRTLLM-379): Support MOE LoRA weights
case ModuleType::kMOE_H_TO_4H:
- case ModuleType::kMOE_GATE:
- case ModuleType::kMOE_4H_TO_H:
- case ModuleType::kMOE_ROUTER:
- case ModuleType::kMLP_ROUTER:
+ case ModuleType::kMOE_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
+ case ModuleType::kMOE_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break;
+ case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break;
+ case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break;
case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
}
}
diff --git a/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp b/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp
index c4f9d888b..6b9c1175f 100644
--- a/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp
+++ b/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp
@@ -15,11 +15,11 @@
*/
#include "tensorrt_llm/runtime/rnnStateBuffers.h"
+#include "iBuffer.h"
#include "tensorrt_llm/runtime/runtimeBuffers.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
using namespace tensorrt_llm::runtime;
-namespace tc = tensorrt_llm::common;
RnnStateBuffers::RnnStateBuffers()
{
@@ -92,8 +92,8 @@ RnnStateBuffers::RnnStateBuffers(
auto statePtrsShape = ITensor::makeShape({localNbLayers});
slotMappingDevice = bufferManager.gpu(slotMappingShape, nvinfer1::DataType::kINT32);
slotMappingHost = BufferManager::cpu(slotMappingShape, nvinfer1::DataType::kINT32);
- rnnStatePtrs = BufferManager::cpu(statePtrsShape, nvinfer1::DataType::kINT64);
- convStatePtrs = BufferManager::cpu(statePtrsShape, nvinfer1::DataType::kINT64);
+ rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType::value);
+ convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType