diff --git a/README.md b/README.md index 0b88d361f..0da5128a5 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ TensorRT-LLM [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/) [![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/) [![cuda](https://img.shields.io/badge/cuda-12.5.1-green)](https://developer.nvidia.com/cuda-downloads) -[![trt](https://img.shields.io/badge/TRT-10.3.0-green)](https://developer.nvidia.com/tensorrt) -[![version](https://img.shields.io/badge/release-0.13.0.dev-green)](./tensorrt_llm/version.py) +[![trt](https://img.shields.io/badge/TRT-10.4.0-green)](https://developer.nvidia.com/tensorrt) +[![version](https://img.shields.io/badge/release-0.14.0.dev-green)](./tensorrt_llm/version.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) [Architecture](./docs/source/architecture/overview.md)   |   [Results](./docs/source/performance/perf-overview.md)   |   [Examples](./examples/)   |   [Documentation](./docs/source/) @@ -17,6 +17,12 @@ TensorRT-LLM
## Latest News +* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12 +[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0) +
+ * [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup [➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link) diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index 45632350c..b901a17bc 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -159,6 +159,8 @@ struct BenchmarkParams std::optional sinkTokenLength{std::nullopt}; bool multiBlockMode{true}; bool enableContextFMHAFP32Acc{false}; + bool cudaGraphMode{false}; + SizeType32 cudaGraphCacheSize{0}; // lora / peft params std::optional loraDir{std::nullopt}; @@ -470,7 +472,38 @@ class Recorder mRequestBenchInfos[requestId].firstTokenSeen = true; } - mRequestBenchInfos[requestId].outputLength += 1; + mRequestBenchInfos[requestId].decodingIter += 1; + } + + void recordToken(uint64_t requestId, std::list const& responseTensors) + { + int32_t outputLength = 1; + for (auto& tensor : responseTensors) + { + if (tensor.name == inference_request::kSequenceLengthTensorName) + { + // Tensor of shape nBeams, and we only need the first one + outputLength = *(bufferCast(*(tensor.tensor))); + break; + } + } + + mRequestBenchInfos[requestId].outputLength += outputLength; + this->recordToken(requestId); + } + + void recordToken(uint64_t requestId, texec::Response const& response) + { + auto outputTokenIds = response.getResult().outputTokenIds; + + int32_t outputLength = 1; + for (auto const& beam : outputTokenIds) + { + outputLength = std::max(static_cast(beam.size()), outputLength); + } + + mRequestBenchInfos[requestId].outputLength += outputLength; + this->recordToken(requestId); } void recordEnd(uint64_t requestId, std::list const& responseTensors, bool hasError) @@ -500,7 +533,7 @@ class Recorder } else { - this->recordToken(requestId); + this->recordToken(requestId, responseTensors); } } @@ -532,7 +565,7 @@ class Recorder } else { - this->recordToken(requestId); + this->recordToken(requestId, response); } } } @@ -821,8 +854,9 @@ class ExecutorServer benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks); texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8, std::nullopt, benchmarkParams.loraHostCacheSize); - texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig( - benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc); + texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode, + benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, + benchmarkParams.cudaGraphCacheSize); texec::ExecutorConfig executorConfig( maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true); executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent); @@ -940,7 +974,7 @@ class ExecutorServer { if (!warmup && !response.hasError()) { - mRecorder->recordToken(reqId); + mRecorder->recordToken(reqId, response); } } } @@ -1228,7 +1262,7 @@ class GptServer { if (errMsg.empty()) { - mRecorder->recordToken(requestId); + mRecorder->recordToken(requestId, response_tensors); } } } @@ -1458,8 +1492,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType : benchmarkParams.executorLookaheadConfig.has_value() ? 
texec::DecodingMode::Lookahead() : texec::DecodingMode::Auto(), benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices); - optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig( - benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc); + optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode, + benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize); auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json"); auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(), @@ -1895,7 +1929,8 @@ int main(int argc, char* argv[]) options.add_options()("return_generation_logits", "Whether to return generation logits.", cxxopts::value()->default_value("false")); - options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.", + options.add_options()("scheduler_policy", + "Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.", cxxopts::value()->default_value("guaranteed_no_evict")); options.add_options()("first_batch_delay", @@ -1946,6 +1981,12 @@ int main(int argc, char* argv[]) cxxopts::value()->default_value("true")); options.add_options()( "encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value()); + options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.", + cxxopts::value()->default_value("false")); + options.add_options()("cuda_graph_cache_size", + "Specify how many cuda graphs are cached in the runtime. Larger cache gives better perf, but consumes more GPU " + "memory.", + cxxopts::value()->default_value("0")); options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation", cxxopts::value()->default_value("false")); @@ -2131,6 +2172,12 @@ int main(int argc, char* argv[]) // Argument: enable_context_fmha_fp32_acc benchmarkParams.enableContextFMHAFP32Acc = result["enable_context_fmha_fp32_acc"].as(); + // Argument: cuda_graph_mode + benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as(); + + // Argument: cuda_graph_mode + benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as(); + std::optional padId; // Argument: Padding token id if (result.count("pad_id")) @@ -2168,6 +2215,10 @@ int main(int argc, char* argv[]) { capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT; } + else if (capacitySchedulerPolicyArg == "static_batch") + { + capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kSTATIC_BATCH; + } else { TLLM_LOG_ERROR("Unexpected scheduler policy: " + capacitySchedulerPolicyArg); diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py index 04ba2ab0f..ce06c9f9f 100644 --- a/benchmarks/python/gpt_benchmark.py +++ b/benchmarks/python/gpt_benchmark.py @@ -80,7 +80,7 @@ def __init__(self, args, batch_sizes, in_out_lens, gpu_weights_percents, kv_cache_type = KVCacheType.CONTINUOUS if hasattr(self, 'kv_cache_type'): - kv_cache_type = self.kv_cache_type + kv_cache_type = KVCacheType(self.kv_cache_type) else: if hasattr(self, 'paged_kv_cache'): kv_cache_type = KVCacheType.PAGED if self.paged_kv_cache == True else KVCacheType.CONTINUOUS diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 5f26170e9..959e2c39c 100644 
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -282,14 +282,37 @@ class GenerationRequest std::vector> mCacheBlockIds; }; -// BlockManager manages overall metadata of KVCacheBlocks in a layer of the -// network. Layers are expected to be symmetric, so the metadata can be -// reused for all layers of the network. -// The array of cache blocks for a layer is called a pool. -// Each pool has shape [max_blocks, 2, num_heads, tokens_per_block, head_size]. -// Size per block and number of blocks per pool are pre-determined and set in -// constructor. These should not be changed after. -// Block shape is [2, num_heads, tokens_per_block, head_size]. +// attach metadata to a pool pointer +class KVCacheBlockPool +{ +public: + SizeType32 numKvHeads; + SizeType32 numLayers; + SizeType32 blockSize; + + // Memory pools. Primary is fast memory, secondary is slower memory used for offloading. + runtime::ITensor::SharedPtr primaryPtr; + runtime::ITensor::SharedPtr secondaryPtr; + + KVCacheBlockPool(SizeType32 numKvHeads, SizeType32 numLayers, SizeType32 blockSize, + runtime::ITensor::SharedPtr primaryPtr = nullptr, runtime::ITensor::SharedPtr secondaryPtr = nullptr) + : numKvHeads(numKvHeads) + , numLayers(numLayers) + , blockSize(blockSize) + , primaryPtr(std::move(primaryPtr)) + , secondaryPtr(std::move(secondaryPtr)) + { + } +}; + +// The BlockManager manages the metadata of KVCacheBlocks. +// It manages multiple arrays of cache blocks called pools. +// Layers with the same number of kv heads are grouped under the same pool. +// Each pool has shape [max_blocks, num_layers, 2, num_kv_heads, tokens_pre_block, head_size], where num_layers refers +// to the number of layers with the same num_kv_heads that share that pool. +// The metadata of KVCacheBlocks is shared between layers, so each block spans all of the managed pool - an allocated +// block matches some chunk of memory in each pool. The shape of the chunk in every pool is [2, num_kv_heads, +// tokens_per_block, head_size]. The size per block and number of blocks are pre-determined and set in the constructor. // BlockManager maintains a list of free blocks at any time. // Alloc pops off the block at the front, and Free pushes it back to the vector. // BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks @@ -300,7 +323,7 @@ class BlockManager using SizeType32 = tensorrt_llm::runtime::SizeType32; using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType; - explicit BlockManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, + explicit BlockManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType = CacheType::kSELF); @@ -338,7 +361,7 @@ class BlockManager [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept { - return mFreePrimaryBlocks.size(); + return mFreePrimaryBlocksSize; } [[nodiscard]] SizeType32 getNumAllocTotalBlocks() const @@ -381,21 +404,26 @@ class BlockManager return mTokensPerBlock; } - //! \brief Get size of one K/V cache block in one layer. - //! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead] - [[nodiscard]] SizeType32 getBlockSize() const + //! \brief Get size of one K/V cache block in one layer for the specified pool. + //! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead] in the specified pool. 
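The new comment block above describes the pooled layout: layers with the same number of KV heads share one pool of shape [max_blocks, num_layers, 2, num_kv_heads, tokens_per_block, head_size]. The following standalone sketch (illustrative values and names only, not repository code) shows that grouping and the per-layer block size that `KVCacheBlockPool::blockSize` corresponds to:

```cpp
// Illustrative sketch only: group layers by KV-head count into "pools" and
// compute the per-layer block size, mirroring the layout described above.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main()
{
    using SizeType32 = std::int32_t;

    // Example model: 4 layers with two distinct KV-head counts (hypothetical values).
    std::vector<SizeType32> const numKvHeadsPerLayer{8, 8, 4, 4};
    SizeType32 const tokensPerBlock = 64;
    SizeType32 const sizePerHead = 128;

    // Layers with the same number of KV heads share one pool.
    std::map<SizeType32, SizeType32> layersPerKvHeadCount; // numKvHeads -> number of layers in that pool
    for (auto const numKvHeads : numKvHeadsPerLayer)
    {
        ++layersPerKvHeadCount[numKvHeads];
    }

    for (auto const& [numKvHeads, numLayers] : layersPerKvHeadCount)
    {
        // Per-layer block size: volume of [numKvHeads, tokensPerBlock, sizePerHead].
        auto const blockSize = numKvHeads * tokensPerBlock * sizePerHead;
        std::cout << "pool(numKvHeads=" << numKvHeads << "): " << numLayers
                  << " layers, chunk shape per layer [2, " << numKvHeads << ", " << tokensPerBlock << ", "
                  << sizePerHead << "], blockSize=" << blockSize << " elements\n";
    }
    return 0;
}
```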
+ [[nodiscard]] SizeType32 getBlockSize(SizeType32 poolIdx) const + { + return mPools.at(poolIdx).blockSize; + } + + [[nodiscard]] SizeType32 getNumPools() const noexcept { - return mBlockSize; + return mPools.size(); } - [[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool() const noexcept + [[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const { - return mPrimaryPool; + return mPools.at(poolIdx).primaryPtr; } - [[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool() const noexcept + [[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool(SizeType32 poolIdx) const { - return mSecondaryPool; + return mPools.at(poolIdx).secondaryPtr; } [[nodiscard]] SizeType32 getNumLayers() const @@ -403,10 +431,32 @@ class BlockManager return mNumLayers; } + [[nodiscard]] SizeType32 getNumPrimaryBlocks() const + { + return mNumPrimaryBlocks; + } + + [[nodiscard]] SizeType32 getNumSecondaryBlocks() const + { + return mNumSecondaryBlocks; + } + + [[nodiscard]] CacheType getCacheType() const + { + return mCacheType; + } + + [[nodiscard]] SizeType32 getLayerPoolIdx(SizeType32 layerIdx) const + { + return mLayerToPool.at(layerIdx); + } + //! \brief Get index in pool to K or V block. //! \param blockId the blockId as returned by getBlockId() //! \param fieldIdx either 0 (K) or 1 (V), - [[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(KVCacheBlock::IdType blockId, SizeType32 fieldIdx) const; + //! \param poolIdx the index of the pool for which the index is calculated (each pool has different strides) + [[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex( + KVCacheBlock::IdType blockId, SizeType32 fieldIdx, SizeType32 poolIdx) const; //! \brief Bring offloaded block from secondary to primary memory. //! \details Does nothing of block is already in primary memory. @@ -451,7 +501,8 @@ class BlockManager void claimLeafBlock(KVCacheBlock& block); //! \brief Compute pointer to raw KV block (K & V, all layers). - [[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(std::shared_ptr block) const; + [[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer( + std::shared_ptr block, SizeType32 poolIdx) const; //! \brief Copy content of src block to dst. void copyBlock(BlockPtr src, BlockPtr dst); @@ -460,23 +511,30 @@ class BlockManager // Number of blocks in pools SizeType32 mNumPrimaryBlocks; SizeType32 mNumSecondaryBlocks; - // List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory, - // we maintain separate queues for these. + // List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory. + // We maintain separate queues for these. + // We cache size of each queue instead of calling std::list::size, because size is O(N) function. + SizeType32 mFreePrimaryBlocksSize; + SizeType32 mFreeSecondaryBlocksSize; FreeBlocksQueue mFreePrimaryBlocks; FreeBlocksQueue mFreeSecondaryBlocks; // List of allocated blocks for each sequences std::vector> mAllocatedBlocksPerSeq; - // Memory pools. Primary is fast memory, secondary is slower memory used for offloading. - runtime::ITensor::SharedPtr mPrimaryPool; - runtime::ITensor::SharedPtr mSecondaryPool; + + // Pool per unique numKvHeads in the model + std::vector mPools; + // Matching of model layers to their pools + std::vector mLayerToPool; + // Whether offloaded blocks should be onboarded before reuse. 
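With pools made explicit, callers resolve a layer to its pool before asking for memory or block indices. The fragment below is a hedged usage sketch built only from the accessors declared above (`getLayerPoolIdx`, `getBlockSize`, `getPrimaryPool`, `getKOrVBlockIndex`); it is meant to compile against this header, and the surrounding setup (constructing a BlockManager, obtaining a block id) is elided:

```cpp
// Hedged usage sketch (not repository code): resolve a layer's pool and fetch
// the K and V block indices with the pool-aware API declared above.
#include "tensorrt_llm/batch_manager/kvCacheManager.h"

namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;

void inspectBlock(kvcm::BlockManager const& blockManager, kvcm::KVCacheBlock::IdType blockId,
    tensorrt_llm::runtime::SizeType32 layerIdx)
{
    // Each layer maps to exactly one pool; pools differ in their KV-head count.
    auto const poolIdx = blockManager.getLayerPoolIdx(layerIdx);

    // Per-pool geometry and memory.
    auto const blockSize = blockManager.getBlockSize(poolIdx); // volume of [numKvHeads, tokensPerBlock, sizePerHead]
    auto const primaryPool = blockManager.getPrimaryPool(poolIdx);

    // Field index 0 selects K, 1 selects V; strides depend on the pool.
    auto const kIndex = blockManager.getKOrVBlockIndex(blockId, /*fieldIdx=*/0, poolIdx);
    auto const vIndex = blockManager.getKOrVBlockIndex(blockId, /*fieldIdx=*/1, poolIdx);

    (void) blockSize;
    (void) primaryPool;
    (void) kIndex;
    (void) vIndex;
}
```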
bool mOnboardBlocks; // Buffer manager runtime::BufferManager mBufferManager; + + // Size of a single KV heads + SizeType32 mSizePerHead; // Number of layers SizeType32 mNumLayers; - // Volume of [numKvHeads, tokensPerBlock, sizePerHead] - SizeType32 mBlockSize; // Used to keep track of number of free blocks during scheduling SizeType32 mSchedulingNumFreeBlocks; // Number of tokens per one block @@ -502,12 +560,18 @@ class KVCacheManager using CudaStreamPtr = std::shared_ptr; using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType; - KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, + KVCacheManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock, CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF); + KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, + SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, + SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock, + CudaStreamPtr stream, bool enableBlockReuse = true, bool onboardBlocks = true, + CacheType cacheType = CacheType::kSELF); + void allocatePools(nvinfer1::DataType dtype, bool useUvm = false); void startScheduling(); @@ -577,11 +641,11 @@ class KVCacheManager /// @return The number of blocks [[nodiscard]] SizeType32 getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const; - /// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for - /// maxNewTokens) + /// @brief Function that computes the number of KV cache blocks remaining to advance a request to completion (i.e. 
+ /// for maxNewTokens); the allocated blocks are excluded /// @param req The request for which we need to calculate the number of needed KV cache blocks /// @return The number of blocks - [[nodiscard]] SizeType32 getNeededBlocksToCompletion(LlmRequest const& req) const; + [[nodiscard]] SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req) const; void addContextTokens(SizeType32 seqSlotIdx, SizeType32 numTokens); @@ -603,6 +667,8 @@ class KVCacheManager [[nodiscard]] runtime::ITensor::UniquePtr getBlockPoolPointers() const; + [[nodiscard]] runtime::ITensor::UniquePtr getLayerToPoolMapping() const; + void getBlockOffsetsOfBatch( runtime::ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const; @@ -610,18 +676,16 @@ class KVCacheManager SizeType32 copyBlockOffsets( runtime::ITensor& output, SizeType32 outputSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const; - // Volume of [2, numKvHeads, tokensPerBlock, sizePerHead] - [[nodiscard]] static SizeType32 constexpr calculatePageSize(tensorrt_llm::runtime::ModelConfig const& modelConfig) - { - return 2 * modelConfig.getNbKvHeads() * modelConfig.getTokensPerBlock() * modelConfig.getSizePerHead(); - } - - // numLayers * 2 * numKvHeads * sizePerHead - [[nodiscard]] static SizeType32 constexpr calculateCacheSizePerToken( + // Sum of numLayers * 2 * numKvHeads * sizePerHead for each pool + [[nodiscard]] static SizeType32 calculateCacheSizePerToken( tensorrt_llm::runtime::ModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig) { - return modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads() - * modelConfig.getSizePerHead(); + // NOTE: We expect the initialization of modelConfig to have already taken the tp size into account and do not + // address it here + // consider only local layers for the calculation + return modelConfig.getSumLocalKvHeads( + worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank()) + * 2 * modelConfig.getSizePerHead(); } [[nodiscard]] static std::tuple const calculateMaxNumBlocks(KvCacheConfig const& config, @@ -640,7 +704,7 @@ class KVCacheManager [[nodiscard]] bool isCrossKv() const { - return mCacheType == CacheType::kCROSS; + return mBlockManager.getCacheType() == CacheType::kCROSS; } //! \brief Find first new block that must be allocated for context phase and return it's concatenated token vector. 
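`calculateCacheSizePerToken` above now sums the KV heads of the rank-local attention layers instead of assuming one uniform head count. A minimal standalone sketch of that arithmetic with made-up sizes (not repository code):

```cpp
// Illustrative arithmetic only: per-token KV-cache size with per-layer KV-head
// counts, mirroring sum(localKvHeads) * 2 (K and V) * sizePerHead.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    using SizeType32 = std::int32_t;

    // KV heads of the attention layers owned by this pipeline-parallel rank (hypothetical values).
    std::vector<SizeType32> const localKvHeadsPerLayer{8, 8, 4, 4};
    SizeType32 const sizePerHead = 128;

    auto const sumLocalKvHeads = std::reduce(localKvHeadsPerLayer.cbegin(), localKvHeadsPerLayer.cend());
    auto const cacheSizePerToken = sumLocalKvHeads * 2 * sizePerHead; // in elements; multiply by dtype size for bytes

    std::cout << "cache elements per token: " << cacheSizePerToken << "\n"; // (8+8+4+4) * 2 * 128 = 6144
    return 0;
}
```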
@@ -691,8 +755,6 @@ class KVCacheManager runtime::ITensor::SharedPtr mSequenceBlockIndices; // Whether to cache KV pages for reuse bool mEnableBlockReuse; - // KV cache type (self or cross) - CacheType mCacheType; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 81b91e24a..1738cc428 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -91,9 +91,9 @@ class BlockIterator }; [[nodiscard]] BlockIterator getBlockBeginIt( - KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam); + KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx); [[nodiscard]] BlockIterator getBlockEndIt( - KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam); + KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx); } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 46c808de0..fed9dd21e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -400,7 +400,8 @@ class GenericLlmRequest TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.has_value() && mInputTokenExtraIds.value(), "Input token extra ids must be provided when enabling kv cache reuse with prompt table"); TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.value()->size() == static_cast(mOrigPromptLen), - "inputTokenExtraIds vector size must be the same as input token vector size."); + "inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).", + mInputTokenExtraIds.value()->size(), static_cast(mOrigPromptLen)); } } @@ -411,7 +412,7 @@ class GenericLlmRequest /// @brief Get the params of the context /// @return The params of the context - std::optional const& getContextPhaseParams() const noexcept + [[nodiscard]] std::optional const& getContextPhaseParams() const noexcept { return mContextPhaseParams; } @@ -423,10 +424,10 @@ class GenericLlmRequest /// @brief Get the state params of the context /// @return The state params of the context - executor::ContextPhaseState const& getContextPhaseState() const + [[nodiscard]] executor::DataTransceiverState const& getDataTransceiverState() const { TLLM_CHECK(mContextPhaseParams.has_value()); - return *static_cast(mContextPhaseParams.value().getState()); + return *static_cast(mContextPhaseParams.value().getState()); } /// @brief Get total number of tokens for this req (prompt + generated) @@ -659,6 +660,11 @@ class GenericLlmRequest return mSequenceIndex > 0; } + [[nodiscard]] RequestIdType getParentRequestId() const + { + return mParentRequestId; + } + /// @brief Return a vector of the last-generated tokens of shape [num_beams] [[nodiscard]] VecTokens const& getLastTokens() { @@ -858,14 +864,46 @@ class GenericLlmRequest return mOrigPromptLen; } - void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen) + [[nodiscard]] SizeType32 getPromptLen() const { - mPrepopulatedPromptLen = prepopulatedPromptLen; + return mPromptLen; } - [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const + void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock) { - return mPrepopulatedPromptLen; + auto const promptLen = 
getPromptLen(); + TLLM_CHECK(prepopulatedPromptLen < promptLen); + + if (prepopulatedPromptLen > 0) + { + // Currently, the runtime process is to apply for cache first and then determine prepopulation. + // Use the prepopulated length to advance the context position and decrease chunk size if necessary. + if (isFullContextRequest()) + { + setContextCurrentPosition(prepopulatedPromptLen); + setContextChunkSize(promptLen); + } + else + { + auto chunkSize = getContextChunkSize(); + if (prepopulatedPromptLen + chunkSize < promptLen) + { + // make sure to end at block boundary after current chunk + auto const flooredEndPosition + = (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock; + chunkSize = flooredEndPosition - prepopulatedPromptLen; + TLLM_CHECK(chunkSize <= getContextChunkSize()); + } + setContextCurrentPosition(prepopulatedPromptLen); + setContextChunkSize(chunkSize); + } + if (!isLastContextChunk()) + { + TLLM_CHECK_WITH_INFO((getContextCurrentPosition() + getContextChunkSize()) % kvTokensPerBlock == 0, + "To prevent cache fragmentation, the context position after current chunk should be divisible " + "by the number of tokens per block, except for the last chunk."); + } + } } void setDraftTokens(std::shared_ptr const& draftTokens) @@ -1276,7 +1314,7 @@ class GenericLlmRequest } // TODO: fill the rank ids result.contextPhaseParams = executor::ContextPhaseParams{ - std::move(firstGenTokens), mContextPhaseParams.value().releaseState()}; + std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState()}; } auto const calculateNbTokensOut = [this](SizeType32 maxNbTokens) @@ -1513,8 +1551,8 @@ class GenericLlmRequest { if (mInputTokenExtraIds.value()->size() != inputTokens.size()) { - std::string errStr = "inputTokenExtraIds vector size must be the same as input token vector size."; - TLLM_THROW(errStr); + TLLM_THROW("inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).", + mInputTokenExtraIds.value()->size(), inputTokens.size()); } VecTokenExtraIds tokenExtraIds = *mInputTokenExtraIds.value(); for (std::size_t i = 0; i < inputTokens.size(); ++i) diff --git a/cpp/include/tensorrt_llm/common/cudaUtils.h b/cpp/include/tensorrt_llm/common/cudaUtils.h index 3bb203f8e..023f97d87 100644 --- a/cpp/include/tensorrt_llm/common/cudaUtils.h +++ b/cpp/include/tensorrt_llm/common/cudaUtils.h @@ -161,7 +161,7 @@ inline std::optional isCudaLaunchBlocking() return result; } -inline void syncAndCheck(char const* const file, int const line) +inline bool doCheckError() { auto const cudaLaunchBlocking = isCudaLaunchBlocking(); #ifndef NDEBUG @@ -170,10 +170,15 @@ inline void syncAndCheck(char const* const file, int const line) bool const checkError = cudaLaunchBlocking.value_or(false); #endif - if (checkError) + return checkError; +} + +inline void syncAndCheck(char const* const file, int const line) +{ + if (doCheckError()) { - cudaError_t result = cudaDeviceSynchronize(); - check(result, "cudaDeviceSynchronize", file, line); + check(cudaGetLastError(), "cudaGetLastError", file, line); + check(cudaDeviceSynchronize(), "cudaDeviceSynchronize", file, line); } } diff --git a/cpp/include/tensorrt_llm/common/mpiUtils.h b/cpp/include/tensorrt_llm/common/mpiUtils.h index edf3da004..4a7bb53ae 100644 --- a/cpp/include/tensorrt_llm/common/mpiUtils.h +++ b/cpp/include/tensorrt_llm/common/mpiUtils.h @@ -380,6 +380,10 @@ class MpiComm void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const; void allgather(void 
const* sendbuf, void* recvbuf, int count, MpiType dtype) const; + + void allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, + std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const; + void barrier() const; void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const; diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index a5c6cf03c..807382c4a 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -43,7 +43,7 @@ char const* version() noexcept; class Model; class Serialization; -class ContextPhaseState; +class DataTransceiverState; /// @brief Sampling configuration class SamplingConfig @@ -283,8 +283,10 @@ struct LookaheadDecodingConfig class ContextPhaseParams { public: - explicit ContextPhaseParams(VecTokens firstGenTokens); - ContextPhaseParams(VecTokens firstGenTokens, void* state); + using RequestIdType = std::uint64_t; + + explicit ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId); + ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state); ContextPhaseParams(ContextPhaseParams const&); ContextPhaseParams(ContextPhaseParams&&); @@ -295,6 +297,8 @@ class ContextPhaseParams [[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept; [[nodiscard]] VecTokens popFirstGenTokens() && noexcept; + [[nodiscard]] RequestIdType getReqId() const noexcept; + [[nodiscard]] void const* getState() const noexcept; [[nodiscard]] void* getState() noexcept; [[nodiscard]] void* releaseState() noexcept; @@ -304,6 +308,9 @@ class ContextPhaseParams static void deleter(void const* data); using StatePtr = std::unique_ptr; + /// @brief This request corresponds to the request ID in the context phase. + RequestIdType mReqId{0}; + /// @brief The first tokens generated by context executor VecTokens mFirstGenTokens; @@ -593,18 +600,24 @@ class KvCacheConfig class ExtendedRuntimePerfKnobConfig { public: - explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false); + explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false, + bool cudaGraphMode = false, SizeType32 cudaGraphCacheSize = 0); bool operator==(ExtendedRuntimePerfKnobConfig const& other) const { - return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc; + return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc + && mCudaGraphMode == other.mCudaGraphMode && mCudaGraphCacheSize == other.mCudaGraphCacheSize; } [[nodiscard]] bool getMultiBlockMode() const; [[nodiscard]] bool getEnableContextFMHAFP32Acc() const; + [[nodiscard]] bool getCudaGraphMode() const; + [[nodiscard]] SizeType32 getCudaGraphCacheSize() const; void setMultiBlockMode(bool multiBlockMode); void setEnableContextFMHAFP32Acc(bool enableContextFMHAFP32Acc); + void setCudaGraphMode(bool cudaGraphMode); + void setCudaGraphCacheSize(SizeType32 cacheSize); private: friend class Serialization; @@ -614,6 +627,13 @@ class ExtendedRuntimePerfKnobConfig /// @brief If enable FMHA runner FP32 accumulation. bool mEnableContextFMHAFP32Acc; + + /// @brief Control if enable cuda graph. + bool mCudaGraphMode; + + /// @brief Number of cuda graphs to be cached in the runtime. + /// The larger the cache, the better the perf, but more GPU memory is consumed. 
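For reference, a hedged sketch of how the two new CUDA-graph knobs can be set through this class, using only the constructor and setters declared above; the values are illustrative and the surrounding executor setup is omitted:

```cpp
// Hedged sketch: enabling CUDA graphs through ExtendedRuntimePerfKnobConfig,
// using only the constructor and setters declared above.
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

texec::ExtendedRuntimePerfKnobConfig makePerfKnobs()
{
    texec::ExtendedRuntimePerfKnobConfig perfKnobs(/*multiBlockMode=*/true,
        /*enableContextFMHAFP32Acc=*/false, /*cudaGraphMode=*/true, /*cudaGraphCacheSize=*/10);

    // The knobs can also be adjusted after construction.
    perfKnobs.setCudaGraphCacheSize(16); // larger cache -> better perf, more GPU memory
    return perfKnobs;
}
```

In gptManagerBenchmark these knobs are driven by the new `--cuda_graph_mode` and `--cuda_graph_cache_size` options added elsewhere in this change.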
+ SizeType32 mCudaGraphCacheSize; }; /// @brief Configuration class for debugging output diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index 11d22c3f0..9fe197dc9 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -75,10 +75,10 @@ class Serialization static void serialize(kv_cache::CacheState const& state, std::ostream& os); [[nodiscard]] static size_t serializedSize(kv_cache::CacheState const& state); - // ContextPhaseState - [[nodiscard]] static ContextPhaseState deserializeContextPhaseState(std::istream& is); - static void serialize(ContextPhaseState const& contextPhaseState, std::ostream& os); - [[nodiscard]] static size_t serializedSize(ContextPhaseState const& contextPhaseState); + // DataTransceiverState + [[nodiscard]] static DataTransceiverState deserializeDataTransceiverState(std::istream& is); + static void serialize(DataTransceiverState const& dataTransceiverState, std::ostream& os); + [[nodiscard]] static size_t serializedSize(DataTransceiverState const& dataTransceiverState); // ContextPhaseParams [[nodiscard]] static ContextPhaseParams deserializeContextPhaseParams(std::istream& is); diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 39e6693d7..a2476caff 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -198,6 +198,10 @@ enum class CapacitySchedulerPolicy /// @brief GUARANTEED_NO_EVICT uses KV cache more conservatively guaranteeing that a request, once started, will run /// to completion without eviction. kGUARANTEED_NO_EVICT = 1, + + /// @brief kSTATIC_BATCH does not schedule new requests until all requests in current batch are completed. + /// Similar to kGUARANTEED_NO_EVICT, requests will run to completion without eviction. + kSTATIC_BATCH = 2 }; std::ostream& operator<<(std::ostream& os, CapacitySchedulerPolicy policy); diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index 358826f50..2db8fcc18 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -62,12 +62,12 @@ class GptDecoderBatched : public IGptDecoderBatched void newRequests(std::vector const& seqSlots, std::vector const& requests, std::vector const& samplingConfigs) override; - TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override; + DecoderFinishedEventPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override; - void forwardSync(decoder_batch::Token const& token) override; + void forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent) override; - void forwardSync( - decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input) override; + void forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent, decoder_batch::Output& output, + decoder_batch::Input const& input) override; void forwardAsync(decoder::Output& output, decoder::Input const& input) override; @@ -271,7 +271,7 @@ class GptDecoderBatched : public IGptDecoderBatched void newRequestExplicitDraftTokens(SizeType32 batchIdx, decoder_batch::Request const& request); //! 
@brief Updates finished state on host for all active requests - void updateFinished(decoder_batch::Token const& token); + void updateFinished(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent); //! @brief Sets inputs for explicit draft tokens. void setExplicitDraftTokensInputs(decoder_batch::Input const& input); @@ -289,7 +289,7 @@ class GptDecoderBatched : public IGptDecoderBatched CudaStreamPtr mRuntimeStream; CudaStreamPtr mDecoderStream; BufferManager mBufferManager; - TokenPtr mForwardToken; + DecoderFinishedEventPtr mDecoderFinishEvent; CudaEvent mForwardEvent; using GptDecoderPtr = std::unique_ptr; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 11464f80e..048fa05a7 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -75,11 +75,11 @@ class Input using Output = decoder::Output; -// TODO: is this a bad name to mix up with token concept in LLM? Would 'Event' be better? And should move to common.h -class Token +// used just as a container for easy returning / passing to function +class DecoderFinishedEvent { public: - explicit Token(CudaEvent&& event, std::vector const& active) + explicit DecoderFinishedEvent(CudaEvent&& event, std::vector const& active) : event(std::move(event)) , active(active) { @@ -96,7 +96,7 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder public: using CudaStreamPtr = std::shared_ptr; using TensorPtr = std::shared_ptr; - using TokenPtr = std::unique_ptr; + using DecoderFinishedEventPtr = std::unique_ptr; //! @brief Setup buffers for ExplicitDraftTokens decoding. virtual void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) = 0; @@ -105,15 +105,15 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder virtual void setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) = 0; //! @brief Run one step for all requests without blocking the host process and return the token for synchronization. - virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0; + virtual DecoderFinishedEventPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0; //! @brief Call decoder forwardSync and wait for the call to `forwardAsync` associated with a token to complete. - virtual void forwardSync( - decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input) + virtual void forwardSync(decoder_batch::DecoderFinishedEvent const& token, decoder_batch::Output& output, + decoder_batch::Input const& input) = 0; //! @brief Wait for the call to `forwardAsync` associated with a token to complete. - virtual void forwardSync(decoder_batch::Token const& token) = 0; + virtual void forwardSync(decoder_batch::DecoderFinishedEvent const& token) = 0; //! @brief Run one step for all requests and wait for completion on the host. 
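The rename from `Token` to `DecoderFinishedEvent` makes explicit that `forwardAsync` returns a synchronization handle rather than generated tokens. A hedged fragment showing the intended pairing, based only on the interface above (the calling context is elided):

```cpp
// Hedged usage sketch of the renamed synchronization handle: forwardAsync
// returns a DecoderFinishedEventPtr that is later passed to forwardSync.
#include "tensorrt_llm/runtime/iGptDecoderBatched.h"

namespace tr = tensorrt_llm::runtime;

void decodeOneStep(tr::IGptDecoderBatched& decoder, tr::decoder_batch::Output& output,
    tr::decoder_batch::Input const& input)
{
    // Launch one decoding step for all active requests without blocking the host.
    auto decoderFinishEvent = decoder.forwardAsync(output, input);

    // ... the caller may overlap other host-side work here ...

    // Block until the step recorded by the event has completed.
    decoder.forwardSync(*decoderFinishEvent);
}
```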
virtual void forward(decoder_batch::Output& output, decoder_batch::Input const& input) diff --git a/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h index 56504bd94..3c6fe731a 100644 --- a/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h +++ b/cpp/include/tensorrt_llm/runtime/lookaheadBuffers.h @@ -62,6 +62,7 @@ class LookaheadRuntimeBuffers TensorMap& inputBuffers, TensorMap& outputBuffers, runtime::WorldConfig const& worldConfig) const; public: + TensorPtr cumSumLength; // [1] the cumulative sum of generation length, on pinned TensorPtr packedMasksDevice; // [forwardBatchSize, tokensPerStep, numPackedMasks], on gpu TensorPtr generationLengthsDevice; // [forwardBatchSize], on gpu TensorPtr positionOffsetsDevice; // [forwardBatchSize, tokensPerStep], on gpu diff --git a/cpp/include/tensorrt_llm/runtime/loraModule.h b/cpp/include/tensorrt_llm/runtime/loraModule.h index 15ac50fa8..d76178eaf 100644 --- a/cpp/include/tensorrt_llm/runtime/loraModule.h +++ b/cpp/include/tensorrt_llm/runtime/loraModule.h @@ -179,7 +179,7 @@ class LoraModule static std::vector createLoraModules(std::vector const& loraModuleNames, SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads, - SizeType32 attentionHeadSize, SizeType32 tpSize); + SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts); static ModuleType constexpr toModuleType(std::string_view const& name) { diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h index fc3ac2928..b1b495e75 100644 --- a/cpp/include/tensorrt_llm/runtime/modelConfig.h +++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h @@ -60,6 +60,9 @@ class ModelConfig { kATTENTION, kRECURRENT, + // NOTE: Linear and noop are attention alternatives introduced in Nemotron-NAS. They do not use the KV cache. 
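Since kLINEAR and kNOOP layers bypass attention and the KV cache, per-rank layer counts are now derived from the per-layer type list (see countLocalLayers further down). A standalone sketch of that counting, assuming, as the header does, that layers divide evenly across pipeline ranks:

```cpp
// Illustrative sketch only: count the attention layers that fall on a given
// pipeline-parallel rank when the model mixes attention, linear and no-op layers.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

enum class LayerType : std::int32_t
{
    kATTENTION,
    kRECURRENT,
    kLINEAR, // attention alternative, no KV cache
    kNOOP,   // attention alternative, no KV cache
};

int main()
{
    using SizeType32 = std::int32_t;

    std::vector<LayerType> const layerTypes{LayerType::kATTENTION, LayerType::kLINEAR, LayerType::kATTENTION,
        LayerType::kNOOP, LayerType::kATTENTION, LayerType::kLINEAR};

    SizeType32 const pipelineParallelism = 2;
    SizeType32 const rank = 1;

    // Assume no remainder, as countLocalLayers in the header does.
    auto const numLocalLayers = static_cast<SizeType32>(layerTypes.size()) / pipelineParallelism;
    auto const first = layerTypes.cbegin() + numLocalLayers * rank;
    auto const numLocalAttention = std::count(first, first + numLocalLayers, LayerType::kATTENTION);

    std::cout << "rank " << rank << " owns " << numLocalAttention << " attention layers\n"; // prints 1
    return 0;
}
```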
+ kLINEAR, + kNOOP, }; enum class KVCacheType : std::int32_t @@ -97,13 +100,13 @@ class ModelConfig kEnabled, }; - explicit ModelConfig(SizeType32 vocabSize, SizeType32 nbAttentionLayers, SizeType32 nbRnnLayers, SizeType32 nbHeads, - SizeType32 hiddenSize, nvinfer1::DataType dtype) + explicit ModelConfig(SizeType32 vocabSize, SizeType32 nbLayers, SizeType32 nbAttentionLayers, + SizeType32 nbRnnLayers, SizeType32 nbHeads, SizeType32 hiddenSize, nvinfer1::DataType dtype) : mVocabSize(vocabSize) + , mNbLayers(nbLayers) , mNbAttentionLayers(nbAttentionLayers) , mNbRnnLayers(nbRnnLayers) , mNbHeads(nbHeads) - , mNbKvHeads(nbHeads) , mHiddenSize(hiddenSize) , mSizePerHead(mHiddenSize / mNbHeads) , mDataType(dtype) @@ -134,6 +137,10 @@ class ModelConfig , mUseShapeInference(true) , mManageWeightsType(ManageWeightsType::kDisabled) { + TLLM_CHECK_WITH_INFO(mNbLayers >= mNbAttentionLayers + mNbRnnLayers, + "Number of layers (%d) expected to be >= number of attention (%d) + number of rnn layers (%d)", mNbLayers, + mNbAttentionLayers, mNbRnnLayers); + setNbKvHeads(mNbHeads); } [[nodiscard]] static std::vector getOptProfilesSplitPoints() noexcept @@ -151,14 +158,55 @@ class ModelConfig return (mVocabSize + worldSize - 1) / worldSize * worldSize; } - [[nodiscard]] SizeType32 constexpr getNbAttentionLayers(SizeType32 pipelineParallelism = 1) const + [[nodiscard]] SizeType32 countLocalLayers( + LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const { - return mNbAttentionLayers / pipelineParallelism; + TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism); + auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder + auto const firstLocalLayerIt = mLayerTypes.cbegin() + (numLocalLayers * pipelineParallelismRank); + return std::count(firstLocalLayerIt, firstLocalLayerIt + numLocalLayers, layerType); } - [[nodiscard]] SizeType32 constexpr getNbRnnLayers(SizeType32 pipelineParallelism = 1) const + [[nodiscard]] SizeType32 countLowerRankLayers( + LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const { - return mNbRnnLayers / pipelineParallelism; + auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder + auto const firstLocalLayer = numLocalLayers * pipelineParallelismRank; + // count number of previous non-local attention layers + return std::count(mLayerTypes.cbegin(), mLayerTypes.cbegin() + firstLocalLayer, layerType); + } + + [[nodiscard]] SizeType32 getNbLayers(SizeType32 pipelineParallelism = 1) const + { + return mNbLayers / pipelineParallelism; // WARNING: assume no remainder + } + + [[nodiscard]] SizeType32 getNbAttentionLayers( + SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const + { + // TODO(oargov): get rid of this invalid state + if (mLayerTypes.empty()) + { + // this assumption might be wrong in a few cases, for example: + // layer types: [attention, recurrent, recurrent], pp=2 ==> first rank has 1 attention layer, not 0 + TLLM_LOG_DEBUG("Assuming uniform distribution of attention layers between ranks"); + return mNbAttentionLayers / pipelineParallelism; + } + return countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank); + } + + [[nodiscard]] SizeType32 getNbRnnLayers( + SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const + { + // TODO(oargov): get rid of this invalid state + if 
(mLayerTypes.empty()) + { + // this assumption might be wrong in a few cases, for example: + // layer types: [attention, attention, recurrent], pp=2 ==> second rank has 1 rnn layer, not 0 + TLLM_LOG_DEBUG("Assuming uniform distribution of recurrent layers between ranks"); + return mNbRnnLayers / pipelineParallelism; + } + return countLocalLayers(LayerType::kRECURRENT, pipelineParallelism, pipelineParallelismRank); } [[nodiscard]] SizeType32 constexpr getNbHeads() const noexcept @@ -166,14 +214,16 @@ class ModelConfig return mNbHeads; } - [[nodiscard]] SizeType32 constexpr getNbKvHeads() const noexcept + [[nodiscard]] SizeType32 getNbKvHeads(SizeType32 layerIdx) const { - return mNbKvHeads; + TLLM_CHECK_WITH_INFO(layerIdx < mNbAttentionLayers, "Layer index %d is out of bounds", layerIdx); + return mNumKvHeadsPerAttentionLayer[layerIdx]; } - void constexpr setNbKvHeads(SizeType32 nbKvHeads) noexcept + // set the number of kv heads for all layers + void setNbKvHeads(SizeType32 nbKvHeads) { - mNbKvHeads = nbKvHeads; + mNumKvHeadsPerAttentionLayer = std::vector(mNbAttentionLayers, nbKvHeads); } [[nodiscard]] SizeType32 constexpr getHiddenSize() const noexcept @@ -645,12 +695,46 @@ class ModelConfig mModelName = modelName; } + [[nodiscard]] std::vector const& getNumKvHeadsPerLayer() const + { + return mNumKvHeadsPerAttentionLayer; + } + + [[nodiscard]] std::pair::const_iterator, std::vector::const_iterator> + getNumKvHeadsPerLayerLocalRange(SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const + { + TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism); + // count number of previous non-local attention layers + auto const numPrevAttnLayers + = countLowerRankLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank); + auto const firstLocalAttentionLayerIt = mNumKvHeadsPerAttentionLayer.cbegin() + numPrevAttnLayers; + auto const numLocalAttentionLayers + = countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank); + return std::make_pair(firstLocalAttentionLayerIt, firstLocalAttentionLayerIt + numLocalAttentionLayers); + } + + void setNumKvHeadsPerLayer(std::vector const& headsPerLayer) + { + auto const numElems = static_cast(headsPerLayer.size()); + TLLM_CHECK_WITH_INFO(numElems == mNbAttentionLayers, + "Length of head_per_layer (%d) must match number of attention layers (%d)", numElems, mNbAttentionLayers); + mNumKvHeadsPerAttentionLayer = headsPerLayer; + } + + [[nodiscard]] SizeType32 getSumLocalKvHeads( + SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const + { + auto [cbegin, cend] = getNumKvHeadsPerLayerLocalRange(pipelineParallelism, pipelineParallelismRank); + auto const sumLocalHeads = std::reduce(cbegin, cend); + return sumLocalHeads; + } + private: SizeType32 mVocabSize; + SizeType32 mNbLayers; SizeType32 mNbAttentionLayers; SizeType32 mNbRnnLayers; SizeType32 mNbHeads; - SizeType32 mNbKvHeads; SizeType32 mHiddenSize; SizeType32 mSizePerHead; nvinfer1::DataType mDataType; @@ -703,6 +787,7 @@ class ModelConfig bool mUseShapeInference; ManageWeightsType mManageWeightsType; std::string mModelName; + std::vector mNumKvHeadsPerAttentionLayer; }; } // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h index 8226c411c..e739e8188 100644 --- a/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h +++ 
b/cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h @@ -97,8 +97,7 @@ class SpeculativeDecodingMode [[nodiscard]] bool constexpr variableDraftLength() const { - // Add Lookahead, when lookahead supports it. - return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens); + return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding); } [[nodiscard]] bool constexpr hasDraftLogits() const diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index 10debf560..2ff2de09e 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -348,9 +348,11 @@ endif() if(NOT WIN32) # Unix-like compilers set(UNDEFINED_FLAG "-Wl,--no-undefined") set(AS_NEEDED_FLAG "-Wl,--as-needed") + set(NO_AS_NEEDED_FLAG "-Wl,--no-as-needed") else() # Windows set(UNDEFINED_FLAG "") set(AS_NEEDED_FLAG "") + set(NO_AS_NEEDED_FLAG "") endif() set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index b18a4d582..c54da94de 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e08b60b89bb4934490ee61383c55c22d831fa1cfcccedea5735400e3574aadbc -size 4671466 +oid sha256:10b940475c5acd80a61674d8ce4e42cc4ef3d806bafb245bbed26751378274e3 +size 4904726 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index d5279318b..ac692bb61 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2b6b3bf449c4b4d67f0bb9879af6b8eda6f46f272eaa5b7305582a2cc8c73e17 -size 4775694 +oid sha256:b2754f7887a1b5c37ba3d589320e16144039cfe5dc6a6c78ee71925861d7d511 +size 5015842 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index 975995460..c8b35c9c0 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -f229593e4699180b52e38f99c8ac31dc libtensorrt_llm_batch_manager_static.a -440b3ae47982d88fc8517c5f01f67b3c libtensorrt_llm_batch_manager_static.pre_cxx11.a -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +ff71eabd0ac6ede5398b5b6ce4e26dcf libtensorrt_llm_batch_manager_static.a +846eb112a182973e7c3b0b193300b4b8 libtensorrt_llm_batch_manager_static.pre_cxx11.a +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 2b9c3f003..2b867222c 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:cb9d3d05ef4b08df0fc02f39c053a4435b58f9431d1ce269439b2c1f0a055b21 -size 4523116 +oid sha256:13b8701dd767b414a5376a91905985979ad9d2b975465ac00835c04656ee6508 +size 4766226 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 54354fcf8..64680e7ae 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6597b35faffe93244d89595dc369ece59729c984871ad5aab531d714d39c8e49 -size 4487214 +oid sha256:cd0b73a017fc5c663235dcd724eb104ecc49d12ff29b6e3744be6ea952d027db +size 4722522 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt index 1b565a5fa..833efc826 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -c015186a31c789891a27e44f5a9ab9ec libtensorrt_llm_batch_manager_static.a -cac21708838abf82b18e1846c40b5c79 libtensorrt_llm_batch_manager_static.pre_cxx11.a -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +1eb5c88f894f3361445d7254cbc29b03 libtensorrt_llm_batch_manager_static.a +4e73341b23e8fb20b732ba08e03a54a8 libtensorrt_llm_batch_manager_static.pre_cxx11.a +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib index c234b864e..9fd773218 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib +++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:85f18d38a66cd8b15c7e447be16171b9db854f2b2fe9dc49daa4f93fae9bc125 -size 30145896 +oid sha256:b4ac61c0b0816477c11bd6c66ec4c2f23f7b6e1400eacd8c07c333f79dec0bea +size 30794956 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt index 063e584b1..db6e80406 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt +++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/version.txt @@ -1,2 +1,2 @@ -8ebb7d383e97bcd738cc24b00d58a2d0 tensorrt_llm_batch_manager_static.lib -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +eefe7310a60098897724f46cf4aa54f8 tensorrt_llm_batch_manager_static.lib +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/common/mpiUtils.cpp b/cpp/tensorrt_llm/common/mpiUtils.cpp index 6022dfd6a..b637e57f1 100644 --- a/cpp/tensorrt_llm/common/mpiUtils.cpp +++ b/cpp/tensorrt_llm/common/mpiUtils.cpp @@ -314,6 +314,18 @@ void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType d #endif // ENABLE_MULTI_DEVICE } +void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, + std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgatherv(sendbuf, sendcount, 
getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(), + getMpiDtype(recvtype), mComm)); + +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const { #if ENABLE_MULTI_DEVICE diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a index 8726ff8ae..d7f58205a 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:41eae80923f8634a635f2fce84fdbe33101ee6cf86c0a98ed4ce30a7f4cea350 -size 1782460 +oid sha256:ebab2cc2c62a826ddec02597178b8e0c9bc316726f37f8eef37c06795aebcf03 +size 1784658 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index 482344588..b8e5962bf 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dd4fb660b1c3e664a012e26cea949dcded855e133aa6aadd01157a15df3e0d44 -size 1808956 +oid sha256:4b630f89708614e63c67871e21b6e32bfde71acc51549b650c57048c0fa343e7 +size 1812686 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt index af0c5cd81..a4434f2dd 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -2744c4784cfd34c7311148c7f7614757 libtensorrt_llm_executor_static.a -d56af9e74a9d49e32860d89dcca024d0 libtensorrt_llm_executor_static.pre_cxx11.a -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +136f1b9d2168cbb9011a341b267af9a2 libtensorrt_llm_executor_static.a +183bd079377d6cd698d46370168a5726 libtensorrt_llm_executor_static.pre_cxx11.a +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a index a9670e009..d1c437693 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e9a5c64c347400297bc8b6907b4feaa890305aaf5c1b45ce57aca8fcae3e881f -size 1846898 +oid sha256:e04c76f6441a49db4d3996c62b4055395ae018384d8ee2f02ea5f0c4c0843902 +size 1853180 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index 0edd3b394..61c25133c 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e4b5912ce9a1c13554f4b16d29deb6f2ad51477c56810b758ba488212f8e5dc9 -size 1757522 +oid 
sha256:95ba1a4b6bdcecbb592bbb42b4998bcb0eb1f45a318163635183bcde6950c4bf +size 1764982 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt index 6717291c8..ad7ba2bf9 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -69dc85fe48625b6f8f684487f2048458 libtensorrt_llm_executor_static.a -d46f0be3543e24c4df51ae287086ca52 libtensorrt_llm_executor_static.pre_cxx11.a -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +dfbd0d424c150253ff758aa5bd37a971 libtensorrt_llm_executor_static.a +e82866739fef1d6df8293541967924bf libtensorrt_llm_executor_static.pre_cxx11.a +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib index cc641f706..2799dc524 100644 --- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib +++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fc94de246db88ca06e5a7c20d04f78057cc3c2928b3fa3e79f49e8b9d90b76da -size 19683718 +oid sha256:aa8ba34fb98c5407e3d6944245086158c61b2c784b15c7b923fdd156b942224d +size 19670642 diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt index e88bce29b..d2e341ae7 100644 --- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt +++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/version.txt @@ -1,2 +1,2 @@ -0e361ba639fa897f489f6d0f48cfe13f tensorrt_llm_executor_static.lib -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +784ad1fabd3d02466f95fbc463b64f5b tensorrt_llm_executor_static.lib +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/cumsumLastDim.cu b/cpp/tensorrt_llm/kernels/cumsumLastDim.cu index 1b7977708..8989e95fc 100644 --- a/cpp/tensorrt_llm/kernels/cumsumLastDim.cu +++ b/cpp/tensorrt_llm/kernels/cumsumLastDim.cu @@ -115,6 +115,11 @@ template void invokeCumsumLastDim(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input, void* __restrict__ output, void* deviceTempStorage, size_t tempStorageBytes, cudaStream_t stream) { + // For empty tensor support + if (batchSize == 0) + { + return; + } if (deviceTempStorage != nullptr) // we need to use DeviceScan { diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt index 49e66bdce..6f3b1ed98 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ 88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt 
b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt index f5da2f14a..d3923a7d2 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ 95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll index fadecd1af..643b3b831 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6fc6f35c712d83404e40a7840a0c9d1f5157df61df91a7207c4e4131783f4676 +oid sha256:1471e322bb44cd65b98ee30e0befa32ae4c86e828f0b4fd4f02d4af4e710d08f size 1128448 diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib index 0592bf8c5..cfe4399d6 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e74ab8e65851dfc44e015714fe166f521649b781c85bd0215d42b488218e9ca5 +oid sha256:e207a8f57b944529163c7ed2ab30639a5f2779c5118602c6ebd50a623d16f845 size 3488 diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt index b9555dc07..6dded519b 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/version.txt @@ -1,3 +1,3 @@ -700fc148d9a0f939e0088bf69e899360 tensorrt_llm_nvrtc_wrapper.lib -6ea6ac6dff8793afbd79dd5768daae85 tensorrt_llm_nvrtc_wrapper.dll -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib +f9b1cc37a27dd0574bb41a2763a97be7 tensorrt_llm_nvrtc_wrapper.dll +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a index 
a6ff47b69..36daef37b 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f4cc1e6f0b6d1e7bc875284275b591d34c707471e636019b4c2904f30798dbc9 +oid sha256:9117f7cf5eef0ed452c0d0bc79242b84def103e7038c9d3df6e366690801ca92 size 25364090 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a index cc7772d70..7eaca6cd9 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:658a3f0bc5b9877e5ad447437287908dc9b7df87ae0e86f5338aaf81e26f723e +oid sha256:2b04913f9e9029a5ce5a222d5cc7492ff53323a548079d2fb32d5b2aeb0c2268 size 25768990 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt index b8d933d49..ecfff5209 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -979f3165fbc7a68528df6e343cc54e3f libtensorrt_llm_internal_cutlass_kernels_static.a -68e84c294a658734a8b26d7270540e1d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +d54fb93f256601f4c4ad7f1c8e6e9919 libtensorrt_llm_internal_cutlass_kernels_static.a +71028d801074f11138e890391e48591d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a index 61b711b17..715fba593 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:24358d8eb15a5e802cbee6d2d27735033eedb33091e9355d199229c3ba7b6447 +oid sha256:d8c685f8ea2f84838dfdbf448eab41c76fe88fe29db0d4a511d6d6d241ad1832 size 44173632 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a index 0f981d9d4..4f403b38e 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a @@ -1,3 +1,3 @@ 
version https://git-lfs.github.com/spec/v1 -oid sha256:4b54efb78e98bf6580f73a3b6b689823d7c2eb851c0bab5f36906f4ebbfc44fc -size 43561142 +oid sha256:b9d75392ba3b59853c43072b4f9949b32cb6724813a39048e4585e9a8fb3e136 +size 43561206 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt index b6f327833..dcd8a686a 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -0b49d88b8b5e83c8c6997c725a37f373 libtensorrt_llm_internal_cutlass_kernels_static.a -7a12fc880d2a13ee5c7cf2b1e169cb19 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +4fc3e1fb0db6a121f88a9141605d9285 libtensorrt_llm_internal_cutlass_kernels_static.a +253731af750407020dbe6f2fbe50fa2b libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib index d5007d0d2..e88023db2 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/tensorrt_llm_internal_cutlass_kernels_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5ba6f8610db3b967f3de4beeff6394cdd3e56d15916f39110ed932c3c3a65417 -size 88141376 +oid sha256:62af58f5e09d1cf5e347b02ef3bd3a186469162fc9645d038fb2cba23b597722 +size 88140804 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt index 15334333c..5bb9d18b8 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-windows-msvc/version.txt @@ -1,2 +1,2 @@ -5c26d1347bb8b47288d598b6d7444900 tensorrt_llm_internal_cutlass_kernels_static.lib -7adf157833793b3215570b0a95b9c4b2998a620c commit \ No newline at end of file +eb7fc4a105eb6e6f52ba865f2b055233 tensorrt_llm_internal_cutlass_kernels_static.lib +7f370deb0090d885d7518c2b146399ba3933c004 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/layers/decodingParams.h b/cpp/tensorrt_llm/layers/decodingParams.h index 0179add1d..07d200704 100644 --- a/cpp/tensorrt_llm/layers/decodingParams.h +++ b/cpp/tensorrt_llm/layers/decodingParams.h @@ -210,8 +210,6 @@ struct LookaheadSetupParams : public DecodingSetupParams TensorPtr positionOffsets; //! see LookaheadDecodingOutputs::attentionPackedMasks TensorPtr attentionPackedMasks; - //! see LookaheadDecodingOutputs::actualGenerationLengths - TensorPtr actualGenerationLengths; }; class BaseDecodingInputs @@ -551,8 +549,6 @@ class LookaheadDecodingOutputs : public SpeculativeDecodingOutputs TensorPtr positionOffsets; //! [maxBatchSize, maxDecodingTokens] TensorPtr positionIds; - //! The actual decoding tokens length, for debug and for future. 
- TensorPtr actualGenerationLengths; }; class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs diff --git a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp index 5b3062be0..db78160b9 100644 --- a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp +++ b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.cpp @@ -18,8 +18,12 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/layers/decodingParams.h" #include "tensorrt_llm/layers/lookaheadDecodingUtils.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/lookaheadModule.h" +#include #include namespace tensorrt_llm::layers @@ -27,6 +31,36 @@ namespace tensorrt_llm::layers using namespace tensorrt_llm::runtime; +LookaheadAlgorithm::LookaheadAlgorithm( + runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id) + : mMaxW(maxW) + , mMaxN(maxN) + , mMaxG(maxG) + , mFilling(0) + , mPoolManager(maxG) + , mId(id) + , mGoldenTokensMax( + runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32)) + , mPrefillsMax(runtime::BufferManager::cpu( + runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32)) + , mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32)) + , mPastTokensMax( + runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32)) + , mGuessTokensMax( + runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxG * (maxN - 1)}), nvinfer1::DataType::kINT32)) +{ + runtime::SizeType32 maxGeneratedLen, maxDraftLen; + std::tie(maxGeneratedLen, std::ignore, maxDraftLen, std::ignore) + = executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource(); + mAttentionMask = runtime::BufferManager::cpu( + runtime::ITensor::makeShape({maxDraftLen, maxDraftLen}), nvinfer1::DataType::kBOOL); + mDraftTokensMax + = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32); + mSampledTokensMax + = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxGeneratedLen}), nvinfer1::DataType::kINT32); + mEncodeMapMax = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32); +} + void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -36,7 +70,7 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT mW = w; mN = n; mG = g; - std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, std::ignore) + std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, mRuntimeMaxDraftPathLen) = executor::LookaheadDecodingConfig(mW, mN, mG).calculateSpeculativeResource(); mPoolManager.setup(mG); @@ -81,8 +115,8 @@ void LookaheadAlgorithm::accept(TensorConstPtr const& generatedTokens) } //! lookahead has two phase, prefill the past tokens matrix and maintain past tokens matrix. 
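
Aside for readers of the new `LookaheadAlgorithm` constructor above: the per-request CPU buffers are sized purely from the (maxW, maxN, maxG) configuration, with the draft and sampled sizes assumed to come from `executor::LookaheadDecodingConfig::calculateSpeculativeResource()` exactly as shown in the diff. The sketch below is illustrative only and not part of the patch; the struct and helper names are hypothetical, and the shape formulas are copied from the allocation calls above.

```cpp
// Illustrative sketch only, not part of the patch. Shape formulas mirror the
// allocations in LookaheadAlgorithm's constructor above; maxGeneratedLen and
// maxDraftLen are assumed to come from calculateSpeculativeResource().
#include <cstdint>

struct LookaheadBufferShapes // hypothetical name, for illustration only
{
    int32_t goldenTokensMax;  // maxN * 2 - 1
    int32_t prefillsMax;      // maxN <= 1 ? 0 : maxN - 2
    int32_t keyTokensMax;     // maxW
    int32_t pastTokensMax;    // maxW * (maxN - 1)
    int32_t guessTokensMax;   // maxG * (maxN - 1)
    int32_t draftTokensMax;   // maxDraftLen (the attention mask is maxDraftLen x maxDraftLen)
    int32_t sampledTokensMax; // maxGeneratedLen
};

// Hypothetical helper: collects the 1-D buffer sizes the constructor allocates.
LookaheadBufferShapes lookaheadBufferShapes(
    int32_t maxW, int32_t maxN, int32_t maxG, int32_t maxGeneratedLen, int32_t maxDraftLen)
{
    return {maxN * 2 - 1, maxN <= 1 ? 0 : maxN - 2, maxW, maxW * (maxN - 1), maxG * (maxN - 1), maxDraftLen,
        maxGeneratedLen};
}
```

For example, with (maxW, maxN, maxG) = (4, 5, 4), the lookahead branch reserves 4 × 4 = 16 past tokens and the verification branch up to 4 × 4 = 16 guess tokens.
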
-runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds, - TensorPtr const& samplingMask, runtime::SizeType32 offset) +runtime::SizeType32 LookaheadAlgorithm::lookahead( + TensorPtr const& draftTokens, TensorPtr const& positionIds, runtime::SizeType32 startPosId) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -90,7 +124,6 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, SizeType32 len = prefill + mFilling * mW; TLLM_CHECK(len <= ITensor::volume(draftTokens->getShape())); TLLM_CHECK(len <= ITensor::volume(positionIds->getShape())); - TLLM_CHECK(len <= ITensor::volume(samplingMask->getShape())); BufferRange prefillRange(*mPrefills); BufferRange pastRange(*mPastTokens); BufferRange draftRange(*draftTokens); @@ -112,11 +145,6 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, } BufferRange positionIdsRange(*positionIds); - BufferRange samplingMaskRange(*samplingMask); - for (auto& v : samplingMaskRange) - { - v = 0; - } SizeType32 idx = 0, wj = 0; auto fillPosition = [&positionIdsRange, &idx](SizeType32 start, SizeType32 len) { @@ -127,20 +155,18 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, }; if (prefill >= 0) { - fillPosition(offset, prefill); + fillPosition(startPosId, prefill); for (wj = 0; wj < mW; wj++) { - fillPosition(offset + prefill + wj, mFilling); - samplingMaskRange[prefill + wj * mFilling + mFilling - 1] = true; + fillPosition(startPosId + prefill + wj, mFilling); } } else { - fillPosition(offset, mFilling - 1); + fillPosition(startPosId, mFilling - 1); for (wj = 1; wj < mW; wj++) { - fillPosition(offset - 1 + wj, mFilling); - samplingMaskRange[wj * mFilling + mFilling - 1 - 1] = true; + fillPosition(startPosId - 1 + wj, mFilling); } } PRINT_VALUES(positionIds); @@ -150,7 +176,7 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, } runtime::SizeType32 LookaheadAlgorithm::guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, - TensorPtr const& samplingMask, runtime::SizeType32 offset, runtime::TokenIdType lastToken) + runtime::SizeType32 startPosId, runtime::TokenIdType lastToken) { auto guesses = mPoolManager.guess(lastToken, mW); @@ -158,67 +184,227 @@ runtime::SizeType32 LookaheadAlgorithm::guess(TensorPtr const& guessTokens, Tens std::for_each(guesses.begin(), guesses.end(), [&len](auto& a) { len += ITensor::volume(a->getShape()); }); TLLM_CHECK(len <= ITensor::volume(guessTokens->getShape())); TLLM_CHECK(len <= ITensor::volume(guessIds->getShape())); - TLLM_CHECK(len <= ITensor::volume(samplingMask->getShape())); BufferRange guessTokensRange(*guessTokens); BufferRange guessIdsRange(*guessIds); - BufferRange samplingMaskRange(*samplingMask); SizeType32 cur = 0; for (auto guess : guesses) { BufferRange guessRange(*guess); std::copy(guessRange.begin(), guessRange.end(), guessTokensRange.begin() + cur); - SizeType32 tmp = offset; + SizeType32 tmp = startPosId; std::for_each( guessIdsRange.begin() + cur, guessIdsRange.begin() + cur + mN - 1, [&tmp](auto& v) { v = tmp++; }); cur += ITensor::volume(guess->getShape()); } - std::for_each(samplingMaskRange.begin(), samplingMaskRange.begin() + len, [](auto& a) { a = true; }); - return len; } +void LookaheadAlgorithm::posIdsToMask(TensorPtr const& mask, TensorConstPtr const& posIds) +{ + auto len = ITensor::volume(posIds->getShape()); + TLLM_CHECK(mask->getDimension<0>() >= len); + TLLM_CHECK(mask->getDimension<1>() >= len); + auto 
posIdsRange = BufferRange(*posIds); + auto maskLocation = BufferLocation(*mask); + + for (auto& item : maskLocation) + { + item = false; + } + + if (len > 0) + { + std::vector> stack; + for (auto i = 0; i < len; i++) + { + auto cur = posIdsRange[i]; + while (stack.size() > 0 && cur <= stack.back().second) + { + stack.pop_back(); + } + TLLM_CHECK(stack.size() > 0 ? cur == stack.back().second + 1 : true); + stack.push_back(std::make_pair(i, cur)); + for (auto prev : stack) + { + maskLocation.at(i, prev.first) = true; + } + } + } +} + +struct TreeValue; +using TreeMap = std::unordered_map; + +struct TreeValue +{ + TreeValue() + : nexts(std::make_shared()) + { + } + + using Nexts = std::shared_ptr; + Nexts nexts{nullptr}; + std::list sources; +}; + +using TreeNode = TreeMap::value_type; + +template +void treeDFS(TreeNode& node, BF const& visitBefore, AF const& visitAfter) +{ + visitBefore(node); + for (auto& next : *(node.second.nexts)) + { + treeDFS(next, visitBefore, visitAfter); + } + visitAfter(node); +} + +SizeType32 LookaheadAlgorithm::treeEncode( + TensorPtr const& tokens, TensorPtr const& posIds, TensorPtr const& mask, TensorPtr const& encodeMap) +{ + TLLM_CHECK(ITensor::volume(tokens->getShape()) == ITensor::volume(posIds->getShape())); + auto len = ITensor::volume(tokens->getShape()); + + BufferRange tokensRange(*tokens); + BufferRange posIdsRange(*posIds); + BufferLocation maskLocation(*mask); + BufferRange mapRange(*encodeMap); + + auto branches = std::make_shared(); + + for (auto i = 0; i < len; i++) + { + auto nexts = branches; + for (auto j = 0; j <= i; j++) + { + if (maskLocation.at(i, j)) + { + auto pos = posIdsRange[j]; + auto tok = tokensRange[j]; + auto found = nexts->find(tok); + if (found != nexts->end()) + { + found->second.sources.push_back(j); + nexts = found->second.nexts; + } + else + { + auto [inserted, ok] = nexts->insert({tok, TreeValue()}); + inserted->second.sources.push_back(j); + nexts = inserted->second.nexts; + } + } + } + } + + for (auto& item : maskLocation) + { + item = 0; + } + std::vector> stack; + SizeType32 offset = 0; + SizeType32 posId = posIdsRange.size() ? posIdsRange[0] : 0; + + auto visitBefore + = [&stack, &maskLocation, &tokensRange, &posIdsRange, &posId, &offset, &mapRange](TreeNode const& node) + { + stack.push_back(std::make_pair(offset, node.first)); + for (auto const& source : node.second.sources) + { + mapRange[source] = offset; + } + for (auto const& prev : stack) + { + maskLocation.at(offset, prev.first) = true; + } + tokensRange[offset] = node.first; + posIdsRange[offset] = posId; + offset++; + posId++; + }; + auto visitAfter = [&stack, &posId](TreeNode const& node) + { + stack.pop_back(); + posId--; + }; + + for (auto& next : *branches) + { + treeDFS(next, visitBefore, visitAfter); + } + + for (SizeType32 i = offset; i < len; i++) + { + tokensRange[i] = 0; + posIdsRange[i] = 0; + } + for (SizeType32 i = 0; i < len; i++) + { + for (SizeType32 j = i < offset ? 
offset : 0; j < len; j++) + { + maskLocation.at(i, j) = false; + } + } + + return offset; +} + void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, - TensorPtr const& samplingMask, TensorPtr const& length, TensorConstPtr const& offsetPtr, - TensorConstPtr const& lastTokenPtr) + TensorPtr const& draftLengthPtr, TensorPtr const& attentionMask, SizeType32 attentionMaskOffset, + TensorConstPtr const& lastPositionIdPtr, TensorConstPtr const& lastTokenPtr) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); if (mRuntimeMaxDraftLen == 0) { - (BufferRange(*length))[0] = 0; + mDraftTokens = ITensor::slice(mDraftTokensMax, 0, 0); + mEncodeMap = ITensor::slice(mEncodeMapMax, 0, 0); + (BufferRange(*draftLengthPtr))[0] = 0; return; } auto lastToken = BufferRange(*lastTokenPtr)[0]; - auto offset = BufferRange(*offsetPtr)[0]; + auto offset = BufferRange(*lastPositionIdPtr)[0]; SizeType32 inputLen = ITensor::volume(draftTokens->getShape()); TLLM_CHECK(inputLen >= mRuntimeMaxDraftLen); BufferRange draftRange(*draftTokens); BufferRange positionRange(*positionIds); - BufferRange samplingRange(*samplingMask); SizeType32 filledLen = 0; filledLen += lookahead(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen), - ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), - ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset); + ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), offset); auto guessStart = filledLen; filledLen += guess(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen), - ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), - ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken); + ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken); auto guessEnd = filledLen; + std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd, + BufferRange(*mGuessTokensMax).begin()); mGuessTokens = ITensor::slice(mGuessTokensMax, 0, guessEnd - guessStart); - std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd, - BufferRange(*mGuessTokens).begin()); + posIdsToMask(mAttentionMask, ITensor::slice(positionIds, 0, filledLen)); - (BufferRange(*length))[0] = filledLen; + auto draftLen = treeEncode(ITensor::slice(draftTokens, 0, filledLen), ITensor::slice(positionIds, 0, filledLen), + mAttentionMask, mEncodeMapMax); + + for (SizeType32 i = 0; i < draftLen; i++) + { + BufferRange srcRange(*ITensor::at(mAttentionMask, {i})); + BufferRange dstRange(*ITensor::slice(attentionMask, {i + attentionMaskOffset, attentionMaskOffset})); + std::copy(srcRange.begin(), srcRange.end(), dstRange.begin()); + } + + std::copy(draftRange.begin(), draftRange.begin() + draftLen, BufferRange(*mDraftTokensMax).begin()); + mDraftTokens = ITensor::slice(mDraftTokensMax, 0, draftLen); + (BufferRange(*draftLengthPtr))[0] = draftLen; + mEncodeMap = ITensor::slice(mEncodeMapMax, 0, filledLen); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -229,29 +415,31 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mGuessTokens->getShape())); + TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mDraftTokens->getShape())); BufferRange goldRange(*goldenTokens); - BufferRange guessTokensRange(*mGuessTokens); - 
auto guessSize = ITensor::volume(mGuessTokens->getShape()); + BufferRange draftRange(*mDraftTokens); + BufferLocation maskLocation(*mAttentionMask); + auto draftSize = ITensor::volume(mDraftTokens->getShape()); + auto end = *BufferRange(*endToken).begin(); - SizeType32 guesses = (mN - 1 > 0) ? (guessSize / (mN - 1)) : 0; - SizeType32 hit = 0, maxHit = 0, hitIdx = 0; - for (SizeType32 i = 0; i < guesses; i++) + SizeType32 maxHit = 0, hitIdx = 0; + for (SizeType32 i = 0; i < draftSize; i++) { SizeType32 hit = 0; - for (SizeType32 j = 0; j < mN - 1; j++) + TokenIdType cur = newLastToken; + for (SizeType32 j = 0; j < draftSize; j++) { - auto idx = i * (mN - 1) + j; - bool ok - = (j == 0) ? (newLastToken == guessTokensRange[idx]) : (goldRange[idx - 1] == guessTokensRange[idx]); - bool finish = guessTokensRange[idx] == *BufferRange(*endToken).begin(); - if (ok && !finish) - { - hit++; - } - else + if (maskLocation.at(i, j)) { - break; + if (draftRange[j] == cur && draftRange[j] != end) + { + hit++; + cur = goldRange[j]; + } + else + { + break; + } } } if (hit > maxHit) @@ -261,17 +449,19 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce } } - BufferRange acceptedRange(*accepted); - acceptedRange[0] = newLastToken; - std::copy(goldRange.begin() + hitIdx * (mN - 1), goldRange.begin() + hitIdx * (mN - 1) + maxHit, - acceptedRange.begin() + 1); + maxHit = maxHit > mRuntimeMaxDraftPathLen ? mRuntimeMaxDraftPathLen : maxHit; + SizeType32 acceptedIdx = 0; + BufferRange acceptedRange(*accepted); BufferRange acceptedOffsetsRange(*acceptedOffsets); - auto lookSize = 1 + mN - 2 - mFilling + mFilling * mW; - // acceptedOffsetsRange[0] = 0; - for (SizeType32 i = 0; i < maxHit; i++) + acceptedRange[acceptedIdx] = newLastToken; + for (SizeType32 j = 0; j < draftSize; j++) { - acceptedOffsetsRange[i] = lookSize + hitIdx * (mN - 1) + i - 1; + if (maskLocation.at(hitIdx, j) && acceptedIdx < maxHit) + { + acceptedOffsetsRange[acceptedIdx++] = j; + acceptedRange[acceptedIdx] = goldRange[j]; + } } *BufferRange(*acceptedLength).begin() = maxHit + 1; @@ -325,7 +515,19 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(ITensor::volume(acceptedTokens->getShape()) >= mN); - BufferRange sampledRange(*sampledTokens); + BufferRange zippedTokensRange(*sampledTokens); + BufferRange sampledRange(*mSampledTokensMax); + + BufferRange mapRange(*mEncodeMap); + BufferRange unzipRange(*mSampledTokensMax); + mSampledTokens = ITensor::slice(mSampledTokensMax, 0, mEncodeMap->getShape().d[0] + 1); + + unzipRange[0] = zippedTokensRange[0]; + for (SizeType32 i = 0; i < mapRange.size(); i++) + { + unzipRange[i + 1] = zippedTokensRange[mapRange[i] + 1]; + } + BufferRange keyRange(*mKeyTokens); BufferRange pastRange(*mPastTokens); @@ -359,13 +561,15 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const } auto guessSize = ITensor::volume(mGuessTokens->getShape()); - auto outputSize = ITensor::volume(sampledTokens->getShape()); + auto outputSize = ITensor::volume(mSampledTokens->getShape()); auto lookSize = 1 + (mN > 1 ? 
mN - 2 : 0) - mFilling + mFilling * mW; TLLM_CHECK(guessSize + lookSize == outputSize); - TensorConstPtr goldenTokens = ITensor::slice(sampledTokens, lookSize, guessSize); + TensorConstPtr goldenTokens = ITensor::slice(mSampledTokens, lookSize, guessSize); + + auto& acptLen = *BufferRange(*acceptedLength).begin(); - verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, goldenTokens, endToken); + verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, ITensor::slice(sampledTokens, 1), endToken); accept(ITensor::slice(acceptedTokens, 0, *BufferRange(*acceptedLength).begin())); diff --git a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h index 99df44128..485734c5a 100644 --- a/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h +++ b/cpp/tensorrt_llm/layers/lookaheadAlgorithm.h @@ -21,6 +21,7 @@ #include "tensorrt_llm/layers/decodingParams.h" #include "tensorrt_llm/runtime/common.h" #include +#include namespace tensorrt_llm::layers { @@ -35,24 +36,7 @@ class LookaheadAlgorithm //! @brief Currently the resource management is to be aligned with batch manager. //! @param w, n, g is the Jacobi window, n-gram level and guess set size respectively. LookaheadAlgorithm( - runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0) - : mMaxW(maxW) - , mMaxN(maxN) - , mMaxG(maxG) - , mFilling(0) - , mPoolManager(maxG) - , mId(id) - , mGoldenTokensMax( - runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32)) - , mPrefillsMax(runtime::BufferManager::cpu( - runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32)) - , mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32)) - , mPastTokensMax( - runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32)) - , mGuessTokensMax( - runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxG * (maxN - 1)}), nvinfer1::DataType::kINT32)) - { - } + runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0); //! @brief setup per request, fill internal states from @param prompt. void setup(TensorConstPtr const& prompt, runtime::SizeType32 w, runtime::SizeType32 n, runtime::SizeType32 g); @@ -62,43 +46,55 @@ class LookaheadAlgorithm void accept(TensorConstPtr const& generatedTokens); //! @brief combine lookahead and guess to prepare the tensors. - //! input @param offsetPtr is position id of the last golden token, in a TensorPtr. + //! input @param lastPositionIdPtr is position id of the last golden token, in a TensorPtr. //! input @param lastTokenPtr the last golden token for searching in the pool, in a TensorPtr. - //! output @param draftTokens, positionIds, samplingMask; including the golden token, the lookahead - //! and the verification branch information. @param length holds the draft tokens length. - void prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, TensorPtr const& samplingMask, - TensorPtr const& length, TensorConstPtr const& offsetPtr, TensorConstPtr const& lastTokenPtr); + //! output @param draftTokens, positionIds includes the lookahead and the verification branch information. + //! output @param draftLengthPtr holds the draft tokens length. + //! output @param attentionMask holds the draft tokens dependency mask, and attentionMaskOffset is the index offset + //! in attentionMask. 
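
To make the "dependency mask" above concrete: it is the ancestor mask that the new static `posIdsToMask` (see the lookaheadAlgorithm.cpp hunk earlier in this diff) derives from tree-linearized position ids. The sketch below is illustrative only and not part of the patch; it mirrors that stack-based logic over plain `std::vector` instead of `ITensor`, so the worked example in the trailing comment can be checked by hand.

```cpp
// Illustrative sketch only (not part of the patch). Mirrors the stack-based
// logic of LookaheadAlgorithm::posIdsToMask shown earlier in this diff.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::vector<bool>> posIdsToMask(std::vector<int32_t> const& posIds)
{
    auto const len = posIds.size();
    std::vector<std::vector<bool>> mask(len, std::vector<bool>(len, false));
    std::vector<std::pair<size_t, int32_t>> stack; // (index, position id) of the current ancestor chain
    for (size_t i = 0; i < len; ++i)
    {
        auto const cur = posIds[i];
        while (!stack.empty() && cur <= stack.back().second)
        {
            stack.pop_back(); // a sibling branch starts: drop entries that are not ancestors of token i
        }
        assert(stack.empty() || cur == stack.back().second + 1);
        stack.emplace_back(i, cur);
        for (auto const& prev : stack)
        {
            mask[i][prev.first] = true; // token i may attend to itself and its ancestors
        }
    }
    return mask;
}

// Example: posIds {5, 6, 7, 6, 7} describes two branches sharing the root at index 0.
// Row 2 of the mask is {1, 1, 1, 0, 0} (first branch), and row 4 is {1, 0, 0, 1, 1}:
// the last token attends to the root, its parent on the second branch (index 3), and itself.
```
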
+ void prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, TensorPtr const& draftLengthPtr, + TensorPtr const& attentionMask, runtime::SizeType32 attentionMaskOffset, + TensorConstPtr const& lastPositionIdPtr, TensorConstPtr const& lastTokenPtr); //! @brief update the internal states and generate accepted tokens from @param outputTokens. - //! input @param sampledTokens is the all the tokens from the language model. The position at samplingMask=1 is - //! valid. input @param endToken is the end token for `verify` early quit. - //! output @param acceptedTokens, acceptedOffsets ind @param acceptedLength. + //! input @param sampledTokens is the all the tokens from the language model. + //! input @param endToken is the end token for `verify` early quit. + //! output @param acceptedTokens, acceptedOffsets in @param acceptedLength. void update(TensorPtr const& acceptedTokens, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength, TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken); + //! generate attention @param mask from @param posIds. + static void posIdsToMask(TensorPtr const& mask, TensorConstPtr const& posIds); + + //! inplace encode the @param tokens and @param posIds according to attention @param masks, and record the offsets + //! in @param encodeMap. + static runtime::SizeType32 treeEncode( + TensorPtr const& tokens, TensorPtr const& posIds, TensorPtr const& masks, TensorPtr const& encodeMap); + private: //! @brief generate lookahead branch information. - //! input @param offset the position id of the last golden token. - //! output @param draftTokens, positionIds, samplingMask of the lookahead branch. + //! input @param startPosId is the first position id of the draftTokens. + //! output @param draftTokens, positionIds of the lookahead branch. //! @return the actual filled lookahead length. - runtime::SizeType32 lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds, - TensorPtr const& samplingMask, runtime::SizeType32 offset); + runtime::SizeType32 lookahead( + TensorPtr const& draftTokens, TensorPtr const& positionIds, runtime::SizeType32 startPosId); //! @brief generate verification branch information. Also save the guessed tokens for future verification. - //! input @param offset the position id of the last golden token. + //! input @param startPosId the first position id. //! input @param lastToken the last golden token for searching in the pool. - //! output @param guessTokens, guessIds, samplingMask of the verification branch. + //! output @param guessTokens, guessIds of the verification branch. //! @return the actual filled guess length. - runtime::SizeType32 guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, TensorPtr const& samplingMask, - runtime::SizeType32 offset, runtime::TokenIdType lastToken); + runtime::SizeType32 guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, runtime::SizeType32 startPosId, + runtime::TokenIdType lastToken); //! @brief verify the guessed tokens results and generate the longest accepted tokens. //! input @param newLastToken is the new-generated last golden token. - //! input @param goldenTokens is the guessed token results from the language model. + //! input @param sampledTokens is the generated token results from the language model. //! input @param endToken is the end token for early quit detection. - //! output @param accepted, acceptedOffsets in @param acceptedLength, . + //! output @param accepted in @param acceptedLength, including the first golden one. + //! 
output @param acceptedOffsets is the offsets of draft tokens, excluding the first golden one. void verify(TensorPtr const& accepted, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength, - runtime::TokenIdType newLastToken, TensorConstPtr const& goldenTokens, TensorConstPtr const& endToken); + runtime::TokenIdType newLastToken, TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken); private: LookaheadPoolManager mPoolManager; @@ -117,6 +113,13 @@ class LookaheadAlgorithm //! the same guess tokens from `guess` and used in `verify` TensorPtr mGuessTokensMax; // shape [mMaxG*(mMaxN-1)] TensorPtr mGuessTokens; // shape [mG*(mN-1)] + TensorPtr mDraftTokensMax; + TensorPtr mDraftTokens; + TensorPtr mAttentionMask; + TensorPtr mEncodeMapMax; + TensorPtr mEncodeMap; + TensorPtr mSampledTokensMax; + TensorPtr mSampledTokens; //! look ahead algorithm parameters, Window size, Level and Guess set size. //! max for reserving resources and current for current request. @@ -127,6 +130,7 @@ class LookaheadAlgorithm runtime::SizeType32 mN{0}; runtime::SizeType32 mG{0}; runtime::SizeType32 mRuntimeMaxDraftLen{0}; + runtime::SizeType32 mRuntimeMaxDraftPathLen{0}; //! in prefilling mode when mFilling < mN-1. runtime::SizeType32 mFilling; diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp index 8214abfb4..32b812967 100644 --- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp +++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp @@ -24,6 +24,7 @@ #include "tensorrt_llm/layers/lookaheadAlgorithm.h" #include "tensorrt_llm/layers/lookaheadDecodingUtils.h" #include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/lookaheadModule.h" @@ -80,14 +81,14 @@ LookaheadDecodingLayer::CpuAlgorithmResources::CpuAlgorithmResources(DecoderD mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32); mNextDraftPosIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32); mGenerationLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); - mGenerationLengthsMax = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); mPositionOffsets = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32); mPositionIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32); + mAttentionMask + = BufferManager::cpu(ITensor::makeShape({maxTokensPerStep, maxTokensPerStep}), nvinfer1::DataType::kBOOL); mPackedMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep, static_cast(divUp(maxTokensPerStep, 32))}), nvinfer1::DataType::kINT32); - mSamplingMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kBOOL); mNextDraftLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); mSequenceLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32); } @@ -113,7 +114,6 @@ LookaheadDecodingLayer::LookaheadDecodingLayer( mWorkspaceSize = getTopKWorkspaceSize(maxBatchSize, maxTokensPerStep, maxTopK, vocabSizePadded); mTargetTokensDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kINT32); - mSamplingMaskDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kBOOL); mCurandStatesDevice 
= mBufferManager->gpu(ITensor::makeShape({maxBatchSize, sizeof(curandState_t)}), nvinfer1::DataType::kINT8); @@ -168,6 +168,7 @@ void LookaheadDecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth { SizeType32 gbi = batchSlotsRange[bi]; (BufferRange(*mCpuAlgo->mGenerationLengths))[gbi] = 1; + (BufferRange(*mCpuAlgo->mNextDraftLengths))[gbi] = 0; BufferLocation(*mCpuAlgo->mPositionOffsets).at(gbi, 0) = 0; BufferRange packedMaskRange(*ITensor::at(mCpuAlgo->mPackedMask, {gbi})); for (auto& mask : packedMaskRange) @@ -184,11 +185,6 @@ void LookaheadDecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth PRINT_SHAPE(setupParams->attentionPackedMasks); mBufferManager->copy( *ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}), *ITensor::at(setupParams->generationLengths, {gbi})); - if (setupParams->actualGenerationLengths) - { - mBufferManager->copy(*ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}), - *ITensor::at(setupParams->actualGenerationLengths, {gbi})); - } mBufferManager->copy( *ITensor::at(mCpuAlgo->mPositionOffsets, {gbi}), *ITensor::at(setupParams->positionOffsets, {gbi})); mBufferManager->copy( @@ -261,39 +257,32 @@ size_t LookaheadDecodingLayer::getWorkspaceSize() const noexcept return std::max(mWorkspaceSize, mSetupWorkspaceSize); } -template -void LookaheadDecodingLayer::posIdsToMask(TensorPtr mask, TensorConstPtr posIds) +inline void initAttentionMask(TensorPtr const& mask, std::shared_ptr& bufferManager) { - auto len = ITensor::volume(posIds->getShape()); - TLLM_CHECK(mask->getDimension<0>() > len); - TLLM_CHECK(mask->getDimension<1>() * 32 > len); - auto posIdsRange = BufferRange(*posIds); - auto maskLocation = BufferLocation(*mask); - - for (auto i = 0; i < maskLocation.size(); i++) + bufferManager->setZero(*mask); + BufferLocation maskLocation(*mask); + auto maskShape = mask->getShape(); + for (SizeType32 i = 0; i < maskShape.d[0]; i++) { - maskLocation[i] = 0; + maskLocation.at(i, 0) = true; } - maskLocation.at(0, 0) = 1; +} - auto setBit = [](SizeType32& x, SizeType32 idx) { x |= (1 << idx); }; - if (len > 0) +inline void convertBoolToInt32(TensorPtr const& dst, TensorConstPtr const& src) +{ + auto dstShape = dst->getShape(); + auto srcShape = src->getShape(); + TLLM_CHECK(dstShape.d[0] == srcShape.d[0]); + TLLM_CHECK(dstShape.d[1] * 32 >= srcShape.d[1]); + BufferLocation dstLocation(*dst); + BufferLocation srcLocation(*src); + + auto setBit = [](SizeType32& x, SizeType32 idx, bool value) { x |= (value << idx); }; + for (auto i = 0; i < srcShape.d[0]; i++) { - std::vector> stack; - stack.emplace_back(0, posIdsRange[0] - 1); - for (auto i = 1; i < len + 1; i++) + for (auto j = 0; j < srcShape.d[1]; j++) { - auto cur = posIdsRange[i - 1]; - while (stack.size() > 0 && cur <= stack.back().second) - { - stack.pop_back(); - } - TLLM_CHECK(stack.size() > 0 ? 
cur == stack.back().second + 1 : true); - stack.emplace_back(i, cur); - for (auto prev : stack) - { - setBit(maskLocation.at(i, prev.first / 32), prev.first % 32); - } + setBit(dstLocation.at(i, j / 32), j % 32, srcLocation.at(i, j)); } } } @@ -307,12 +296,16 @@ void LookaheadDecodingLayer::forwardSyncCPU( mCpuAlgo->mBatchSlots->reshape(inputs->batchSlots->getShape()); mBufferManager->copy(*inputs->batchSlots, *mCpuAlgo->mBatchSlots); mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep); - mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep); mBufferManager->copy(*inputs->endIds, *mCpuAlgo->mEndIds); mBufferManager->copy(*outputs->sequenceLength.value(), *mCpuAlgo->mSequenceLengths); mBufferManager->copy(*mTargetTokensDevice, *mCpuAlgo->mTargetTokens); + if (outputs->prevDraftLengths) + { + mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, *outputs->prevDraftLengths); + } + mBufferManager->getStream().synchronize(); auto const batchSize = inputs->localBatchSize; @@ -325,7 +318,6 @@ void LookaheadDecodingLayer::forwardSyncCPU( BufferRange numNewTokensCumSumRange(*mCpuAlgo->mNumNewTokensCumSum); BufferRange batchSlotsRange(*mCpuAlgo->mBatchSlots); BufferRange generationLengthsRange(*mCpuAlgo->mGenerationLengths); - BufferRange generationLengthsMaxRange(*mCpuAlgo->mGenerationLengthsMax); BufferRange nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths); BufferRange sequenceLengthsRange(*mCpuAlgo->mSequenceLengths); BufferLocation pathsOffsetLocation(*mCpuAlgo->mPathsOffsets); @@ -334,6 +326,7 @@ void LookaheadDecodingLayer::forwardSyncCPU( mBufferManager->setZero(*mCpuAlgo->mPathsOffsets); mBufferManager->setZero(*mCpuAlgo->mNumNewTokens); mBufferManager->setZero(*mCpuAlgo->mNumNewTokensCumSum); + mBufferManager->setZero(*mCpuAlgo->mPackedMask); for (SizeType32 bi = 0; bi < batchSize; bi++) { @@ -342,7 +335,6 @@ void LookaheadDecodingLayer::forwardSyncCPU( SizeType32 const tokensPerStep = generationLengthsRange[gbi]; TensorPtr sampledTokens = ITensor::slice(mCpuAlgo->mTargetTokens, {gbi, 0}, tokensPerStep); - PRINT_VALUES(sampledTokens); if (tokensPerStep == 1) { @@ -369,14 +361,18 @@ void LookaheadDecodingLayer::forwardSyncCPU( sequenceLengthsRange[gbi] += numNewTokensRange[gbi]; + initAttentionMask(mCpuAlgo->mAttentionMask, mBufferManager); + theAlgo.prepare( // ITensor::at(mCpuAlgo->mNextDraftTokens, {gbi}), // ITensor::at(mCpuAlgo->mNextDraftPosIds, {gbi}), // - ITensor::at(mCpuAlgo->mSamplingMask, {gbi}), // ITensor::at(mCpuAlgo->mNextDraftLengths, {gbi}), // + mCpuAlgo->mAttentionMask, 1, // ITensor::at(mCpuAlgo->mSequenceLengths, {gbi}), // ITensor::at(mCpuAlgo->mOutputIds, {gbi, numNewTokensRange[gbi] - 1})); + convertBoolToInt32(ITensor::at(mCpuAlgo->mPackedMask, {gbi}), mCpuAlgo->mAttentionMask); + BufferLocation posIdsLocation(*ITensor::at(mCpuAlgo->mPositionIds, {gbi})); for (auto& posid : posIdsLocation) { @@ -385,20 +381,14 @@ void LookaheadDecodingLayer::forwardSyncCPU( mBufferManager->copy(*ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]), *ITensor::slice(mCpuAlgo->mPositionIds, {gbi, 1}, nextDraftLengthsRange[gbi])); - posIdsToMask( // - ITensor::at(mCpuAlgo->mPackedMask, {gbi}), // - ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi])); - BufferRange offsetRange(*ITensor::at(mCpuAlgo->mPositionOffsets, {gbi})); - TLLM_CHECK_WITH_INFO( - posIdsLocation.size() == offsetRange.size(), "%ld, %ld", posIdsLocation.size(), offsetRange.size()); for (auto i = 0; i < 
posIdsLocation.size(); i++) { offsetRange[i] = posIdsLocation[i] - posIdsLocation[0]; } + TensorPtr accepted = ITensor::slice(mCpuAlgo->mOutputIds, {gbi, 0}, numNewTokensRange[gbi]); TensorPtr draft = ITensor::slice(mCpuAlgo->mNextDraftTokens, {gbi, 0}, nextDraftLengthsRange[gbi]); - TLLM_LOG_DEBUG("CPU ALGO [ %d ] forward, %s", gbi, D(sampledTokens).values().c_str()); TLLM_LOG_DEBUG("[%d][%d] CPU ALGO [ %d ] forward, %s, %s", mGlobalSteps, batchSize, gbi, D(accepted).values().c_str(), D(draft).values().c_str()); @@ -430,29 +420,23 @@ void LookaheadDecodingLayer::forwardSyncCPU( mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, *outputs->numNewTokensCumSum); // mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, *outputs->nextDraftTokens); - mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks); + for (SizeType32 bi = 0; bi < batchSize; bi++) + { + SizeType32 gbi = batchSlotsRange[bi]; + // nextDraftLengthsRange[gbi] = mDecoderDomain.getMaxDecodingTokens() - 1; + generationLengthsRange[gbi] = nextDraftLengthsRange[gbi] + 1; + } if (outputs->nextDraftLengths) { mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, *outputs->nextDraftLengths); } - for (SizeType32 bi = 0; bi < batchSize; bi++) - { - SizeType32 gbi = batchSlotsRange[bi]; - generationLengthsRange[gbi] = nextDraftLengthsRange[gbi] + 1; - generationLengthsMaxRange[gbi] = mDecoderDomain.getMaxDecodingTokens(); - } mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks); - mBufferManager->copy(*mCpuAlgo->mGenerationLengthsMax, *outputs->generationLengths); + mBufferManager->copy(*mCpuAlgo->mGenerationLengths, *outputs->generationLengths); mBufferManager->copy(*mCpuAlgo->mPositionOffsets, *outputs->positionOffsets); mBufferManager->copy(*mCpuAlgo->mPositionIds, *outputs->positionIds); - if (outputs->actualGenerationLengths) - { - mBufferManager->copy(*mCpuAlgo->mGenerationLengths, *outputs->actualGenerationLengths); - } - mBufferManager->getStream().synchronize(); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h index 536d21727..f2470a411 100644 --- a/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h +++ b/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.h @@ -48,7 +48,6 @@ class LookaheadDecodingLayer : public BaseLayer private: void forwardSyncCPU(std::shared_ptr const& outputs, std::shared_ptr const& inputs); - void posIdsToMask(TensorPtr mask, TensorConstPtr posIds); private: using Base::mDecoderDomain; @@ -57,7 +56,6 @@ class LookaheadDecodingLayer : public BaseLayer size_t mSetupWorkspaceSize{}; TensorPtr mCurandStatesDevice; TensorPtr mTargetTokensDevice; - TensorPtr mSamplingMaskDevice; struct CpuAlgorithmResources { @@ -78,11 +76,10 @@ class LookaheadDecodingLayer : public BaseLayer TensorPtr mNextDraftTokens; TensorPtr mNextDraftPosIds; - TensorPtr mSamplingMask; TensorPtr mNextDraftLengths; TensorPtr mSequenceLengths; TensorPtr mGenerationLengths; - TensorPtr mGenerationLengthsMax; + TensorPtr mAttentionMask; TensorPtr mPackedMask; TensorPtr mPositionOffsets; TensorPtr mPositionIds; diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp index cd5529b07..946d50bd4 100644 --- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp @@ -423,6 +423,7 @@ int 
BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc request_seq_len, mNumHeads, mHeadSize, padding_offset, (float*) nullptr, 0, stream); } } + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp index ba73171c4..3d0ff1db3 100644 --- a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp @@ -143,6 +143,7 @@ int CumsumLastDimPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc invokeCumsumLastDim( batchSize, inputLength, inputs[getInputTensorIdx()], outputs[0], wp, mTempStorageBytes, stream); + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp index 291900103..71ad6591f 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp @@ -164,7 +164,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel memset(&xqaParams, 0, sizeof(XQAParams)); xqaParams.data_type = ConvertMMHAToXQAParamsHelper::data_type; - xqaParams.layer_idx = mLayerIdx; + xqaParams.layer_idx = mLayerIdxInCachePool; xqaParams.num_q_heads = mNumHeads; xqaParams.num_kv_heads = mNumKVHeads; xqaParams.head_size = mHeadSize; @@ -376,13 +376,13 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \ - const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream); \ + FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream); \ template void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \ - const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream); \ + FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream); \ template void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \ - const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream); \ + FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream); \ template void fusedQKV_masked_attention_dispatch(Multihead_attention_params&, \ - const FusedQKVMaskedAttentionDispatchParams&, cudaStream_t stream); + FusedQKVMaskedAttentionDispatchParams const&, cudaStream_t stream); INSTANTIATE_MMHA_DISPATCH(float, float) INSTANTIATE_MMHA_DISPATCH(uint16_t, half) #ifdef ENABLE_BF16 @@ -391,8 +391,8 @@ INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16) #undef INSTANTIATE_MMHA_DISPATCH GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length, - int num_kv_heads, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale, - tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, + int num_kv_heads, int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling, + float qk_tanh_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int rotary_embedding_dim, // for RoPE. 
Use 0 for non-RoPE float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale, @@ -411,6 +411,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, , mVisionStart(vision_start) , mVisionLength(vision_length) , mNumKVHeads(num_kv_heads) + , mLayerIdxInCachePool(layer_idx_in_cache_pool) , mHeadSize(head_size) , mUnidirectional(unidirectional) , mQScaling(q_scaling) @@ -525,6 +526,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng read(d, mVisionStart); read(d, mVisionLength); read(d, mNumKVHeads); + read(d, mLayerIdxInCachePool); read(d, mHeadSize); read(d, mUnidirectional); read(d, mQScaling); @@ -721,7 +723,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams #include #include #include @@ -41,8 +43,8 @@ static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"}; static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"}; GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, - int num_kv_heads, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale, - tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, + int num_kv_heads, int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling, + float qk_tanh_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int rotary_embedding_dim, // for RoPE. 0 for non-RoPE float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, float rotary_embedding_scale, float rotary_embedding_short_m_scale, @@ -57,9 +59,9 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_ bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable, int spec_decoding_max_generation_length) - : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size, - unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base, - rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale, + : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, layer_idx_in_cache_pool, + head_size, unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, + rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled, @@ -94,6 +96,7 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const case IdxEntry::KV_CACHE_BLOCK_OFFSETS: return useKVCache() && mPagedKVCache; case IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS: return useKVCache() && mPagedKVCache; case IdxEntry::HOST_KV_CACHE_POOL_POINTERS: return useKVCache() && mPagedKVCache; + case IdxEntry::HOST_KV_CACHE_POOL_MAPPING: return useKVCache() && mPagedKVCache; case IdxEntry::PAST_KEY_VALUE: return useKVCache() && !mPagedKVCache; case 
IdxEntry::KV_CACHE_QUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant(); case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant(); @@ -244,6 +247,11 @@ bool GPTAttentionPlugin::supportsFormatCombination( // kv cache pool pointers return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR; } + else if (useKVCache() && mPagedKVCache && (pos == getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING))) + { + // kv cache pool mapping + return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR; + } else if (useKVCache() && mKVCacheQuantMode.hasInt8KvCache() && (!mPagedKVCache && (pos == getIdx(IdxEntry::PAST_KEY_VALUE) || pos == nbInputs + 1))) { @@ -478,6 +486,7 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, outputDesc, inputs, outputs, workspace, stream); } + sync_check_cuda_error(); TLLM_LOG_TRACE("Attention plugin stop at layer %d", mLayerIdx); return 0; @@ -624,27 +633,36 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 auto const& kvCacheBlockOffsets = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)]; auto const& kvCacheBlockOffsetsShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims; max_blocks_per_sequence = kvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1]; - auto const seqStride = getStride(kvCacheBlockOffsetsShape, 0); + + std::int32_t const* host_pool_mapping + = static_cast(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING)]); + + const int32_t layerToPool = host_pool_mapping[mLayerIdx]; + auto const seqStride = getStride(kvCacheBlockOffsetsShape, 1); + auto const poolStride = getStride(kvCacheBlockOffsetsShape, 0); auto const seqOffset = seqIdxBeg * seqStride; + auto const poolOffset = layerToPool * poolStride; block_offsets = reinterpret_cast(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)]) - + seqOffset; + + poolOffset + seqOffset; host_block_offsets = reinterpret_cast(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)]) - + seqOffset; + + poolOffset + seqOffset; auto const* const typed_host_pool_pointers = static_cast(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]); auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 
1 : sizeof(T)); + auto const blockSize = mTokensPerBlock * mNumKVHeads * mHeadSize; auto const bytesPerBlock = blockSize * cacheElemSize; - auto const layerOffset = mLayerIdx * 2 * bytesPerBlock; + auto const layerOffset = mLayerIdxInCachePool * 2 * bytesPerBlock; - host_primary_pool_pointer = reinterpret_cast(typed_host_pool_pointers[0] + layerOffset); - host_secondary_pool_pointer = reinterpret_cast(typed_host_pool_pointers[1] + layerOffset); + host_primary_pool_pointer = reinterpret_cast(typed_host_pool_pointers[layerToPool * 2] + layerOffset); + host_secondary_pool_pointer + = reinterpret_cast(typed_host_pool_pointers[layerToPool * 2 + 1] + layerOffset); } AttentionOutT* context_buf_ = static_cast(outputs[0]) @@ -962,8 +980,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField auto* obj = new GPTAttentionPlugin(p.getScalar("layer_idx").value(), p.getScalar("num_heads").value(), p.getScalar("vision_start").value(), p.getScalar("vision_length").value(), p.getScalar("num_kv_heads").value(), - p.getScalar("head_size").value(), p.getScalar("unidirectional").value(), - p.getScalar("q_scaling").value(), p.getScalar("qk_tanh_scale").value(), + p.getScalar("layer_idx_in_cache_pool").value(), p.getScalar("head_size").value(), + p.getScalar("unidirectional").value(), p.getScalar("q_scaling").value(), + p.getScalar("qk_tanh_scale").value(), static_cast(p.getScalar("position_embedding_type").value()), p.getScalar("rotary_embedding_dim").value(), p.getScalar("rotary_embedding_base").value(), static_cast(p.getScalar("rotary_embedding_scale_type").value()), diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h index aeeae99ce..7982d3c07 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h @@ -85,7 +85,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon { public: GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads, - int head_size, int unidirectional, float q_scaling, float qk_tanh_scale, + int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int rotary_embedding_dim, // for RoPE. 
0 for non-RoPE float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, @@ -182,6 +182,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon KV_CACHE_BLOCK_OFFSETS, HOST_KV_CACHE_BLOCK_OFFSETS, HOST_KV_CACHE_POOL_POINTERS, + HOST_KV_CACHE_POOL_MAPPING, PAST_KEY_VALUE, KV_CACHE_QUANTIZATION_SCALE, KV_CACHE_DEQUANTIZATION_SCALE, diff --git a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp index b4f3e46dc..97fcd3358 100644 --- a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp @@ -90,6 +90,7 @@ int IdentityPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer cudaMemcpyAsync(outputs[0], inputs[0], count, cudaMemcpyDeviceToDevice, stream); + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp index 5b8dae3c9..1e33ac63f 100644 --- a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp @@ -177,7 +177,7 @@ int LayernormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* input scale, dynamic_scale, output); } #endif - + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp b/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp index 780beef6c..7cc5dae9c 100644 --- a/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/lruPlugin/lruPlugin.cpp @@ -214,6 +214,7 @@ int lruPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1 { invokeRGLRUUpdate(lru_params, stream); } + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp b/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp index 1f39c40c3..c45f512b2 100644 --- a/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.cpp @@ -200,6 +200,7 @@ int MambaConv1dPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, { invokeMambaConv1dGeneration(mambaConv1dParams, stream); } + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp index fd3affcb4..c81550106 100644 --- a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp @@ -203,7 +203,7 @@ int QuantizePerTokenPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, } #endif // ENABLE_FP8 #endif // ENABLE_BF16 - + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp index 1d7e7026d..8bf93d146 100644 --- a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp @@ -134,7 +134,7 @@ int QuantizeTensorPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, stream, mProp.maxGridSize[0]); } #endif - + sync_check_cuda_error(); return 0; } diff --git 
a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp index f46afadef..372bcf2c8 100644 --- a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp @@ -236,7 +236,7 @@ int RmsnormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDe } #endif // ENABLE_FP8 #endif // ENABLE_BF16 - + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp index d82d855f6..8cbdd5c7e 100644 --- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp @@ -347,6 +347,7 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc { invokeSelectiveScanUpdate(ssm_params, stream); } + sync_check_cuda_error(); return 0; } diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt old mode 100644 new mode 100755 index 65f54c0c3..1f7bf73fb --- a/cpp/tensorrt_llm/pybind/CMakeLists.txt +++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt @@ -41,9 +41,11 @@ set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC "${TORCH_INSTALL_PREFIX}/lib") target_link_libraries( - ${TRTLLM_PYBIND_MODULE} - PUBLIC ${SHARED_TARGET} ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python - ${UNDEFINED_FLAG}) + ${TRTLLM_PYBIND_MODULE} PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} + ${NO_AS_NEEDED_FLAG}) +target_link_libraries( + ${TRTLLM_PYBIND_MODULE} PUBLIC ${Python3_LIBRARIES} ${TORCH_LIBRARIES} + torch_python ${UNDEFINED_FLAG}) target_compile_definitions(${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}) target_compile_definitions(${TRTLLM_PYBIND_MODULE} diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index c617cc8ba..2c74104e9 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -178,19 +178,25 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def(py::self != py::self); py::class_(m, "ModelConfig") - .def(py::init(), - py::arg("vocab_size"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"), py::arg("num_heads"), - py::arg("hidden_size"), py::arg("data_type")) + .def(py::init(), + py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"), + py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type")) .def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize) .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size")) - .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1) - .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1) + .def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1) + .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1, + py::arg("pipeline_parallelism_rank") = 0) + .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1, + py::arg("pipeline_parallelism_rank") = 0) + .def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, 
py::arg("layer_idx")) + .def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, py::arg("num_kv_heads")) .def_property_readonly("num_heads", &tr::ModelConfig::getNbHeads) .def_property_readonly("hidden_size", &tr::ModelConfig::getHiddenSize) .def_property_readonly("size_per_head", &tr::ModelConfig::getSizePerHead) .def_property_readonly("data_type", &tr::ModelConfig::getDataType) - .def_property("num_kv_heads", &tr::ModelConfig::getNbKvHeads, &tr::ModelConfig::setNbKvHeads) .def_property("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead) + .def_property( + "num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer) .def_property("use_gpt_attention_plugin", py::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, py::const_), py::overload_cast(&tr::ModelConfig::useGptAttentionPlugin)) diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp index ca0746980..4a79a64ee 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -93,7 +93,8 @@ void InitBindings(pybind11::module_& m) py::enum_(m, "CapacitySchedulerPolicy") .value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION) - .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT); + .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT) + .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH); py::enum_(m, "ContextChunkingPolicy") .value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS) @@ -299,7 +300,8 @@ void InitBindings(pybind11::module_& m) .def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize); py::class_(m, "ContextPhaseParams") - .def(py::init(), py::arg("first_gen_tokens")); + .def(py::init(), py::arg("first_gen_tokens"), + py::arg("req_id")); py::class_ request(m, "Request"); request @@ -631,14 +633,18 @@ void InitBindings(pybind11::module_& m) auto extendedRuntimePerfKnobConfigSetstate = [](py::tuple state) { - if (state.size() != 2) + if (state.size() != 4) { throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!"); } - return tle::ExtendedRuntimePerfKnobConfig(state[0].cast(), state[1].cast()); + return tle::ExtendedRuntimePerfKnobConfig( + state[0].cast(), state[1].cast(), state[2].cast(), state[2].cast()); }; auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self) - { return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc()); }; + { + return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc(), self.getCudaGraphMode(), + self.getCudaGraphCacheSize()); + }; py::class_(m, "ExtendedRuntimePerfKnobConfig") .def( py::init(), py::arg("multi_block_mode") = true, py::arg("enable_context_fmha_fp32_acc") = false) @@ -646,6 +652,10 @@ void InitBindings(pybind11::module_& m) &tle::ExtendedRuntimePerfKnobConfig::setMultiBlockMode) .def_property("enable_context_fmha_fp32_acc", &tle::ExtendedRuntimePerfKnobConfig::getEnableContextFMHAFP32Acc, &tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc) + .def_property("cuda_graph_mode", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphMode, + &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphMode) + .def_property("cuda_graph_cache_size", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphCacheSize, + 
&tle::ExtendedRuntimePerfKnobConfig::setCudaGraphCacheSize) .def(py::pickle(extendedRuntimePerfKnobConfigGetstate, extendedRuntimePerfKnobConfigSetstate)); auto executorConfigGetState = [](tle::ExecutorConfig const& self) diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 5403b8cd2..8e1f57e9f 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -803,7 +803,7 @@ void GptDecoderBatched::forwardDispatch( } } -GptDecoderBatched::TokenPtr GptDecoderBatched::forwardAsync( +GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync( decoder_batch::Output& output, decoder_batch::Input const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -813,7 +813,7 @@ GptDecoderBatched::TokenPtr GptDecoderBatched::forwardAsync( CudaEvent eventStop{}; mRuntimeStream->record(eventStop); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return std::make_unique(std::move(eventStop), input.active); + return std::make_unique(std::move(eventStop), input.active); } void GptDecoderBatched::forwardDecoder( @@ -1019,12 +1019,12 @@ void GptDecoderBatched::forwardDecoder( TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void GptDecoderBatched::updateFinished(decoder_batch::Token const& token) +void GptDecoderBatched::updateFinished(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); for (std::int32_t i = 0; i < mActualBatchSize; ++i) { - if (token.active[i] && !mFinished[i]) + if (decoderFinishEvent.active[i] && !mFinished[i]) { auto finishedSum = ITensor::slice(mJointDecodingOutput->finishedSum, i, 1); mFinished[i] = mFinished[i] @@ -1035,25 +1035,25 @@ void GptDecoderBatched::updateFinished(decoder_batch::Token const& token) TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void GptDecoderBatched::forwardSync(decoder_batch::Token const& token) +void GptDecoderBatched::forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - token.event.synchronize(); + decoderFinishEvent.event.synchronize(); - updateFinished(token); + updateFinished(decoderFinishEvent); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void GptDecoderBatched::forwardSync( - decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input) +void GptDecoderBatched::forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent, + decoder_batch::Output& output, decoder_batch::Input const& input) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - token.event.synchronize(); + decoderFinishEvent.event.synchronize(); forwardDispatch(output, input, ForwardType::kSYNC); - updateFinished(token); + updateFinished(decoderFinishEvent); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -1231,7 +1231,7 @@ void GptDecoderBatched::forwardAsync(decoder::Output& output, decoder::Input con batchOutput.cacheIndirection = output.cacheIndirection; batchOutput.sequenceLengths = output.sequenceLengths; - mForwardToken = forwardAsync(batchOutput, batchInput); + mDecoderFinishEvent = forwardAsync(batchOutput, batchInput); mBufferManager.setZero(*mFinishedSum); kernels::reduce( *mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mRuntimeStream); @@ -1243,7 +1243,7 @@ void GptDecoderBatched::forwardAsync(decoder::Output& output, decoder::Input con void GptDecoderBatched::forwardSync() { TLLM_LOG_TRACE("%s 
start", __PRETTY_FUNCTION__); - forwardSync(*mForwardToken); + forwardSync(*mDecoderFinishEvent); // wait for mFinishedSum to be updated mForwardEvent.synchronize(); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp index 620be923b..da58300fa 100644 --- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp +++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp @@ -85,6 +85,8 @@ std::vector buildLayerTypes( auto constexpr layerNameAttention = "attention"; auto constexpr layerNameRecurrent = "recurrent"; + auto constexpr layerNameLinear = "linear"; + auto constexpr layerNameNoop = "no_op"; // The json field specifies a "group" of layers, which gets repeated multiple times // Note that the total number of layers does not need to be a multiple of a layer @@ -102,9 +104,17 @@ std::vector buildLayerTypes( { result[i] = ModelConfig::LayerType::kRECURRENT; } + else if (layerStringTypes[i % groupSize] == layerNameLinear) + { + result[i] = ModelConfig::LayerType::kLINEAR; + } + else if (layerStringTypes[i % groupSize] == layerNameNoop) + { + result[i] = ModelConfig::LayerType::kNOOP; + } else { - TLLM_LOG_ERROR("Unknown layer type: %s", layerStringTypes[i % groupSize].c_str()); + TLLM_LOG_WARNING("Unknown layer type: %s, assuming attention", layerStringTypes[i % groupSize].c_str()); } } @@ -147,9 +157,25 @@ ModelConfig createModelConfig( auto const mlpHiddenSize = parseJsonFieldOptional(config, mlpHiddenSizeField); - auto modelConfig = ModelConfig{vocabSize, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType}; + auto numKvHeadsPerAttentionLayer + = parseJsonFieldOr>(config, "num_kv_heads_per_layer", std::vector()); + + auto modelConfig + = ModelConfig{vocabSize, numLayers, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType}; + + if (!numKvHeadsPerAttentionLayer.empty()) + { + std::transform(numKvHeadsPerAttentionLayer.cbegin(), numKvHeadsPerAttentionLayer.cend(), + numKvHeadsPerAttentionLayer.begin(), + [tensorParallelism](SizeType32 const numKvHeads) { return std::max(numKvHeads / tensorParallelism, 1); }); + modelConfig.setNumKvHeadsPerLayer(numKvHeadsPerAttentionLayer); + } + else + { + modelConfig.setNbKvHeads(numKvHeads); + } + modelConfig.setSizePerHead(sizePerHead); - modelConfig.setNbKvHeads(numKvHeads); modelConfig.setLayerTypes(layerTypes); // Set logits datatype @@ -269,9 +295,24 @@ void parseLora(ModelConfig& modelConfig, Json const& json, Json const& pluginCon if (loraTargetModules.has_value()) { + auto const& loraModuleNames = loraTargetModules.value(); + auto const& numKvHeadsPerLayer = modelConfig.getNumKvHeadsPerLayer(); + if (!loraModuleNames.empty()) + { + TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayer.cbegin(), numKvHeadsPerLayer.cend(), + [firstNumKvHeads = numKvHeadsPerLayer[0]](SizeType32 numKvHeads) + { return numKvHeads == firstNumKvHeads; }), + "LORA with a VGQA model is not supported"); + } + // TODO(oargov): don't assume all layers have the same num_kv_heads to support VGQA + auto const numKvHeads = numKvHeadsPerLayer.empty() ? modelConfig.getNbHeads() : numKvHeadsPerLayer[0]; + bool hasMoE = !engineVersionNone && json.at("pretrained_config").contains("moe"); + auto const numExperts = hasMoE + ? 
json.at("pretrained_config").at("moe").at("num_experts").template get() + : SizeType32{0}; modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(), modelConfig.getHiddenSize(), - modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), modelConfig.getNbKvHeads(), - modelConfig.getSizePerHead(), tensorParallelism)); + modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), numKvHeads, modelConfig.getSizePerHead(), + tensorParallelism, numExperts)); } modelConfig.setMaxLoraRank(loraMaxRank); diff --git a/cpp/tensorrt_llm/runtime/gptSession.cpp b/cpp/tensorrt_llm/runtime/gptSession.cpp index c5d4dda55..c5bc84cf1 100644 --- a/cpp/tensorrt_llm/runtime/gptSession.cpp +++ b/cpp/tensorrt_llm/runtime/gptSession.cpp @@ -219,8 +219,13 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea // tokens, when enabling cyclic kv cache. auto const useOneMoreBlock = maxBeamWidth > 1 && maxSequenceLength > maxAttentionWindow; - auto const localNbLayers = mModelConfig.getNbAttentionLayers(mWorldConfig.getPipelineParallelism()); - auto const nbKvHeads = mModelConfig.getNbKvHeads(); + auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = mModelConfig.getNumKvHeadsPerLayerLocalRange( + mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank()); + TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd, + [firstNumKvHeads = *numKvHeadsPerLayerBegin](SizeType32 numKvHeads) + { return numKvHeads == firstNumKvHeads; }), + "Deprecated session API does not support multiple cache pools, use the newer executor API instead"); + auto const sizePerHead = mModelConfig.getSizePerHead(); bool constexpr enableBlockReuse{false}; bool enableDiffMaxAttenWin = false; @@ -235,7 +240,8 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea TLLM_CHECK_WITH_INFO(maxBeamWidth == 1 || !enableDiffMaxAttenWin, "Can't support layer-wise max_attention_window with beam search. 
Please use a unified max_attention_window for " "all layers."); - mKvCacheManager = std::make_shared(localNbLayers, nbKvHeads, sizePerHead, tokensPerBlock, + mKvCacheManager = std::make_shared( + std::vector(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd), sizePerHead, tokensPerBlock, blocksInPrimaryPool, blocksInSecondaryPool, maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength, useOneMoreBlock, mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.onboardBlocks); @@ -253,6 +259,7 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea for (auto& buffers : mBuffers) { buffers->transformerBuffers->setKvPoolPointers(mKvCacheManager.get()); + buffers->transformerBuffers->setKvPoolMapping(mKvCacheManager.get()); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp b/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp index 8ecb2061c..8f543f9ed 100644 --- a/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/lookaheadBuffers.cpp @@ -11,7 +11,9 @@ */ #include "tensorrt_llm/runtime/lookaheadBuffers.h" +#include "iTensor.h" #include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/layers/lookaheadDecodingUtils.h" #include "tensorrt_llm/runtime/common.h" namespace tensorrt_llm::runtime @@ -28,8 +30,6 @@ LookaheadDecodingBuffers::LookaheadDecodingBuffers( , positionIds( bufferManager.gpu(ITensor::makeShape({maxNumSequences, maxTokensPerStep}), nvinfer1::DataType::kINT32)) { - TLLM_LOG_DEBUG( - "LookaheadDecodingBuffers, maxNumSequences = %d, maxTokensPerStep = %d", maxNumSequences, maxTokensPerStep); } LookaheadRuntimeBuffers::LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, @@ -40,11 +40,11 @@ LookaheadRuntimeBuffers::LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeTy TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK_WITH_INFO(maxBeamWidth == 1, "Lookahead decoding does not support beam search"); - // auto const tokensPerStep = modelConfig.getMaxTokensPerStep(); auto const tokensPerStep = modelConfig.getMaxDecodingTokens(); auto const numPackedMasks = static_cast(tensorrt_llm::common::divUp(tokensPerStep, 32)); - // Copy buffers to device + cumSumLength = manager.pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32); + packedMasksDevice = manager.gpu(ITensor::makeShape({maxBatchSize * tokensPerStep, numPackedMasks}), nvinfer1::DataType::kINT32); positionOffsetsDevice = manager.gpu(ITensor::makeShape({maxBatchSize, tokensPerStep}), nvinfer1::DataType::kINT32); @@ -76,24 +76,59 @@ void LookaheadRuntimeBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType auto const tokensPerStep = modelConfig.getMaxDecodingTokens(); + manager.copy(seqSlots, *batchSlotsHostCopy); + manager.copy(*decoderLookaheadBuffers.generationLengths, *generationLengthsHostCopy); manager.copy(*decoderLookaheadBuffers.positionOffsets, *positionOffsetsHostCopy); manager.copy(*decoderLookaheadBuffers.packedMasks, *packedMaskHostCopy); manager.copy(*decoderLookaheadBuffers.positionIds, *positionIdsHostCopy); - manager.copy(seqSlots, *batchSlotsHostCopy); - manager.copy(*decoderLookaheadBuffers.generationLengths, *generationLengthsHostCopy); manager.getStream().synchronize(); BufferRange batchSlotsRange(*batchSlotsHostCopy); + BufferRange cumSumLengthRange(*cumSumLength); + + SizeType32 maxGenerationLength = 0; + for (SizeType32 bi = 0; bi < numGenSequences; bi++) + { + SizeType32 gbi = batchSlotsRange[bi + numCtxSequences]; + SizeType32 theLength = 
BufferRange(*generationLengthsHostCopy)[gbi]; + maxGenerationLength = std::max(maxGenerationLength, theLength); + } + + auto positionOffsetShape = positionOffsetsHost->getShape(); + positionOffsetShape.d[1] = maxGenerationLength; + positionOffsetsHost->reshape(positionOffsetShape); + positionOffsetsDevice->reshape(positionOffsetShape); + + auto positionIdsShape = positionIdsHostCopy->getShape(); + auto positionIdsShape1D = ITensor::makeShape({ITensor::volume(positionIdsShape)}); + positionIdsHostCopy->reshape(positionIdsShape1D); + positionIdsHost->reshape(positionIdsShape1D); + + cumSumLengthRange[0] = 0; for (SizeType32 bi = 0; bi < numGenSequences; bi++) { SizeType32 gbi = batchSlotsRange[bi + numCtxSequences]; + SizeType32 theLength = BufferRange(*generationLengthsHostCopy)[gbi]; + manager.copy(*ITensor::at(generationLengthsHostCopy, {gbi}), *ITensor::at(generationLengthsHost, {bi})); - manager.copy(*ITensor::at(positionOffsetsHostCopy, {gbi}), *ITensor::at(positionOffsetsHost, {bi})); - manager.copy(*ITensor::slice(packedMaskHostCopy, gbi * tokensPerStep, tokensPerStep), - *ITensor::slice(packedMaskHost, bi * tokensPerStep, tokensPerStep)); - manager.copy(*ITensor::at(positionIdsHostCopy, {gbi}), *ITensor::at(positionIdsHost, {bi})); + + manager.copy(*ITensor::slice(positionOffsetsHostCopy, {gbi, 0}, theLength), + *ITensor::slice(positionOffsetsHost, {bi, 0}, theLength)); + + manager.copy(*ITensor::slice(packedMaskHostCopy, gbi * tokensPerStep, theLength), + *ITensor::slice(packedMaskHost, cumSumLengthRange[0], theLength)); + + manager.copy(*ITensor::slice(positionIdsHostCopy, gbi * tokensPerStep, theLength), + *ITensor::slice(positionIdsHost, cumSumLengthRange[0], theLength)); + + cumSumLengthRange[0] += theLength; } + + positionIdsHostCopy->reshape(positionIdsShape); + positionIdsHost->reshape(positionIdsShape); + positionIdsDevice->reshape(positionIdsShape); + manager.copy(*ITensor::slice(generationLengthsHost, 0, numGenSequences), *ITensor::slice(generationLengthsDevice, 0, numGenSequences)); manager.copy(*ITensor::slice(positionOffsetsHost, 0, numGenSequences), @@ -102,6 +137,7 @@ void LookaheadRuntimeBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType *ITensor::slice(packedMasksDevice, 0, numGenSequences * tokensPerStep)); manager.copy( *ITensor::slice(positionIdsHost, 0, numGenSequences), *ITensor::slice(positionIdsDevice, 0, numGenSequences)); + positionIdsDevice->reshape(ITensor::makeShape({cumSumLengthRange[0]})); manager.getStream().synchronize(); diff --git a/cpp/tensorrt_llm/runtime/loraModule.cpp b/cpp/tensorrt_llm/runtime/loraModule.cpp index 2716e78cb..8a8e2e559 100644 --- a/cpp/tensorrt_llm/runtime/loraModule.cpp +++ b/cpp/tensorrt_llm/runtime/loraModule.cpp @@ -21,7 +21,7 @@ namespace tensorrt_llm::runtime std::vector LoraModule::createLoraModules(std::vector const& loraModuleNames, SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads, - SizeType32 attentionHeadSize, SizeType32 tpSize) + SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts) { auto const hidden = hiddenSize * tpSize; auto const mlpHidden = mlpHiddenSize * tpSize; @@ -55,10 +55,10 @@ std::vector LoraModule::createLoraModules(std::vector c case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break; // TODO(TRTLLM-379): Support MOE LoRA weights case ModuleType::kMOE_H_TO_4H: - case ModuleType::kMOE_GATE: - case ModuleType::kMOE_4H_TO_H: - case ModuleType::kMOE_ROUTER: - case 
ModuleType::kMLP_ROUTER: + case ModuleType::kMOE_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break; + case ModuleType::kMOE_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break; + case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break; + case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break; case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName); } } diff --git a/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp b/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp index c4f9d888b..6b9c1175f 100644 --- a/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp @@ -15,11 +15,11 @@ */ #include "tensorrt_llm/runtime/rnnStateBuffers.h" +#include "iBuffer.h" #include "tensorrt_llm/runtime/runtimeBuffers.h" #include "tensorrt_llm/runtime/utils/sessionUtils.h" using namespace tensorrt_llm::runtime; -namespace tc = tensorrt_llm::common; RnnStateBuffers::RnnStateBuffers() { @@ -92,8 +92,8 @@ RnnStateBuffers::RnnStateBuffers( auto statePtrsShape = ITensor::makeShape({localNbLayers}); slotMappingDevice = bufferManager.gpu(slotMappingShape, nvinfer1::DataType::kINT32); slotMappingHost = BufferManager::cpu(slotMappingShape, nvinfer1::DataType::kINT32); - rnnStatePtrs = BufferManager::cpu(statePtrsShape, nvinfer1::DataType::kINT64); - convStatePtrs = BufferManager::cpu(statePtrsShape, nvinfer1::DataType::kINT64); + rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType::value); + convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType::value); } else { @@ -179,8 +179,8 @@ void RnnStateBuffers::fillStatePtrs() rnnStatePtr.resize(mLocalNbLayers); convStatePtr.resize(mLocalNbLayers); - void** rnnStatePtrArray = static_cast(rnnStatePtrs->data()); - void** convStatePtrArray = static_cast(convStatePtrs->data()); + auto* rnnStatePtrArray = bufferCast(*rnnStatePtrs); + auto* convStatePtrArray = bufferCast(*convStatePtrs); for (int i = 0; i < mLocalNbLayers; i++) { diff --git a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp index 2eb101648..fe0cf7c8a 100644 --- a/cpp/tensorrt_llm/runtime/tllmRuntime.cpp +++ b/cpp/tensorrt_llm/runtime/tllmRuntime.cpp @@ -20,6 +20,7 @@ #include "tensorrt_llm/common/nvtxUtils.h" #include "tensorrt_llm/common/safetensors.h" #include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/layers/lookaheadDecodingUtils.h" #include "tllmLogger.h" #include @@ -182,7 +183,9 @@ bool TllmRuntime::executeContext(SizeType32 contextIndex) const { NVTX3_FUNC_RANGE(); auto& context = getContext(contextIndex); - return context.enqueueV3(mStream->get()); + auto res = context.enqueueV3(mStream->get()); + sync_check_cuda_error(); + return res; } void TllmRuntime::setInputTensors(SizeType32 contextIndex, TensorMap const& tensorMap) diff --git a/cpp/tensorrt_llm/runtime/transformerBuffers.cpp b/cpp/tensorrt_llm/runtime/transformerBuffers.cpp index f8a78f091..fead9addf 100644 --- a/cpp/tensorrt_llm/runtime/transformerBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/transformerBuffers.cpp @@ -15,12 +15,15 @@ */ #include "tensorrt_llm/runtime/transformerBuffers.h" +#include "iTensor.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/stlUtils.h" #include "tensorrt_llm/runtime/runtimeBuffers.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include 
"tensorrt_llm/runtime/utils/sessionUtils.h" #include // std::getenv +#include using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; @@ -34,6 +37,7 @@ TransformerBuffers::TransformerBuffers() presentKeysVals.clear(); presentKeysValsAlt.clear(); kvCacheBlockPoolPointers = nullptr; + kvCacheBlockPoolMapping = nullptr; kvCacheBlockOffsetsHost = nullptr; kvCacheBlockOffsetsDevice = nullptr; } @@ -101,15 +105,16 @@ void TransformerBuffers::reshape( auto const maxAttentionWindow = generationConfig.maxAttentionWindow; auto const kvCacheReserve = ITensor::makeShape( - {batchSize, 2, modelConfig.getNbKvHeads(), maxAttentionWindow, modelConfig.getSizePerHead()}); + {batchSize, 2, modelConfig.getNbKvHeads(0), maxAttentionWindow, modelConfig.getSizePerHead()}); auto const kvCacheShape - = ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()}); + = ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(0), maxInputLength, modelConfig.getSizePerHead()}); + if (modelConfig.isPagedKVCache()) { auto cacheBlockOffsetsShape = kvCacheBlockOffsetsHost->getShape(); if (cacheBlockOffsetsShape.nbDims > 0) { - cacheBlockOffsetsShape.d[0] = batchSize; + cacheBlockOffsetsShape.d[1] = batchSize; kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape); kvCacheBlockOffsetsDevice->reshape(cacheBlockOffsetsShape); } @@ -123,7 +128,8 @@ void TransformerBuffers::reshape( utils::reshapeBufferVector(presentKeysVals, kvCacheReserve); } - auto const localNbLayers = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism()); + auto const localNbLayers + = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank()); if (modelConfig.useGptAttentionPlugin()) { @@ -147,7 +153,7 @@ void TransformerBuffers::reshapeKvTensors( { auto const& manager = runtime.getBufferManager(); - auto const cacheBlockOffsetsShape = ITensor::makeShape({maxBatchSize * maxBeamWidth, 2, maxBlocksPerSeq}); + auto const cacheBlockOffsetsShape = ITensor::makeShape({1, maxBatchSize * maxBeamWidth, 2, maxBlocksPerSeq}); kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape); manager.setZero(*kvCacheBlockOffsetsHost); @@ -161,6 +167,11 @@ void TransformerBuffers::setKvPoolPointers(KvCacheManager const* kvCacheManager) kvCacheBlockPoolPointers = kvCacheManager->getBlockPoolPointers(); } +void TransformerBuffers::setKvPoolMapping(KvCacheManager const* kvCacheManager) +{ + kvCacheBlockPoolMapping = kvCacheManager->getLayerToPoolMapping(); +} + TransformerBuffers TransformerBuffers::sliceTo( GenerationConfig const& generationConfig, ModelConfig const& modelConfig, SizeType32 offset, SizeType32 batchSize) { @@ -169,8 +180,15 @@ TransformerBuffers TransformerBuffers::sliceTo( auto const generationBatchSize = generationConfig.batchSize; if (modelConfig.isPagedKVCache()) { + auto const& realCacheBlockOffsetsShape = kvCacheBlockOffsetsHost->getShape(); - auto const maxBlocksPerSeq = realCacheBlockOffsetsShape.d[2]; + auto const numPools = realCacheBlockOffsetsShape.d[0]; + // (oargov) with multiple pools, slicing the tensor along the batch*beam dimension would require us to support + // non-contiguous tensors. with a single pool, we can just ignore the pools dimension when slicing and restore + // it later. this is part of the deprecated GPTSession API, so not supporting VGQA here should be ok. 
+ TLLM_CHECK_WITH_INFO(numPools == 1, + "Deprecated transformerBuffers API does not support multiple cache pools, use the newer API instead"); + auto const maxBlocksPerSeq = realCacheBlockOffsetsShape.d[3]; // enable slicing by moving generationBatchSize to first dim auto const fakeCacheBlockOffsetsShape = ITensor::makeShape({generationBatchSize, 2, maxBlocksPerSeq}); @@ -178,13 +196,14 @@ TransformerBuffers TransformerBuffers::sliceTo( TensorPtr kvCacheBlockOffsetsDeviceView{ITensor::view(kvCacheBlockOffsetsDevice, fakeCacheBlockOffsetsShape)}; // slice and reshape to correct shape - auto const cacheBlockOffsetsShape = ITensor::makeShape({batchSize, 2, maxBlocksPerSeq}); + auto const cacheBlockOffsetsShape = ITensor::makeShape({numPools, batchSize, 2, maxBlocksPerSeq}); buffers.kvCacheBlockOffsetsHost = ITensor::slice(kvCacheBlockOffsetsHostView, offset, batchSize); buffers.kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape); buffers.kvCacheBlockOffsetsDevice = ITensor::slice(kvCacheBlockOffsetsDeviceView, offset, batchSize); buffers.kvCacheBlockOffsetsDevice->reshape(cacheBlockOffsetsShape); buffers.kvCacheBlockPoolPointers = kvCacheBlockPoolPointers; + buffers.kvCacheBlockPoolMapping = kvCacheBlockPoolMapping; } else { @@ -529,7 +548,7 @@ void TransformerBuffers::postContextStep(RuntimeBuffers* runtimeBuffers, if (modelConfig.useGptAttentionPlugin() && modelConfig.isPagedKVCache()) { auto cacheBlockOffsetsShape = kvCacheBlockOffsetsHost->getShape(); - cacheBlockOffsetsShape.d[0] = batchSize * beamWidth; + cacheBlockOffsetsShape.d[1] = batchSize * beamWidth; kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape); kvCacheBlockOffsetsDevice->reshape(cacheBlockOffsetsShape); } @@ -720,6 +739,7 @@ void TransformerBuffers::getRuntimeBuffers(RuntimeBuffers const* runtimeBuffers, inputBuffers.insert_or_assign("kv_cache_block_offsets", kvCacheBlockOffsetsDevice); inputBuffers.insert_or_assign("host_kv_cache_block_offsets", kvCacheBlockOffsetsHost); inputBuffers.insert_or_assign("host_kv_cache_pool_pointers", kvCacheBlockPoolPointers); + inputBuffers.insert_or_assign("host_kv_cache_pool_mapping", kvCacheBlockPoolMapping); } else { diff --git a/cpp/tensorrt_llm/runtime/transformerBuffers.h b/cpp/tensorrt_llm/runtime/transformerBuffers.h index 5e4a6a847..4692e9b0e 100644 --- a/cpp/tensorrt_llm/runtime/transformerBuffers.h +++ b/cpp/tensorrt_llm/runtime/transformerBuffers.h @@ -53,6 +53,7 @@ class TransformerBuffers runtime::TllmRuntime const& runtime); void setKvPoolPointers(KvCacheManager const* kvCacheManager); + void setKvPoolMapping(KvCacheManager const* kvCacheManager); void reset(BufferManager& manager){}; @@ -92,9 +93,10 @@ class TransformerBuffers TensorPtr maxAttentionWindows; // with attention plugin, host tensor TensorPtr sinkTokenLengths; // with attention plugin, host tensor TensorPtr kvCacheBlockPoolPointers; - TensorPtr kvCacheBlockOffsetsHost; // [batchSize * beamWidth, 2, maxBlocksPerSeq * 2] - TensorPtr kvCacheBlockOffsetsDevice; // [batchSize * beamWidth, 2, maxBlocksPerSeq * 2] - TensorPtr runtimePerfKnobsHost; // can hold max 16 perf knobs + TensorPtr kvCacheBlockPoolMapping; + TensorPtr kvCacheBlockOffsetsHost; // [numPools, batchSize * beamWidth, 2, maxBlocksPerSeq * 2] + TensorPtr kvCacheBlockOffsetsDevice; // [numPools, batchSize * beamWidth, 2, maxBlocksPerSeq * 2] + TensorPtr runtimePerfKnobsHost; // can hold max 16 perf knobs }; } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/utils/sessionUtils.cpp 
b/cpp/tensorrt_llm/runtime/utils/sessionUtils.cpp index a15cc1f0d..f324cf5f9 100644 --- a/cpp/tensorrt_llm/runtime/utils/sessionUtils.cpp +++ b/cpp/tensorrt_llm/runtime/utils/sessionUtils.cpp @@ -22,6 +22,7 @@ #include #include +#include using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; @@ -89,6 +90,16 @@ void reshapeBufferVector(std::vector& vector, nvinfer1::Dims } } +void assertNoVGQA(ModelConfig const& modelConfig, WorldConfig const& worldConfig) +{ + auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange( + worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank()); + TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd, + [firstNumKvHeads = *numKvHeadsPerLayerBegin](SizeType32 numKvHeads) + { return numKvHeads == firstNumKvHeads; }), + "Deprecated session API does not support multiple cache pools, use the newer executor API instead"); +} + std::vector sliceBufferVector( std::vector const& vector, SizeType32 const offset, SizeType32 const size) { diff --git a/cpp/tensorrt_llm/runtime/utils/sessionUtils.h b/cpp/tensorrt_llm/runtime/utils/sessionUtils.h index 5fdd94f3e..4627cb369 100644 --- a/cpp/tensorrt_llm/runtime/utils/sessionUtils.h +++ b/cpp/tensorrt_llm/runtime/utils/sessionUtils.h @@ -56,6 +56,8 @@ std::vector createBufferVector( void reshapeBufferVector(std::vector& vector, nvinfer1::Dims const& shape); +void assertNoVGQA(ModelConfig const& modelConfig, WorldConfig const& worldConfig); + std::vector sliceBufferVector( std::vector const& vector, SizeType32 offset, SizeType32 size); diff --git a/cpp/tests/layers/lookaheadAlgorithmTest.cpp b/cpp/tests/layers/lookaheadAlgorithmTest.cpp index fc70b2bff..68eb8c193 100644 --- a/cpp/tests/layers/lookaheadAlgorithmTest.cpp +++ b/cpp/tests/layers/lookaheadAlgorithmTest.cpp @@ -21,6 +21,7 @@ #include "tensorrt_llm/layers/lookaheadAlgorithm.h" #include "tensorrt_llm/layers/lookaheadDecodingUtils.h" #include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/lookaheadModule.h" #include "tests/layers/randomLlm.h" @@ -84,9 +85,10 @@ TEST_P(LookaheadAlgorithmTest, predict) std::tie(std::ignore, std::ignore, maxDraftLenRuntime, std::ignore) = executor::LookaheadDecodingConfig(w, n, g).calculateSpeculativeResource(); auto shape = ITensor::makeShape({maxTokensPerStep}); + auto shape2d = ITensor::makeShape({maxTokensPerStep, maxTokensPerStep}); auto shapeSingle = ITensor::makeShape({1}); TensorPtr posidMax = BufferManager::cpu(shape, nvinfer1::DataType::kINT32); - TensorPtr smaskMax = BufferManager::cpu(shape, nvinfer1::DataType::kBOOL); + TensorPtr attentionMaskMax = BufferManager::cpu(shape2d, nvinfer1::DataType::kBOOL); TensorPtr inputLengthPtr = BufferManager::cpu(shapeSingle, nvinfer1::DataType::kINT32); auto& inputLength(*BufferRange(*inputLengthPtr).begin()); @@ -123,26 +125,34 @@ TEST_P(LookaheadAlgorithmTest, predict) { TLLM_LOG_DEBUG("\noracle[%d] = '%c'", sequenceLength - 1, static_cast(sequenceRange[sequenceLength - 1])); bufferCast(*posidMax)[0] = sequenceLength - 1; - bufferCast(*smaskMax)[0] = true; + BufferLocation amaskLocation(*attentionMaskMax); + for (auto& item : amaskLocation) + { + item = false; + } + for (SizeType32 i = 0; i < maxTokensPerStep; i++) + { + amaskLocation.at(i, 0) = true; + } + algo.prepare( // ITensor::slice(sequence, sequenceLength, maxDraftLenRuntime), // ITensor::slice(posidMax, 1, maxDraftLenRuntime), // - ITensor::slice(smaskMax, 1, 
maxDraftLenRuntime), // inputLengthPtr, // + attentionMaskMax, 1, // sequenceLengthPtr, // ITensor::slice(sequence, sequenceLength - 1, 1)); TensorPtr input = ITensor::slice(sequence, sequenceLength - 1, inputLength + 1); TensorPtr posid = ITensor::slice(posidMax, 0, inputLength + 1); - TensorPtr smask = ITensor::slice(smaskMax, 0, inputLength + 1); + TensorPtr amask = ITensor::slice(attentionMaskMax, 0, inputLength + 1); PRINT_TOKENS(input); PRINT_VALUES(posid); - PRINT_VALUES(smask); + PRINT_VALUES(amask); TensorPtr output = ITensor::slice(outputMax, 0, inputLength + 1); - llm.foretell(output, input, posid); - llm.sampleByMask(output, smask); + llm.foretell(output, input, posid, amask); PRINT_TOKENS(output); // algo.update(acceptedMax, acceptedOffsetsMax, acceptedLengthPtr, output, endIdPtr); @@ -207,4 +217,46 @@ INSTANTIATE_TEST_CASE_P(CombineLookaheadAlgorithmTestSmall_222, LookaheadAlgorit testing::Combine(testing::Values(std::make_tuple(2, 2)), testing::Values(std::make_tuple(2, 2)), testing::Values(std::make_tuple(2, 2)))); +TEST(LookaheadAlgorithmTest, treeEncodeTest) +{ + auto testWithData = [](TensorPtr inputTokens, TensorPtr inputPosIds, SizeType32 lastPosId, SizeType32 gold_len) + { + auto shape = inputTokens->getShape(); + auto shape2d = ITensor::makeShape({shape.d[0], shape.d[0]}); + + TensorPtr inputMasks = BufferManager::cpu(shape2d, nvinfer1::DataType::kBOOL); + LookaheadAlgorithm::posIdsToMask(inputMasks, inputPosIds); + + TensorPtr outputTokens = BufferManager::cpu(shape, nvinfer1::DataType::kINT32); + TensorPtr outputPosIds = BufferManager::cpu(shape, nvinfer1::DataType::kINT32); + TensorPtr encodeMap = BufferManager::cpu(shape, nvinfer1::DataType::kINT32); + TensorPtr outputMasks = BufferManager::cpu(shape2d, nvinfer1::DataType::kBOOL); + + // auto len = LookaheadAlgorithm::treeEncode(outputTokens, outputPosIds, outputMasks, inputTokens, inputPosIds, + // inputMasks, '$', 9); + auto len = LookaheadAlgorithm::treeEncode(inputTokens, inputPosIds, inputMasks, encodeMap); + TLLM_LOG_DEBUG("len = %d", len); + + EXPECT_EQ(len, gold_len); + }; + + testWithData( // + initTensor(std::string("01234512345")), // + initTensor({10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}), // + 9, 6); + + testWithData( // + initTensor(std::string("01234512abc")), // + initTensor({10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}), // + 9, 9); + + testWithData( // + initTensor(std::string("01234512abc2aBCD")), // + initTensor({10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 12, 13, 14, 15, 16}), // + 9, 12); + + testWithData(initTensor(std::string("wmplhi folxamp")), + initTensor({21, 22, 23, 24, 25, 26, 27, 21, 22, 23, 24, 21, 22, 23, 24}), 20, 15); +} + } // namespace tensorrt_llm::tests::layers diff --git a/cpp/tests/layers/lookaheadDecodingLayerTest.cpp b/cpp/tests/layers/lookaheadDecodingLayerTest.cpp index e3460a52b..f8f7c04f7 100644 --- a/cpp/tests/layers/lookaheadDecodingLayerTest.cpp +++ b/cpp/tests/layers/lookaheadDecodingLayerTest.cpp @@ -230,11 +230,11 @@ class LookaheadDecodingLayerTest : public testing::Test TensorPtr mNumNewTokensCumSum; TensorPtr mPathsOffsets; TensorPtr mDraftLengths; + TensorPtr mPrevDraftLengths; TensorPtr mDraftTokens; TensorPtr mPackedMasks; TensorPtr mPackedMasksBool; TensorPtr mGenerationLengths; - TensorPtr mGenerationLengthsMax; TensorPtr mPositionOffsets; TensorPtr mPositionIds; TensorPtr mAttentionPackedMask; @@ -371,6 +371,7 @@ void LookaheadDecodingLayerTest::allocateBuffers() ITensor::makeShape({mMaxTokensPerStep, maxBatchSize, 1}), nvinfer1::DataType::kINT32); 
mNumNewTokens = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mDraftLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); + mPrevDraftLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mDraftTokens = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32); auto packedMaskShape = ITensor::makeShape( @@ -382,7 +383,6 @@ void LookaheadDecodingLayerTest::allocateBuffers() mPathsOffsets = BufferManager::pinnedPool( ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32); mGenerationLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); - mGenerationLengthsMax = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32); mPositionOffsets = BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, mMaxTokensPerStep}), nvinfer1::DataType::kINT32); mPositionIds @@ -462,10 +462,8 @@ void LookaheadDecodingLayerTest::newRequests(std::vector requestIds) setupParams->prompt.emplace_back(mPrompt[gbi]); setupParams->algoConfigs.emplace_back(mTestParam.w, mTestParam.n, mTestParam.g); PRINT_TOKENS(setupParams->prompt[bi]); - setupParams->generationLengths = mGenerationLengthsMax; - setupParams->actualGenerationLengths = mGenerationLengths; + setupParams->generationLengths = mGenerationLengths; setupParams->positionOffsets = mPositionOffsets; - // setupParams->outputs.positionIds = mPositionIds; setupParams->attentionPackedMasks = mPackedMasks; } std::vector seed(requestIds.begin(), requestIds.end()); @@ -669,14 +667,14 @@ void LookaheadDecodingLayerTest::decodeForward() PRINT_VALUES(mSequenceLengths); outputParams->sequenceLength = mSequenceLengths; outputParams->nextDraftLengths = mDraftLengths; + outputParams->prevDraftLengths = mPrevDraftLengths; outputParams->nextDraftTokens = mDraftTokens; outputParams->packedMasks = mPackedMasks; outputParams->numNewTokens = mNumNewTokens; outputParams->newTokens = mNewTokens; outputParams->numNewTokensCumSum = mNumNewTokensCumSum; outputParams->pathsOffsets = mPathsOffsets; - outputParams->generationLengths = mGenerationLengthsMax; - outputParams->actualGenerationLengths = mGenerationLengths; + outputParams->generationLengths = mGenerationLengths; outputParams->positionOffsets = mPositionOffsets; outputParams->positionIds = mPositionIds; outputParams->packedMasks = mPackedMasks; diff --git a/cpp/tests/layers/randomLlm.cpp b/cpp/tests/layers/randomLlm.cpp index 2116186a6..9746286d9 100644 --- a/cpp/tests/layers/randomLlm.cpp +++ b/cpp/tests/layers/randomLlm.cpp @@ -276,8 +276,8 @@ void LookaheadRandomLlm::foretell(TensorPtr const& output, TensorConstPtr const& { right &= maskLocation.at(i, j) ? oracleRange[positionRange[j]] == inputRange[j] : true; } - if (i < verifyStart) - { // lookahead might be right + if (i < verifyStart && false) + { // lookahead might be right. Since we verify lookahead branch, then must be right. outputRange[i] = ((right || rand() % 5) && legal) ? 
oracleRange[positionRange[i] + 1] : invalid; } else diff --git a/cpp/tests/resources/scripts/build_llama_engines.py b/cpp/tests/resources/scripts/build_llama_engines.py index 425b636f4..12f56b364 100644 --- a/cpp/tests/resources/scripts/build_llama_engines.py +++ b/cpp/tests/resources/scripts/build_llama_engines.py @@ -90,7 +90,7 @@ def build_engines(model_cache: str, only_multi_gpu: bool): tp_pp_sizes = [(1, 1)] if only_multi_gpu: - tp_pp_sizes = [(1, 4), (4, 1), (1, 2), (2, 2)] + tp_pp_sizes = [(1, 4), (4, 1), (1, 2), (2, 2), (2, 1)] for tp_size, pp_size in tp_pp_sizes: tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu" print(f"\nBuilding fp16 tp{tp_size} pp{pp_size} engine") diff --git a/cpp/tests/resources/scripts/generate_expected_llama_output.py b/cpp/tests/resources/scripts/generate_expected_llama_output.py index 08d904201..cff87fbe0 100644 --- a/cpp/tests/resources/scripts/generate_expected_llama_output.py +++ b/cpp/tests/resources/scripts/generate_expected_llama_output.py @@ -72,7 +72,7 @@ def generate_outputs(num_beams, only_multi_gpu=False): elif COMM_WORLD.size == 4: tp_pp_sizes = [(4, 1), (2, 2), (1, 4)] elif COMM_WORLD.size == 2: - tp_pp_sizes = [(1, 2)] + tp_pp_sizes = [(1, 2), (2, 1)] else: raise RuntimeError( f"The world size of MPI {COMM_WORLD.size} is not equal to 1, 2, or 4." diff --git a/cpp/tests/resources/scripts/test_cpp.py b/cpp/tests/resources/scripts/test_cpp.py index f8ca28a1c..7e40a4cd4 100755 --- a/cpp/tests/resources/scripts/test_cpp.py +++ b/cpp/tests/resources/scripts/test_cpp.py @@ -664,6 +664,17 @@ def run_single_gpu_tests(build_dir: _pl.Path, if excluded_tests: ctest.extend(["-E", "|".join(excluded_tests)]) parallel_run_ctest(ctest, cwd=build_dir, env=cpp_env, timeout=timeout) + if run_gpt: + xml_output_file = build_dir / "results-single-gpu-disagg-executor_gpt.xml" + trt_model_test = produce_mpirun_command( + global_commands=["mpirun", "--allow-run-as-root"], + nranks=2, + local_commands=[ + "tests/executor/executorTest", + "--gtest_filter=*GptSingleDeviceDisaggExecutorTest*" + ], + leader_commands=[f"--gtest_output=xml:{xml_output_file}"]) + run_command(trt_model_test, cwd=build_dir, env=cpp_env, timeout=timeout) def produce_mpirun_command(*, global_commands, nranks, local_commands, @@ -777,25 +788,37 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500): run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500) new_env = copy.copy(cpp_env) - xml_output_file = build_dir / "results-multi-gpu-dist-executor_gpt.xml" + xml_output_file = build_dir / "results-multi-gpu-disagg-executor-2-process.xml" trt_model_test = produce_mpirun_command( global_commands=["mpirun", "--allow-run-as-root"], nranks=2, local_commands=[ - "executor/executorTest", - "--gtest_filter=DistExecutorTest.GPTTokenComparison" + "executor/executorTest", "--gtest_filter=*DisaggExecutorTest*" ], leader_commands=[f"--gtest_output=xml:{xml_output_file}"]) run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500) new_env = copy.copy(cpp_env) - xml_output_file = build_dir / "results-multi-gpu-dist-executor_chatglm.xml" + new_env["RUN_LLAMA_MULTI_GPU"] = "true" + xml_output_file = build_dir / "results-multi-gpu-disagg-executor-4-process.xml" trt_model_test = produce_mpirun_command( global_commands=["mpirun", "--allow-run-as-root"], - nranks=2, + nranks=4, + local_commands=[ + "executor/executorTest", "--gtest_filter=*DisaggExecutorTest*" + ], + leader_commands=[f"--gtest_output=xml:{xml_output_file}"]) + run_command(trt_model_test, cwd=tests_dir, env=new_env, 
timeout=1500) + + new_env = copy.copy(cpp_env) + new_env["RUN_LLAMA_MULTI_GPU"] = "true" + xml_output_file = build_dir / "results-multi-gpu-disagg-executor-8-process.xml" + trt_model_test = produce_mpirun_command( + global_commands=["mpirun", "--allow-run-as-root"], + nranks=8, local_commands=[ "executor/executorTest", - "--gtest_filter=DistExecutorTest.ChatGLMTokenComparison" + "--gtest_filter=*LlamaTP2PP2DisaggExecutorTest*" ], leader_commands=[f"--gtest_output=xml:{xml_output_file}"]) run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500) diff --git a/cpp/tests/runtime/gptDecoderBatchedTest.cpp b/cpp/tests/runtime/gptDecoderBatchedTest.cpp index 152a060f0..7ea3a00ab 100644 --- a/cpp/tests/runtime/gptDecoderBatchedTest.cpp +++ b/cpp/tests/runtime/gptDecoderBatchedTest.cpp @@ -195,7 +195,8 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector& sa SizeType32 constexpr nbRnnLayers{0}; SizeType32 constexpr nbHeads{16}; SizeType32 constexpr hiddenSize{1024}; - ModelConfig modelConfig{vocabSize, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype}; + ModelConfig modelConfig{ + vocabSize, nbAttentionLayers + nbRnnLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype}; modelConfig.useGptAttentionPlugin(false); auto streamPtr = std::make_shared(); @@ -315,7 +316,8 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector(); @@ -440,7 +442,8 @@ void testDecoderDraft(nvinfer1::DataType const dtype, std::vector(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); + mModelConfig = std::make_unique(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); mModelConfig->setMlpHiddenSize(32); mWorldConfig = std::make_unique(2, 1, 0); std::vector modules{ @@ -166,8 +166,8 @@ TEST_F(LoraCacheTest, LoraCachePageManagerTest) TEST_F(LoraCacheTest, determineNumPages) { - ModelConfig modelConfig(0, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT); - modelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2)); + ModelConfig modelConfig(0, 2, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT); + modelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2, 0)); WorldConfig worldConfig(1, 1, 0); LoraCachePageManagerConfig pageConfig(MemoryType::kCPU, nvinfer1::DataType::kFLOAT, 12393, 40, 80, 16, 1); @@ -358,7 +358,7 @@ TEST_F(LoraCacheTest, basicPutGet) TEST_F(LoraCacheTest, splitTransposeCpu) { - auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); + auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); auto worldConfig = WorldConfig(2, 1, 0); SizeType32 const split{2}; @@ -421,7 +421,7 @@ TEST_F(LoraCacheTest, splitTransposeCpu) TEST_F(LoraCacheTest, copyToPages_tp1) { - auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); + auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); modelConfig.setMlpHiddenSize(32); auto worldConfig = WorldConfig(1, 1, 0); std::vector modules{ @@ -479,7 +479,7 @@ TEST_F(LoraCacheTest, copyToPages_tp1) TEST_F(LoraCacheTest, copyToPages_tp2_rank0) { - auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); + auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); modelConfig.setMlpHiddenSize(32); auto worldConfig = WorldConfig(2, 1, 0); std::vector modules{ @@ -536,7 +536,7 @@ TEST_F(LoraCacheTest, copyToPages_tp2_rank0) TEST_F(LoraCacheTest, copyToPages_tp2_rank1) { - auto modelConfig = ModelConfig(0, 2, 0, 1, 16, 
nvinfer1::DataType::kFLOAT); + auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); modelConfig.setMlpHiddenSize(32); auto worldConfig = WorldConfig(2, 1, 1); std::vector modules{ diff --git a/cpp/tests/runtime/loraManagerTest.cpp b/cpp/tests/runtime/loraManagerTest.cpp index 496fb57b0..0718bb316 100644 --- a/cpp/tests/runtime/loraManagerTest.cpp +++ b/cpp/tests/runtime/loraManagerTest.cpp @@ -59,7 +59,7 @@ class LoraManagerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-t { protected: LoraManagerTest() - : mModelConfig(1, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT) + : mModelConfig(1, 2, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT) { } @@ -70,7 +70,7 @@ class LoraManagerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-t mWorldConfig = WorldConfig(2); - mModelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2)); + mModelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2, 0)); } std::unique_ptr mManager; @@ -80,7 +80,7 @@ class LoraManagerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-t PeftTable getPeftTable(SizeType32 tpRank = 0) { - auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); + auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); modelConfig.setMlpHiddenSize(32); auto worldConfig = WorldConfig(2, 2, 3); std::vector modules{ @@ -292,7 +292,7 @@ static std::tuple, std::vector, PeftTable> createF TEST_F(LoraManagerTest, fillInputTensors) { LoraManager loraManager; - auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); + auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT); modelConfig.setMlpHiddenSize(32); auto worldConfig = WorldConfig(1, 1, 0); std::vector modules{ diff --git a/cpp/tests/runtime/loraUtilsTest.cpp b/cpp/tests/runtime/loraUtilsTest.cpp index 88b2b4936..b44303346 100644 --- a/cpp/tests/runtime/loraUtilsTest.cpp +++ b/cpp/tests/runtime/loraUtilsTest.cpp @@ -86,7 +86,7 @@ TEST_F(LoraUtilsTest, dims_mem_type) TEST_F(LoraUtilsTest, loraValidateRequestTensors) { - auto modelConfig = ModelConfig(0, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT); + auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT); auto worldConfig = WorldConfig(); std::optional optReqLoraWeights diff --git a/docker/common/install_cmake.sh b/docker/common/install_cmake.sh index db0dece6d..dac5e9d0a 100644 --- a/docker/common/install_cmake.sh +++ b/docker/common/install_cmake.sh @@ -3,7 +3,7 @@ set -ex ARCH=$(uname -m) -CMAKE_VERSION="3.24.4" +CMAKE_VERSION="3.30.2" PARSED_CMAKE_VERSION=$(echo $CMAKE_VERSION | sed 's/\.[0-9]*$//') CMAKE_FILE_NAME="cmake-${CMAKE_VERSION}-linux-${ARCH}" diff --git a/docker/common/install_tensorrt.sh b/docker/common/install_tensorrt.sh index 77d58c9b1..ff91fbb0e 100644 --- a/docker/common/install_tensorrt.sh +++ b/docker/common/install_tensorrt.sh @@ -2,7 +2,7 @@ set -ex -TRT_VER="10.3.0.26" +TRT_VER="10.4.0.26" # Align with the pre-installed cuDNN / cuBLAS / NCCL versions from # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-07.html#rel-24-07 CUDA_VER="12.5" # 12.5.1 @@ -14,6 +14,7 @@ CUBLAS_VER="12.5.3.2-1" # Align with the pre-installed CUDA / NVCC / NVRTC versions from # https://docs.nvidia.com/cuda/archive/12.5.1/cuda-toolkit-release-notes/index.html NVRTC_VER="12.5.82-1" +CUDA_RUNTIME="12.5.82-1" for i in "$@"; do case $i in @@ -71,12 +72,14 @@ install_centos_requirements() { 
yum -y install epel-release wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libnccl-${NCCL_VER}.x86_64.rpm wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libnccl-devel-${NCCL_VER}.x86_64.rpm - yum remove -y libnccl* && yum -y localinstall libnccl-${NCCL_VER}.x86_64.rpm libnccl-devel-${NCCL_VER}.x86_64.rpm - wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${NVRTC_VER}.noarch.rpm - yum remove -y cuda-toolkit* && yum -y localinstall cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${NVRTC_VER}.noarch.rpm + yum remove -y "libnccl*" && yum -y localinstall libnccl-${NCCL_VER}.x86_64.rpm libnccl-devel-${NCCL_VER}.x86_64.rpm + wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${CUDA_RUNTIME}.noarch.rpm + wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-12-config-common-${CUDA_RUNTIME}.noarch.rpm + wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch.rpm + yum remove -y "cuda-toolkit*" && yum -y localinstall cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${CUDA_RUNTIME}.noarch.rpm cuda-toolkit-12-config-common-${CUDA_RUNTIME}.noarch.rpm cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch.rpm wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm - yum remove -y libcublas* && yum -y localinstall libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm + yum remove -y "libcublas*" && yum -y localinstall libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm yum clean all nvcc --version } @@ -84,7 +87,7 @@ install_centos_requirements() { install_tensorrt() { PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") - TRT_CUDA_VERSION="12.5" + TRT_CUDA_VERSION="12.6" if [ -z "$RELEASE_URL_TRT" ];then ARCH=${TRT_TARGETARCH} @@ -92,8 +95,8 @@ install_tensorrt() { if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi - if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi - RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz + if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu24_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi + RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz fi wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar tar -xf /tmp/TensorRT.tar -C /usr/local/ diff --git a/docs/source/advanced/batch-manager.md b/docs/source/advanced/batch-manager.md index d46e05d6d..4a6d8650a 100644 --- a/docs/source/advanced/batch-manager.md +++ 
b/docs/source/advanced/batch-manager.md @@ -147,6 +147,7 @@ Note: this feature isn't supported with the `V1` batching scheme for the moment. * `capacitySchedulerPolicy`, policy used to select the subset available requests in each iteration of the InflightBatching generation loop. - `MAX_UTILIZATION` packs as many requests as the underlying TRT engine can support in any iteration of the InflightBatching generation loop. While this is expected to maximize GPU throughput, it might require that some requests be paused and restarted depending on peak KV cache memory availability. - `GUARANTEED_NO_EVICT` uses KV cache more conservatively guaranteeing that a request, once started, will run to completion without eviction. + - `STATIC_BATCH`, similarly to `GUARANTEED_NO_EVICT`, schedules the maximum possible batch size without eviction. New requests are scheduled only after all requests in the previous batch have finished. ### Optional GptManager parameters * `TrtGptModelOptionalParams` class encapsulates the following fields: @@ -227,6 +228,9 @@ It can also adopt a more conservative approach and schedule requests only when i knows that the memory allocation will be sufficient to process all active requests even in the worst case of KV cache consumption. That mode corresponds to a `SchedulerConfig::capacitySchedulerPolicy` set to `kGUARANTEED_NO_EVICT`. +Another traditional batching scheme, with a batch of requests running in lockstep +until generation for all of them is completed, corresponds to +`SchedulerConfig::capacitySchedulerPolicy` set to `kSTATIC_BATCH`. The `GptManager`'s worker thread terminates when the `GptManager` destructor is called and there are no more active requests. diff --git a/docs/source/advanced/executor.md b/docs/source/advanced/executor.md index 3c7964614..8955c6bae 100644 --- a/docs/source/advanced/executor.md +++ b/docs/source/advanced/executor.md @@ -50,7 +50,7 @@ If replication is expensive or infeasible, use `LogitsPostProcessorConfig::setRe The `Request` class is used to define properties of the request, such as the input token ids and the maximum number of tokens to generate. The `streaming` parameter can be used to indicate if the request should generate a response for each new generated tokens (`streaming = true`) or only after all tokens have been generated (`streaming = false`). Other mandatory parameters of the request include the sampling configuration (defined by the `SamplingConfig` class) which contains parameters controlling the decoding process and the output configuration (defined by the `OutputConfig` class) which controls what information should be included in the `Result` for a particular response. -Optional parameters can also be provided when constructing a request such as a list of bad words, a list of stop words, a client id, or configurations objects for prompt tuning, LoRA, or speculative decoding for example. +Optional parameters can also be provided when constructing a request, such as a list of bad words, a list of stop words, a client id, configuration objects for prompt tuning, LoRA, or speculative decoding, or the number of sequences to generate. ### The Response Class @@ -58,7 +58,19 @@ The `awaitResponses` method of the `Executor` class returns a vector of response ### The Result Class -The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens.
If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request. +The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false` and `numReturnSequences = 1`, a single response will be returned, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` and `numReturnSequences = 1` is used, a `Result` will include one or more tokens (depending on the request `returnAllGeneratedTokens` parameter), except for the last result, and the `isFinal` flag will be set to `true` for the last result associated with this request. + +The request `numReturnSequences` parameter controls the number of output sequences to generate for each prompt. When this option is used, the Executor will return at least `numReturnSequences` responses for each request, each containing one Result. The `sequenceIndex` attribute of the `Result` class indicates the index of the generated sequence in the result (`0 <= sequenceIndex < numReturnSequences`). It contains a Boolean parameter called `isSequenceFinal` that indicates if this is the last result for the sequence and also contains a Boolean parameter `isFinal` that indicates when all sequences for the request have been generated. When `numReturnSequences = 1`, `isFinal` is identical to `isSequenceFinal`. + +Here is an example that shows what a subset of 3 responses might look like for `numReturnSequences = 3`: + +``` +Response 1: requestId = 1, Result with sequenceIndex = 0, isSequenceFinal = false, isFinal = false +Response 2: requestId = 1, Result with sequenceIndex = 1, isSequenceFinal = true, isFinal = false +Response 3: requestId = 1, Result with sequenceIndex = 2, isSequenceFinal = false, isFinal = false +``` + +In this example, each response contains one result for a different sequence. The `isSequenceFinal` flag of the second Result is set to true, indicating that it is the last result for `sequenceIndex = 1`; however, the `isFinal` flag of each response is set to false because sequences 0 and 2 are not completed. ### Sending Requests with Different Beam Widths diff --git a/docs/source/architecture/model-weights-loader.md b/docs/source/architecture/model-weights-loader.md index 919d2713e..eb393d4a7 100644 --- a/docs/source/architecture/model-weights-loader.md +++ b/docs/source/architecture/model-weights-loader.md @@ -195,6 +195,7 @@ loader = ModelWeightsLoader(external_checkpoint_dir, llava_dict) loader.generate_tllm_weights(trtllm_model) ``` Users need to specify the different part from the default `tllm_to_externel_key_dict`. The loader still have support across different precisions. +The support for LLaVA and Exaone is in `LLaMAForCausalLM.from_hugging_face()` of [model.py](../../../tensorrt_llm/models/llama/model.py), and can also be used as an example. ### Models with customized weight layout For models with different weight layout, users can write the conversion loop explicitly and do customized operations.
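To make the capacity scheduler policy and the multi-sequence result flags documented in the `batch-manager.md` and `executor.md` changes above more concrete, here is a minimal C++ sketch against the `tensorrt_llm::executor` API. It is a sketch only: the engine path and input token ids are placeholders, and the use of a `setNumReturnSequences` setter on the `Request` is an assumption for illustration; the exact way `numReturnSequences` is attached to a request may differ between releases.

```cpp
// Sketch: schedule lockstep batches with kSTATIC_BATCH and consume per-sequence
// results (sequenceIndex / isSequenceFinal / isFinal) for numReturnSequences = 3.
#include "tensorrt_llm/executor/executor.h"

#include <iostream>

namespace texec = tensorrt_llm::executor;

int main()
{
    // STATIC_BATCH: no eviction, new requests are scheduled only after the current batch finishes.
    texec::SchedulerConfig schedulerConfig(texec::CapacitySchedulerPolicy::kSTATIC_BATCH);
    texec::ExecutorConfig executorConfig(/*maxBeamWidth=*/1, schedulerConfig);
    // "/path/to/engine_dir" is a placeholder for a previously built engine directory.
    texec::Executor executor("/path/to/engine_dir", texec::ModelType::kDECODER_ONLY, executorConfig);

    // Ask for three independent sequences for one prompt. Depending on the release,
    // numReturnSequences may instead be a Request constructor argument.
    texec::Request request({1, 2, 3, 4}, /*maxTokens=*/32);
    request.setNumReturnSequences(3);
    auto const requestId = executor.enqueueRequest(request);

    // Consume responses until the request reports isFinal, i.e. all sequences are done.
    bool requestDone = false;
    while (!requestDone)
    {
        for (auto const& response : executor.awaitResponses(requestId))
        {
            if (response.hasError())
            {
                std::cerr << response.getErrorMsg() << std::endl;
                return 1;
            }
            auto const& result = response.getResult();
            std::cout << "sequenceIndex=" << result.sequenceIndex
                      << " isSequenceFinal=" << result.isSequenceFinal
                      << " numOutputTokens=" << result.outputTokenIds.at(0).size() << std::endl;
            requestDone |= result.isFinal;
        }
    }
    return 0;
}
```

With `kSTATIC_BATCH`, the sequences of this request run in the same lockstep batch, and `isFinal` only becomes `true` once every sequence has reported `isSequenceFinal`.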
@@ -225,9 +226,10 @@ for tllm_key, _ in tqdm(trtllm_model.named_parameters()): tllm_weights.update(loader.load(tllm_key, preprocess=customized_preprocess)) else: tllm_weights.update(loader.load(tllm_key)) -loader.check(tllm_weights) +loader.fill(tllm_weights) ``` This will apply `preprocess` after `load_tensor()` and before `postprocess`, and demonstrates how to convert the loaded shard into default HF layout. The loader still have support for precisions quantized from FP16/BF16 (e.g. INT8-wo/INT4-wo), the other precisions may require special operations, and can be addressed inside the `preprocess` function. +The support for Qwen-1 is in `QWenForCausalLM.from_hugging_face()` of [model.py](../../../tensorrt_llm/models/qwen/model.py), and can also be taken as example. ### Fully customized If the model weights loader cannot satisfy the requirements, users can write the conversion loop totally on their own. diff --git a/docs/source/installation/build-from-source-windows.md b/docs/source/installation/build-from-source-windows.md index 15ede162f..9dcb3e1b2 100644 --- a/docs/source/installation/build-from-source-windows.md +++ b/docs/source/installation/build-from-source-windows.md @@ -11,7 +11,7 @@ This section is for advanced users. Skip this section if you plan to use the pre 1. Install prerequisites listed in our [Installing on Windows](https://nvidia.github.io/TensorRT-LLM/installation/windows.html) document. 2. Install [CMake](https://cmake.org/download/), version 3.27.7 is recommended, and select the option to add it to the system path. 3. Download and install [Visual Studio 2022](https://visualstudio.microsoft.com/). -4. Download and unzip [TensorRT 10.3.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip). +4. Download and unzip [TensorRT 10.4.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip). ## Building a TensorRT-LLM Docker Image @@ -65,7 +65,7 @@ git submodule update --init --recursive 2. Build TensorRT-LLM. This command generates `build\tensorrt_llm-*.whl`. ```bash -python .\scripts\build_wheel.py -a "89-real" --trt_root C:\workspace\TensorRT-10.3.0.26\ +python .\scripts\build_wheel.py -a "89-real" --trt_root C:\workspace\TensorRT-10.4.0.26\ ``` 3. Copy or move `build\tensorrt_llm-*.whl` into your mounted folder so it can be accessed on your host machine. If you intend to use the C++ runtime, you'll also need to gather various DLLs from the build into your mounted folder. For more information, refer to [C++ Runtime Usage](#c-runtime-usage). @@ -103,7 +103,7 @@ python .\scripts\build_wheel.py -a "89-real" --trt_root C:\workspace\TensorRT-10 1. Install [CMake](https://cmake.org/download/), version 3.27.7 is recommended, and select the option to add it to the system path. 2. Download and install [Visual Studio 2022](https://visualstudio.microsoft.com/). When prompted to select more Workloads, check **Desktop development with C++**. - 3. Download and unzip [TensorRT 10.3.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip). Move the folder to a location you can reference later, such as `%USERPROFILE%\inference\TensorRT`. + 3. Download and unzip [TensorRT 10.4.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip). 
Move the folder to a location you can reference later, such as `%USERPROFILE%\inference\TensorRT`. 1. Add the libraries for TensorRT to your system's `Path` environment variable. Your `Path` should include a line like this: diff --git a/docs/source/installation/windows.md b/docs/source/installation/windows.md index 2d4a3a7e9..e105e998f 100644 --- a/docs/source/installation/windows.md +++ b/docs/source/installation/windows.md @@ -4,7 +4,7 @@ ```{note} The Windows release of TensorRT-LLM is currently in beta. -We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-LLM/releases/tag/v0.12.0) for the most stable experience. +We recommend checking out the [v0.13.0 tag](https://github.com/NVIDIA/TensorRT-LLM/releases/tag/v0.13.0) for the most stable experience. ``` **Prerequisites** @@ -15,7 +15,7 @@ We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-L 1. Install all dependencies together. - 1. Run the provided PowerShell script `setup_env.ps1` located under the `/windows/` folder which installs Python and CUDA 12.4.1 automatically with default settings. Run PowerShell as Administrator to use the script. + 1. Run the provided PowerShell script `setup_env.ps1` located under the `/windows/` folder which installs Python and CUDA 12.5.1 automatically with default settings. Run PowerShell as Administrator to use the script. ```bash ./setup_env.ps1 [-skipCUDA] [-skipPython] @@ -52,7 +52,7 @@ We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-L before installing TensorRT-LLM with the following command. ```bash - pip install tensorrt_llm==0.12.0 --extra-index-url https://pypi.nvidia.com --extra-index-url https://download.pytorch.org/whl/cu121/ + pip install tensorrt_llm==0.13.0 --extra-index-url https://pypi.nvidia.com --extra-index-url https://download.pytorch.org/whl/ ``` Run the following command to verify that your TensorRT-LLM installation is working properly. diff --git a/docs/source/media/image-09-29-2024.png b/docs/source/media/image-09-29-2024.png new file mode 100644 index 000000000..840c76907 Binary files /dev/null and b/docs/source/media/image-09-29-2024.png differ diff --git a/docs/source/performance/perf-best-practices.md b/docs/source/performance/perf-best-practices.md index a5c81f1e4..ea9e8214a 100644 --- a/docs/source/performance/perf-best-practices.md +++ b/docs/source/performance/perf-best-practices.md @@ -85,10 +85,6 @@ select better kernels. However, this feature will increase the engine build time. -**Known issue**: We observed that enabling multiple profiles can lead to extra -unexpected GPU memory usage on some cases starting from v0.11. The issue will be -addressed in future releases. - ### GPT Attention Plugin and Context Fused Multi-Head Attention The GPT attention plugin and fused multi-head attention kernel are enabled by @@ -166,24 +162,21 @@ improve throughput. However, the following conditions have to be satisfied: 2. Both look_up plugin and gemm plugin are enabled, 3. The sharding dimension of the embedding lookup table is set correctly. -To enable the features, use the `--use_parallel_embedding`, -`--use_embedding_sharing`, `--use_lookup_plugin`, `--use_gemm_plugin` -arguments, and set correct dimension to `--embedding_sharding_dim` argument -with `trtllm-build`. 
See those +To enable the features, use the `--use_parallel_embedding`, `--embedding_sharding_dim`, and +`--use_embedding_sharing` arguments in `convert_checkpoint.py`, and use the +`--lookup_plugin` and `--gemm_plugin` arguments in the `trtllm-build` command. See those [Examples](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/gpt#embedding-parallelism-and-sharing) for details. ### Horizontal Fusion in Gated-MLP Horizontal fusion in Gated-MLP combines two Matmul operations into a single one -followed by a separate SwiGLU kernel. However, for FP8 PTQ, the -downside is slight reduction of accuracy because one of the quantization scaling -factors are discarded. - -If both model and batch sizes are large, it is recommended to enable the feature -by using the `--use_fused_mlp=enable` argument with `trtllm-build`. When the workload -is very small, or if you're using FP8 PTQ and the accuracy after enabling it -does not satisfy your requirement, it is not recommended to enable that feature. +followed by a separate SwiGLU kernel. It can effectively reduce latency. + +The feature is enabled by default. However, for FP8 PTQ, the downside is a slight +reduction of accuracy because one of the quantization scaling factors is discarded. +If you're using FP8 PTQ and the accuracy does not satisfy your requirement, you +can try disabling the feature by passing the `--use_fused_mlp=disable` argument to `trtllm-build`. ### GEMM + SwiGLU Fusion in Gated-MLP diff --git a/docs/source/performance/perf-overview.md b/docs/source/performance/perf-overview.md index d125c3252..97e3affe1 100644 --- a/docs/source/performance/perf-overview.md +++ b/docs/source/performance/perf-overview.md @@ -18,12 +18,6 @@ performance that can be delivered by TensorRT-LLM. The following issues are being addressed to improve the efficiency of TensorRT-LLM. -### Unexpected extra GPU memory allocation when enabling `--multiple_profiles` - -We observed that enabling multiple profiles can lead to extra -unexpected GPU memory usage on some cases starting from v0.11. -The issue will be addressed in future releases.
- ### Fused Matmul + Gated-SiLU (LLaMA) The current implementation combines two Matmul operations into one Matmul followed by @@ -45,7 +39,7 @@ The performance numbers below were collected using the steps described in this d | | | | | | | | | | | ------------ | ------------------------ | ------------- | --------------- | ----------- | -------------- | -------------- | -------------- | ------- | | | | **GPU** | H200 141GB HBM3 | GH200 120GB | H100 80GB HBM3 | H100 80GB HBM3 | A100-SXM4-80GB | L40S | -| | | **Precision** | FP8 | FP8 | FP8 | Mixed | Mixed | FP8 | +| | | **Precision** | FP8 | FP8 | FP8 | FP16 | FP16 | FP8 | | **Model** | **Input/Output Lengths** | **TP** | | | | | | | | GPTJ 6B | 128/128 | 1 | 24834.76 | 22454.79 | 24429.55 | 13085.91 | 5864.81 | 7647.24 | | | 128/2048 | 1 | 8348.93 | 6656.25 | 7831.38 | 3882.21 | 2194.57 | 1843.91 | diff --git a/docs/source/quick-start-guide.md b/docs/source/quick-start-guide.md index 4b90d62db..ea62b081e 100644 --- a/docs/source/quick-start-guide.md +++ b/docs/source/quick-start-guide.md @@ -58,72 +58,16 @@ python3 ../run.py --engine_dir ./llama-3.1-8b-engine --max_output_len 100 --tok To create a production-ready deployment of your LLM, use the [Triton Inference Server backend for TensorRT-LLM](https://github.com/triton-inference-server/tensorrtllm_backend) to leverage the TensorRT-LLM C++ runtime for rapid inference execution and include optimizations like in-flight batching and paged KV caching. Triton Inference Server with the TensorRT-LLM backend is available as a [pre-built container through NVIDIA NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags). -1. Pull down the example model repository so that Triton Inference Server can read the model and any associated metadata. - - ```bash - # After exiting the TensorRT-LLM Docker container - cd .. - git clone https://github.com/triton-inference-server/tensorrtllm_backend.git - cd tensorrtllm_backend - cp ../TensorRT-LLM/examples/llama/out/* all_models/inflight_batcher_llm/tensorrt_llm/1/ - ``` - - The `tensorrtllm_backend` repository includes the skeleton of a model repository under `all_models/inflight_batcher_llm/` that you can use. - -2. Copy the model you compiled ({ref}`quick-start-guide-compile`) to the example model repository. - -3. Modify the configuration files from the model repository. Specify the path to the compiled model engine, the tokenizer, and how to handle memory allocation for the KV cache when performing inference in batches. - - ```bash - python3 tools/fill_template.py --in_place \ - all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt \ - decoupled_mode:true,engine_dir:/all_models/inflight_batcher_llm/tensorrt_llm/1,\ - max_tokens_in_paged_kv_cache:,batch_scheduler_policy:guaranteed_completion,kv_cache_free_gpu_mem_fraction:0.2,\ - max_num_sequences:4 - - python tools/fill_template.py --in_place \ - all_models/inflight_batcher_llm/preprocessing/config.pbtxt \ - tokenizer_type:llama,tokenizer_dir:Meta-Llama-3.1-8B-Instruct - - python tools/fill_template.py --in_place \ - all_models/inflight_batcher_llm/postprocessing/config.pbtxt \ - tokenizer_type:llama,tokenizer_dir:Meta-Llama-3.1-8B-Instruct - ``` - -4. Start Triton Inference Server in the container. Specify `world_size`, which is the number of GPUs the model was built for, and point to the `model_repo` that was just set up. 
- - ```bash - docker run -it --rm --gpus all --network host --shm-size=1g \ - -v $(pwd)/all_models:/all_models \ - -v $(pwd)/scripts:/opt/scripts \ - nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 - - # Log in to huggingface-cli to get tokenizer - huggingface-cli login --token ***** - - # Install python dependencies - pip install sentencepiece protobuf - - # Launch Server - python /opt/scripts/launch_triton_server.py --model_repo /all_models/inflight_batcher_llm --world_size 1 - ``` - -## Send Requests - -Use one of the Triton Inference Server client libraries or send HTTP requests to the generated endpoint. To get started, you can use the more fully featured client script or the following command: - -```bash -curl -X POST localhost:8000/v2/models/ensemble/generate -d \ -'{ -"text_input": "How do I count to nine in French?", -"parameters": { -"max_tokens": 100, -"bad_words":[""], -"stop_words":[""] -} -}' +1. Clone the TensorRT-LLM backend repository: + +```console +cd .. +git clone https://github.com/triton-inference-server/tensorrtllm_backend.git +cd tensorrtllm_backend ``` +2. Refer to [End to end workflow to run llama 7b](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/llama.md) in the TensorRT-LLM backend repository to deploy the model with Triton Inference Server. + ## LLM API The LLM API is a Python API to setup & infer with TensorRT-LLM directly in python.It allows for optimizing models by specifying a HuggingFace repo name or a model checkpoint. The LLM API handles checkpoint conversion, engine building, engine loading, and model inference, all from one python object. @@ -146,7 +90,6 @@ In this Quick Start Guide, you: - Retrieved the model weights - Compiled and ran the model - Deployed the model with Triton Inference Server -- Sent HTTP requests For more examples, refer to: diff --git a/docs/source/reference/support-matrix.md b/docs/source/reference/support-matrix.md index 48f0caa03..bb7071aab 100644 --- a/docs/source/reference/support-matrix.md +++ b/docs/source/reference/support-matrix.md @@ -87,13 +87,11 @@ If a GPU is not listed, it is important to note that TensorRT-LLM is expected to - TensorRT-LLM requires Linux x86_64 or Windows. * - GPU Model Architectures - - - [NVIDIA Hopper H100 GPU](https://www.nvidia.com/en-us/data-center/h100/) - - [NVIDIA L40S GPU](https://www.nvidia.com/en-us/data-center/l40s/) - - [NVIDIA Ada Lovelace GPU](https://www.nvidia.com/en-us/technologies/ada-architecture/) - - [NVIDIA Ampere A100 GPU](https://www.nvidia.com/en-us/data-center/a100/) - - [NVIDIA A30 GPU](https://www.nvidia.com/en-us/data-center/products/a30-gpu/) - - [NVIDIA Turing T4 GPU](https://www.nvidia.com/en-us/data-center/tesla-t4/) - - [NVIDIA Volta V100 GPU](https://www.nvidia.com/en-us/data-center/v100/) (experimental) + - [NVIDIA Hopper Architecture](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/) + - [NVIDIA Ada Lovelace Architecture](https://www.nvidia.com/en-us/technologies/ada-architecture/) + - [NVIDIA Ampere Architecture](https://www.nvidia.com/en-us/data-center/ampere-architecture/) + - [NVIDIA Turing Architecture](https://www.nvidia.com/en-us/geforce/turing/) + - [NVIDIA Volta Architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) (experimental) ``` (support-matrix-software)= @@ -110,7 +108,7 @@ The following table shows the supported software for TensorRT-LLM. 
* - Container - [24.07](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html) * - TensorRT - - [10.3](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html) + - [10.4](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html) * - Precision - - Hopper (SM90) - FP32, FP16, BF16, FP8, INT8, INT4 diff --git a/docs/source/release-notes.md b/docs/source/release-notes.md index 4183e23ea..b4b77ea43 100644 --- a/docs/source/release-notes.md +++ b/docs/source/release-notes.md @@ -5,6 +5,65 @@ All published functionality in the Release Notes has been fully tested and verified with known limitations documented. To share feedback about this release, access our [NVIDIA Developer Forum](https://forums.developer.nvidia.com/). +## TensorRT-LLM Release 0.13.0 + +### Key Features and Enhancements + - Supported lookahead decoding (experimental), see `docs/source/speculative_decoding.md`. + - Added some enhancements to the `ModelWeightsLoader` (a unified checkpoint converter, see `docs/source/architecture/model-weights-loader.md`). + - Supported Qwen models. + - Supported auto-padding for indivisible TP shape in INT4-wo/INT8-wo/INT4-GPTQ. + - Improved performance on `*.bin` and `*.pth`. + - Supported OpenAI Whisper in C++ runtime. + - Added some enhancements to the `LLM` class. + - Supported LoRA. + - Supported engine building using dummy weights. + - Supported `trust_remote_code` for customized models and tokenizers downloaded from Hugging Face Hub. + - Supported beam search for streaming mode. + - Supported tensor parallelism for Mamba2. + - Supported returning generation logits for streaming mode. + - Added `curand` and `bfloat16` support for `ReDrafter`. + - Added sparse mixer normalization mode for MoE models. + - Added support for QKV scaling in FP8 FMHA. + - Supported FP8 for MoE LoRA. + - Supported KV cache reuse for P-Tuning and LoRA. + - Supported in-flight batching for CogVLM models. + - Supported LoRA for the `ModelRunnerCpp` class. + - Supported `head_size=48` cases for FMHA kernels. + - Added FP8 examples for DiT models, see `examples/dit/README.md`. + - Supported decoder with encoder input features for the C++ `executor` API. + +### API Changes + - [BREAKING CHANGE] Set `use_fused_mlp` to `True` by default. + - [BREAKING CHANGE] Enabled `multi_block_mode` by default. + - [BREAKING CHANGE] Enabled `strongly_typed` by default in `builder` API. + - [BREAKING CHANGE] Renamed `maxNewTokens`, `randomSeed` and `minLength` to `maxTokens`, `seed` and `minTokens` following OpenAI style. + - The `LLM` class + - [BREAKING CHANGE] Updated `LLM.generate` arguments to include `PromptInputs` and `tqdm`. + - The C++ `executor` API + - [BREAKING CHANGE] Added `LogitsPostProcessorConfig`. + - Added `FinishReason` to `Result`. + +### Model Updates + - Supported Gemma 2, see "Run Gemma 2" section in `examples/gemma/README.md`. + +### Fixed Issues + - Fixed an accuracy issue when enabling remove padding for cross attention. (#1999) + - Fixed the failure in converting qwen2-0.5b-instruct when using `smoothquant`. (#2087) + - Matched the `exclude_modules` pattern in `convert_utils.py` to the changes in `quantize.py`. (#2113) + - Fixed build engine error when `FORCE_NCCL_ALL_REDUCE_STRATEGY` is set. + - Fixed unexpected truncation in the quant mode of `gpt_attention`. + - Fixed the hang caused by a race condition when canceling requests. + - Fixed the default factory for `LoraConfig`.
(#1323) + +### Infrastructure Changes + - Base Docker image for TensorRT-LLM is updated to `nvcr.io/nvidia/pytorch:24.07-py3`. + - Base Docker image for TensorRT-LLM Backend is updated to `nvcr.io/nvidia/tritonserver:24.07-py3`. + - The dependent TensorRT version is updated to 10.4.0. + - The dependent CUDA version is updated to 12.5.1. + - The dependent PyTorch version is updated to 2.4.0. + - The dependent ModelOpt version is updated to v0.15. + + ## TensorRT-LLM Release 0.12.0 ### Key Features and Enhancements diff --git a/examples/apps/README.md b/examples/apps/README.md index a076951af..75d0e7663 100644 --- a/examples/apps/README.md +++ b/examples/apps/README.md @@ -28,7 +28,8 @@ curl http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": , - "prompt": "Where is New York?", + "messages":[{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Where is New York?"}], "max_tokens": 16, "temperature": 0 }' diff --git a/examples/apps/openai_client.py b/examples/apps/openai_client.py index bdcf2d49e..710cd7c63 100644 --- a/examples/apps/openai_client.py +++ b/examples/apps/openai_client.py @@ -38,7 +38,7 @@ def run_chat(args: argparse.Namespace): api_key="tensorrt_llm", ) prompt = args.prompt if args.prompt else "Where is New York?" - completion = client.chat.completions.create( + chat_completion = client.chat.completions.create( model="llama-v3-8b-instruct-hf", messages=[{ "role": "user", @@ -46,8 +46,10 @@ def run_chat(args: argparse.Namespace): }], top_p=args.top_p, temperature=args.temperature, + max_tokens=args.max_tokens, stream=args.stream, n=args.n, + logprobs=args.return_logprobs, extra_body={ "top_k": args.top_k, "use_beam_search": args.use_beam_search, @@ -55,11 +57,11 @@ def run_chat(args: argparse.Namespace): }, ) if args.stream: - for chunk in completion: + for chunk in chat_completion: print(chunk) else: - for choice in completion.choices: - print(choice.message) + for choice in chat_completion.choices: + print(choice) if __name__ == "__main__": @@ -78,6 +80,7 @@ def run_chat(args: argparse.Namespace): choices=["chat", "completions"], default="chat") parser.add_argument("--prompt", type=str, default=None) + parser.add_argument("--return_logprobs", action="store_true", default=False) args = parser.parse_args() if args.api == "chat": run_chat(args) diff --git a/examples/apps/openai_server.py b/examples/apps/openai_server.py index 89608b05f..5c502aa13 100644 --- a/examples/apps/openai_server.py +++ b/examples/apps/openai_server.py @@ -18,6 +18,7 @@ from tensorrt_llm.hlapi import LLM, BuildConfig, KvCacheConfig from tensorrt_llm.hlapi.llm import RequestOutput from tensorrt_llm.hlapi.openai_protocol import ( + ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, @@ -68,10 +69,15 @@ def __init__(self, kv_cache_config: KvCacheConfig, hf_tokenizer: PreTrainedTokenizer = None): self.llm = llm - self.model = model self.kv_cache_config = kv_cache_config self.tokenizer = hf_tokenizer + model_dir = Path(model) + if model_dir.exists() and model_dir.is_dir(): + self.model = model_dir.name + else: + self.model = model + self.app = FastAPI() @self.app.exception_handler(RequestValidationError) @@ -110,12 +116,7 @@ async def version(self) -> JSONResponse: return JSONResponse(content=ver) async def get_model(self) -> JSONResponse: - 
model_dir = Path(self.model) - if model_dir.exists() and model_dir.is_dir(): - model = model_dir.name - else: - model = self.model - model_list = ModelList(data=[ModelCard(id=model)]) + model_list = ModelList(data=[ModelCard(id=self.model)]) return JSONResponse(content=model_list.model_dump()) async def openai_chat(self, request: ChatCompletionRequest) -> Response: @@ -138,29 +139,47 @@ def stream_usage_info(prompt_tokens: int, completion_tokens: int): usage = None return usage + def create_logprobs(token_ids: List[int], + logprobs: List[float]) -> ChatCompletionLogProbs: + assert len(token_ids) == len(logprobs), \ + "token_ids and logprobs have different lengths" + content: List[ChatCompletionLogProbsContent] = [] + for token_id, logprob in zip(token_ids, logprobs): + token = self.tokenizer.decode(token_id) + # returning multiple logprobs is not supported + first_logprob = ChatCompletionLogProbsContent( + token=token, logprob=max(logprob, -9999.0), + bytes=list(token.encode("utf-8", errors="replace")) + ) + content.append(first_logprob) + chat_logprobs = ChatCompletionLogProbs(content=content) + return chat_logprobs + async def chat_stream_generator(promise: RequestOutput) -> AsyncGenerator[str, None]: first_iteration = True num_choices = 1 if request.n is None else request.n finish_reason_sent = [False] * num_choices role = get_role() + + def yield_first_chat(num_tokens: int, role: str = None, content: str = None): + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, content=content), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + choices=[choice_data], model=self.model) + chunk.usage = stream_usage_info(num_tokens, 0) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + async for res in promise: prompt_tokens = len(res.prompt_token_ids) if first_iteration: - # Send first response for each request.n (index) with - # the role - for i in range(num_choices): - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(role=role), - logprobs=None, - finish_reason=None) - chunk = ChatCompletionStreamResponse( - choices=[choice_data], model=request.model) - chunk.usage = stream_usage_info( - prompt_tokens, 0) - - data = chunk.model_dump_json() - yield f"data: {data}\n\n" + yield_first_chat(prompt_tokens, role=role) if request.echo: last_msg_content = "" @@ -171,22 +190,7 @@ async def chat_stream_generator(promise: RequestOutput) -> AsyncGenerator[str, N "content"] if last_msg_content: - for i in range(num_choices): - choice_data = ( - ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage( - content=last_msg_content), - logprobs=None, - finish_reason=None)) - chunk = ChatCompletionStreamResponse( - choices=[choice_data], - model=request.model) - chunk.usage = stream_usage_info( - prompt_tokens, 0) - data = chunk.model_dump_json( - exclude_unset=True) - yield f"data: {data}\n\n" + yield_first_chat(prompt_tokens, content=last_msg_content) first_iteration = False for output in res.outputs: @@ -213,8 +217,12 @@ async def chat_stream_generator(promise: RequestOutput) -> AsyncGenerator[str, N index=i, delta=delta_message, finish_reason=None) + if request.logprobs: + logprobs = output.logprobs_diff + token_ids = output.token_ids_diff + choice_data.logprobs = create_logprobs(token_ids, logprobs) chunk = ChatCompletionStreamResponse( - choices=[choice_data], model=request.model) + choices=[choice_data], model=self.model) 
chunk.usage = stream_usage_info( prompt_tokens, output.length) data = chunk.model_dump_json() @@ -233,7 +241,7 @@ async def chat_stream_generator(promise: RequestOutput) -> AsyncGenerator[str, N ) final_usage_chunk = ChatCompletionStreamResponse( - choices=[], model=request.model, usage=final_usage) + choices=[], model=self.model, usage=final_usage) final_usage_data = final_usage_chunk.model_dump_json() yield f"data: {final_usage_data}\n\n" @@ -259,6 +267,9 @@ async def create_chat_response(promise: RequestOutput) -> JSONResponse: index=output.index, message=message, ) + + if request.logprobs: + choice.logprobs = create_logprobs(output.token_ids, output.logprobs) choices.append(choice) if request.echo: @@ -279,7 +290,7 @@ async def create_chat_response(promise: RequestOutput) -> JSONResponse: total_tokens=num_prompt_tokens + num_generated_tokens, ) response = ChatCompletionResponse( - model=request.model, + model=self.model, choices=choices, usage=usage, ) @@ -303,12 +314,15 @@ async def create_chat_response(promise: RequestOutput) -> JSONResponse: ) sampling_params = request.to_sampling_params() - promise = self.llm.generate_async(prompt, sampling_params, - request.stream) + promise = self.llm.generate_async( + inputs=prompt, + sampling_params=sampling_params, + streaming=request.stream, + ) if request.stream: response_generator = chat_stream_generator(promise) return StreamingResponse(content=response_generator, - media_type="text/event-stream") + media_type="text/event-stream") else: response = await create_chat_response(promise) return JSONResponse(content=response.model_dump()) @@ -352,7 +366,7 @@ async def create_completion_generator(generator: AsyncIterator[Tuple[int, Reques delta_text = prompt + delta_text echoed[response_idx] = True response = CompletionStreamResponse( - model=request.model, + model=self.model, choices=[ CompletionResponseStreamChoice( index=response_idx, text=delta_text) @@ -387,7 +401,7 @@ async def create_completion_response(generator: AsyncIterator[Tuple[int, Request total_tokens=num_gen_tokens + num_prompt_tokens, ) response = CompletionResponse( - model=request.model, + model=self.model, choices=choices, usage=usage_info, ) @@ -403,8 +417,11 @@ async def create_completion_response(generator: AsyncIterator[Tuple[int, Request promises: List[RequestOutput] = [] sampling_params = request.to_sampling_params() for prompt in prompts: - promise = self.llm.generate_async(prompt, sampling_params, - request.stream) + promise = self.llm.generate_async( + inputs=prompt, + sampling_params=sampling_params, + streaming=request.stream, + ) promises.append(promise) generator = merge_promises(promises) num_choices = len(prompts) if request.n is None else len(prompts) * request.n diff --git a/examples/apps/requirements.txt b/examples/apps/requirements.txt index e6c1aa78c..8f0fadd9c 100644 --- a/examples/apps/requirements.txt +++ b/examples/apps/requirements.txt @@ -1,4 +1,4 @@ -fastapi +fastapi==0.112 uvicorn colorama httpx diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt index 5c1d7709e..75b8c7914 100644 --- a/examples/baichuan/requirements.txt +++ b/examples/baichuan/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.15.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/bloom/README.md b/examples/bloom/README.md index c6506ff8d..16066b577 100644 --- a/examples/bloom/README.md +++ b/examples/bloom/README.md @@ -123,7 +123,7 @@ 
trtllm-build --checkpoint_dir ./bloom/176B/trt_ckpt/fp16/8-gpu/ \ --workers 2 # share embedding table between embedding() and lm_head() layers -# To reduce the generated engine size, we has to use gemm and lookup plugin (--use_gemm_plugin --use_lookup_plugin) and must shard the embedding table in the vocab dimension. +# To reduce the generated engine size, we has to use gemm and lookup plugin (--gemm_plugin --lookup_plugin) and must shard the embedding table in the vocab dimension. python convert_checkpoint.py --model_dir ./bloom/176B/ \ --dtype float16 \ --output_dir ./bloom/176B/trt_ckpt/fp16/8-gpu/ \ diff --git a/examples/bloom/convert_checkpoint.py b/examples/bloom/convert_checkpoint.py index ad2eb7e5f..1c19c9c3f 100644 --- a/examples/bloom/convert_checkpoint.py +++ b/examples/bloom/convert_checkpoint.py @@ -22,7 +22,7 @@ from tensorrt_llm.quantization import QuantAlgo, QuantMode from tensorrt_llm.models.convert_utils import iterate_shard_files, load_state_dict, \ load_calib_dataset, split_matrix_tp, get_weight_and_bias, split, smooth_gemm, \ - generate_int8 + generate_int8,get_weight # isort: on @@ -394,8 +394,7 @@ def get_tllm_linear_sq_weight( if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + results[prefix + 'weight'] = cur_weights.t().contiguous() if smoother_value is None: results[last_prefix] = torch.from_numpy( np.array([1.0], dtype=np.float32)) @@ -406,33 +405,26 @@ def get_tllm_linear_sq_weight( axis=cat_dim)[rank] else: cur_per_channel_value = vals["scale_w_quant_orig.col"] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array(cur_per_channel_value, - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape( + col_shape).contiguous() else: - original_weights = np.array(vals["weight.int8"]) + original_weights = vals["weight.int8"] cur_weights = np.split(original_weights, tensor_parallel, axis=cat_dim)[rank] if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + results[prefix + 'weight'] = cur_weights.t().contiguous() cur_per_channel_value = vals["scale_y_accum_quant"] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array([cur_per_channel_value], - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape( + col_shape).contiguous() - results[last_prefix] = torch.from_numpy( - np.array([vals['scale_x_orig_quant']], - dtype=np.float32)).contiguous() + results[last_prefix] = vals['scale_x_orig_quant'].contiguous() - results[prefix + 'act_scale'] = torch.from_numpy( - np.array([[vals["scale_y_quant_orig"]]], - dtype=np.float32)).contiguous() + results[prefix + 'act_scale'] = vals["scale_y_quant_orig"].contiguous() if smoother_value is not None: cur_smoother_value = np.split(smoother_value, diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt index 4abbbacac..bf948f69f 100644 --- a/examples/bloom/requirements.txt +++ b/examples/bloom/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt index e16e87b61..eeb7788bc 100644 --- 
a/examples/chatglm/requirements.txt +++ b/examples/chatglm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 protobuf diff --git a/examples/dbrx/convert_checkpoint.py b/examples/dbrx/convert_checkpoint.py index 86d884af6..41e19b309 100644 --- a/examples/dbrx/convert_checkpoint.py +++ b/examples/dbrx/convert_checkpoint.py @@ -7,9 +7,8 @@ import traceback from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, Optional +from typing import Dict, Optional, Tuple -import numpy as np import safetensors import torch import torch.nn as nn @@ -203,6 +202,27 @@ def args_to_build_options(args): } +def get_weight(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> torch.Tensor: + if f'{prefix}' in params: + return params[f'{prefix}'].to(dtype).detach().cpu() + elif f'{prefix}.weight' not in params: + return None + return params[f'{prefix}.weight'].to(dtype).detach().cpu() + + +def get_bias(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> torch.Tensor: + if f'{prefix}.bias' not in params: + return None + return params[f'{prefix}.bias'].to(dtype).detach().cpu() + + +def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str, + dtype: torch.dtype) -> Tuple[torch.Tensor]: + return get_weight(params, prefix, dtype), get_bias(params, prefix, dtype) + + @torch.no_grad() def capture_activation_range(model, tokenizer, @@ -369,9 +389,8 @@ def convert_hf_dbrx(model_params: dict, is_qkv=True, multi_query_mode=multi_query_mode) weights[ - f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.from_numpy( - np.array([int8_weights['scale_y_quant_orig']], - dtype=np.float32)).contiguous() + f'{tllm_prex}.attention.kv_cache_scaling_factor'] = int8_weights[ + 'scale_y_quant_orig'].contiguous() # input layer_norm input_ln_weight = get_weight(model_params, diff --git a/examples/dbrx/requirements.txt b/examples/dbrx/requirements.txt index 2858ba809..94f0ed2d2 100644 --- a/examples/dbrx/requirements.txt +++ b/examples/dbrx/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt index 21d904ffb..815dca5b8 100644 --- a/examples/falcon/requirements.txt +++ b/examples/falcon/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 transformers>=4.31.0 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt index 85042ed1c..0a5ffa77b 100644 --- a/examples/gemma/requirements.txt +++ b/examples/gemma/requirements.txt @@ -3,7 +3,7 @@ # WAR the new posting of "nvidia-cudnn-cu12~=9.0". # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9". 
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64" -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 flax~=0.8.0 # jax[cuda12_pip]~=0.4.19; platform_system != "Windows" jax~=0.4.19; platform_system == "Windows" diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index 1cdca2c02..c179f7ffc 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptj/requirements.txt b/examples/gptj/requirements.txt index 0060a1631..73cd7c4df 100644 --- a/examples/gptj/requirements.txt +++ b/examples/gptj/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt index c29598eda..15bcb91c5 100644 --- a/examples/gptneox/requirements.txt +++ b/examples/gptneox/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 rouge_score~=0.1.2 evaluate~=0.4.1 diff --git a/examples/grok/requirements.txt b/examples/grok/requirements.txt index b13b48bc7..3019f01b3 100644 --- a/examples/grok/requirements.txt +++ b/examples/grok/requirements.txt @@ -1,6 +1,6 @@ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/hf_lora_convert.py b/examples/hf_lora_convert.py index 4808b303f..06f43a847 100644 --- a/examples/hf_lora_convert.py +++ b/examples/hf_lora_convert.py @@ -47,15 +47,25 @@ def get_all_lora_weights(lora_weights): pattern = re.compile( r'.*\.layers\.([0-9]+)\.(self_attn|mlp)\.([a-z_]+)\.lora_(A|B)\.weight.*' ) + moe_pattern = re.compile( + r'.*\.layers\.([0-9]+)\.(block_sparse_moe)\.((experts)\.([0-9]+)\.|)([a-zA-Z0-9_]+)\.lora_(A|B)\.weight.*' + ) for key, weights in lora_weights.items(): m = pattern.match(key) - if not m: + m_moe = moe_pattern.match(key) + if m: + layer_idx = int(m.group(1)) + hf_module = m.group(3) + inout = "in" if m.group(4) == "A" else "out" + all_weights[layer_idx][hf_module][inout] = weights + elif m_moe: + layer_idx = int(m_moe.group(1)) + hf_module = m_moe.group(6) + inout = "in" if m_moe.group(7) == "A" else "out" + all_weights[layer_idx][hf_module][inout] = weights + else: print(f"no match {key}") continue - layer_idx = int(m.group(1)) - hf_module = m.group(3) - inout = "in" if m.group(4) == "A" else "out" - all_weights[layer_idx][hf_module][inout] = weights return all_weights @@ -87,6 +97,10 @@ def preprocess_lora_weights(lora_model): "gate_up_proj": "mlp_h_to_4h", "c_fc": "mlp_h_to_4h", "c_proj": "mlp_4h_to_h", + "w1": "moe_h_to_4h", + "w2": "moe_4h_to_h", + "w3": "moe_gate", + "gate": "moe_router", } # lora modules on llama hf_modules_to_module_id = { k: LoraManager.LORA_MODULE_IDS[v] diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt index d2fa98b74..3078974dd 100644 --- a/examples/internlm/requirements.txt +++ b/examples/internlm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 
+tensorrt_llm==0.14.0.dev2024100100 datasets==2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/jais/requirements.txt b/examples/jais/requirements.txt index 1cdca2c02..c179f7ffc 100644 --- a/examples/jais/requirements.txt +++ b/examples/jais/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/llama/README.md b/examples/llama/README.md index c54681dc0..d373f932a 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -717,16 +717,19 @@ To run the GPTQ LLaMa example, the following steps are required: 1. Weight quantization: - Quantized weights for GPTQ are generated using [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa.git) as follow: + Quantized weights for GPTQ are generated using [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) as follow: ```bash - git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git - cd GPTQ-for-LLaMa - pip install -r requirements.txt + git clone https://github.com/AutoGPTQ/AutoGPTQ + cd AutoGPTQ + pip install . + + # Download the quant_autogptq script + wget https://gist.githubusercontent.com/TheBloke/b47c50a70dd4fe653f64a12928286682/raw/ebcee019d90a178ee2e6a8107fdd7602c8f1192a/quant_autogptq.py # Quantize weights into INT4 and save as safetensors # Quantized weight with parameter "--act-order" is not supported in TRT-LLM - python llama.py ./tmp/llama/7B/ c4 --wbits 4 --true-sequential --groupsize 128 --save_safetensors ./llama-7b-4bit-gs128.safetensors + python quant_autogptq.py ./tmp/llama/7B ./llama-7b-4bit-gs128.safetensors wikitext --bits 4 --group_size 128 --desc_act 0 --damp 0.1 --dtype float16 --seqlen 4096 --num_samples 3 --use_fast ``` Let us build the TRT-LLM engine with the saved `./llama-7b-4bit-gs128.safetensors`. diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py index 820db5b2d..17034568f 100644 --- a/examples/llama/convert_checkpoint.py +++ b/examples/llama/convert_checkpoint.py @@ -79,7 +79,7 @@ def parse_arguments(): type=str, nargs='?', default='int8', - choices=['int8', 'int4', 'int4_gptq'], + choices=['int8', 'int4', 'int4_gptq', 'int4_awq'], help= 'Define the precision for the weights when using weight-only quantization.' 'You must also use --use_weight_only for that argument to have an impact.' 
@@ -281,6 +281,12 @@ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: quant_config.pre_quant_scale = False quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + if args.weight_only_precision == 'int4_awq': + quant_config.group_size = args.group_size + quant_config.has_zero_point = False + quant_config.pre_quant_scale = True + quant_config.quant_algo = QuantAlgo.W4A16_AWQ + return quant_config diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index 5a5c7ac02..0658842e8 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/llm-api/requirements.txt b/examples/llm-api/requirements.txt index a9f717dc2..6d9b7b3e5 100644 --- a/examples/llm-api/requirements.txt +++ b/examples/llm-api/requirements.txt @@ -1,2 +1,2 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt index 159c21287..08e64bb07 100644 --- a/examples/mamba/requirements.txt +++ b/examples/mamba/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 transformers>=4.39.0 datasets~=2.14.5 evaluate diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt index fd75e3f5b..82689ec1c 100644 --- a/examples/medusa/requirements.txt +++ b/examples/medusa/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt index d2b3b6edb..9edade2c6 100644 --- a/examples/mixtral/requirements.txt +++ b/examples/mixtral/requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 transformers==4.38.2 accelerate==0.25.0 diff --git a/examples/model_api/llama.py b/examples/model_api/llama.py index 5c0b1a8cd..c699ee192 100644 --- a/examples/model_api/llama.py +++ b/examples/model_api/llama.py @@ -55,14 +55,14 @@ def main(): engine.save(args.engine_dir) tokenizer = AutoTokenizer.from_pretrained(args.hf_model_dir) - executor = GenerationExecutor.create(args.engine_dir) - sampling_params = SamplingParams(max_tokens=5) + with GenerationExecutor.create(args.engine_dir) as executor: + sampling_params = SamplingParams(max_tokens=5) - input_str = "What should you say when someone gives you a gift? You should say:" - output = executor.generate(tokenizer.encode(input_str), - sampling_params=sampling_params) - output_str = tokenizer.decode(output.outputs[0].token_ids) - print(f"{input_str} {output_str}") + input_str = "What should you say when someone gives you a gift? 
You should say:" + output = executor.generate(tokenizer.encode(input_str), + sampling_params=sampling_params) + output_str = tokenizer.decode(output.outputs[0].token_ids) + print(f"{input_str} {output_str}") if __name__ == "__main__": diff --git a/examples/model_api/llama_quantize.py b/examples/model_api/llama_quantize.py index a6cd05d6d..699229181 100644 --- a/examples/model_api/llama_quantize.py +++ b/examples/model_api/llama_quantize.py @@ -63,14 +63,14 @@ def main(): engine.save(engine_dir) tokenizer = AutoTokenizer.from_pretrained(args.hf_model_dir) - executor = GenerationExecutor.create(engine_dir) - sampling_params = SamplingParams(max_tokens=5) + with GenerationExecutor.create(engine_dir) as executor: + sampling_params = SamplingParams(max_tokens=5) - input_str = "What should you say when someone gives you a gift? You should say:" - output = executor.generate(tokenizer.encode(input_str), - sampling_params=sampling_params) - output_str = tokenizer.decode(output.outputs[0].token_ids) - print(f"{input_str} {output_str}") + input_str = "What should you say when someone gives you a gift? You should say:" + output = executor.generate(tokenizer.encode(input_str), + sampling_params=sampling_params) + output_str = tokenizer.decode(output.outputs[0].token_ids) + print(f"{input_str} {output_str}") if __name__ == "__main__": diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt index 0060a1631..73cd7c4df 100644 --- a/examples/mpt/requirements.txt +++ b/examples/mpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md index a2e85a42e..79b28d294 100644 --- a/examples/multimodal/README.md +++ b/examples/multimodal/README.md @@ -621,7 +621,7 @@ Currently, CogVLM only support bfloat16 precision. 1. Download Huggingface weights ```bash - export MODEL_NAME="Phi-3-vision-128k-instruct" + export MODEL_NAME="Phi-3-vision-128k-instruct" # or Phi-3.5-vision-instruct git clone https://huggingface.co/microsoft/${MODEL_NAME} tmp/hf_models/${MODEL_NAME} ``` diff --git a/examples/nemotron/requirements.txt b/examples/nemotron/requirements.txt index bf3e39b44..eedf21867 100644 --- a/examples/nemotron/requirements.txt +++ b/examples/nemotron/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 nemo-toolkit[all]==2.0.0rc1 megatron-core==0.8.0 datasets~=2.14.5 diff --git a/examples/nemotron_nas/README.md b/examples/nemotron_nas/README.md new file mode 100644 index 000000000..ef3da2448 --- /dev/null +++ b/examples/nemotron_nas/README.md @@ -0,0 +1,91 @@ +# Nemotron-NAS + +This document shows how to convert and build a model generated by Nemotron-NAS, such as Llama-3_1-Nemotron-51B-Instruct, in TensorRT-LLM. + +- [Nemotron-NAS](#nemotron-nas) + - [Overview](#overview) + - [Support Matrix](#support-matrix) + - [Custom Layers](#custom-layers) + - [Usage](#usage) + - [Build TensorRT engine(s)](#build-tensorrt-engines) + - [Runtime](#runtime) + +## Overview +The TensorRT-LLM Nemotron-NAS implementation can be found in [tensorrt_llm/models/nemotron_nas/model.py](../../tensorrt_llm/models/nemotron_nas/model.py). The TensorRT-LLM Nemotron-NAS example code is located in [`examples/nemotron_nas`](./).
There is one main file: + +* [`convert_checkpoint.py`](./convert_checkpoint.py) to convert the model into the TensorRT-LLM checkpoint format. + +## Support Matrix + * FP16 + * BF16 + * Tensor parallelism + * Pipeline parallelism + +## Custom Layers +Nemotron-NAS offers the ability to replace both `attention` and `FFN` layers with either `Linear` or `NoOp` layers. +`attention` layers can be replaced with `LinearAttention` (which eventually calls `tensorrt_llm/layers/Linear`). +Additionally, `attention` layers can be replaced with `NoOpAttention` (which returns zero, effectively a no-op). +`LinearAttention` and `NoOpAttention` require no kv-cache. +Likewise, `FFN` layers can be replaced with either `LinearFFN` or `NoOpFFN`. + +Different attention layers of the model may have a different number of key-value attention heads, and different MLP layers may have different hidden sizes. + +## A note on Pipeline Parallelism +Due to the model's non-uniform architecture, different pipeline-parallel ranks may run different types of layers, resulting in a possibly unbalanced load between GPUs during inference. + +## Usage + +The TensorRT-LLM example code is located at [examples/nemotron_nas](./). It takes HF weights as input and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference. + +### Build TensorRT engine(s) + +To create a TensorRT engine, you first need to obtain a Nemotron-NAS checkpoint in HF format, for example [Llama-3_1-Nemotron-51B-Instruct](https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct). + +The `trtllm-build` command builds TensorRT engine(s) from a TensorRT-LLM checkpoint. If no checkpoint directory is specified, TensorRT-LLM builds the engine(s) with dummy weights. + +The `trtllm-build` command has a variety of options. In particular, for plugin options that require a data type (e.g., `gpt_attention_plugin`), you can + * explicitly specify `float16`/`bfloat16`/`float32`, so that the plugins are enabled with the specified precision; or + * specify `auto`, so that the plugins are enabled with the precision automatically inferred from the model dtype (i.e., the dtype specified in weight conversion). + +```bash +# Optional: prepare environment variables +export MODEL_DIR=... +export TRT_CHECKPOINT_DIR=... +export TRT_ENGINE_DIR=... +export TP_SIZE=... +export PP_SIZE=... + +# Create a local copy of the model checkpoint +git clone https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct $MODEL_DIR + +# Convert the model to a TensorRT-LLM BF16 checkpoint +# Note: the --trust_remote_code flag is currently required +python convert_checkpoint.py --model_dir $MODEL_DIR \ + --dtype bfloat16 \ + --output_dir $TRT_CHECKPOINT_DIR \ + --tp_size=$TP_SIZE --pp_size=$PP_SIZE \ + --trust_remote_code + +# Build the model engine; precision and parallelism are taken from the converted checkpoint +trtllm-build --checkpoint_dir $TRT_CHECKPOINT_DIR \ + --output_dir $TRT_ENGINE_DIR \ + --gemm_plugin auto \ + --kv_cache_type paged +``` + +The conversion script supports additional models with variable GQA, such as [DeciLM-7B](https://huggingface.co/Deci/DeciLM-7B). + +## Runtime +Once built, the TensorRT engine may be used with any TensorRT-LLM entry point or API. For example, to run inference with [examples/run.py](../run.py): + +```bash +export MODEL_DIR=... +export TRT_ENGINE_DIR=... + +python run.py --engine_dir $TRT_ENGINE_DIR --tokenizer_dir $MODEL_DIR --max_output_len 1024 ... 
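+# A hypothetical concrete invocation (the prompt text is illustrative; any other
+# run.py sampling flags may be appended in place of the "..." above):
+#   python run.py --engine_dir $TRT_ENGINE_DIR --tokenizer_dir $MODEL_DIR \
+#       --max_output_len 1024 --input_text "Explain neural architecture search in one paragraph."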
+ +# for multi-GPU inference (engine must be built with either tp_size>1, pp_size>1, or both) +export NUM_GPUS=... +mpirun -n $NUM_GPUS --allow-run-as-root python run.py ... +``` diff --git a/examples/nemotron_nas/calibration_utils.py b/examples/nemotron_nas/calibration_utils.py new file mode 100644 index 000000000..42b4382fa --- /dev/null +++ b/examples/nemotron_nas/calibration_utils.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DATASET = "Magpie-Align/Magpie-Pro-MT-300K-v0.1" + + +def create_trtllm_magpie_calibration_dataset(output_dir: str, + calib_size: int = 512) -> None: + from datasets import load_dataset + + dataset = load_dataset(DATASET, split="train") + + def transform(conversation): + value = '\n'.join(turn['value'] + for turn in conversation['conversations']) + return {"text": value} + + dataset = dataset.select(range(calib_size)).map( + transform, remove_columns=dataset.column_names) + # https://github.com/huggingface/datasets/issues/6703#issuecomment-1974766332 + dataset.to_parquet(output_dir + "/data.parquet") + + +if __name__ == "__main__": + import sys + output_dir = sys.argv[1] + create_trtllm_magpie_calibration_dataset(output_dir) diff --git a/examples/nemotron_nas/convert_checkpoint.py b/examples/nemotron_nas/convert_checkpoint.py new file mode 100644 index 000000000..f2fcc04e8 --- /dev/null +++ b/examples/nemotron_nas/convert_checkpoint.py @@ -0,0 +1,162 @@ +import argparse +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from tensorrt_llm._utils import release_gc +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import DeciLMForCausalLM + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, required=True) + + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--dtype', + type=str, + default='bfloat16', + choices=['float32', 'bfloat16', 'float16']) + + parser.add_argument('--load_by_shard', + action='store_true', + help='Load a pretrained model shard-by-shard.') + + parser.add_argument("--load_model_on_cpu", action="store_true") + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). 
' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_embedding_sharing', + action="store_true", + default=False, + help= + 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' + 'Note: the flag might not take effect when the criteria are not met.') + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + + parser.add_argument( + '--save_config_only', + action="store_true", + default=False, + help= + 'Only save the model config w/o read and converting weights, be careful, this is for debug only' + ) + + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Pass trust_remote_code=True to HF loading functions as needed") + + args = parser.parse_args() + return args + + +def args_to_build_options(args): + return { + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.use_embedding_sharing, + } + + +def convert_and_save_hf(args): + model_dir = args.model_dir + load_by_shard = args.load_by_shard + world_size = args.tp_size * args.pp_size + # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. + # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, + # before the refactor is done. + override_fields = {} + override_fields.update(args_to_build_options(args)) + + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + model = DeciLMForCausalLM.from_hugging_face( + model_dir, + args.dtype, + mapping=mapping, + load_by_shard=load_by_shard, + load_model_on_cpu=args.load_model_on_cpu, + trust_remote_code=args.trust_remote_code, + **override_fields, + ) + model.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del model + + execute(args.workers, [convert_and_save_rank] * world_size, args) + release_gc() + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." 
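+
+# execute() runs one conversion job per rank: serially when --workers is 1, otherwise in a
+# thread pool, collecting per-rank exceptions and failing the conversion if any rank raised.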
+ + +def main(): + args = parse_arguments() + tik = time.time() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # TODO(oargov): all deci checkpoints require trust_remote_code=True at the moment, remove this when this changes + # NOTE: we opt not to make this the default since users should be made aware of this in-case they don't want to trust remote code + assert args.trust_remote_code, "Nemotron NAS checkpoint require --trust_remote_code" + + convert_and_save_hf(args) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp b/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp index 004373d94..b165e577a 100644 --- a/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp +++ b/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp @@ -197,15 +197,17 @@ int TritonFlashAttentionPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputD nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { + int res = 1; if (mType == DataType::kHALF) { - return enqueueImpl(inputDesc, outputDesc, inputs, outputs, workspace, stream); + res = enqueueImpl(inputDesc, outputDesc, inputs, outputs, workspace, stream); } else if (mType == DataType::kFLOAT) { - return enqueueImpl(inputDesc, outputDesc, inputs, outputs, workspace, stream); + res = enqueueImpl(inputDesc, outputDesc, inputs, outputs, workspace, stream); } - return 1; + sync_check_cuda_error(); + return res; } // IPluginV2Ext Methods diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt index 0060a1631..73cd7c4df 100644 --- a/examples/opt/requirements.txt +++ b/examples/opt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/phi/README.md b/examples/phi/README.md index c6c25ebb5..802ee282a 100644 --- a/examples/phi/README.md +++ b/examples/phi/README.md @@ -1,18 +1,16 @@ # Phi -This document explains how to build the [phi-2](https://huggingface.co/microsoft/phi-2), [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct), -[Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct), [Phi-3-small-8k-instruct](https://huggingface.co/microsoft/Phi-3-small-8k-instruct), [Phi-3-small-128k-instruct](https://huggingface.co/microsoft/Phi-3-small-128k-instruct), [Phi-3-medium-4k-instruct](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/) and [Phi-3-medium-128k-instruct](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/) -models using TensorRT-LLM and run on a single GPU. - -- [Phi](#phi) - - [Overview](#overview) - - [Support Matrix](#support-matrix) - - [Usage](#usage) - - [1. Convert weights from HF Transformers to TensorRT-LLM format](#1-convert-weights-from-hf-transformers-to-tensorrt-llm-format) - - [2. Build TensorRT engine(s)](#2-build-tensorrt-engines) - - [3. Summarization using the Phi model](#3-summarization-using-the-phi-model) - - [4. Quantization](#4-quantization) - - [5. 
Run Phi-3 with LoRA](#5-run-phi-3-with-lora) +This document explains how to build the Phi-2, Phi-3, and Phi-3.5 families of models using TensorRT-LLM and run them on a single GPU or multiple GPUs. +For multimodal models (Phi-3-vision-128k-instruct and Phi-3.5-vision-instruct), see `../multimodal/README.md`. + +- [Overview](#overview) +- [Support Matrix](#support-matrix) +- [Usage](#usage) + - [1. Convert weights from HF Transformers to TensorRT-LLM format](#1-convert-weights-from-hf-transformers-to-tensorrt-llm-format) + - [2. Build TensorRT engine(s)](#2-build-tensorrt-engines) + - [3. Summarization using the Phi model](#3-summarization-using-the-phi-model) + - [4. Quantization](#4-quantization) + - [5. Run Phi-3 with LoRA](#5-run-phi-3-with-lora) ## Overview @@ -29,13 +27,15 @@ In addition, there are two shared files in the parent folder [`examples`](../) f | Model Name | FP16 | BF16 | FP8 | INT8 | TP | | :--------------: | :---: | :---: | :---: | :---: | :---: | -| phi-2 | Y | Y | | | Y | +| Phi-2 | Y | Y | | | Y | | Phi-3-mini-4k-instruct | Y | Y | Y | Y | | Phi-3-mini-128k-instruct | Y | Y | Y | Y | | Phi-3-small-8k-instruct | Y | Y | Y | Y | Y | | Phi-3-small-128k-instruct | Y | Y | Y | Y | Y | | Phi-3-medium-8k-instruct | Y | Y | Y | Y | | Phi-3-medium-128k-instruct | Y | Y | Y | Y | +| Phi-3.5-mini-instruct | Y | Y | Y | Y | +| Phi-3.5-MoE-instruct | Y | Y | Y | Y | Y | * Model Name: the name of the model, the same as the name on HuggingFace * TP: Tensor Parallel @@ -57,6 +57,11 @@ python ./convert_checkpoint.py \ --dtype float16 ``` +If a model supports tensor parallelism, the number of tensor-parallel ranks to split the model into can be specified via the `--tp_size` argument to `convert_checkpoint.py`. + +For the Phi-3.5-MoE-instruct model, expert parallelism can be enabled using the `--moe_tp_size` and `--moe_ep_size` arguments. +The section on Parallelism Modes in `../mixtral/README.md` discusses tensor and expert parallelism for Mixture of Experts models in detail. + ### 2. Build TensorRT engine(s) TensorRT-LLM builds TensorRT engine(s) using a HF checkpoint. If no checkpoint directory is specified, TensorRT-LLM will build engine(s) using dummy weights. diff --git a/examples/phi/convert_checkpoint.py b/examples/phi/convert_checkpoint.py index cddb110b0..249dae2f2 100644 --- a/examples/phi/convert_checkpoint.py +++ b/examples/phi/convert_checkpoint.py @@ -59,6 +59,20 @@ def parse_arguments(): 'Define the precision for the weights when using weight-only quantization.' 'You must also use --use_weight_only for that argument to have an impact.' ) + parser.add_argument( + '--moe_tp_size', + type=int, + default=-1, + help= + 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' + ) + parser.add_argument( + '--moe_ep_size', + type=int, + default=-1, + help= + 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' + ) parser.add_argument('--output_dir', type=str, default='tllm_checkpoint', @@ -110,6 +124,18 @@ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: args = parse_arguments() assert args.pp_size == 1, "Pipeline parallelism is not supported." 
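+    # The block below resolves the MoE parallelism mapping. For example, with
+    # --tp_size 8 and --moe_ep_size 2, moe_tp_size defaults to 8 // 2 = 4, which
+    # satisfies the asserted invariant moe_tp_size * moe_ep_size == tp_size.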
+ world_size = args.tp_size * args.pp_size + if (args.moe_tp_size == -1 and args.moe_ep_size == -1): + # moe default to tp-only + args.moe_tp_size = args.tp_size + args.moe_ep_size = 1 + elif (args.moe_tp_size == -1): + args.moe_tp_size = args.tp_size // args.moe_ep_size + elif (args.moe_ep_size == -1): + args.moe_ep_size = args.tp_size // args.moe_tp_size + assert (args.moe_tp_size * args.moe_ep_size == args.tp_size + ), "moe_tp_size * moe_ep_size must equal to tp_size" + tik = time.time() if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) @@ -119,39 +145,35 @@ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: model_type = model_config.architectures[0] supported_models = [ 'PhiForCausalLM', 'Phi3ForCausalLM', 'Phi3VForCausalLM', - 'Phi3SmallForCausalLM' + 'Phi3SmallForCausalLM', 'PhiMoEForCausalLM' ] if model_type not in supported_models: assert False, "Invalid model type" - phi_model = Phi3ForCausalLM if model_type.find( - 'Phi3') != -1 else PhiForCausalLM - - hf_model = None + is_phi3 = 'Phi3' in model_type or 'MoE' in model_type + phi_model = Phi3ForCausalLM if is_phi3 else PhiForCausalLM - override_fields = {} - # override_fields.update(args_to_build_options(args)) quant_config = args_to_quant_config(args) def convert_and_save_rank(args, rank): - mapping = Mapping(world_size=args.tp_size * args.pp_size, + mapping = Mapping(world_size=world_size, rank=rank, tp_size=args.tp_size, - pp_size=args.pp_size) + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size) phi = phi_model.from_hugging_face( - args.model_dir if hf_model is None else hf_model, + args.model_dir, args.dtype, mapping=mapping, quant_config=quant_config, - **override_fields, ) phi.save_checkpoint(args.output_dir, save_config=(rank == 0)) del phi - execute(args.workers, [convert_and_save_rank] * args.tp_size * args.pp_size, - args) + execute(args.workers, [convert_and_save_rank] * world_size, args) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt index d041d57d6..b711cf2ad 100644 --- a/examples/phi/requirements.txt +++ b/examples/phi/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt index 9e5059e68..4df6d2b70 100644 --- a/examples/quantization/requirements.txt +++ b/examples/quantization/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets>=2.14.4 nemo-toolkit[all]<=1.20.0,>=1.18.0 rouge_score~=0.1.2 diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt index a100bbffe..4ad36e3f6 100644 --- a/examples/qwen/requirements.txt +++ b/examples/qwen/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/qwenvl/README.md b/examples/qwenvl/README.md index 4578ca92e..d4e2da3a9 100644 --- a/examples/qwenvl/README.md +++ b/examples/qwenvl/README.md @@ -50,18 +50,16 @@ python3 run.py \ --tokenizer_dir=./Qwen-VL-Chat \ --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ - --vit_engine_dir=./plan \ - --images_path='{"image": "./pics/demo.jpeg"}' \ - 
--input_dir='{"image": "image.pt"}' + --vit_engine_path=./plan/visual_encoder/visual_encoder_fp16.plan \ + --images_path='{"image": "./pics/demo.jpeg"}' ``` 4.2 (Optional) For multiple rounds of dialogue, you can run: ```bash python3 run_chat.py \ --tokenizer_dir=./Qwen-VL-Chat \ --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ - --vit_engine_dir=./plan \ - --images_path='{"image": "./pics/demo.jpeg"}' \ - --input_dir='{"image": "image.pt"}' + --vit_engine_path=./plan/visual_encoder/visual_encoder_fp16.plan \ + --images_path='{"image": "./pics/demo.jpeg"}' ``` 4.3 (Optional) To show the bounding box result in the demo picture, install OpenCV, ZMQ, and request: ```bash @@ -85,7 +83,7 @@ python3 run_chat.py \ --tokenizer_dir=./Qwen-VL-Chat \ --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ - --vit_engine_dir=./plan \ + --vit_engine_path=./plan/visual_encoder/visual_encoder_fp16.plan \ --display \ --port=8006 ``` @@ -98,7 +96,7 @@ python3 run_chat.py \ --tokenizer_dir=./Qwen-VL-Chat \ --qwen_engine_dir=./trt_engines/Qwen-VL-7B-Chat \ - --vit_engine_dir=./plan \ + --vit_engine_path=./plan/visual_encoder/visual_encoder_fp16.plan \ --display \ --local_machine ``` diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt index a6b70325a..989f4fb60 100644 --- a/examples/qwenvl/requirements.txt +++ b/examples/qwenvl/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/qwenvl/run.py b/examples/qwenvl/run.py index f721cd307..d0bbfa3a0 100644 --- a/examples/qwenvl/run.py +++ b/examples/qwenvl/run.py @@ -25,6 +25,7 @@ import tensorrt_llm import tensorrt_llm.profiler as profiler from tensorrt_llm import logger +from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.quantization import QuantMode from tensorrt_llm.runtime import (ModelConfig, SamplingConfig, Session, TensorInfo) @@ -113,8 +114,11 @@ def get_model(self): num_layers = config["pretrained_config"]["num_hidden_layers"] num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_heads) - paged_kv_cache = config["build_config"]["plugin_config"][ - "paged_kv_cache"] + if "kv_cache_type" in config["build_config"]: + kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"]) + else: + kv_cache_type = KVCacheType.CONTINUOUS + tokens_per_block = config["build_config"]["plugin_config"][ "tokens_per_block"] max_prompt_embedding_table_size = config["build_config"].get( @@ -144,7 +148,7 @@ def get_model(self): vocab_size=vocab_size, num_layers=num_layers, gpt_attention_plugin=use_gpt_attention_plugin, - paged_kv_cache=paged_kv_cache, + kv_cache_type=kv_cache_type, tokens_per_block=tokens_per_block, remove_input_padding=remove_input_padding, dtype=dtype, @@ -418,9 +422,9 @@ def parse_arguments(): parser.add_argument("--max_new_tokens", type=int, default=200) parser.add_argument("--log_level", type=str, default="info") parser.add_argument( - "--vit_engine_dir", + "--vit_engine_path", type=str, - default="qwen_outputs", + default="plan/visual_encoder/visual_encoder_fp16.plan", ) parser.add_argument( "--qwen_engine_dir", @@ -468,18 +472,25 @@ def parse_arguments(): type=int, help="Use beam search if num_beams >1", default=1) + parser.add_argument("--display", default=False, action='store_true') + parser.add_argument('--port', type=str, default='8006') + parser.add_argument("--local_machine", default=False, action='store_true') + return 
parser.parse_args() -def vit_process(image_path, vit_engine_dir, stream): +def vit_process(image_path, vit_engine_path, stream): img_processor = Preprocss(448) - logger.info(f"Loading engine from {vit_engine_dir}") - with open(vit_engine_dir, "rb") as f: + logger.info(f"Loading engine from {vit_engine_path}") + with open(vit_engine_path, "rb") as f: engine_buffer = f.read() - logger.info(f"Creating session from engine {vit_engine_dir}") + logger.info(f"Creating session from engine {vit_engine_path}") session_vit = Session.from_serialized_engine(engine_buffer) device = torch.device("cuda") if torch.cuda.is_available() else "cpu" - images = img_processor.encode(image_path).to(device) + image_path_list = [] + for item in image_path: + image_path_list.append(next(iter(item.values()))) + images = img_processor.encode(image_path_list).to(device) batch_size = images.size(0) images = images.expand(batch_size, -1, -1, -1).contiguous() visual_inputs = {"input": images.float()} @@ -510,7 +521,7 @@ def vit_process(image_path, vit_engine_dir, stream): args = parse_arguments() stream = torch.cuda.current_stream().cuda_stream tensorrt_llm.logger.set_level(args.log_level) - image_embeds = vit_process(args.images_path, args.vit_engine_dir, stream) + image_embeds = vit_process(args.images_path, args.vit_engine_path, stream) qinfer = QWenInfer( args.tokenizer_dir, args.qwen_engine_dir, diff --git a/examples/qwenvl/run_chat.py b/examples/qwenvl/run_chat.py index 37becbb24..e3457b8a5 100644 --- a/examples/qwenvl/run_chat.py +++ b/examples/qwenvl/run_chat.py @@ -80,7 +80,7 @@ def exist_cooridinate(input): if __name__ == '__main__': args = parse_arguments() stream = torch.cuda.current_stream().cuda_stream - image_embeds = vit_process(args.input_dir, args.vit_engine_dir, stream) + image_embeds = vit_process(args.images_path, args.vit_engine_path, stream) qinfer = QWenInfer(args.tokenizer_dir, args.qwen_engine_dir, args.log_level, args.output_csv, args.output_npy, args.num_beams) qinfer.qwen_model_init() diff --git a/examples/recurrentgemma/requirements.txt b/examples/recurrentgemma/requirements.txt index 05d787ade..9303bc411 100644 --- a/examples/recurrentgemma/requirements.txt +++ b/examples/recurrentgemma/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 git+https://github.com/google-deepmind/recurrentgemma.git flax>=0.8.2 jax~=0.4.23 diff --git a/examples/redrafter/requirements.txt b/examples/redrafter/requirements.txt index fd75e3f5b..82689ec1c 100644 --- a/examples/redrafter/requirements.txt +++ b/examples/redrafter/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/run.py b/examples/run.py index 7a3b767fc..cadd6cd19 100644 --- a/examples/run.py +++ b/examples/run.py @@ -18,6 +18,7 @@ import csv import os from pathlib import Path +from typing import List, Optional import numpy as np import torch @@ -198,35 +199,43 @@ def parse_input_token_extra_ids(prompt_table_path, kv_cache_enable_block_reuse, def print_output(tokenizer, - output_ids, - input_lengths, - sequence_lengths, - output_csv=None, - output_npy=None, - context_logits=None, - generation_logits=None, - cum_log_probs=None, - log_probs=None, - output_logits_npy=None, - output_cum_log_probs_npy=None, - output_log_probs_npy=None): - batch_size, num_beams, _ = output_ids.size() + 
output_ids: torch.Tensor, + input_lengths: List[int], + sequence_lengths: torch.Tensor, + output_csv: Optional[str] = None, + output_npy: Optional[str] = None, + context_logits: Optional[torch.Tensor] = None, + generation_logits: Optional[torch.Tensor] = None, + cum_log_probs: Optional[torch.Tensor] = None, + log_probs: Optional[torch.Tensor] = None, + output_logits_npy: Optional[str] = None, + output_cum_log_probs_npy: Optional[str] = None, + output_log_probs_npy: Optional[str] = None): + num_output_sents, num_beams, _ = output_ids.size() + batch_size = len(input_lengths) + num_return_sequences = num_output_sents // batch_size + if output_csv is None and output_npy is None: - for batch_idx in range(batch_size): - inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist( - ) + for i in range(batch_size * num_return_sequences): + batch_idx = i // num_return_sequences + seq_idx = i % num_return_sequences + inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist() input_text = tokenizer.decode(inputs) - print(f'Input [Text {batch_idx}]: \"{input_text}\"') + if seq_idx == 0: + print(f'Input [Text {batch_idx}]: \"{input_text}\"') + for beam in range(num_beams): output_begin = input_lengths[batch_idx] - output_end = sequence_lengths[batch_idx][beam] - outputs = output_ids[batch_idx][beam][ - output_begin:output_end].tolist() + output_end = sequence_lengths[i][beam] + outputs = output_ids[i][beam][output_begin:output_end].tolist() output_text = tokenizer.decode(outputs) - print( - f'Output [Text {batch_idx} Beam {beam}]: \"{output_text}\"') + index_str = (f'Text {batch_idx} Seq {seq_idx} Beam {beam}' + if num_return_sequences > 1 else + f'Text {batch_idx} Beam {beam}') + print(f'Output [{index_str}]: \"{output_text}\"') output_ids = output_ids.reshape((-1, output_ids.size(2))) + if output_csv is not None: output_file = Path(output_csv) output_file.parent.mkdir(exist_ok=True, parents=True) @@ -394,6 +403,7 @@ def main(args): "WARNING: using this option may increase network usage significantly (quadratically w.r.t output length)." ) args.return_all_generated_tokens = True + runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp runner_kwargs = dict( engine_dir=args.engine_dir, @@ -430,7 +440,8 @@ def main(args): kv_cache_free_gpu_memory_fraction=args. 
kv_cache_free_gpu_memory_fraction, enable_chunked_context=args.enable_chunked_context, - multi_block_mode=args.multi_block_mode) + multi_block_mode=args.multi_block_mode, + cuda_graph_mode=args.cuda_graph_mode) runner_kwargs.update( enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc) runner = runner_cls.from_dir(**runner_kwargs) @@ -453,6 +464,7 @@ def main(args): top_k=args.top_k, top_p=args.top_p, num_beams=args.num_beams, + num_return_sequences=args.num_return_sequences, length_penalty=args.length_penalty, early_stopping=args.early_stopping, repetition_penalty=args.repetition_penalty, @@ -483,10 +495,10 @@ def main(args): sequence_lengths = curr_outputs['sequence_lengths'] cum_log_probs = None log_probs = None - if args.output_cum_log_probs_npy != None: - cum_log_probs = outputs['cum_log_probs'] - if args.output_log_probs_npy != None: - log_probs = outputs['log_probs'] + if args.output_cum_log_probs_npy is not None: + cum_log_probs = curr_outputs['cum_log_probs'] + if args.output_log_probs_npy is not None: + log_probs = curr_outputs['log_probs'] print_output( tokenizer, output_ids, @@ -510,9 +522,9 @@ def main(args): context_logits = outputs['context_logits'] if runner.gather_generation_logits: generation_logits = outputs['generation_logits'] - if args.output_cum_log_probs_npy != None: + if args.output_cum_log_probs_npy is not None: cum_log_probs = outputs['cum_log_probs'] - if args.output_log_probs_npy != None: + if args.output_log_probs_npy is not None: log_probs = outputs['log_probs'] print_output(tokenizer, output_ids, @@ -550,9 +562,9 @@ def main(args): frequency_penalty=args.frequency_penalty, stop_words_list=stop_words_list, bad_words_list=bad_words_list, - output_cum_log_probs=(args.output_cum_log_probs_npy != - None), - output_log_probs=(args.output_log_probs_npy != None), + output_cum_log_probs=(args.output_cum_log_probs_npy + is not None), + output_log_probs=(args.output_log_probs_npy is not None), random_seed=args.random_seed, lora_uids=args.lora_task_uids, lookahead_config=args.lookahead_config, diff --git a/examples/skywork/requirements.txt b/examples/skywork/requirements.txt index 563246878..0ac45f703 100644 --- a/examples/skywork/requirements.txt +++ b/examples/skywork/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets~=2.16.1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/smaug/requirements.txt b/examples/smaug/requirements.txt index 5a5c7ac02..0658842e8 100644 --- a/examples/smaug/requirements.txt +++ b/examples/smaug/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/summarize.py b/examples/summarize.py index 1ede00e7d..b908c83f6 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -15,6 +15,7 @@ import argparse import ast +import itertools import os from pathlib import Path @@ -143,6 +144,7 @@ def main(args): random_seed = args.random_seed temperature = args.temperature num_beams = args.num_beams + num_return_sequences = args.num_return_sequences length_penalty = args.length_penalty early_stopping = args.early_stopping repetition_penalty = args.repetition_penalty @@ -165,13 +167,22 @@ def main(args): # TODO: Add random_seed flag in gptj rouge_dir = args.rouge_dir if args.rouge_dir and os.path.exists( args.rouge_dir) else "rouge" - metric_tensorrt_llm = 
[evaluate.load(rouge_dir) for _ in range(num_beams)] - metric_hf = [evaluate.load(rouge_dir) for _ in range(num_beams)] - for i in range(num_beams): - metric_tensorrt_llm[i].seed = 0 + metric_tensorrt_llm = [[evaluate.load(rouge_dir) for _ in range(num_beams)] + for _ in range(num_return_sequences)] + for i, j in itertools.product(range(num_return_sequences), + range(num_beams)): + metric_tensorrt_llm[i][j].seed = 0 + ppls_trt_llm = [[[] for _ in range(num_beams)] + for _ in range(num_return_sequences)] + + # HF returns num_return_sequences output ids. If beam search is enabled, + # num_return_sequences should be less or equal to num_beams. + num_returns_hf = (num_beams + if num_return_sequences == 1 else num_return_sequences) + metric_hf = [evaluate.load(rouge_dir) for _ in range(num_returns_hf)] + for i in range(num_returns_hf): metric_hf[i].seed = 0 - ppls_trt_llm = [[] for _ in range(num_beams)] - ppls_hf = [[] for _ in range(num_beams)] + ppls_hf = [[] for _ in range(num_returns_hf)] def _prepare_inputs(batch_input_texts, eval_task='summarize', @@ -245,6 +256,7 @@ def eval_trt_llm(datapoint, batch_input_ids, max_new_tokens=output_len, max_attention_window_size=max_attention_window_size, + num_return_sequences=num_return_sequences, sink_token_length=sink_token_length, end_id=end_id, pad_id=pad_id, @@ -263,6 +275,7 @@ def eval_trt_llm(datapoint, lookahead_config=args.lookahead_config, output_sequence_lengths=True, return_dict=True, + random_seed=random_seed, medusa_choices=args.medusa_choices) torch.cuda.synchronize() @@ -270,46 +283,46 @@ def eval_trt_llm(datapoint, if runtime_rank == 0: output_ids = outputs['output_ids'] output_beams_list = [ - tokenizer.batch_decode(output_ids[batch_idx, :, - input_lengths[batch_idx]:], - skip_special_tokens=True) - for batch_idx in range(batch_size) + tokenizer.batch_decode( + beam_tokens[:, input_lengths[i // num_return_sequences]:], + skip_special_tokens=True) + for i, beam_tokens in enumerate(output_ids) ] output_ids_list = [ - output_ids[batch_idx, :, input_lengths[batch_idx]:] - for batch_idx in range(batch_size) + beam_tokens[:, input_lengths[i // num_return_sequences]:] + for i, beam_tokens in enumerate(output_ids) ] - ppls = [[] for _ in range(batch_size)] - seq_lengths_array = outputs["sequence_lengths"].cpu().tolist() + ppls = [[] for _ in range(batch_size * num_return_sequences)] lengths_info = { 'input_lengths': input_lengths, - 'seq_lengths': seq_lengths_array + 'seq_lengths': outputs["sequence_lengths"].cpu().tolist(), } if eval_ppl: seq_lengths = outputs['sequence_lengths'] context_logits = outputs['context_logits'] - # Remove the first generation logits which are same to last context logits + # Remove the first generation logits which are same to last + # context logits. 
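+            # With num_return_sequences > 1, outputs are flattened to
+            # [batch * num_return_sequences, beam, ...]; result_idx // num_return_sequences
+            # below recovers the originating batch entry for input_lengths lookups.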
generation_logits = outputs['generation_logits'][:, :, 1:] - for batch_idx in range(batch_size): + for result_idx in range(batch_size * num_return_sequences): + batch_idx = result_idx // num_return_sequences # [batch, beam, step] for beam_idx in range(num_beams): - curr_len = seq_lengths[batch_idx, beam_idx] + curr_len = seq_lengths[result_idx, beam_idx] curr_ctx_len = input_lengths[batch_idx] curr_gen_len = curr_len - curr_ctx_len - curr_ids = output_ids[batch_idx, beam_idx, 1:curr_len] + curr_ids = output_ids[result_idx, beam_idx, 1:curr_len] curr_logits = torch.cat([ - context_logits[batch_idx], - generation_logits[batch_idx, + context_logits[result_idx], + generation_logits[result_idx, beam_idx, :curr_gen_len - 1] ], dim=0) curr_ppl = ppl(curr_logits, curr_ids) - logger.debug( - f"TensorRT-LLM PPL: {curr_ppl:.3f} | Generation length: {curr_gen_len}" - ) - ppls[batch_idx].append(curr_ppl) + logger.debug(f"TensorRT-LLM PPL: {curr_ppl:.3f} | " + f"Generation length: {curr_gen_len}") + ppls[result_idx].append(curr_ppl) return output_beams_list, output_ids_list, ppls, lengths_info return [], [], [], {} @@ -361,11 +374,14 @@ def eval_hf(datapoint, else: hf_config.update({ "num_beams": num_beams, - "num_return_sequences": num_beams, "early_stopping": local_early_stopping, }) + assert num_return_sequences < num_beams, ( + f'In HF, num_return_sequences ({num_return_sequences}) ' + f'has to be smaller or equal to num_beams ({num_beams})') outputs = model.generate(batch_input_ids, max_new_tokens=output_len, + num_return_sequences=num_return_sequences, temperature=temperature, eos_token_id=end_id, pad_token_id=pad_id, @@ -379,13 +395,12 @@ def eval_hf(datapoint, context_outputs = model(batch_input_ids) output_ids = outputs['sequences'] - tokens_list = output_ids[:, len(batch_input_ids[0]):].tolist() - output_ids = output_ids.reshape([batch_size, num_beams, -1]) + tokens_list = output_ids[:, max_length:].tolist() + output_ids = output_ids.reshape([batch_size, num_returns_hf, -1]) output_lines_list = [ - tokenizer.batch_decode(output_ids[:, i, - len(batch_input_ids[0]):], + tokenizer.batch_decode(output_ids[:, i, max_length:], skip_special_tokens=True) - for i in range(num_beams) + for i in range(num_returns_hf) ] ppls = [[] for _ in range(batch_size)] @@ -407,7 +422,7 @@ def eval_hf(datapoint, generation_logits = generation_logits.view(batch_size, num_beams, max_gen_len, voc_size) for batch_idx in range(batch_size): - for beam_idx in range(num_beams): + for beam_idx in range(num_returns_hf): curr_len = seq_lens[batch_idx, beam_idx] curr_ctx_len = input_lengths[batch_idx] curr_gen_len = curr_len - curr_ctx_len @@ -472,7 +487,8 @@ def eval_hf(datapoint, kv_cache_free_gpu_memory_fraction=args. 
kv_cache_free_gpu_memory_fraction, enable_chunked_context=args.enable_chunked_context, - multi_block_mode=args.multi_block_mode) + multi_block_mode=args.multi_block_mode, + cuda_graph_mode=args.cuda_graph_mode) runner_kwargs.update( enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc) runner = runner_cls.from_dir(**runner_kwargs) @@ -515,7 +531,7 @@ def eval_hf(datapoint, min_input_length=args.min_input_length) profiler.stop('tensorrt_llm') - empty_batch = (runtime_rank == 0 and len(output_tensorrt_llm) == 0) + empty_batch = runtime_rank == 0 and len(output_tensorrt_llm) == 0 empty_batch = mpi_broadcast(empty_batch, 0) if empty_batch: # No valid samples in the current batch, skip this iteration @@ -526,23 +542,21 @@ def eval_hf(datapoint, input_lengths = lengths_info['input_lengths'] seq_lengths = lengths_info['seq_lengths'] output_token_count_trt_llm = sum( - seq_lengths[bs][bm] - input_lengths[bs] - for bm in range(len(output_tensorrt_llm[0])) - for bs in range(len(output_tensorrt_llm))) + beam_len - input_lengths[seq_idx // num_return_sequences] + for seq_idx, beam_lens in enumerate(seq_lengths) + for beam_len in beam_lens) total_output_token_count_trt_llm += output_token_count_trt_llm - for batch_idx in range(len(output_tensorrt_llm)): - for beam_idx in range(num_beams): - metric_tensorrt_llm[beam_idx].add_batch( - predictions=[ - output_tensorrt_llm[batch_idx][beam_idx] - ], - references=[ - datapoint[dataset_output_key][batch_idx] - ]) + for result_idx, output_beams in enumerate(output_tensorrt_llm): + batch_idx, seq_idx = divmod(result_idx, + num_return_sequences) + reference = datapoint[dataset_output_key][batch_idx] + for beam_idx, output_beam in enumerate(output_beams): + metric_tensorrt_llm[seq_idx][beam_idx].add_batch( + predictions=[output_beam], references=[reference]) if args.eval_ppl: - ppls_trt_llm[beam_idx].append( - curr_ppls_trt_llm[batch_idx][beam_idx]) + ppls_trt_llm[seq_idx][beam_idx].append( + curr_ppls_trt_llm[result_idx][beam_idx]) if output_dir is not None: for i in range(len(output_tensorrt_llm[0])): for beam_idx in range(num_beams): @@ -642,8 +656,7 @@ def eval_hf(datapoint, if runtime_rank == 0: seq_lengths = [len(tokens) for tokens in token_list] total_output_token_count_hf += sum(seq_lengths) - - for beam_idx in range(num_beams): + for beam_idx in range(num_returns_hf): for batch_idx in range(len(output_hf[beam_idx])): metric_hf[beam_idx].add_batch( predictions=[output_hf[beam_idx][batch_idx]], @@ -655,7 +668,7 @@ def eval_hf(datapoint, curr_ppls_hf[batch_idx][beam_idx]) if output_dir is not None: for i in range(len(output_hf[0])): - for beam_idx in range(num_beams): + for beam_idx in range(num_returns_hf): with (output_dir / 'hf.out').open('a') as f: f.write( f'[{data_point_idx + i}] [Beam {beam_idx}] {output_hf[beam_idx][i]}\n' @@ -670,7 +683,7 @@ def eval_hf(datapoint, ite_count += 1 del model - if runtime_rank == 0: + if runtime_rank == 0 and args.max_ite > 0: if test_trt_llm: np.random.seed(0) # rouge score use sampling to compute the score logger.info( @@ -683,24 +696,25 @@ def eval_hf(datapoint, logger.info( f'TensorRT-LLM (tokens per second: {total_output_token_count_trt_llm / profiler.elapsed_time_in_sec("tensorrt_llm")})' ) - for beam_idx in range(num_beams): + for seq_idx, beam_idx in itertools.product( + range(num_return_sequences), range(num_beams)): logger.info(f"TensorRT-LLM beam {beam_idx} result") if args.eval_task != "eval_context_ppl": computed_metrics_tensorrt_llm = metric_tensorrt_llm[ - beam_idx].compute() + 
seq_idx][beam_idx].compute() for key in computed_metrics_tensorrt_llm.keys(): logger.info( f' {key} : {computed_metrics_tensorrt_llm[key]*100}' ) - if args.check_accuracy and beam_idx == 0: + if args.check_accuracy and seq_idx == 0 and beam_idx == 0: assert computed_metrics_tensorrt_llm[ 'rouge1'] * 100 > args.tensorrt_llm_rouge1_threshold if args.eval_ppl: logger.info( - f" Per-token perplexity: {np.mean(ppls_trt_llm[beam_idx])}" + f" Per-token perplexity: {np.mean(ppls_trt_llm[seq_idx][beam_idx])}" ) - if args.check_accuracy and beam_idx == 0: - avg_ppl = np.mean(ppls_trt_llm[beam_idx]) + if args.check_accuracy and seq_idx == 0 and beam_idx == 0: + avg_ppl = np.mean(ppls_trt_llm[seq_idx][beam_idx]) assert avg_ppl < args.tensorrt_llm_ppl_threshold, f"[FAILED] average PPL ({avg_ppl}) is larger than threshold ({args.tensorrt_llm_ppl_threshold})" if test_hf: np.random.seed(0) # rouge score use sampling to compute the score @@ -714,7 +728,7 @@ def eval_hf(datapoint, f'Hugging Face (tokens per second: {total_output_token_count_hf / profiler.elapsed_time_in_sec("hf")})' ) - for beam_idx in range(num_beams): + for beam_idx in range(num_returns_hf): logger.info(f"HF beam {beam_idx} result") computed_metrics_hf = metric_hf[beam_idx].compute() if args.eval_task != "eval_context_ppl": diff --git a/examples/utils.py b/examples/utils.py index 308f6dffa..de3361c3a 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -218,6 +218,10 @@ def add_common_args(parser): type=int, help="Use beam search if num_beams > 1", default=1) + parser.add_argument('--num_return_sequences', + type=int, + help="Number of sequences to generate for each input.", + default=1) parser.add_argument('--temperature', type=float, default=1.0) parser.add_argument('--top_k', type=int, default=1) parser.add_argument('--top_p', type=float, default=0.0) @@ -285,6 +289,9 @@ def add_common_args(parser): parser.add_argument('--enable_context_fmha_fp32_acc', action='store_true', help="Enable FMHA runner FP32 accumulation.") + parser.add_argument('--cuda_graph_mode', + action='store_true', + help="Enable cuda graphs in the inference.") parser.add_argument('--log_level', type=str, default='info') parser.add_argument( '--no_prompt_template', diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index 9c9fed2b4..5bf15bc3b 100644 --- a/examples/whisper/requirements.txt +++ b/examples/whisper/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.14.0.dev2024092401 +tensorrt_llm==0.14.0.dev2024100100 tiktoken datasets kaldialign diff --git a/requirements-dev-windows.txt b/requirements-dev-windows.txt index 819e0e3eb..2789ebaef 100644 --- a/requirements-dev-windows.txt +++ b/requirements-dev-windows.txt @@ -1,5 +1,5 @@ -r requirements-windows.txt ---extra-index-url https://download.pytorch.org/whl/cu121 +--extra-index-url https://download.pytorch.org/whl/cu124 datasets einops graphviz diff --git a/requirements-windows.txt b/requirements-windows.txt index 6592c0417..e3fc56308 100644 --- a/requirements-windows.txt +++ b/requirements-windows.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com ---extra-index-url https://download.pytorch.org/whl/cu121 +--extra-index-url https://download.pytorch.org/whl/cu124 accelerate==0.25.0 build colored @@ -16,10 +16,10 @@ h5py==3.10.0 pywin32 StrEnum sentencepiece>=0.1.99 -tensorrt~=10.3.0 +tensorrt~=10.4.0 tokenizers>=0.14 # Default torch is CPU-only on Windows, so need to specify a torch version with GPU support -torch==2.4.0+cu121 
+torch==2.4.0+cu124 nvidia-modelopt~=0.15.0 transformers>=4.38.2 wheel diff --git a/requirements.txt b/requirements.txt index dfb97c25c..698662167 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ lark mpi4py numpy<2 onnx>=1.12.0 -openai +openai==1.39.0 polygraphy psutil pynvml>=11.5.0 @@ -17,7 +17,7 @@ pandas h5py==3.10.0 StrEnum sentencepiece>=0.1.99 -tensorrt~=10.3.0 +tensorrt~=10.4.0 # https://github.com/pytorch/pytorch/blob/v2.4.0/version.txt uses 2.4.0a0. # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-07.html#rel-24-07 uses 2.4.0a0. torch>=2.4.0a0,<=2.4.0 @@ -27,7 +27,6 @@ pillow==10.3.0 wheel optimum evaluate -janus mpmath>=1.3.0 click click_option_group diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py index 2d7df1e75..35282ce28 100644 --- a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py +++ b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py @@ -26,6 +26,7 @@ class IdxEntry(Enum): KV_CACHE_BLOCK_OFFSETS = auto() HOST_KV_CACHE_BLOCK_OFFSETS = auto() HOST_KV_CACHE_POOL_POINTERS = auto() + HOST_KV_CACHE_POOL_MAPPING = auto() PAST_KEY_VALUE = auto() KV_CACHE_QUANTIZATION_SCALE = auto() KV_CACHE_DEQUANTIZATION_SCALE = auto() @@ -101,6 +102,8 @@ def is_entry_used(self, entry: IdxEntry) -> bool: return self.use_cache and self.paged_kv_cache elif entry == IdxEntry.HOST_KV_CACHE_POOL_POINTERS: return self.use_cache and self.paged_kv_cache + elif entry == IdxEntry.HOST_KV_CACHE_POOL_MAPPING: + return self.use_cache and self.paged_kv_cache elif entry == IdxEntry.PAST_KEY_VALUE: return self.use_cache and not self.paged_kv_cache elif entry == IdxEntry.KV_CACHE_QUANTIZATION_SCALE: diff --git a/tensorrt_llm/bench/build/benchmark_config.yml b/tensorrt_llm/bench/build/benchmark_config.yml index 80a509f47..ab432c85d 100644 --- a/tensorrt_llm/bench/build/benchmark_config.yml +++ b/tensorrt_llm/bench/build/benchmark_config.yml @@ -44,6 +44,11 @@ meta-llama/Meta-Llama-3-8B: general: max_batch_size: 2048 max_num_tokens: 8192 +meta-llama/Meta-Llama-3.1-8B: + tp1_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 8192 meta-llama/Meta-Llama-3-70B: tp1_pp1: general: @@ -64,6 +69,46 @@ meta-llama/Meta-Llama-3-70B: general: max_batch_size: 8192 max_num_tokens: 16384 +meta-llama/Meta-Llama-3.1-70B: + tp1_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 2048 + tp2_pp1: + general: + max_batch_size: 256 + max_num_tokens: 1024 + 4096: + max_batch_size: 2048 + max_num_tokens: 1024 + tp4_pp1: + general: + max_batch_size: 2048 + max_num_tokens: 1024 + tp8_pp1: + general: + max_batch_size: 8192 + max_num_tokens: 16384 +meta-llama/Meta-Llama-3.1-405B: + tp8_pp1: + general: + max_batch_size: 320 + max_num_tokens: 5440 + 256: + max_batch_size: 2048 + max_num_tokens: 4096 + 2500: + max_batch_size: 320 + max_num_tokens: 512 + 4096: + max_batch_size: 192 + max_num_tokens: 512 + 5500: + max_batch_size: 192 + max_num_tokens: 512 + 22000: + max_batch_size: 64 + max_num_tokens: 768 mistralai/Mixtral-8x7B-v0.1: tp2_pp1: general: @@ -73,6 +118,10 @@ mistralai/Mixtral-8x7B-v0.1: general: max_batch_size: 8192 max_num_tokens: 8192 + tp8_pp1: + general: + max_batch_size: 8192 + max_num_tokens: 8192 mistralai/Mistral-7B-v0.1: tp1_pp1: general: diff --git a/tensorrt_llm/bench/run/dataclasses.py b/tensorrt_llm/bench/run/dataclasses.py index e234b4491..e507f0bdb 100644 --- 
a/tensorrt_llm/bench/run/dataclasses.py +++ b/tensorrt_llm/bench/run/dataclasses.py @@ -2,7 +2,7 @@ from importlib.util import find_spec from pathlib import Path -from typing import Any, List +from typing import Any, List, Optional from pydantic import (BaseModel, Field, PositiveFloat, computed_field, model_validator) @@ -93,12 +93,59 @@ def get_scheduler_config(self) -> trtllm.SchedulerConfig: ) -class ResponseRecord(BaseModel): - request_id: int - timestamp: float - output_tokens: List[int] - is_final: bool - has_error: bool +class RequestRecord(BaseModel): + id: int = -1 + num_input_tokens: int = -1 + tokens: List[int] = [] + error_tokens: int = 0 + start_timestamp: int = -1 + first_token_timestamp: int = -1 + end_timestamp: int = -1 + + def register_event(self, is_error: bool, is_final: bool, timestamp: int, + tokens: List[int]) -> None: + if is_final: + self.end_timestamp = timestamp + elif self.first_token_timestamp == -1: + self.first_token_timestamp = timestamp + + if is_error: + self.error_tokens += 1 + + self.tokens += tokens + + @computed_field + def num_output_tokens(self) -> int: + return len(self.tokens) + + @computed_field + def num_generated_tokens(self) -> int: + return self.num_output_tokens - 1 + + @computed_field + def generation_time(self) -> int: + return self.end_timestamp - self.time_to_first_token + + @computed_field + def time_to_first_token(self) -> int: + return self.first_token_timestamp - self.start_timestamp + + @computed_field + def intertoken_latency(self) -> float: + return (self.end_timestamp - + self.first_token_timestamp) / self.num_generated_tokens + + @computed_field + def end_to_end_latency(self) -> int: + return self.end_timestamp - self.start_timestamp + + @computed_field + def total_token_throughput(self) -> float: + return self.num_output_tokens / self.end_to_end_latency + + @computed_field + def output_token_throughput(self) -> float: + return self.num_output_tokens / self.generation_time class PercentileStats(BaseModel): @@ -112,43 +159,17 @@ class PercentileStats(BaseModel): @classmethod def from_iterable(cls, values: List[Any]) -> PercentileStats: length = len(values) + sorted_values = sorted(values) return cls( - p50=values[int(length * 0.50)], - p95=values[int(length * 0.95)], - p99=values[int(length * 0.99)], + p50=sorted_values[int(length * 0.50)], + p95=sorted_values[int(length * 0.95)], + p99=sorted_values[int(length * 0.99)], average=float(sum(values)) / length, minimum=min(values), maximum=max(values), ) -class RequestStats(BaseModel): - request_id: int - input_tokens: int - time_log: List[float] = Field(default_factory=list, init=False) - error_responses: int = Field(default=0, init=False) - num_responses: int = Field(default=0, init=False) - num_tokens: int = Field(default=0, init=False) - - @computed_field - def first_token_latency(self) -> float: - try: - return self.time_log[1] - self.time_log[0] - except IndexError: - return 0 - - @computed_field - def request_latency(self) -> float: - return max(self.time_log) - min(self.time_log) - - def register_event(self, is_error: bool, is_response: bool, - timestamp: float, num_tokens: int) -> None: - self.time_log.append(timestamp) - self.error_responses += 1 if is_error else 0 - self.num_responses += 1 if is_response else 0 - self.num_tokens += num_tokens - - class BenchmarkStatistics(BaseModel): total_latency_ns: float total_output_tokens: int @@ -156,8 +177,10 @@ class BenchmarkStatistics(BaseModel): num_requests: int issue_rate_ns: float - request_percentiles: PercentileStats = None 
- token_percentiles: PercentileStats = None + request_percentiles: Optional[PercentileStats] = None + token_percentiles: Optional[PercentileStats] = None + itl_percentiles: Optional[PercentileStats] = None + ttft_percentiles: Optional[PercentileStats] = None @computed_field def token_throughput_ns(self) -> float: diff --git a/tensorrt_llm/bench/run/run.py b/tensorrt_llm/bench/run/run.py index 52eed4f5a..4cf22d836 100644 --- a/tensorrt_llm/bench/run/run.py +++ b/tensorrt_llm/bench/run/run.py @@ -15,8 +15,10 @@ import tensorrt_llm.bindings.executor as trtllm from tensorrt_llm.bench.dataclasses import BenchmarkEnvironment from tensorrt_llm.bench.enums import IFBSchedulingPolicy -from tensorrt_llm.bench.run.dataclasses import ResponseRecord, RuntimeConfig -from tensorrt_llm.bench.run.utils import (StatsKeeper, get_executor_request, +from tensorrt_llm.bench.run.dataclasses import (BenchmarkStatistics, + RuntimeConfig) +from tensorrt_llm.bench.run.utils import (ResponseTuple, StatsKeeper, + get_executor_request, get_settings_from_engine) from tensorrt_llm.bench.utils.data import generate_dataset_from_stream from tensorrt_llm.logger import logger @@ -83,6 +85,12 @@ help="Number of requests to cap benchmark run at. Minimum between value and" "length of dataset.", ) +@click.option( + "--streaming", + is_flag=True, + default=False, + help="Enable streaming mode for requests.", +) @click.pass_obj def run_command( bench_env: BenchmarkEnvironment, @@ -113,6 +121,7 @@ def run_command( runtime_max_tokens = runtime_max_bs if runtime_max_tokens else engine_tokens kv_cache_percent = params.pop("kv_cache_free_gpu_mem_fraction") beam_width = params.pop("beam_width") + streaming = params.pop("streaming") # Update configuration with runtime options exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent @@ -138,7 +147,10 @@ def run_command( while requests: request = requests.pop() executor_requests.append( - get_executor_request(request, pad_id=-1, eos_id=-1)) + get_executor_request(request, + pad_id=-1, + eos_id=-1, + streaming=streaming)) del request logger.info("Setting up benchmarker and infrastructure.") @@ -151,6 +163,7 @@ def run_command( runtime_cfg=runtime_config, request_queue=new_request_queue, response_queue=response_queue, + streaming=streaming, ) logger.set_level("info") try: @@ -195,7 +208,7 @@ def __init__(self, runtime_cfg: RuntimeConfig, self.response_thread = Thread(target=self.response_daemon) self.response_thread.start() - def enqueue(self, *requests: trtllm.Request) -> Generator[int]: + def enqueue(self, *requests: trtllm.Request) -> Generator[Tuple[int, int]]: """Generate the next request identifier. 
Yields: @@ -227,17 +240,14 @@ def response_daemon(self) -> None: def _process_response() -> None: responses = self.executor.await_responses(timeout=timedelta( - milliseconds=1)) + microseconds=0.00000000000001)) now = monotonic_ns() - for response in responses: - # logger.info("Pushing response to queue") - self.responses.put( - ResponseRecord( - timestamp=now, - request_id=response.request_id, - has_error=response.has_error(), - is_final=response.result.is_final, - output_tokens=response.result.output_token_ids[0])) + if len(responses) > 0: + self.responses.put([ + ResponseTuple(now, r.request_id, r.result.is_final, + r.has_error(), r.result.output_token_ids[0]) + for r in responses + ]) while not self._shutdown.is_set(): _process_response() @@ -259,6 +269,7 @@ def __init__( runtime_cfg: RuntimeConfig, request_queue: mp.Queue, response_queue: mp.Queue, + streaming: bool, ) -> None: """Initialize the throughput benchmark. @@ -280,6 +291,7 @@ def __init__( # Runtime configuration for Executor self.runtime_config = deepcopy(runtime_cfg) + self.streaming = streaming self.executor = None # Request and response reporting structures @@ -364,8 +376,17 @@ def _process_requests() -> None: new_request[2]) while not self.response_queue.empty(): - response: ResponseRecord = self.response_queue.get_nowait() - self.statistics.register_response(response) + responses: Tuple[ + int, + List[trtllm.Response]] = self.response_queue.get_nowait() + for response in responses: + self.statistics.register_response( + response.request_id, + response.timestamp, + response.final, + response.error, + response.tokens, + ) logger.info("Collecting live stats...") # TODO: Revisit this conditional, if the request rate is slow enough this @@ -382,7 +403,7 @@ def _process_requests() -> None: self.parsing_complete.set() logger.info("Ending statistics collection.") - def report_statistics(self) -> None: + def report_statistics(self) -> BenchmarkStatistics: """Report internal statistics about benchmark.""" config_path = self.runtime_config.engine_dir / "config.json" @@ -395,8 +416,8 @@ def report_statistics(self) -> None: pretrain_cfg = engine_config["pretrained_config"] total_latency_s = stats.total_latency_ns / 1.0e9 - logger.info( - "\n===========================================================\n" + logging_info = ( + "\n\n===========================================================\n" "= ENGINE DETAILS\n" "===========================================================\n" f"Model:\t\t\t{rt_cfg.model}\n" @@ -416,16 +437,34 @@ def report_statistics(self) -> None: f"Max Runtime Batch Size:\t{rt_cfg.settings_config.max_batch_size}\n" f"Max Runtime Tokens:\t{rt_cfg.settings_config.max_num_tokens}\n" f"Scheduling Policy:\t{rt_cfg.settings_config.scheduler_policy.values[1]}\n" - f"KV Memory Percentage:\t{rt_cfg.settings_config.kv_cache_percent * 100.0}%\n" - f"Issue Rate (req/sec):\t{stats.issue_rate_ns * 1e9}" + f"KV Memory Percentage:\t{rt_cfg.settings_config.kv_cache_percent * 100.0:.2f}%\n" + f"Issue Rate (req/sec):\t{stats.issue_rate_ns * 1e9:.4E}\n" f"\n" "===========================================================\n" - "= STATISTICS\n" + "= PERFORMANCE OVERVIEW \n" "===========================================================\n" f"Number of requests:\t\t{stats.num_requests}\n" - f"Average Input Length (tokens):\t{stats.average_input_length}\n" - f"Average Output Length (tokens):\t{stats.average_output_length}\n" - f"Token Throughput (tokens/sec):\t{stats.total_output_tokens / total_latency_s}\n" - f"Request Throughput 
(req/sec):\t{stats.num_requests / total_latency_s}\n" - f"Total Latency (seconds):\t{total_latency_s}\n" - "===========================================================\n") + f"Average Input Length (tokens):\t{stats.average_input_length:.4f}\n" + f"Average Output Length (tokens):\t{stats.average_output_length:.4f}\n" + f"Token Throughput (tokens/sec):\t{stats.total_output_tokens / total_latency_s:.4f}\n" + f"Request Throughput (req/sec):\t{stats.num_requests / total_latency_s:.4f}\n" + f"Total Latency (ms):\t\t{stats.total_latency_ns * 1.0e-6:.4f}\n") + + if self.streaming: + logging_info = ( + f"{logging_info}" + "\n" + "===========================================================\n" + "= STREAMING STATISTICS \n" + "===========================================================\n" + f"Average request latency (ms):\t\t{stats.request_percentiles.average * 1.0e-6:.4f}\n" + f"Average time-to-first-token (ms):\t{stats.ttft_percentiles.average * 1.0e-6:.4f}\n" + f"Average inter-token latency (ms):\t{stats.itl_percentiles.average * 1.0e-6:.4f}\n" + ) + + logging_info = ( + f"{logging_info}" + "\n===========================================================\n") + + logger.info(logging_info) + return stats diff --git a/tensorrt_llm/bench/run/utils.py b/tensorrt_llm/bench/run/utils.py index 73d176db9..ad69d160e 100644 --- a/tensorrt_llm/bench/run/utils.py +++ b/tensorrt_llm/bench/run/utils.py @@ -1,16 +1,18 @@ from __future__ import annotations import json -from collections import defaultdict +from collections import defaultdict, namedtuple from pathlib import Path -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, Union import tensorrt_llm.bindings.executor as trtllm from tensorrt_llm.bench.run.dataclasses import (BenchmarkStatistics, - PercentileStats, RequestStats, - ResponseRecord) + PercentileStats, RequestRecord) from tensorrt_llm.bindings import InferenceRequest +ResponseTuple = namedtuple( + "ResponseTuple", ["timestamp", "request_id", "final", "error", "tokens"]) + def get_executor_request(request: InferenceRequest, pad_id: int, @@ -66,60 +68,51 @@ def get_settings_from_engine( class StatsKeeper: def __init__(self) -> None: - self.requests: RequestStats = {} + self.requests: Dict[RequestRecord] = defaultdict(RequestRecord) self.num_complete: int = 0 - self._unseen_cache = defaultdict(list) - def register_request( self, request_id: int, timestamp: float, num_tokens: int, ) -> None: - request = RequestStats(request_id=request_id, input_tokens=num_tokens) - request.register_event(False, False, timestamp, 0) - self.requests[request_id] = request - - def register_response(self, response: ResponseRecord) -> None: - request_id = response.request_id - - if request_id not in self.requests: - self._unseen_cache[request_id].append(response) - else: - self.requests[request_id].register_event( - is_error=response.has_error, - is_response=True, - timestamp=response.timestamp, - num_tokens=len(response.output_tokens)) + record = self.requests[request_id] + record.num_input_tokens = num_tokens + record.start_timestamp = timestamp - if response.is_final: - self.num_complete += 1 + def register_response(self, request_id: int, timestamp: int, final: bool, + error: bool, tokens: List[int]) -> None: + record = self.requests[request_id] + record.register_event(error, final, timestamp, tokens) + if final: + self.num_complete = self.num_complete + 1 def generate_statistics_summary(self) -> None: total_output_tokens: int = 0 total_input_tokens: int = 0 num_requests = len(self.requests) - 
total_request_latency: float = 0.0 start_time = float("inf") end_time = -1 request_latencies = [] + intertoken_avg_latencies = [] + ttft_times = [] last_queue_time = 0.0 queue_time_total = 0.0 for entry in self.requests.values(): - entry.time_log.sort() + start_time = min(entry.start_timestamp, start_time) + end_time = max(entry.end_timestamp, end_time) + queue_time_total += entry.start_timestamp - last_queue_time + last_queue_time = entry.start_timestamp - queue_time_total += entry.time_log[0] - last_queue_time - last_queue_time = entry.time_log[0] + request_latencies.append(entry.end_to_end_latency) + ttft_times.append(entry.time_to_first_token) + intertoken_avg_latencies.append(entry.intertoken_latency) - request_latencies.append(entry.request_latency) - total_output_tokens += entry.num_tokens - total_input_tokens += entry.input_tokens - total_request_latency += entry.request_latency - start_time = min(start_time, entry.time_log[0]) - end_time = max(end_time, entry.time_log[-1]) + total_output_tokens += entry.num_output_tokens + total_input_tokens += entry.num_input_tokens stats = BenchmarkStatistics( num_requests=num_requests, @@ -128,6 +121,9 @@ def generate_statistics_summary(self) -> None: total_input_tokens=total_input_tokens, request_percentiles=PercentileStats.from_iterable( request_latencies), + itl_percentiles=PercentileStats.from_iterable( + intertoken_avg_latencies), + ttft_percentiles=PercentileStats.from_iterable(ttft_times), issue_rate_ns=queue_time_total / num_requests) return stats diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 335f5c139..5849bc20d 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -974,14 +974,11 @@ def deserialize_managed_weights(path: str | Path) -> dict[str, np.ndarray]: return managed_weights -def build(model: PretrainedModel, - build_config: BuildConfig, - return_build_config: bool = False) -> Engine | BuildConfig: +def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: '''Build engine from given model and optimization options specified in the build_config WARNING: this function may change the given \p model object state in some optimization passes to avoid cloning a model since normally the LLM models consumes large memory. Create a new fresh model object if you need to build with different options. - ''' tic = time.time() # avoid changing the input config @@ -991,6 +988,12 @@ def build(model: PretrainedModel, _init_max_seq_len(model.config, build_config) + if build_config.plugin_config.reduce_fusion and ( + model.config.mapping.tp_size == 1 + or model.config.architecture != "LlamaForCausalLM"): + logger.warning('Overriding reduce_fusion to False') + build_config.plugin_config.reduce_fusion = False + if model.config.quantization.quant_algo == QuantAlgo.FP8 or \ model.config.quantization.kv_cache_quant_algo == QuantAlgo.FP8: build_config.strongly_typed = True @@ -1097,11 +1100,6 @@ def build(model: PretrainedModel, nccl_plugin = model.config.dtype if model.config.mapping.world_size > 1 else None network.plugin_config.set_nccl_plugin(nccl_plugin) - # NOTE: Please never change the build_config object after this point! 
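
The StatsKeeper.generate_statistics_summary changes above fold per-request timestamps into three latency series: end-to-end latency, time-to-first-token (TTFT), and average inter-token latency (ITL). A minimal sketch of that aggregation is shown here; RequestTimeline and the nanosecond values are illustrative stand-ins, not the RequestRecord class from this patch.

# Illustrative only: a simplified stand-in for the per-request record kept by StatsKeeper.
from dataclasses import dataclass, field
from statistics import mean
from typing import List

@dataclass
class RequestTimeline:
    start_ns: int                                        # submission time
    token_ns: List[int] = field(default_factory=list)    # arrival time of each streamed response

    @property
    def end_to_end_latency(self) -> int:
        return self.token_ns[-1] - self.start_ns

    @property
    def time_to_first_token(self) -> int:
        return self.token_ns[0] - self.start_ns

    @property
    def intertoken_latency(self) -> float:
        # Average gap between consecutive streamed responses.
        if len(self.token_ns) < 2:
            return 0.0
        return mean(b - a for a, b in zip(self.token_ns, self.token_ns[1:]))

timeline = RequestTimeline(start_ns=0, token_ns=[5_000_000, 9_000_000, 13_000_000])
print(timeline.time_to_first_token)    # 5000000 ns, i.e. 5 ms TTFT
print(timeline.intertoken_latency)     # 4000000.0 ns average ITL
print(timeline.end_to_end_latency)     # 13000000 ns
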
- if return_build_config: - # Get an modified build_config that is the same as the one in the final engine dir - return build_config - with net_guard(network): # Prepare network.set_named_parameters(model.named_parameters()) @@ -1150,6 +1148,10 @@ def build(model: PretrainedModel, "max_batch_size": build_config.max_batch_size, } + if build_config.speculative_decoding_mode == SpeculativeDecodingMode.LOOKAHEAD_DECODING: + prepare_input_args[ + "spec_decoding_is_generation_length_variable"] = True + inputs = model.prepare_inputs(**prepare_input_args) model(**inputs) diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index 609b4dac5..e4efeb509 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -342,9 +342,6 @@ def build_model( "StreamingLLM is only supported in the llama model." real_rank = rank - if build_config.plugin_config.reduce_fusion and model_config.mapping.tp_size == 1: - build_config.plugin_config.reduce_fusion = False - model_config.mapping.gpus_per_node = build_config.auto_parallel_config.gpus_per_node if build_config.auto_parallel_config.enabled: assert rank < build_config.auto_parallel_config.world_size diff --git a/tensorrt_llm/executor.py b/tensorrt_llm/executor.py index 7c6578b97..eaafdb5a3 100644 --- a/tensorrt_llm/executor.py +++ b/tensorrt_llm/executor.py @@ -2,6 +2,7 @@ import atexit import concurrent.futures import datetime +import io import json import secrets import time @@ -9,14 +10,14 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from multiprocessing.connection import Client, Listener +from multiprocessing.shared_memory import SharedMemory from pathlib import Path from queue import Queue -from typing import (Any, Dict, Generator, List, NamedTuple, Optional, Tuple, - Union) +from typing import (Any, Dict, Generator, List, Literal, NamedTuple, Optional, + Tuple, Union) import numpy as np import torch -from janus import Queue as AsyncQueue from ._utils import mpi_rank, mpi_world_size from .bindings import executor as tllm @@ -24,7 +25,7 @@ from .hlapi.mpi_session import (MpiPoolSession, MpiSession, external_mpi_comm_available, find_free_port, need_spawn_mpi_workers) -from .hlapi.utils import ManagedThread, SamplingParams, exception_handler +from .hlapi.utils import ManagedThread, SamplingParams from .lora_manager import LoraManager from .runtime import ModelConfig from .runtime.model_runner import _engine_config_to_model_config @@ -98,20 +99,42 @@ class CompletionOutput: token_ids (List[int]): The token ids of the generated output text. cumulative_logprob (float): The cumulative log probability of the generated output text. logprobs (List[float]): The log probabilities of the top probability words at each position if the logprobs are requested. + finish_reason (Literal['stop', 'length']): The reason why the sequence is finished. + stop_reason (Union[int, str]): The stop string or token id that caused the completion to stop, None if the completion finished for some other reason. generation_logits (torch.Tensor): The logits on the generated output token ids. + length (int): The number of generated tokens. + token_ids_diff (List[int]): Newly generated token ids. + logprobs_diff (List[float]): Logprobs of newly generated tokens. + text_diff (str): Newly generated tokens. 
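
As an illustration of the consume-once semantics of the *_diff attributes listed above (the token ids are invented, and the object is a plain CompletionOutput as defined in this class):

# Hypothetical streaming loop: each access to token_ids_diff returns only what was
# appended since the previous access, then advances the internal cursor.
out = CompletionOutput(index=0)

out.token_ids.extend([11, 12, 13])     # first streamed chunk
print(out.token_ids_diff)              # [11, 12, 13]

out.token_ids.extend([14])             # second streamed chunk
print(out.token_ids_diff)              # [14]
print(out.token_ids_diff)              # [] (already consumed)
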
""" index: int text: str = "" token_ids: List[int] = field(default_factory=list) cumulative_logprob: Optional[float] = None logprobs: List[float] = field(default_factory=list) - generation_logits: Optional[torch.Tensor] = field(default=None, repr=False) + finish_reason: Optional[Literal['stop', 'length']] = None + stop_reason: Optional[Union[int, str]] = None + generation_logits: Optional[torch.Tensor] = None _last_text: str = field(default="", init=False, repr=False) + _last_logprobs_len: int = field(default=0, init=False, repr=False) + _last_token_ids_len: int = field(default=0, init=False, repr=False) @property def length(self): return len(self.token_ids) + @property + def token_ids_diff(self) -> List[int]: + diff = self.token_ids[self._last_token_ids_len:] + self._last_token_ids_len = len(self.token_ids) + return diff + + @property + def logprobs_diff(self) -> List[float]: + diff = self.logprobs[self._last_logprobs_len:] + self._last_logprobs_len = len(self.logprobs) + return diff + @property def text_diff(self) -> str: diff = self.text[len(self._last_text):] @@ -119,6 +142,68 @@ def text_diff(self) -> str: return diff +class _SyncQueue: + ''' + A simplified Queue that provides a `get` method that is compatible with the asyncio event loop. + ''' + + def __init__(self, + queue: Queue, + event: asyncio.Event, + loop: Optional[asyncio.AbstractEventLoop] = None): + self._q = queue + self._event = event + self._loop = loop or asyncio.get_event_loop() + + def put(self, item) -> None: + + async def _set_event(event): + event.set() + + self._q.put_nowait(item) + + if self._loop.is_running(): + asyncio.run_coroutine_threadsafe(_set_event(self._event), + self._loop) + else: + raise AsyncQueue.EventLoopShutdownError + + def full(self) -> bool: + return self._q.full() + + +class _AsyncQueue: + ''' + A simplified asyncio.Queue that provides a `get` method that is compatible with the standard library Queue. + ''' + + def __init__(self, queue: Queue): + self._event = asyncio.Event() + self._q = queue + + async def get(self): + await self._event.wait() + res = self._q.get() + if self._q.empty(): + self._event.clear() + return res + + +class AsyncQueue: + ''' + AsyncQueue is container containing `async_q` for `async get` and `sync_q` for sync `get`. + This is used to provide a compatible interface for janus.Queue. 
+ ''' + + class EventLoopShutdownError(Exception): + pass + + def __init__(self): + self._q = Queue() + self.async_q = _AsyncQueue(self._q) + self.sync_q = _SyncQueue(self._q, self.async_q._event) + + class CppExecutorError(RuntimeError): def __init__(self, message: Optional[str] = None): @@ -204,14 +289,23 @@ def handle_response(self, response: "GenerationExecutor.Response"): self.outputs[i].generation_logits = tensors.generation_logits[ i, :self.outputs[i].length] - if self.finished and not self._generation_request.sampling_params.include_stop_str_in_output: - for beam_output in self.outputs: - for stop_ids in self._generation_request.sampling_params._get_stop_words( - ): - if beam_output.token_ids[-len(stop_ids):] == stop_ids: - beam_output.token_ids = beam_output.token_ids[:-len( - stop_ids)] - break + if self.finished: + for i, beam_output in enumerate(self.outputs): + if response.finish_reasons[i] == tllm.FinishReason.END_ID: + beam_output.finish_reason = 'stop' + elif response.finish_reasons[i] == tllm.FinishReason.STOP_WORDS: + beam_output.finish_reason = 'stop' + sampling_params = self._generation_request.sampling_params + for stop_reason, stop_ids in sampling_params._get_stop_reasons_and_words( + ): + if beam_output.token_ids[-len(stop_ids):] == stop_ids: + beam_output.stop_reason = stop_reason + if not sampling_params.include_stop_str_in_output: + beam_output.token_ids = beam_output.token_ids[:-len( + stop_ids)] + break + elif response.finish_reasons[i] == tllm.FinishReason.LENGTH: + beam_output.finish_reason = 'length' if tensors.context_logits is not None: self.context_logits = tensors.context_logits @@ -281,7 +375,10 @@ def exception(self, timeout: Optional[float] = None): return e def _repr_fields(self): - return ['request_id', 'prompt_token_ids', 'outputs', 'finished'] + return [ + 'request_id', 'prompt_token_ids', 'outputs', 'finished', + "context_logits" + ] def __repr__(self) -> str: repr = [] @@ -305,8 +402,10 @@ class GenerationExecutor(ABC): class ResponseTensors(NamedTuple): output_token_ids: list - context_logits: Optional[torch.Tensor] - generation_logits: Optional[torch.Tensor] + # context_logits is a tensor or a string denoting the path to the shared memory. + context_logits: Optional[torch.Tensor | str] + # generation_logits is a tensor or a string denoting the path to the shared memory. + generation_logits: Optional[torch.Tensor | str] log_probs: Optional[list] cum_log_probs: Optional[list] @@ -314,6 +413,7 @@ class Response(NamedTuple): """ The response from the cpp-executor to the Python main thread. """ request_id: int tensors: Optional["GenerationExecutor.ResponseTensors"] + finish_reasons: Optional[List[tllm.FinishReason]] is_final: Optional[bool] # error is either str from cpp-executor or a Exception from Python threads/processes error: Optional[str | Exception] @@ -327,6 +427,8 @@ def __init__(self): self._stats = None self.stats_queue = None + atexit.register(self.shutdown) + # This is used to capture the exceptions from the threads. self._error_queue = Queue() @@ -334,9 +436,6 @@ def __init__(self): self._pending_responses: Dict[ int, List[GenerationExecutor.PendingResponse]] = {} - exception_handler.register(self, 'shutdown') - atexit.register(self.shutdown) - @abstractmethod def submit(self, request: GenerationRequest) -> GenerationResult: pass @@ -406,6 +505,7 @@ def _handle_background_error(self): # more than one error. if not self._error_queue.empty(): e = self._error_queue.get() + self.shutdown() # We can catch some exceptions here. 
raise e @@ -567,9 +667,13 @@ def __init__( self._lora_manager = LoraManager() self.await_response_thread = ManagedThread( - self.await_response_task, error_queue=self._error_queue) + self.await_response_task, + error_queue=self._error_queue, + name="await_response_thread") self.dispatch_stats_thread = ManagedThread( - self.dispatch_stats_task, error_queue=self._error_queue) + self.dispatch_stats_task, + error_queue=self._error_queue, + name="dispatch_stats_thread") def create_stats_queue(self): # Stats queue is created during first submission to ensure event loop exists if it is needed. @@ -619,10 +723,12 @@ def await_response_task(self) -> bool: milliseconds=100)): req_id = response.request_id if response.has_error(): - rsp = self.Response(req_id, - tensors=None, - is_final=None, - error=response.error_msg) + rsp = self.Response( + req_id, + tensors=None, + finish_reasons=response.result.finish_reasons, + is_final=None, + error=response.error_msg) else: tensors = self.ResponseTensors( response.result.output_token_ids, @@ -630,10 +736,12 @@ def await_response_task(self) -> bool: response.result.generation_logits, response.result.log_probs, response.result.cum_log_probs) - rsp = self.Response(req_id, - tensors, - is_final=response.result.is_final, - error=None) + rsp = self.Response( + req_id, + tensors, + finish_reasons=response.result.finish_reasons, + is_final=response.result.is_final, + error=None) if self._to_delay_response(rsp): continue @@ -647,6 +755,7 @@ def await_response_task(self) -> bool: if bck_error is not None: rsp = self.Response(req_id, tensors=None, + finish_reasons=None, is_final=None, error=bck_error) @@ -663,7 +772,14 @@ def dispatch_stats_task(self) -> bool: for stats in self.engine.get_latest_iteration_stats(): while hasattr(self.stats_queue, "full") and self.stats_queue.full(): self.stats_queue.get() - self.stats_queue.put(stats.to_json_str()) + + try: + self.stats_queue.put(stats.to_json_str()) + except AsyncQueue.EventLoopShutdownError: + # This happens in the last stats loop while the generate workflow is stopped. + pass + except Exception as e: + raise e return True # success @@ -816,29 +932,102 @@ def setup(self): def put(self, obj: Any): if self.conn is None: self.setup() + + if isinstance(obj, GenerationExecutor.Response): + tensors = self._store_tensors_in_shmm(obj.tensors) + obj = GenerationExecutor.Response(request_id=obj.request_id, + tensors=tensors, + finish_reasons=obj.finish_reasons, + is_final=obj.is_final, + error=obj.error) + self.conn.send(obj) def get(self) -> Any: if self.conn is None: self.setup() - return self.conn.recv() + + obj = self.conn.recv() + if isinstance(obj, GenerationExecutor.Response): + tensors = self._load_tensors_from_shmm(obj.tensors) + obj = GenerationExecutor.Response(request_id=obj.request_id, + tensors=tensors, + finish_reasons=obj.finish_reasons, + is_final=obj.is_final, + error=obj.error) + return obj + + def _store_tensors_in_shmm( + self, tensors: GenerationExecutor.ResponseTensors + ) -> GenerationExecutor.ResponseTensors: + # The tensors are huge and cannot be transferred through socket directly. We need to store them in shared memory, + # and replace the tensors with the shared memory path. 
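
The helpers below implement this for ResponseTensors; a standalone round trip of the same pattern looks roughly as follows (buffer handling simplified, and in the real setup the consumer runs in a different process):

import io
from multiprocessing.shared_memory import SharedMemory

import torch

def send_tensor(tensor: torch.Tensor) -> str:
    # Producer side: serialize into a shared-memory block and pass only its name around.
    payload = io.BytesIO()
    torch.save(tensor, payload)
    data = payload.getvalue()
    shm = SharedMemory(create=True, size=len(data))
    shm.buf[:len(data)] = data
    shm.close()                   # drop this handle, the segment itself stays alive
    return shm.name

def recv_tensor(name: str) -> torch.Tensor:
    # Consumer side: rebuild the tensor, then release the segment.
    shm = SharedMemory(name=name, create=False)
    tensor = torch.load(io.BytesIO(bytes(shm.buf)))
    shm.close()
    shm.unlink()                  # the consumer owns cleanup, mirroring _load_tensors_from_shmm
    return tensor

name = send_tensor(torch.arange(8, dtype=torch.float32))
print(recv_tensor(name))          # tensor([0., 1., 2., 3., 4., 5., 6., 7.])
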
+ def store_tensor(tensor: Optional[torch.Tensor]) -> Optional[str]: + if tensor is None: + return None + # NOTE: We create random shmm here rather than two specific shmm for context and generation logit, since the + # shmm may not be read timely by the IpcQueue.get() in the other side, so there might be multiple alive shmm + # for logits. + # A known issue: the shmm instance may leak if the IpcQueue.get() thread is stopped before the IpcQueue.put() + # thread. This is not a big issue since the shmm will be automatically cleaned up when the process exits. + shm = SharedMemory(create=True, size=tensor.nbytes + 2048) + torch.save(tensor, shm._mmap) + shm.close() + return shm.name + + return GenerationExecutor.ResponseTensors( + output_token_ids=tensors.output_token_ids, + context_logits=store_tensor(tensors.context_logits), + generation_logits=store_tensor(tensors.generation_logits), + log_probs=tensors.log_probs, + cum_log_probs=tensors.cum_log_probs, + ) + + def _load_tensors_from_shmm( + self, tensors: GenerationExecutor.ResponseTensors + ) -> GenerationExecutor.ResponseTensors: + + def load_tensor(tensor: Optional[str]) -> Optional[torch.Tensor]: + if tensor is None or isinstance(tensor, torch.Tensor): + return tensor + + shm = SharedMemory(name=tensor, create=False) + tensor = torch.load(io.BytesIO(shm.buf)) + shm.close() + shm.unlink() + return tensor + + return GenerationExecutor.ResponseTensors( + output_token_ids=tensors.output_token_ids, + context_logits=load_tensor(tensors.context_logits), + generation_logits=load_tensor(tensors.generation_logits), + log_probs=tensors.log_probs, + cum_log_probs=tensors.cum_log_probs, + ) @property def address(self) -> Tuple[str, int, bytes]: return (self.host_port[0], self.host_port[1], self.authkey) + def __del__(self): + if self.conn is not None: + self.conn.close() + if self.is_server: + self.listener.close() + class ExecutorBindingsProxy(GenerationExecutor): - def __init__( - self, - workers_kwargs, - model_world_size: int = 1, - mpi_session: Optional[MpiSession] = None, - ) -> None: + def __init__(self, + workers_kwargs, + model_world_size: int = 1, + mpi_session: Optional[MpiSession] = None, + *, + worker_cls: type = ExecutorBindingsWorker) -> None: super().__init__() self.workers_started = False + self.worker_cls = worker_cls self.request_queue = IpcQueue(is_server=True) # Return request id back to dispatcher @@ -868,22 +1057,23 @@ def __init__( }) self.dispatch_result_thread = ManagedThread( - self.dispatch_result_task, error_queue=self._error_queue) + self.dispatch_result_task, + error_queue=self._error_queue, + name="proxy_dispatch_result_thread") self.dispatch_stats_thread = ManagedThread( - self.dispatch_stats_task, error_queue=self._error_queue) - - exception_handler.register(self, 'shutdown') - atexit.register(self.shutdown) + self.dispatch_stats_task, + error_queue=self._error_queue, + name="proxy_dispatch_stats_thread") @staticmethod - def workers_main( - engine: Union[Path, Engine], - request_queue_addr: Tuple[str, int, bytes], - request_id_queue_addr: Tuple[str, int, bytes], - result_queue_addr: Tuple[str, int, bytes], - stats_queue_addr: Tuple[str, int, bytes], - executor_config: tllm.ExecutorConfig = tllm.ExecutorConfig(1) - ) -> None: + def workers_main(engine: Union[Path, Engine], + request_queue_addr: Tuple[str, int, bytes], + request_id_queue_addr: Tuple[str, int, bytes], + result_queue_addr: Tuple[str, int, bytes], + stats_queue_addr: Tuple[str, int, bytes], + executor_config: tllm.ExecutorConfig = tllm.ExecutorConfig( + 
1), + worker_cls: type = ExecutorBindingsWorker) -> None: result_queue = None if mpi_rank() == 0: @@ -899,7 +1089,7 @@ def notify_proxy_threads_to_quit(): mp_stats_queue.put(None) try: - executor = ExecutorBindingsWorker(engine, executor_config) + executor = worker_cls(engine, executor_config) except Exception as e: raise CppExecutorError(f"Failed to initialize executor: {e}") from e @@ -917,7 +1107,7 @@ def notify_proxy_threads_to_quit(): notify_proxy_threads_to_quit() except ExecutorBindingsWorker.WorkerExit as e: - raise e + raise e # This will capture by the with-statement and exit normally. except Exception as e: # other critical errors if mpi_rank() == 0: @@ -947,14 +1137,29 @@ def dispatch_result_task(self) -> bool: return True # success def dispatch_stats_task(self) -> bool: - if (stats := self.mp_stats_queue.get()) is None: - return False # shutdown the thread - # get-stats is not urgent, so we can sleep a bit + time.sleep(0.1) + + try: + stats = self.mp_stats_queue.get() + except: + return False + + if stats is None: + return False + while self.stats_queue.full(): self.stats_queue.get() - self.stats_queue.put(stats) + + try: + self.stats_queue.put(stats) + except AsyncQueue.EventLoopShutdownError: + # This happens in the last stats loop while the generate workflow is stopped. + pass + except Exception as e: + raise e + return True # success def start(self): @@ -966,7 +1171,9 @@ def mpi_done_callback(future: concurrent.futures.Future): self._error_queue.put_nowait(future.exception()) self.mpi_futures = self.mpi_session.submit( - ExecutorBindingsProxy.workers_main, **self.workers_kwargs) + ExecutorBindingsProxy.workers_main, + **self.workers_kwargs, + worker_cls=self.worker_cls) for fut in self.mpi_futures: fut.add_done_callback(mpi_done_callback) @@ -988,8 +1195,10 @@ def shutdown(self): f.result() if self.dispatch_result_thread.is_alive(): - self.dispatcher.join() + self.dispatch_result_thread.stop() + self.dispatch_result_thread.join() if self.dispatch_stats_thread.is_alive(): + self.dispatch_stats_thread.stop() self.dispatch_stats_thread.join() # It is possible that some requests are still pending in the workers, we need to process them before shutdown diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 797c9ec2c..082a83561 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -4594,6 +4594,7 @@ def gpt_attention( kv_cache_block_offsets: Optional[Tensor] = None, host_kv_cache_block_offsets: Tensor = None, host_kv_cache_pool_pointers: Tensor = None, + host_kv_cache_pool_mapping: Tensor = None, do_cross_attention: bool = False, cross_qkv: Optional[Tensor] = None, # for cross attention cross_qkv_length: Optional[Tensor] = None, # for cross attention @@ -4609,6 +4610,7 @@ def gpt_attention( spec_decoding_position_offsets: Tensor = None, spec_decoding_packed_mask: Tensor = None, host_runtime_perf_knobs: Optional[Tensor] = None, + layer_idx_in_cache_pool: Optional[int] = None, ) -> Tuple[Tensor, Optional[Tensor]]: ''' Add an operation that performs the multi-head attention in GPT-like models. @@ -4785,9 +4787,12 @@ def gpt_attention( The same as kv_cache_block_offsets, but on cpu, host_kv_cache_pool_pointers: - The tensor of pool pointers for the KV cache. Its shape is [2], + The tensor of pool pointers for the KV cache. Its shape is [num_layers, 2], See KV cache section in docs/source/advanced/gpt-attention.md, on gpu, + host_kv_cache_pool_mapping: + The tensor of pool mapping for the different memory pools. 
Its shape is [num_layers,], + do_cross_attention: bool = False Do we use this as cross attention instead of self attention, @@ -4861,6 +4866,9 @@ def gpt_attention( assert host_max_attention_window_sizes is not None assert host_sink_token_length is not None + if layer_idx_in_cache_pool is None: + layer_idx_in_cache_pool = layer_idx + paged_kv_cache_flag = default_net().plugin_config.paged_kv_cache if isinstance(qkv, list): is_unfuse_qkv_gemm = 1 @@ -4884,6 +4892,10 @@ def gpt_attention( num_kv_heads = trt.PluginField("num_kv_heads", np.array(num_kv_heads, dtype=np.int32), trt.PluginFieldType.INT32) + layer_idx_in_cache_pool = trt.PluginField( + "layer_idx_in_cache_pool", + np.array(layer_idx_in_cache_pool, dtype=np.int32), + trt.PluginFieldType.INT32) head_size = trt.PluginField("head_size", np.array(hidden_size_per_head, dtype=np.int32), trt.PluginFieldType.INT32) @@ -5034,13 +5046,14 @@ def gpt_attention( trt.PluginFieldType.INT32) pfc = trt.PluginFieldCollection([ - layer_idx, nheads, vision_start, vision_length, num_kv_heads, head_size, - unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, - rotary_embedding_dim, rotary_embedding_base, - rotary_embedding_scale_type, rotary_embedding_scale, - rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, - rotary_embedding_max_positions, rotary_embedding_original_max_positions, - tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type, enable_xqa, + layer_idx, nheads, vision_start, vision_length, num_kv_heads, + layer_idx_in_cache_pool, head_size, unidirectional, q_scaling, + qk_tanh_scale, position_embedding_type, rotary_embedding_dim, + rotary_embedding_base, rotary_embedding_scale_type, + rotary_embedding_scale, rotary_embedding_short_m_scale, + rotary_embedding_long_m_scale, rotary_embedding_max_positions, + rotary_embedding_original_max_positions, tp_size, tp_rank, + unfuse_qkv_gemm, context_fmha_type, enable_xqa, kv_cache_quant_mode_field, remove_input_padding, mask_type, block_sparse_block_size, block_sparse_homo_head_pattern, block_sparse_num_local_blocks, block_sparse_vertical_stride, @@ -5079,9 +5092,10 @@ def gpt_attention( assert kv_cache_block_offsets is not None, "Paged kv cache is enabled, the kv_cache_block_offsets tensor shall not be None" assert host_kv_cache_block_offsets is not None, "Paged kv cache is enabled, the host_kv_cache_block_offsets tensor shall not be None" assert host_kv_cache_pool_pointers is not None, "Paged kv cache is enabled, the host_kv_cache_pool_pointers tensor shall not be None" + assert host_kv_cache_pool_mapping is not None, "Paged kv cache is enabled, the host_kv_cache_pool_mapping tensor shall not be None" plug_inputs += [ kv_cache_block_offsets, host_kv_cache_block_offsets, - host_kv_cache_pool_pointers + host_kv_cache_pool_pointers, host_kv_cache_pool_mapping ] else: plug_inputs += [past_key_value] diff --git a/tensorrt_llm/hlapi/build_cache.py b/tensorrt_llm/hlapi/build_cache.py index ccb1af824..30fcfffba 100644 --- a/tensorrt_llm/hlapi/build_cache.py +++ b/tensorrt_llm/hlapi/build_cache.py @@ -35,6 +35,10 @@ class BuildCacheConfig: cache_root (str): The root directory for the build cache. max_records (int): The maximum number of records to store in the cache. max_cache_storage_gb (float): The maximum amount of storage (in GB) to use for the cache. + + Note: + The build-cache assumes the weights of the model are not changed during the execution. If the weights are + changed, you should remove the caches manually. 
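
Because the key is derived from configuration only, identical configs must always serialize identically; that is what the json.dumps(..., sort_keys=True) calls in get_engine_building_cache_stage below are for. A minimal sketch of such a deterministic key (the config dicts and the sha256 truncation are illustrative, not the exact scheme used by BuildCache):

import hashlib
import json

def cache_key(*configs: dict) -> str:
    # sort_keys makes the serialization independent of dict insertion order,
    # so semantically identical configs always hash to the same key.
    blob = json.dumps(configs, sort_keys=True)
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16]

build_cfg = {"max_batch_size": 8, "max_num_tokens": 8192}
quant_cfg = {"quant_algo": None, "kv_cache_quant_algo": None}

assert cache_key(build_cfg, quant_cfg) == cache_key(
    {"max_num_tokens": 8192, "max_batch_size": 8},      # same content, different key order
    {"kv_cache_quant_algo": None, "quant_algo": None},
)
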
""" def __init__(self, @@ -82,31 +86,33 @@ def __init__(self, config: Optional[BuildCacheConfig] = None): if config.max_records < 1: raise ValueError("max_records should be greater than 0") + def free_storage_in_gb(self) -> float: + ''' Get the free storage capacity of the cache. ''' + # measure the root directory + if self.cache_root.parent.exists(): + usage = shutil.disk_usage(self.cache_root.parent) + return usage.free / 1024**3 + return 0 + def get_engine_building_cache_stage(self, build_config: BuildConfig, model_path: Optional[Path] = None, + force_rebuild: bool = False, **kwargs) -> 'CachedStage': ''' Get the build step for engine building. ''' - from tensorrt_llm.hlapi.llm_utils import \ - _ModelFormatKind # avoid cyclic import - force_rebuild = False - if parallel_config := kwargs.get('parallel_config'): - if parallel_config.auto_parallel: - force_rebuild = True - if model_format := kwargs.get('model_format'): - if model_format is not _ModelFormatKind.HF: - force_rebuild = True - - build_config_str = BuildCache.prune_build_config_for_cache_key( - build_config.to_dict()) + build_config_str = json.dumps(self.prune_build_config_for_cache_key( + build_config.to_dict()), + sort_keys=True) + + kwargs_str = json.dumps(kwargs, sort_keys=True) return CachedStage(parent=self, kind=CacheRecord.Kind.Engine, cache_root=self.cache_root, force_rebuild=force_rebuild, - inputs=[build_config_str, model_path, kwargs]) + inputs=[build_config_str, model_path, kwargs_str]) def prune_caches(self, has_incoming_record: bool = False): ''' @@ -246,7 +252,7 @@ def get_cache_metadata(self) -> dict: } return res - def cache_hitted(self) -> bool: + def is_cached(self) -> bool: ''' Check if the product of the build step is in the cache ''' @@ -265,22 +271,36 @@ def cache_hitted(self) -> bool: @contextlib.contextmanager def write_guard(self): - ''' - Write the filelock to indicate that the build step is in progress + ''' Guard the cache writing process. + + The cache writing process should be atomic, so the filelock is used to protect the cache writing process. And + the cache metadata will be written to the cache directory. 
+ + Args: + final_engien_dir: the final engine directory ''' self.parent.prune_caches(has_incoming_record=True) target_dir = self.get_cache_path() - target_dir.mkdir(parents=True, exist_ok=True) + + # To avoid the cache modification conflict, a dummy directory is used to write the cache, and then rename it to + # the target directory + dummy_target_dir = Path(f"{target_dir.parent}/{target_dir.name}.dummy") + + dummy_target_dir.mkdir(parents=True, exist_ok=True) # TODO[chunweiy]: deal with the cache modification conflict - lock = filelock.FileLock(target_dir / '.filelock', timeout=10) + lock = filelock.FileLock(dummy_target_dir / '.filelock', timeout=10) - with open(target_dir / 'metadata.json', 'w') as f: + with open(dummy_target_dir / 'metadata.json', 'w') as f: f.write(json.dumps(self.get_cache_metadata())) - lock.__enter__() - yield target_dir / 'content' - lock.__exit__(None, None, None) + with lock: + yield dummy_target_dir / 'content' + + # If engine building is successful, rename the dummy directory to the target directory + if target_dir.exists(): + shutil.rmtree(target_dir) + shutil.move(dummy_target_dir, target_dir) @dataclass(unsafe_hash=True) diff --git a/tensorrt_llm/hlapi/llm.py b/tensorrt_llm/hlapi/llm.py index 7750b6d63..023e55db2 100644 --- a/tensorrt_llm/hlapi/llm.py +++ b/tensorrt_llm/hlapi/llm.py @@ -80,6 +80,8 @@ class LLM: dtype(str): The data type for the model weights and activations. + trust_remote_code(bool): Download the model and tokenizer from trust remote code (e.g, Hugging Face) + revision(Optional[str]): The revision of the model. tokenzier_revision(Optional[str]): The revision of the tokenizer. @@ -92,11 +94,13 @@ def __init__(self, skip_tokenizer_init: bool = False, tensor_parallel_size: int = 1, dtype: str = "auto", + trust_remote_code: bool = False, revision: Optional[str] = None, tokenizer_revision: Optional[str] = None, **kwargs: Any): self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor) + self.mpi_session: Optional[MpiSession] = None try: self.args = LlmArgs.from_kwargs( @@ -105,6 +109,7 @@ def __init__(self, skip_tokenizer_init=skip_tokenizer_init, tensor_parallel_size=tensor_parallel_size, dtype=dtype, + trust_remote_code=trust_remote_code, revision=revision, tokenizer_revision=tokenizer_revision, **kwargs) @@ -113,7 +118,6 @@ def __init__(self, f"Failed to parse the arguments for the LLM constructor: {e}") raise e - self.mpi_session: Optional[MpiSession] = None if self.args.parallel_config.is_multi_gpu: if get_device_count() < self.args.parallel_config.world_size: raise RuntimeError( @@ -131,16 +135,21 @@ def __init__(self, self.mpi_session = MpiCommSession( n_workers=self.args.parallel_config.world_size) - # Due to the Executor can only accept a engine path, we need to save the engine to a directory - self._engine_dir: Optional[Path] = None - self._executor: Optional[GenerationExecutor] = None - self._workspace = tempfile.TemporaryDirectory("llm-workspace") + try: + # Due to the Executor can only accept a engine path, we need to save the engine to a directory + self._engine_dir: Optional[Path] = None + self._executor: Optional[GenerationExecutor] = None + self._workspace = tempfile.TemporaryDirectory("llm-workspace") - self.runtime_context: Optional[_ModelRuntimeContext] = None - self.llm_build_stats = LlmBuildStats() + self.runtime_context: Optional[_ModelRuntimeContext] = None + self.llm_build_stats = LlmBuildStats() - self._build_model() - self._tokenizer = self._try_load_tokenizer() + self._build_model() + self._tokenizer = 
self._try_load_tokenizer() + except Exception as e: + if self.mpi_session is not None: + self.mpi_session.shutdown() + raise e exception_handler.register(self, '_shutdown') @@ -314,7 +323,7 @@ def _build_model(self): elif self.args.build_config.plugin_config.lora_plugin: engine_config = EngineConfig.from_json_file(self._engine_dir / "config.json") - lora_config = self.args.build_config.lora_config + lora_config = engine_config.build_config.lora_config max_lora_rank = lora_config.max_lora_rank num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \ len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) @@ -352,7 +361,8 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]: if self.runtime_context is not None: return self.runtime_context.tokenizer - return ModelLoader.load_hf_tokenizer(self.args.model_dir) + return ModelLoader.load_hf_tokenizer(self.args.model_dir, + self.args.trust_remote_code) @property def tokenizer(self) -> Optional[TokenizerBase]: diff --git a/tensorrt_llm/hlapi/llm_utils.py b/tensorrt_llm/hlapi/llm_utils.py index c8c5c447b..431450783 100644 --- a/tensorrt_llm/hlapi/llm_utils.py +++ b/tensorrt_llm/hlapi/llm_utils.py @@ -56,7 +56,8 @@ from .tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import from .utils import (GpuArch, download_hf_model, download_hf_pretrained_config, - print_colored, print_traceback_on_error, set_docstring) + get_directory_size_in_gb, print_colored, + print_traceback_on_error, set_docstring) @dataclass @@ -182,7 +183,7 @@ def from_module(cls, module: Module): LLMARGS_STAET_DOCSTRING = "The arguments for constructing a LLM instance.\n\nParameters:\n" -# The arguments locate in LLM class's explicit arg-list. +# The arguments locate in LLM class's explicit arg-list, these will not be included in LLM class's apidocs. LLMARGS_EXPLICIT_ARGS_DOCSTRING = r""" model (str or Path): The model name or a local model directory. Note that if the value could be both a model name or a local model directory, @@ -213,6 +214,15 @@ def from_module(cls, module: Module): load_format (Literal['auto', 'dummy'], default='auto'): The format of the model weights to load. * 'auto' will try to load the weights from the provided checkpoint. * 'dummy' will initialize the weights with random values, which is mainly for profiling. +""" + +# The arguments locate in LLM class's kwargs, and will be concatenated to LLM class's apidocs. +# The parallel_config is replaced by {auto_parallel, pipeline_parallel_size} arguments, the tensor_parallel_size is +# already in the LLM class's apidocs, so it is not included here. +LLMARGS_REMAINING_ARGS_DOCSTRING = r""" + auto_parallel (bool, default=False): Enable auto parallel mode. + + pipeline_parallel_size (int, default=1): The pipeline parallel size. enable_lora (bool, default=False): Enable LoRA adapters. @@ -221,9 +231,7 @@ def from_module(cls, module: Module): max_loras (int, default=4): Maximum number of LoRA adapters to be stored in GPU memory. max_cpu_loras (int, default=4): Maximum number of LoRA adapters to be stored in CPU memory. -""" -LLMARGS_REMAINING_ARGS_DOCSTRING = r""" build_config (BuildConfig, default=BuildConfig()): The build configuration for the model. Default is an empty BuildConfig instance. @@ -266,6 +274,8 @@ def from_module(cls, module: Module): Default is None. enable_tqdm (bool, default=False): Whether to display a progress bar during model building. 
+ + trust_remote_code (bool, default=False): Whether to trust remote code when downloading model and tokenizer from Hugging Face. """ @@ -338,12 +348,13 @@ class LlmArgs: # Display the model building progress bar enable_tqdm: bool = False + trust_remote_code: bool = False + def __post_init__(self): # NOTE: this is only for the compatibility with the old API, and will be removed in the future # chunked context is disabled by default, and it is recommended to keep it enabled. # The underlying implementation might disable it if it is not supported. self.enable_chunked_context: bool = False - # TODO[chunweiy]: Enable this option in the future # Currently we want HLAPI to be consistent with the lower APIs in the model building, thus disable this to avoid # magics. @@ -381,7 +392,6 @@ def __post_init__(self): @classmethod def from_kwargs(cls, **kwargs) -> "LlmArgs": LlmArgs._check_executor_config_options_consistency() - parallel_config = _ParallelConfig( tp_size=kwargs.pop('tensor_parallel_size', 1), pp_size=kwargs.pop('pipeline_parallel_size', 1), @@ -1040,6 +1050,7 @@ def copy_hf_tokenizer_data_to_engine_dir(): # supports end-to-end task. # This is only for HF model for now, not available for users' customized tokenizers. import shutil + for name in os.listdir(model_dir): src = os.path.join(model_dir, name) dst = os.path.join(engine_dir, name) @@ -1102,7 +1113,8 @@ def _download_hf_model(self): def _load_model_from_hf(self): ''' Load a TRT-LLM model from a HF model. ''' assert self._model_dir is not None - model_cls = AutoModelForCausalLM.get_trtllm_model_class(self._model_dir) + model_cls = AutoModelForCausalLM.get_trtllm_model_class( + self._model_dir, self.llm_args.trust_remote_code) if self.llm_args.load_format == 'dummy': config = model_cls.config_class.from_hugging_face( str(self._model_dir), @@ -1123,6 +1135,7 @@ def _load_model_from_hf(self): mapping=self.mapping, quant_config=self.llm_args.quant_config, **self.llm_args.calib_config.to_dict(), + trust_remote_code=self.llm_args.trust_remote_code, ) if self.llm_args.parallel_config.is_multi_gpu: mpi_barrier() @@ -1136,6 +1149,7 @@ def _load_model_from_hf(self): quant_config=self.llm_args.quant_config, load_model_on_cpu= True, # TODO:TRTLLM-195 to enhance the weights loading memory usage and chose best location + trust_remote_code=self.llm_args.trust_remote_code, **self.convert_checkpoint_options, ) @@ -1177,13 +1191,17 @@ def _build_engine(self): self.build_config, BuildConfig), f"build_config is not set yet: {self.build_config}" - self.build_config.update(auto_parallel_config=self.auto_parallel_config) - self.build_config.update_kv_cache_type(self._model_info.architecture) + # avoid the original build_config is modified, avoid the side effect + copied_build_config = copy.deepcopy(self.build_config) + + copied_build_config.update( + auto_parallel_config=self.auto_parallel_config) + copied_build_config.update_kv_cache_type(self._model_info.architecture) if self.auto_parallel_config.enabled: self.model.config.mapping.rank = self.rank assert self.model is not None, "model is loaded yet." 
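
A hypothetical end-to-end use of the trust_remote_code plumbing added above; the model name and prompt are placeholders, and the import assumes the high-level API exported from tensorrt_llm.hlapi:

from tensorrt_llm.hlapi import LLM

# trust_remote_code is only needed for checkpoints that ship custom modeling or tokenizer code.
llm = LLM(model="some-org/custom-arch-model",
          trust_remote_code=True)

outputs = llm.generate(["Hello, my name is"])     # default sampling settings
print(outputs[0].outputs[0].text)
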
- engine = build(self.model, self.build_config) + engine = build(self.model, copied_build_config) self._engine_buffer = engine.engine self._engine_config = engine.config @@ -1227,14 +1245,16 @@ def load_extra_build_configs_from_engine( return Namespace(**build_config) @staticmethod - def load_hf_tokenizer(model_dir) -> Optional[TransformersTokenizer]: + def load_hf_tokenizer(model_dir, + trust_remote_code) -> Optional[TransformersTokenizer]: try: - return TransformersTokenizer.from_pretrained(model_dir, - legacy=False, - padding_side='left', - truncation_side='left', - trust_remote_code=True, - use_fast=True) + return TransformersTokenizer.from_pretrained( + model_dir, + legacy=False, + padding_side='left', + truncation_side='left', + trust_remote_code=trust_remote_code, + use_fast=True) except Exception as e: logger.error(f"Failed to load tokenizer from {model_dir}: {e}") return None @@ -1286,11 +1306,11 @@ def __call__(self) -> Path: self._hf_model_dir = self.llm_args.model_dir if self.llm_args.model_format is _ModelFormatKind.HF else None self.engine_cache_stage = self._get_engine_cache_stage() - if self.engine_cache_stage.cache_hitted(): + if self.engine_cache_stage.is_cached(): + self.llm_build_stats.cache_hitted = True print_colored( f"Reusing cached engine in {self.engine_cache_stage.get_engine_path()}\n\n", 'grey') - self.llm_build_stats.cache_hitted = True self.llm_args.model = self.engine_cache_stage.get_engine_path() self.llm_build_stats.engine_dir = self.llm_args.model_dir return self.llm_build_stats.engine_dir @@ -1317,68 +1337,51 @@ def build_cache_enabled(self) -> bool: ) and not self.llm_args.parallel_config.auto_parallel def _get_engine_cache_stage(self) -> CachedStage: - ''' - Get the cache stage for engine building. - ''' + ''' Get the cache stage for engine building. ''' build_cache = BuildCache(self.llm_args.enable_build_cache) assert self._hf_model_dir is not None, "HF model dir is required for cache key." - dummy_build_config = CachedModelLoader.get_final_build_config( - self.llm_args, self._hf_model_dir) + + def serialize(d) -> str: + dic = asdict(d) if not isinstance( + d, PretrainedConfig) else d.to_dict() + return json.dumps(dic, sort_keys=True) + + parallel_config = self.llm_args.parallel_config + + force_rebuild = False + if parallel_config.auto_parallel: + force_rebuild = True + if self.llm_args.model_format is not _ModelFormatKind.HF: + force_rebuild = True return build_cache.get_engine_building_cache_stage( - build_config=dummy_build_config, + build_config=self.llm_args.build_config, model_path=self._hf_model_dir, - # for PretrainedConfig - parallel_config=self.llm_args.parallel_config, + force_rebuild=force_rebuild, # Other configs affecting the engine building - quant_config=self.llm_args.quant_config) - - @staticmethod - def get_final_build_config(llm_args: LlmArgs, - model_dir: Path) -> BuildConfig: - ''' - Get the build_config for cache key. The tricky part is that, the build_config will be altered in `build()`, - but we need a final version of build_config before `build()` is called for cache key. - - Args: - llm_args: The LlmArgs for building the model. - model_dir: The path to the local HF model. - ''' + parallel_config=serialize(parallel_config), + pretrained_config=serialize(self.get_pretrained_config()), + quant_config=serialize(self.llm_args.quant_config), + ) - # This is only needed by BuildCache for cache key - # The build() doesn't need the real model instance to get a updated BuildConig. What is really needed is the - # dtype. 
That's why the model will be downloaded from HF if necessary to get the accurate dtype. - - pretrained_config = AutoConfig.from_hugging_face( - model_dir, - mapping=Mapping(world_size=llm_args.parallel_config.world_size, - tp_size=llm_args.parallel_config.tp_size, - pp_size=llm_args.parallel_config.pp_size), - quant_config=llm_args.quant_config, - dtype=llm_args.dtype) - - @dataclass - class DummyModel: - # This is only used for getting the updated BuildConfig from build() without actually loading the whole - # pretrained model to save overhead and memory. - config: PretrainedConfig - - # dry_run to get the updated build_config for cache key. The build_config is modified within build(), so using - # a build_config before build() is not correct for cache key, so we need to get the build_config after build() - # in dry_run mode. - dummy_model = DummyModel(pretrained_config) - dummy_build_config = copy.copy(llm_args.build_config) - dummy_build_config.dry_run = True - updated_build_config = build(dummy_model, - dummy_build_config, - return_build_config=True) - return updated_build_config + def get_pretrained_config(self) -> PretrainedConfig: + ''' Get the PretrainedConfig for cache key. + NOTE, this is not the HF model's config, but the TRT-LLM's config. We use this as a generic information for + HF and other models. ''' + assert self._hf_model_dir is not None + return AutoConfig.from_hugging_face( + self._hf_model_dir, + mapping=Mapping(world_size=self.llm_args.parallel_config.world_size, + tp_size=self.llm_args.parallel_config.tp_size, + pp_size=self.llm_args.parallel_config.pp_size), + quant_config=self.llm_args.quant_config, + dtype=self.llm_args.dtype) def _build_model(self) -> Path: model_format = self.llm_args.model_format - def build_task(): + def build_task(engine_dir: Path): if model_format is not _ModelFormatKind.TLLM_ENGINE: model_loader_kwargs = { 'llm_args': self.llm_args, @@ -1391,22 +1394,44 @@ def build_task(): # The engine_dir:Path will be stored to MPINodeState.state build_infos = self.mpi_session.submit_sync( CachedModelLoader._node_build_task, - engine_dir=self.get_engine_dir(), + engine_dir=engine_dir, **model_loader_kwargs) self.llm_build_stats.build_steps_info = build_infos[0] else: # single-gpu with ModelLoader(**model_loader_kwargs) as model_loader: - model_loader(self.get_engine_dir()) + model_loader(engine_dir=engine_dir) release_gc() + has_storage = True if self.build_cache_enabled: - with self.engine_cache_stage.write_guard(): - build_task() - return self.get_engine_dir() - else: - build_task() + try: + # TODO[chunweiy]: Cover the case when the model is from HF model hub. + if self.llm_args.is_local_model: + # This is not perfect, but will make build-cache much more robust. + has_storage = self.engine_cache_stage.parent.free_storage_in_gb( + ) >= get_directory_size_in_gb(self.llm_args.model_dir) + except ValueError: + has_storage = False + except Exception as e: + logger.error(e) + has_storage = False + + if has_storage: + with self.engine_cache_stage.write_guard() as engine_dir: + build_task(engine_dir) + self.llm_build_stats.cache_hitted = True + + else: + print_colored( + "The cache directory is too small, build-cache is disabled.\n", + 'grey') + self.llm_build_stats.cache_hitted = False + self.llm_build_stats.cache_info = "The cache root directory is too small." 
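
The has_storage guard above boils down to: keep the build cache enabled only when the cache root has at least as much free space as the source checkpoint occupies. The same check in isolation, with placeholder paths:

import os
import shutil
from pathlib import Path

def dir_size_gb(directory: Path) -> float:
    total = sum(
        os.path.getsize(os.path.join(dirpath, name))
        for dirpath, _, filenames in os.walk(directory)
        for name in filenames)
    return total / 1024**3

def free_space_gb(path: Path) -> float:
    return shutil.disk_usage(path).free / 1024**3

model_dir = Path("/models/llama-7b-hf")                  # placeholder checkpoint location
cache_root = Path.home() / ".cache" / "tensorrt_llm"     # placeholder cache root

use_build_cache = free_space_gb(cache_root.parent) >= dir_size_gb(model_dir)
print(f"build cache enabled: {use_build_cache}")
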
+ + if not (has_storage and self.build_cache_enabled): + build_task(self.get_engine_dir()) return self.get_engine_dir() @@ -1437,6 +1462,7 @@ class LlmBuildStats: ''' LlmBuildStats is the statistics for the LLM model building. ''' # Whether the cache is hitted for the engine cache_hitted: bool = False + cache_info: Optional[str] = None model_from_hf_hub: bool = False diff --git a/tensorrt_llm/hlapi/mpi_session.py b/tensorrt_llm/hlapi/mpi_session.py index 44c58ebaa..ad6c0b8b6 100644 --- a/tensorrt_llm/hlapi/mpi_session.py +++ b/tensorrt_llm/hlapi/mpi_session.py @@ -95,7 +95,7 @@ def submit_sync(self, task: (...), *args, **kwargs) -> List[Any]: def shutdown(self): if self.mpi_pool is not None: - self.mpi_pool.shutdown() + self.mpi_pool.shutdown(wait=False) self.mpi_pool = None def _start_mpi_pool(self): @@ -155,7 +155,7 @@ def submit_sync(self, task: (...), *args, **kwargs) -> List[Any]: def shutdown(self): if self.mpi_pool is not None: - self.mpi_pool.shutdown() + self.mpi_pool.shutdown(wait=False) self.mpi_pool = None def _start_mpi_pool(self): diff --git a/tensorrt_llm/hlapi/openai_protocol.py b/tensorrt_llm/hlapi/openai_protocol.py index f235d5a43..21d6e1f41 100644 --- a/tensorrt_llm/hlapi/openai_protocol.py +++ b/tensorrt_llm/hlapi/openai_protocol.py @@ -479,6 +479,7 @@ def to_sampling_params(self) -> SamplingParams: stop_token_ids=self.stop_token_ids, stop=self.stop, include_stop_str_in_output=self.include_stop_str_in_output, + return_log_probs=self.logprobs, ) if self.min_p > 0: sampling_params.top_p_min = self.min_p @@ -517,8 +518,9 @@ def check_tool_choice(cls, data): @model_validator(mode="before") @classmethod def check_logprobs(cls, data): - if "top_logprobs" in data or "logprobs" in data: - raise ValueError("returning log probs is not supported") + top_logprobs = data.get("top_logprobs") + if top_logprobs is not None and top_logprobs > 0: + raise ValueError("top_logprobs is not supported") return data @model_validator(mode="before") diff --git a/tensorrt_llm/hlapi/utils.py b/tensorrt_llm/hlapi/utils.py index 0d3034a9a..21776591b 100644 --- a/tensorrt_llm/hlapi/utils.py +++ b/tensorrt_llm/hlapi/utils.py @@ -201,6 +201,24 @@ def _get_stop_words(self) -> List[List[int]]: "please call the setup method.") return words + self._stop_word_ids + def _get_stop_reasons_and_words( + self) -> List[Tuple[Union[str, int], List[int]]]: + stop_reasons = [] + if self.stop_token_ids is not None: + stop_reasons.extend(self.stop_token_ids) + if self.stop is not None: + if isinstance(self.stop, str): + stop_reasons.append(self.stop) + else: + stop_reasons.extend(self.stop) + stop_words = self._get_stop_words() + if len(stop_reasons) != len(stop_words): + raise RuntimeError( + f"The number of {self.__class__.__name__}.stop_token_ids ({self.stop_token_ids}) " + f"and {self.__class__.__name__}.stop ({self.stop}) are inconsistent with the " + f"processed stop_words ({stop_words}).") + return list(zip(stop_reasons, stop_words)) + def _get_sampling_config(self) -> tllme.SamplingConfig: expected_fields = [ "beam_width", "top_k", "top_p", "top_p_min", "top_p_reset_ids", @@ -407,6 +425,18 @@ def decorator(fn): return decorator +def get_directory_size_in_gb(directory: Path) -> float: + """ Get the size of the directory. 
""" + if not (directory.is_dir() and directory.exists()): + raise ValueError(f"{directory} is not a directory.") + total_size = 0 + for dirpath, dirnames, filenames in os.walk(directory): + for f in filenames: + fp = os.path.join(dirpath, f) + total_size += os.path.getsize(fp) + return total_size / 1024**3 # GB + + class ManagedThread(threading.Thread): """ A thread that will put exceptions into an external queue if the task fails. @@ -420,8 +450,12 @@ class ManagedThread(threading.Thread): **kwargs: The arguments to pass to the task """ - def __init__(self, task: Callable[..., bool], error_queue: Queue, **kwargs): - super().__init__() + def __init__(self, + task: Callable[..., bool], + error_queue: Queue, + name: Optional[str] = None, + **kwargs): + super().__init__(name=name) self.task = task self.error_queue = error_queue self.kwargs = kwargs @@ -437,7 +471,8 @@ def run(self): except Exception as e: logger.error(f"Error in thread {self.name}: {e}") self.error_queue.put(e) - break + + logger.info(f"Thread {self.name} stopped.") def stop(self): self.stop_event.set() diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index 995425bc1..b7ad79cff 100644 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -251,11 +251,13 @@ def __init__(self, kv_cache_block_offsets: Tensor = None, host_kv_cache_block_offsets: Tensor = None, host_kv_cache_pool_pointers: Tensor = None, + host_kv_cache_pool_mapping: Tensor = None, cache_indirection: Tensor = None, past_key_value_length: Tensor = None, cross_kv_cache_block_offsets: Tensor = None, host_cross_kv_cache_block_offsets: Tensor = None, - host_cross_kv_cache_pool_pointers: Tensor = None): + host_cross_kv_cache_pool_pointers: Tensor = None, + host_cross_kv_cache_pool_mapping: Tensor = None): self.past_key_value = past_key_value self.host_past_key_value_lengths = host_past_key_value_lengths self.host_max_attention_window_sizes = host_max_attention_window_sizes @@ -263,9 +265,11 @@ def __init__(self, self.kv_cache_block_offsets = kv_cache_block_offsets self.host_kv_cache_block_offsets = host_kv_cache_block_offsets self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers + self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping self.cross_kv_cache_block_offsets = cross_kv_cache_block_offsets self.host_cross_kv_cache_block_offsets = host_cross_kv_cache_block_offsets self.host_cross_kv_cache_pool_pointers = host_cross_kv_cache_pool_pointers + self.host_cross_kv_cache_pool_mapping = host_cross_kv_cache_pool_mapping self.cache_indirection = cache_indirection # self.past_key_value_length = past_key_value_length @@ -349,7 +353,8 @@ def __init__(self, max_attn_value=0.0, block_sparse_params=None, use_implicit_relative_attention=False, - reorder=False): + reorder=False, + layer_idx_in_cache_pool=None): super().__init__() self.local_layer_idx = local_layer_idx @@ -357,6 +362,7 @@ def __init__(self, self.attention_mask_type = attention_mask_type self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size self.num_kv_heads = num_kv_heads + self.layer_idx_in_cache_pool = layer_idx_in_cache_pool if layer_idx_in_cache_pool is not None else local_layer_idx assert num_attention_heads % tp_size == 0, \ "num_attention_heads must be divisible by tp_size" self.num_attention_heads = num_attention_heads // tp_size @@ -857,8 +863,7 @@ def compute_cross_qkv(encoder_output): embed_positions_short_factors_for_attention_plugin, concat([0, 0, 0]), concat([ - 
max(attention_params.sequence_length, - self.original_max_position_embeddings), + self.max_position_embeddings, self.rotary_embedding_dim // 2, 2 ])) long = slice( @@ -866,8 +871,7 @@ def compute_cross_qkv(encoder_output): embed_positions_long_factors_for_attention_plugin, concat([0, 0, 0]), concat([ - max(attention_params.sequence_length, - self.original_max_position_embeddings), + self.max_position_embeddings, self.rotary_embedding_dim // 2, 2 ])) short = short.view((1, -1)) @@ -916,6 +920,7 @@ def compute_cross_qkv(encoder_output): layer_idx=self.local_layer_idx, num_heads=self.num_attention_heads, num_kv_heads=self.num_attention_kv_heads, + layer_idx_in_cache_pool=self.layer_idx_in_cache_pool, hidden_size_per_head=self.attention_head_size, q_scaling=self.q_scaling, rotary_embedding_dim=self.rotary_embedding_dim, @@ -956,6 +961,9 @@ def compute_cross_qkv(encoder_output): host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers if not self.cross_attention else kv_cache_params.host_cross_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. + host_kv_cache_pool_mapping if not self.cross_attention else + kv_cache_params.host_cross_kv_cache_pool_mapping, do_cross_attention=self.cross_attention, cross_qkv=cross_qkv, cross_qkv_length=attention_params.encoder_max_input_length, @@ -1702,6 +1710,8 @@ def forward(self, host_kv_cache_block_offsets, host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. + host_kv_cache_pool_mapping, do_cross_attention=self.cross_attention, cross_qkv=None, cross_qkv_length=attention_params.encoder_max_input_length, diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 00507a118..18ed817e1 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -244,6 +244,15 @@ def load_hf_lora( if len(lora_config.lora_target_modules) == 0: lora_config.lora_target_modules = lora_loader.get_target_modules( trtllm_modules_to_hf_modules) + if len(lora_config.lora_target_modules) == 0: + raise ValueError( + "lora_target_modules is empty. " + "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules." 
+ ) + + missing_qkv_modules = LoraManager.get_missing_qkv_modules( + lora_config.lora_target_modules) + lora_config.lora_target_modules.extend(missing_qkv_modules) if lora_loader.is_valid: config = model.config diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index 15ee4ed52..7f3a66df9 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -22,7 +22,6 @@ from .cogvlm.model import CogVLMForCausalLM from .dbrx.config import DbrxConfig from .dbrx.model import DbrxForCausalLM -from .deci.model import DeciLMForCausalLM from .deepseek_v1.model import DeepseekForCausalLM from .dit.model import DiT from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder @@ -44,6 +43,7 @@ from .modeling_utils import (PretrainedConfig, PretrainedModel, SpeculativeDecodingMode) from .mpt.model import MPTForCausalLM, MPTModel +from .nemotron_nas.model import DeciLMForCausalLM from .opt.model import OPTForCausalLM, OPTModel from .phi3.model import Phi3ForCausalLM, Phi3Model from .phi.model import PhiForCausalLM, PhiModel @@ -127,6 +127,7 @@ 'Phi3ForCausalLM': Phi3ForCausalLM, 'Phi3VForCausalLM': Phi3ForCausalLM, 'Phi3SmallForCausalLM': Phi3ForCausalLM, + 'PhiMoEForCausalLM': Phi3ForCausalLM, 'MambaForCausalLM': MambaForCausalLM, 'GPTNeoXForCausalLM': GPTNeoXForCausalLM, 'GPTJForCausalLM': GPTJForCausalLM, diff --git a/tensorrt_llm/models/automodel.py b/tensorrt_llm/models/automodel.py index 974064305..a65781a88 100644 --- a/tensorrt_llm/models/automodel.py +++ b/tensorrt_llm/models/automodel.py @@ -14,6 +14,7 @@ def from_hugging_face(hf_model_or_dir, quant_config: Optional[QuantConfig] = None, **kwargs): import transformers + hf_config = transformers.AutoConfig.from_pretrained( hf_model_or_dir, trust_remote_code=True) hf_arch = hf_config.architectures[0] @@ -41,10 +42,11 @@ def from_hugging_face(hf_model_or_dir, class AutoModelForCausalLM: @staticmethod - def get_trtllm_model_class(hf_model_or_dir): + def get_trtllm_model_class(hf_model_or_dir, trust_remote_code=False): import transformers + hf_config = transformers.AutoConfig.from_pretrained( - hf_model_or_dir, trust_remote_code=True) + hf_model_or_dir, trust_remote_code=trust_remote_code) hf_arch = hf_config.architectures[0] trtllm_model_cls = MODEL_MAP.get(hf_arch, None) diff --git a/tensorrt_llm/models/baichuan/config.py b/tensorrt_llm/models/baichuan/config.py index 902487c81..1cffbefce 100644 --- a/tensorrt_llm/models/baichuan/config.py +++ b/tensorrt_llm/models/baichuan/config.py @@ -55,13 +55,13 @@ def from_hugging_face( quant_config: Optional[QuantConfig] = None, **kwargs): import transformers - + trust_remote_code = kwargs.pop('trust_remote_code', True) if isinstance(hf_config_or_dir, transformers.PretrainedConfig): hf_config = hf_config_or_dir else: hf_config_dir = str(hf_config_or_dir) hf_config = transformers.AutoConfig.from_pretrained( - hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) model_version = kwargs.pop('model_version', None) if model_version is None: diff --git a/tensorrt_llm/models/baichuan/convert.py b/tensorrt_llm/models/baichuan/convert.py index 9c5b0dc39..c2395e372 100644 --- a/tensorrt_llm/models/baichuan/convert.py +++ b/tensorrt_llm/models/baichuan/convert.py @@ -470,7 +470,8 @@ def quantize(hf_model_dir: str, output_dir: str, config: BaichuanConfig, device: str = 'cuda', - calib_dataset: str = 'ccdv/cnn_dailymail'): + calib_dataset: str = 'ccdv/cnn_dailymail', + trust_remote_code: bool = True): os.makedirs(output_dir, 
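The new LoRA guard above rejects an empty lora_target_modules list and then extends it with the Q/K/V companions that must be loaded together. A hedged sketch of that completion step follows; the module names and grouping are an illustration, not the exact LoraManager.get_missing_qkv_modules logic.

# Illustrative only: assumes attn_q/attn_k/attn_v form one fused group that
# must be complete; see LoraManager.get_missing_qkv_modules for the real rule.
def missing_qkv_modules_sketch(target_modules):
    qkv_group = ["attn_q", "attn_k", "attn_v"]
    if any(m in target_modules for m in qkv_group):
        return [m for m in qkv_group if m not in target_modules]
    return []

print(missing_qkv_modules_sketch(["attn_q", "mlp_h_to_4h"]))  # ['attn_k', 'attn_v']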
exist_ok=True) config.to_json_file(os.path.join(output_dir, 'config.json')) @@ -482,10 +483,9 @@ def quantize(hf_model_dir: str, device_map='auto' if device != 'cpu' else 'cpu', torch_dtype='auto' if not config.quantization.use_plugin_sq else torch.float16, - trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(hf_model_dir, - use_fast=False, - trust_remote_code=True) + trust_remote_code=trust_remote_code) + tokenizer = AutoTokenizer.from_pretrained( + hf_model_dir, use_fast=False, trust_remote_code=trust_remote_code) dataset = load_calib_dataset(calib_dataset) act_range = capture_activation_range(hf_model, tokenizer, dataset) diff --git a/tensorrt_llm/models/baichuan/model.py b/tensorrt_llm/models/baichuan/model.py index cd95c57a5..67311253f 100644 --- a/tensorrt_llm/models/baichuan/model.py +++ b/tensorrt_llm/models/baichuan/model.py @@ -182,8 +182,12 @@ def from_hugging_face( hf_model = hf_model_or_dir hf_config_or_dir = hf_model.config else: + trust_remote_code = kwargs.pop('trust_remote_code', True) + hf_model = transformers.AutoModelForCausalLM.from_pretrained( - hf_model_or_dir, trust_remote_code=True, torch_dtype='auto') + hf_model_or_dir, + trust_remote_code=trust_remote_code, + torch_dtype='auto') hf_config_or_dir = hf_model_or_dir config = BaichuanConfig.from_hugging_face(hf_config_or_dir, diff --git a/tensorrt_llm/models/chatglm/config.py b/tensorrt_llm/models/chatglm/config.py index 147f23d6c..c2f0b17e2 100644 --- a/tensorrt_llm/models/chatglm/config.py +++ b/tensorrt_llm/models/chatglm/config.py @@ -75,6 +75,7 @@ def from_hugging_face( quant_config: Optional[QuantConfig] = None, **kwargs): import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) # load hugging face config if isinstance(hf_config_or_dir, transformers.PretrainedConfig): @@ -82,7 +83,7 @@ def from_hugging_face( else: hf_config_dir = str(hf_config_or_dir) hf_config = transformers.AutoConfig.from_pretrained( - hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) logits_dtype = kwargs.pop('logits_dtype', 'float32') use_parallel_embedding = kwargs.pop('use_parallel_embedding', False) diff --git a/tensorrt_llm/models/chatglm/convert.py b/tensorrt_llm/models/chatglm/convert.py index 76b919110..9d1c59a61 100644 --- a/tensorrt_llm/models/chatglm/convert.py +++ b/tensorrt_llm/models/chatglm/convert.py @@ -286,9 +286,8 @@ def get_tllm_linear_sq_weight(vals, else: cur_per_channel_value = vals["scale_w_quant_orig"] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array(cur_per_channel_value, - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + + 'per_channel_scale'] = cur_per_channel_value.reshape(col_shape) else: if per_channel: original_weights = np.array(vals["weight.int8.col"]) @@ -382,7 +381,7 @@ def load_weights_from_hf_model(hf_model: AutoModel, if use_smooth_quant: qkv_act_range = act_range.get( f'{prefix}.{attention_attr_name}.query_key_value') - qkv_vals_int8 = generate_int8(qkv_weight.t().numpy(), + qkv_vals_int8 = generate_int8(qkv_weight.t(), qkv_act_range, is_qkv=True, multi_query_mode=True) @@ -430,7 +429,7 @@ def load_weights_from_hf_model(hf_model: AutoModel, if int8_kv_cache: qkv_act_range = act_range.get( f'{prefix}.{attention_attr_name}.query_key_value') - qkv_vals_int8 = generate_int8(qkv_weight.t().numpy(), + qkv_vals_int8 = generate_int8(qkv_weight.t(), qkv_act_range, is_qkv=True, multi_query_mode=True) @@ -448,7 +447,7 @@ def load_weights_from_hf_model(hf_model: AutoModel, 
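Several converters in this patch stop hard-coding trust_remote_code=True and instead pop it from kwargs with True as the default, so a caller can opt out of executing remote model code. A caller-side sketch, with a placeholder checkpoint path and assuming the usual from_hugging_face keyword flow:

# "/path/to/hf/model" is a placeholder; trust_remote_code travels through
# **kwargs into transformers.AutoConfig/AutoModelForCausalLM.from_pretrained.
from tensorrt_llm.models import BaichuanForCausalLM

trt_model = BaichuanForCausalLM.from_hugging_face(
    "/path/to/hf/model",
    dtype="float16",
    trust_remote_code=False,  # default remains True when omitted
)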
f'{prefix}.{attention_attr_name}.dense') dense_smoother = smoother.get( f'{prefix}.{attention_attr_name}.dense') - dense_vals_int8 = generate_int8(attn_dense_weight.t().numpy(), + dense_vals_int8 = generate_int8(attn_dense_weight.t(), dense_act_range, is_qkv=False, multi_query_mode=True) @@ -483,7 +482,7 @@ def load_weights_from_hf_model(hf_model: AutoModel, if use_smooth_quant: fc_act_range = act_range.get(f'{prefix}.mlp.dense_h_to_4h') - fc_vals_int8 = generate_int8(mlp_fc_weight.t().numpy(), + fc_vals_int8 = generate_int8(mlp_fc_weight.t(), fc_act_range, is_qkv=False, multi_query_mode=True) @@ -558,7 +557,7 @@ def load_weights_from_hf_model(hf_model: AutoModel, if use_smooth_quant: proj_act_range = act_range.get(f'{prefix}.mlp.dense_4h_to_h') proj_smoother = smoother.get(f'{prefix}.mlp.dense_4h_to_h') - proj_vals_int8 = generate_int8(mlp_proj_weight.t().numpy(), + proj_vals_int8 = generate_int8(mlp_proj_weight.t(), proj_act_range, is_qkv=False, multi_query_mode=True) @@ -671,7 +670,8 @@ def quantize(hf_model_dir: str, output_dir: str, config: ChatGLMConfig, calib_dataset: str = 'cnn_dailymail', - device: str = 'auto'): + device: str = 'auto', + trust_remote_code: bool = True): ''' Quantize the save the model as TRT-LLM checkpoint to output_dir ''' @@ -694,7 +694,7 @@ def quantize(hf_model_dir: str, device_map = 'auto' if device != "cpu" else 'cpu' hf_model = AutoModel.from_pretrained( hf_model_dir, - trust_remote_code=True, + trust_remote_code=trust_remote_code, torch_dtype='auto' if config.chatglm_version != 'glm' else getattr( torch, config.dtype), device_map=device_map) @@ -703,7 +703,7 @@ def quantize(hf_model_dir: str, "TOKENIZERS_PARALLELISM", "false") tokenizer = AutoTokenizer.from_pretrained( hf_model_dir, - trust_remote_code=True, + trust_remote_code=trust_remote_code, ) dataset = load_calib_dataset(calib_dataset) diff --git a/tensorrt_llm/models/chatglm/model.py b/tensorrt_llm/models/chatglm/model.py index a6b82c928..1db13dbf1 100644 --- a/tensorrt_llm/models/chatglm/model.py +++ b/tensorrt_llm/models/chatglm/model.py @@ -283,6 +283,7 @@ def from_hugging_face( ''' Create a LLaMAForCausalLM object from give parameters ''' load_model_on_cpu = kwargs.pop('load_model_on_cpu', False) + trust_remote_code = kwargs.pop('trust_remote_code', True) config = ChatGLMConfig.from_hugging_face(hf_model_or_dir, dtype=dtype, @@ -295,7 +296,7 @@ def from_hugging_face( device_map = 'auto' if not load_model_on_cpu else 'cpu' hf_model = AutoModel.from_pretrained( hf_model_or_dir, - trust_remote_code=True, + trust_remote_code=trust_remote_code, torch_dtype='auto' if config.chatglm_version != 'glm' else getattr( torch, config.dtype), device_map=device_map) diff --git a/tensorrt_llm/models/convert_utils.py b/tensorrt_llm/models/convert_utils.py index be587fd5f..7d25399fe 100644 --- a/tensorrt_llm/models/convert_utils.py +++ b/tensorrt_llm/models/convert_utils.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -import numpy as np import torch from datasets import load_dataset @@ -404,9 +403,8 @@ def generate_int8( # compute weight scaling factors for fp->int8 and int8->fp if is_qkv and not multi_query_mode: scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max( - dim=-1, keepdims=True)[0].cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, - -1).cpu().numpy() + dim=-1, keepdims=True)[0] + scale_w_orig_quant_c = 127. 
/ act_range["w"].reshape(3, -1) elif is_qkv and multi_query_mode: hidden_dim = weights.shape[0] local_dim = act_range["w"].shape[0] @@ -421,62 +419,67 @@ def generate_int8( scale_w_v.max(dim=0, keepdim=True)[0] ]) - scale_w_orig_quant_t = 127. / scale_w_qkv_t.cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() + scale_w_orig_quant_t = 127. / scale_w_qkv_t + scale_w_orig_quant_c = 127. / act_range["w"] else: - scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() + scale_w_orig_quant_t = 127. / act_range["w"].max() + scale_w_orig_quant_c = 127. / act_range["w"] scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c + scale_w_orig_quant_c = scale_w_orig_quant_c.to(torch.float32) + scale_w_orig_quant_t = scale_w_orig_quant_t.to(torch.float32) + # compute the rest of needed scaling factors - scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) - scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) - scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) + scale_x_orig_quant_t = 127. / act_range["x"].max() + scale_y_orig_quant_t = 127. / act_range["y"].max() + scale_y_quant_orig_t = act_range["y"].max() / 127. scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * scale_w_orig_quant_t) scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * scale_w_orig_quant_c) if is_qkv and not multi_query_mode: - scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, - scale_w_orig_quant_c.shape) - scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, - scale_w_orig_quant_c.shape) + scale_y_accum_quant_t = torch.broadcast_to(scale_y_accum_quant_t, + scale_w_orig_quant_c.shape) + scale_w_quant_orig_t = torch.broadcast_to(scale_w_quant_orig_t, + scale_w_orig_quant_c.shape) if is_qkv and multi_query_mode: - scale_q_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[0], - scale_w_q.shape) - scale_k_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[1], - scale_w_k.shape) - scale_v_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[2], - scale_w_v.shape) - scale_y_accum_quant_t = np.concatenate( + scale_q_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[0], + scale_w_q.shape) + scale_k_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[1], + scale_w_k.shape) + scale_v_y_accum_t = torch.broadcast_to(scale_y_accum_quant_t[2], + scale_w_v.shape) + scale_y_accum_quant_t = torch.concat( [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t]) - scale_w_quant_orig_t = np.concatenate([ - np.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape), - np.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape), - np.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape) + scale_w_quant_orig_t = torch.concat([ + torch.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape), + torch.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape), + torch.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape) ]) - to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) + to_i8 = lambda x: x.round().clip(-127, 127).to(torch.int8) if is_qkv and multi_query_mode: - scale_w_quant_orig_t_expand = np.ones([weights.shape[-1]]) - scale_w_quant_orig_t_expand[:hidden_dim] = scale_w_quant_orig_t[0] - scale_w_quant_orig_t_expand[hidden_dim:hidden_dim + - kv_dim] = scale_w_quant_orig_t[1] - scale_w_quant_orig_t_expand[-kv_dim:] = scale_w_quant_orig_t[2] - weight_int8 = to_i8(weights * scale_w_quant_orig_t_expand) + if 
weights.device != scale_w_quant_orig_t.device: + scale_w_quant_orig_t = scale_w_quant_orig_t.to(weights.device) + weight_int8 = to_i8(weights / scale_w_quant_orig_t) else: + if weights.device != scale_w_orig_quant_t.device: + scale_w_orig_quant_t = scale_w_orig_quant_t.to(weights.device) weight_int8 = to_i8(weights * scale_w_orig_quant_t) + if weights.device != scale_w_orig_quant_c.device: + scale_w_orig_quant_c = scale_w_orig_quant_c.to(weights.device) + return { "weight.int8": weight_int8, "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), - "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), - "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), - "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), - "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), - "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), - "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), + "scale_x_orig_quant": scale_x_orig_quant_t.to(torch.float32), + "scale_w_quant_orig": scale_w_quant_orig_t.to(torch.float32), + "scale_w_quant_orig.col": scale_w_quant_orig_c.to(torch.float32), + "scale_y_accum_quant": scale_y_accum_quant_t.to(torch.float32), + "scale_y_accum_quant.col": scale_y_accum_quant_c.to(torch.float32), + "scale_y_quant_orig": scale_y_quant_orig_t.to(torch.float32), } diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py index 3f540e690..52a013d16 100644 --- a/tensorrt_llm/models/enc_dec/model.py +++ b/tensorrt_llm/models/enc_dec/model.py @@ -1160,12 +1160,16 @@ def forward(self, host_cross_kv_cache_block_offsets, host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. + host_kv_cache_pool_mapping, cross_kv_cache_block_offsets=kv_cache_params. cross_kv_cache_block_offsets, host_cross_kv_cache_block_offsets=kv_cache_params. host_cross_kv_cache_block_offsets, host_cross_kv_cache_pool_pointers=kv_cache_params. - host_cross_kv_cache_pool_pointers), + host_cross_kv_cache_pool_pointers, + host_cross_kv_cache_pool_mapping=kv_cache_params. 
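generate_int8 above now stays in torch end to end, but the arithmetic is unchanged: a symmetric scale of 127 over the observed absolute maximum, either per tensor or per output channel. A stripped-down sketch with toy tensors, leaving out the QKV and multi-query special cases:

import torch

# Toy illustration of the symmetric int8 scaling used above.
w = torch.randn(16, 32)                      # [in_features, out_features]
amax_t = w.abs().max()                       # per-tensor range
amax_c = w.abs().amax(dim=0)                 # per output channel

scale_w_orig_quant_t = 127.0 / amax_t        # fp -> int8
scale_w_orig_quant_c = 127.0 / amax_c
scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t   # int8 -> fp

to_i8 = lambda x: x.round().clip(-127, 127).to(torch.int8)
weight_int8 = to_i8(w * scale_w_orig_quant_t)        # per-tensor quantized weight
weight_int8_col = to_i8(w * scale_w_orig_quant_c)    # per-channel quantized weight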
+ host_cross_kv_cache_pool_mapping), attention_params=attention_params, lora_layer_params=lora_layer_params, cross_kv_cache_gen=cross_kv_cache_gen, @@ -1601,10 +1605,12 @@ def prepare_inputs(self, kv_cache_block_offsets = None host_kv_cache_block_offsets = None host_kv_cache_pool_pointers = None + host_kv_cache_pool_mapping = None cross_kv_cache_block_offsets = None host_cross_kv_cache_block_offsets = None host_cross_kv_cache_pool_pointers = None + host_cross_kv_cache_pool_mapping = None if use_cache: if not paged_kv_cache: @@ -1669,21 +1675,25 @@ def prepare_inputs(self, x for x in max_cross_blocks_per_seq_range[0] ]] - kv_cache_block_offsets = Tensor(name=f'kv_cache_block_offsets', - dtype=trt.int32, - shape=[-1, 2, -1], - dim_range=OrderedDict([ - ('batch_size_beam_width', - [bb_range]), - ('kv', [2]), - ('max_blocks_per_seq', - max_blocks_per_seq_range), - ])) + # TODO(oargov): add support for vgqa, meanwhile assume a single kv cache pool + num_kv_cache_pools = 1 + + kv_cache_block_offsets = Tensor( + name=f'kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_blocks_per_seq', max_blocks_per_seq_range), + ])) host_kv_cache_block_offsets = Tensor( name=f'host_kv_cache_block_offsets', dtype=trt.int32, - shape=[-1, 2, -1], + shape=[num_kv_cache_pools, -1, 2, -1], dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), ('batch_size_beam_width', [bb_range]), ('kv', [2]), ('max_blocks_per_seq', max_blocks_per_seq_range), @@ -1691,17 +1701,26 @@ def prepare_inputs(self, host_kv_cache_pool_pointers = Tensor( name=f'host_kv_cache_pool_pointers', dtype=trt.int64, - shape=[2], + shape=[num_kv_cache_pools, 2], dim_range=OrderedDict([ - ('num_pools', [2]), + ('num_pools_layers', [num_kv_cache_pools]), + ('num_pools_kv', [2]), + ])) + host_kv_cache_pool_mapping = Tensor( + name=f"host_kv_cache_pool_mapping", + dtype=trt.int32, + shape=[num_pp_layers], + dim_range=OrderedDict([ + ('pools_mapping', [num_pp_layers]), ])) # paged blocks for cross kv cross_kv_cache_block_offsets = Tensor( name=f'cross_kv_cache_block_offsets', dtype=trt.int32, - shape=[-1, 2, -1], + shape=[num_kv_cache_pools, -1, 2, -1], dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), ('batch_size_beam_width', [bb_range]), ('kv', [2]), ('max_cross_blocks_per_seq', @@ -1710,8 +1729,9 @@ def prepare_inputs(self, host_cross_kv_cache_block_offsets = Tensor( name=f'host_cross_kv_cache_block_offsets', dtype=trt.int32, - shape=[-1, 2, -1], + shape=[num_kv_cache_pools, -1, 2, -1], dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), ('batch_size_beam_width', [bb_range]), ('kv', [2]), ('max_cross_blocks_per_seq', @@ -1720,10 +1740,18 @@ def prepare_inputs(self, host_cross_kv_cache_pool_pointers = Tensor( name=f'host_cross_kv_cache_pool_pointers', dtype=trt.int64, - shape=[2], + shape=[num_kv_cache_pools, 2], dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), ('num_pools', [2]), ])) + host_cross_kv_cache_pool_mapping = Tensor( + name=f"host_cross_kv_cache_pool_mapping", + dtype=trt.int32, + shape=[num_pp_layers], + dim_range=OrderedDict([ + ('pools_mapping', [num_pp_layers]), + ])) for i in layers_range: past_key_value.append(None) @@ -1737,11 +1765,14 @@ def prepare_inputs(self, kv_cache_block_offsets=kv_cache_block_offsets, host_kv_cache_block_offsets=host_kv_cache_block_offsets, 
host_kv_cache_pool_pointers=host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=host_kv_cache_pool_mapping, cross_kv_cache_block_offsets=cross_kv_cache_block_offsets, host_cross_kv_cache_block_offsets= host_cross_kv_cache_block_offsets, host_cross_kv_cache_pool_pointers= host_cross_kv_cache_pool_pointers, + host_cross_kv_cache_pool_mapping= + host_cross_kv_cache_pool_mapping, ) attention_params = AttentionParams( diff --git a/tensorrt_llm/models/falcon/config.py b/tensorrt_llm/models/falcon/config.py index aacbc3104..79af97dff 100644 --- a/tensorrt_llm/models/falcon/config.py +++ b/tensorrt_llm/models/falcon/config.py @@ -53,13 +53,14 @@ def from_hugging_face( quant_config: Optional[QuantConfig] = None, **kwargs): import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) if isinstance(hf_config_or_dir, transformers.PretrainedConfig): hf_config = hf_config_or_dir else: hf_config_dir = str(hf_config_or_dir) hf_config = transformers.AutoConfig.from_pretrained( - hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) # Falcon-7B config may not have num_kv_heads or n_head_kv. # Although Falcon-180B uses GQA (num_kv_heads=8), its config diff --git a/tensorrt_llm/models/falcon/model.py b/tensorrt_llm/models/falcon/model.py index a8f6458da..627335eeb 100644 --- a/tensorrt_llm/models/falcon/model.py +++ b/tensorrt_llm/models/falcon/model.py @@ -65,8 +65,7 @@ def __init__(self, config: FalconConfig, layer_idx: int): tp_rank=tp_rank, bias=config.bias, position_embedding_type=config.position_embedding_type, - quant_mode=config.quantization.quant_mode, - ) + quant_mode=config.quantization.quant_mode) mlp_hidden_size = hidden_size * 4 if config.intermediate_size is None else config.intermediate_size diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py index f34ad71d1..4024b2e88 100644 --- a/tensorrt_llm/models/gemma/model.py +++ b/tensorrt_llm/models/gemma/model.py @@ -78,8 +78,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): tp_size=config.mapping.tp_size, quant_mode=config.quant_mode, q_scaling=q_scaling, - max_attn_value=max_attn_value, - ) + max_attn_value=max_attn_value) mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size diff --git a/tensorrt_llm/models/gemma/smoothquant.py b/tensorrt_llm/models/gemma/smoothquant.py index 294ffe992..640ff7ed9 100644 --- a/tensorrt_llm/models/gemma/smoothquant.py +++ b/tensorrt_llm/models/gemma/smoothquant.py @@ -194,12 +194,12 @@ def get_tllm_linear_sq_weight(vals, results = {} def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): - q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) - q_split = np.split(q, tp_size, axis=-1) - k_split = np.split(k, tp_size, axis=-1) - v_split = np.split(v, tp_size, axis=-1) + q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1) + q_split = torch.split(q, q.shape[-1] // tp_size, dim=-1) + k_split = torch.split(k, q.shape[-1] // tp_size, dim=-1) + v_split = torch.split(v, q.shape[-1] // tp_size, dim=-1) return [ - np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) + torch.concat((q_split[ii], k_split[ii], v_split[ii]), dim=-1) for ii in range(tp_size) ][cur_rank] @@ -207,9 +207,9 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): if per_token: if per_channel: - original_weights = np.array(vals["weight.int8.col"]) + original_weights = vals["weight.int8.col"] else: - original_weights = 
np.array(vals["weight.int8"]) + original_weights = vals["weight.int8"] local_dim = original_weights.shape[0] head_size = (original_weights.shape[1] - local_dim) // 2 @@ -217,14 +217,14 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): cur_weights = multi_query_split(original_weights, local_dim, head_size, tensor_parallel, rank) else: - cur_weights = np.split(original_weights, - tensor_parallel, - axis=cat_dim)[rank] + cur_weights = torch.split(original_weights, + original_weights.shape[-1] // + tensor_parallel, + dim=-1)[rank] if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + results[prefix + 'weight'] = cur_weights.t().contiguous() if smoother_value is None: results[last_prefix] = torch.from_numpy( np.array([1.0], dtype=np.float32)) @@ -233,6 +233,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): cur_per_channel_value = vals["scale_w_quant_orig.col"] if smoother_value is None: if multi_query_mode: + cur_per_channel_value = multi_query_split( vals["scale_w_quant_orig.col"], local_dim, head_size, tensor_parallel, rank) @@ -240,7 +241,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): cur_per_channel_value = np.split( vals["scale_w_quant_orig.col"], tensor_parallel, - axis=cat_dim)[rank] + axis=-1)[rank] else: cur_per_channel_value = vals["scale_w_quant_orig"] if is_qkv: @@ -249,18 +250,18 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): vals["scale_w_quant_orig"], local_dim, head_size, tensor_parallel, rank) else: - cur_per_channel_value = np.split(vals["scale_w_quant_orig"], - tensor_parallel, - axis=cat_dim)[rank] + cur_per_channel_value = torch.split( + vals["scale_w_quant_orig"], + vals["scale_w_quant_orig"].shape[-1] // tensor_parallel, + dim=-1)[rank] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array(cur_per_channel_value, - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + + 'per_channel_scale'] = cur_per_channel_value.reshape(col_shape) else: if per_channel: - original_weights = np.array(vals["weight.int8.col"]) + original_weights = vals["weight.int8.col"] else: - original_weights = np.array(vals["weight.int8"]) + original_weights = vals["weight.int8"] local_dim = original_weights.shape[0] head_size = (original_weights.shape[1] - local_dim) // 2 @@ -268,14 +269,14 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): cur_weights = multi_query_split(original_weights, local_dim, head_size, tensor_parallel, rank) else: - cur_weights = np.split(original_weights, - tensor_parallel, - axis=cat_dim)[rank] + cur_weights = torch.split(original_weights, + original_weights.shape[-1] // + tensor_parallel, + dim=-1)[rank] if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + results[prefix + 'weight'] = cur_weights.t().contiguous() if per_channel: cur_per_channel_value = vals["scale_y_accum_quant.col"] @@ -303,22 +304,19 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): tensor_parallel, axis=cat_dim)[rank] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array([cur_per_channel_value], - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape( + col_shape).contiguous() - results[last_prefix] = 
torch.from_numpy( - np.array([vals['scale_x_orig_quant']], - dtype=np.float32)).contiguous() + results[last_prefix] = vals['scale_x_orig_quant'].contiguous() - results[prefix + 'act_scale'] = torch.from_numpy( - np.array([[vals["scale_y_quant_orig"]]], - dtype=np.float32)).contiguous() + results[prefix + 'act_scale'] = vals["scale_y_quant_orig"].contiguous() if smoother_value is not None: - cur_smoother_value = np.split(smoother_value, - tensor_parallel, - axis=cat_dim)[rank] + cur_smoother_value = torch.split(smoother_value, + smoother_value.shape[-1] // + tensor_parallel, + dim=cat_dim)[rank] + results[prefix + 'smoother'] = cur_smoother_value.reshape( smoother_shape).contiguous().to(torch.float32) @@ -679,6 +677,8 @@ def convert_hf_model(*, hf_model: "AutoModelForCausalLM", mapping: Mapping, attn_dense_weight = attn_dense_weight.t() int8_weights = generate_int8( attn_dense_weight, act_range.get(prefix + 'self_attn.o_proj')) + # import pdb + # pdb.set_trace() weights.update( get_tllm_linear_sq_weight( int8_weights, @@ -785,7 +785,7 @@ def convert_hf_model(*, hf_model: "AutoModelForCausalLM", mapping: Mapping, smoother_value=smoother[prefix + 'mlp.down_proj'], smoother_shape=[1, intermediate_size // tensor_parallel], rank=mapping.tp_rank, - cat_dim=0)) + cat_dim=-1)) else: mlp_proj_weight = split_matrix_tp(mlp_proj_weight, tensor_parallel, diff --git a/tensorrt_llm/models/generation_mixin.py b/tensorrt_llm/models/generation_mixin.py index cb12289f8..92361e705 100644 --- a/tensorrt_llm/models/generation_mixin.py +++ b/tensorrt_llm/models/generation_mixin.py @@ -14,7 +14,7 @@ # limitations under the License. import math from collections import OrderedDict -from typing import List +from typing import List, Optional import tensorrt as trt @@ -166,28 +166,34 @@ def get_profiles_ranges( } return num_profiles, ranges - def prepare_attention_inputs(self, - *, - max_batch_size, - max_beam_width, - max_input_len, - max_seq_len, - num_kv_heads, - head_size, - num_layers, - kv_dtype, - kv_cache_type: KVCacheType, - num_profiles=1, - enable_ctx_gen_opt_profiles=False, - remove_input_padding=False, - use_gpt_attention_plugin=False, - tokens_per_block=64, - mapping=Mapping(), - streamingllm=False, - attn_layer_idx=None, - opt_batch_size=None, - num_kv_heads_per_layer=None): - + def prepare_attention_inputs( + self, + *, + max_batch_size, + max_beam_width, + max_input_len, + max_seq_len, + num_kv_heads, + head_size, + num_layers, + kv_dtype, + kv_cache_type: KVCacheType, + num_profiles=1, + enable_ctx_gen_opt_profiles=False, + remove_input_padding=False, + use_gpt_attention_plugin=False, + tokens_per_block=64, + mapping=Mapping(), + streamingllm=False, + attn_layer_idx=None, + opt_batch_size=None, + num_kv_heads_per_layer: Optional[List[int]] = None): + + if attn_layer_idx is not None and num_kv_heads_per_layer is not None: + assert len(attn_layer_idx) == len(num_kv_heads_per_layer), ( + f"Expected len(attn_layer_idx) ({len(attn_layer_idx)})" + f" == len(num_kv_heads_per_layer) ({len(num_kv_heads_per_layer)})" + ) default_range = GenerationMixin.default_range if opt_batch_size: @@ -245,23 +251,40 @@ def prepare_attention_inputs(self, max_len_range = [_max_len_range] * num_profiles num_kv_heads = (num_kv_heads + mapping.tp_size - 1) // mapping.tp_size + if num_kv_heads_per_layer is not None: + num_kv_heads_per_layer = [ + (nheads + mapping.tp_size - 1) // mapping.tp_size + for nheads in num_kv_heads_per_layer + ] + layers_range = mapping.pp_layers(num_layers) - num_pp_layers = len(layers_range) if 
attn_layer_idx is None: attn_layer_idx = [i for i in range(num_layers)] + # layer indices of attention layers local to the current pp rank + local_attn_layers = [i for i in layers_range if i in attn_layer_idx] + # number of attention layers local to previous pp ranks + num_attn_layers_lower_ranks = attn_layer_idx.index(local_attn_layers[0]) past_key_value = [] kv_cache_block_offsets = None host_kv_cache_block_offsets = None host_kv_cache_pool_pointers = None + host_kv_cache_pool_mapping = None if kv_cache_type == KVCacheType.DISABLED: for i in layers_range: past_key_value.append(None) else: if kv_cache_type != KVCacheType.PAGED: - for i in layers_range: + for layer_idx in layers_range: + if layer_idx not in local_attn_layers: + # not an attention layer ==> give it None pkv input + past_key_value.append(None) + continue + + attn_idx = local_attn_layers.index(layer_idx) if num_kv_heads_per_layer is not None: - heads_dim_name = f"num_heads_{attn_layer_idx[i]}" - kv_heads = num_kv_heads_per_layer[i] + heads_dim_name = f"num_heads_{layer_idx}" + kv_heads = num_kv_heads_per_layer[ + num_attn_layers_lower_ranks + attn_idx] else: heads_dim_name = "num_heads" kv_heads = num_kv_heads @@ -274,7 +297,7 @@ def prepare_attention_inputs(self, ('head_size', [head_size] * num_profiles), ]) - kv = Tensor(name=f'past_key_value_{attn_layer_idx[i]}', + kv = Tensor(name=f'past_key_value_{layer_idx}', dtype=kv_dtype, shape=[-1, 2, kv_heads, -1, head_size], dim_range=kv_dim_range) @@ -300,21 +323,28 @@ def prepare_attention_inputs(self, math.ceil(kv_cache_range[0][2] / tokens_per_block) ]] * num_profiles - kv_cache_block_offsets = Tensor(name=f'kv_cache_block_offsets', - dtype=trt.int32, - shape=[-1, 2, -1], - dim_range=OrderedDict([ - ('batch_size_beam_width', - bb_range), - ('kv', [2] * num_profiles), - ('max_blocks_per_seq', - max_blocks_per_seq_range), - ])) + num_kv_cache_pools = 1 if num_kv_heads_per_layer is None else len( + set(num_kv_heads_per_layer[num_attn_layers_lower_ranks: + num_attn_layers_lower_ranks + + len(local_attn_layers)])) + kv_cache_block_offsets = Tensor( + name=f'kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', + [num_kv_cache_pools] * num_profiles), + ('batch_size_beam_width', bb_range), + ('kv', [2] * num_profiles), + ('max_blocks_per_seq', max_blocks_per_seq_range), + ])) host_kv_cache_block_offsets = Tensor( name=f'host_kv_cache_block_offsets', dtype=trt.int32, - shape=[-1, 2, -1], + shape=[num_kv_cache_pools, -1, 2, -1], dim_range=OrderedDict([ + ('num_kv_cache_pools', + [num_kv_cache_pools] * num_profiles), ('batch_size_beam_width', bb_range), ('kv', [2] * num_profiles), ('max_blocks_per_seq', max_blocks_per_seq_range), @@ -322,9 +352,20 @@ def prepare_attention_inputs(self, host_kv_cache_pool_pointers = Tensor( name=f'host_kv_cache_pool_pointers', dtype=trt.int64, - shape=[2], + shape=[num_kv_cache_pools, 2], + dim_range=OrderedDict([ + ('num_pools_layers', + [num_kv_cache_pools] * num_profiles), + ('num_pools_kv', [2] * num_profiles), + ])) + + host_kv_cache_pool_mapping = Tensor( + name=f'host_kv_cache_pool_mapping', + dtype=trt.int32, + shape=[len(local_attn_layers)], dim_range=OrderedDict([ - ('num_pools', [2] * num_profiles), + ('pools_mapping', + [len(local_attn_layers)] * num_profiles), ])) for i in layers_range: @@ -403,9 +444,10 @@ def prepare_attention_inputs(self, host_max_attention_window_sizes = Tensor( name=f'host_max_attention_window_sizes', dtype=trt.int32, - 
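prepare_attention_inputs now derives one KV-cache pool per distinct per-layer KV-head count among the local attention layers, and host_kv_cache_pool_mapping carries one pool index per local attention layer. A small sketch of that grouping; the head counts and the mapping construction are invented for illustration:

# Invented example of grouping variable-GQA layers into KV-cache pools.
num_kv_heads_per_layer = [8, 8, 1, 8, 1, 8]   # local attention layers only

pool_of_heads = {h: i for i, h in enumerate(dict.fromkeys(num_kv_heads_per_layer))}
num_kv_cache_pools = len(pool_of_heads)
host_kv_cache_pool_mapping = [pool_of_heads[h] for h in num_kv_heads_per_layer]
print(num_kv_cache_pools, host_kv_cache_pool_mapping)   # 2 [0, 0, 1, 0, 1, 0]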
shape=[num_pp_layers], - dim_range=OrderedDict([('num_layers', - [num_pp_layers] * num_profiles)])) + shape=[len(local_attn_layers)], + dim_range=OrderedDict([ + ('num_layers', [len(local_attn_layers)] * num_profiles) + ])) host_sink_token_length = Tensor(name='host_sink_token_length', dtype=trt.int32, @@ -437,6 +479,7 @@ def prepare_attention_inputs(self, 'kv_cache_block_offsets': kv_cache_block_offsets, 'host_kv_cache_block_offsets': host_kv_cache_block_offsets, 'host_kv_cache_pool_pointers': host_kv_cache_pool_pointers, + 'host_kv_cache_pool_mapping': host_kv_cache_pool_mapping, 'context_lengths': context_lengths, 'host_context_lengths': host_context_lengths, 'host_request_types': host_request_types, diff --git a/tensorrt_llm/models/gpt/config.py b/tensorrt_llm/models/gpt/config.py index ca72fea63..ba34ae255 100644 --- a/tensorrt_llm/models/gpt/config.py +++ b/tensorrt_llm/models/gpt/config.py @@ -91,6 +91,7 @@ def from_hugging_face( quant_config: Optional[QuantConfig] = None, **kwargs): import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) from .convert import get_needed_padding @@ -98,7 +99,7 @@ def from_hugging_face( hf_config = hf_config_or_dir else: hf_config = transformers.AutoConfig.from_pretrained( - hf_config_or_dir, trust_remote_code=True) + hf_config_or_dir, trust_remote_code=trust_remote_code) gpt_variant = kwargs.pop('gpt_variant', None) if gpt_variant is None: diff --git a/tensorrt_llm/models/gpt/convert.py b/tensorrt_llm/models/gpt/convert.py index d8c4ccc76..18c3f4343 100644 --- a/tensorrt_llm/models/gpt/convert.py +++ b/tensorrt_llm/models/gpt/convert.py @@ -270,12 +270,12 @@ def get_tllm_linear_sq_weight(vals, results = {} def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): - q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) - q_split = np.split(q, tp_size, axis=-1) - k_split = np.split(k, tp_size, axis=-1) - v_split = np.split(v, tp_size, axis=-1) + q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1) + q_split = torch.split(q, q.shape[-1] // tp_size, dim=-1) + k_split = torch.split(k, q.shape[-1] // tp_size, dim=-1) + v_split = torch.split(v, q.shape[-1] // tp_size, dim=-1) return [ - np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) + torch.concat((q_split[ii], k_split[ii], v_split[ii]), dim=-1) for ii in range(tp_size) ][cur_rank] @@ -283,9 +283,9 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): if per_token: if per_channel: - original_weights = np.array(vals["weight.int8.col"]) + original_weights = vals["weight.int8.col"] else: - original_weights = np.array(vals["weight.int8"]) + original_weights = vals["weight.int8"] local_dim = original_weights.shape[0] head_size = (original_weights.shape[1] - local_dim) // 2 @@ -299,8 +299,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() + results[prefix + 'weight'] = cur_weights.t().contiguous() if smoother_value is None: results[last_prefix] = torch.from_numpy( np.array([1.0], dtype=np.float32)) @@ -329,9 +328,8 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): tensor_parallel, axis=cat_dim)[rank] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array(cur_per_channel_value, - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + 'per_channel_scale'] = 
cur_per_channel_value.reshape( + col_shape).contiguous() else: if per_channel: original_weights = np.array(vals["weight.int8.col"]) @@ -379,17 +377,12 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): tensor_parallel, axis=cat_dim)[rank] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array([cur_per_channel_value], - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape( + col_shape).contiguous() - results[last_prefix] = torch.from_numpy( - np.array([vals['scale_x_orig_quant']], - dtype=np.float32)).contiguous() + results[last_prefix] = vals['scale_x_orig_quant'].contiguous() - results[prefix + 'act_scale'] = torch.from_numpy( - np.array([[vals["scale_y_quant_orig"]]], - dtype=np.float32)).contiguous() + results[prefix + 'act_scale'] = vals["scale_y_quant_orig"].contiguous() if smoother_value is not None: cur_smoother_value = np.split(smoother_value, @@ -479,10 +472,10 @@ def load_weights_from_hf_model(hf_model, if use_smooth_quant: qkv_out_dim = qkv_w.shape[0] - qkv_w_numpy = qkv_w.t().numpy() + qkv_w_t = qkv_w.t() if not multi_query_mode: - qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size) - int8_weights = generate_int8(qkv_w_numpy, + qkv_w_t = qkv_w_t.reshape(hidden_size, 3, hidden_size) + int8_weights = generate_int8(qkv_w_t, act_range.get(f'{prefix}.attn.c_attn'), is_qkv=True, multi_query_mode=multi_query_mode) @@ -528,17 +521,16 @@ def load_weights_from_hf_model(hf_model, plugin_weight_only_quant_type)) if int8_kv_cache: - qkv_w_numpy = qkv_w.t().numpy() + qkv_w_t = qkv_w.t() if not multi_query_mode: - qkv_w_numpy = qkv_w_numpy.reshape(hidden_size, 3, hidden_size) - int8_weights = generate_int8(qkv_w_numpy, + qkv_w_t = qkv_w_t.reshape(hidden_size, 3, hidden_size) + int8_weights = generate_int8(qkv_w_t, act_range.get(f'{prefix}.attn.c_attn'), is_qkv=True, multi_query_mode=multi_query_mode) weights[ - f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.from_numpy( - np.array([int8_weights['scale_y_quant_orig']], - dtype=np.float32)).contiguous() + f'{tllm_prex}.attention.kv_cache_scaling_factor'] = int8_weights[ + 'scale_y_quant_orig'].contiguous() # (2) Attention Dense Linear if gpt_variant == 'starcoder2': @@ -557,8 +549,8 @@ def load_weights_from_hf_model(hf_model, attn_dense_w = attn_dense_w.t().contiguous() # transpose for Conv1D if use_smooth_quant: - attn_dense_w_numpy = attn_dense_w.t().numpy() - int8_weights = generate_int8(attn_dense_w_numpy, + attn_dense_w_t = attn_dense_w.t() + int8_weights = generate_int8(attn_dense_w_t, act_range.get(f'{prefix}.attn.c_proj')) # change it to the real smoother if dense layer is applied smooth quant fake_smoother_value = torch.ones([1, hidden_size], @@ -606,8 +598,8 @@ def load_weights_from_hf_model(hf_model, mlp_fc_w = pad_array_up_to(mlp_fc_w, 0, mapping.tp_size) mlp_fc_b = pad_array_up_to(mlp_fc_b, 0, mapping.tp_size) if use_smooth_quant: - mlp_fc_w_numpy = mlp_fc_w.t().numpy() - int8_weights = generate_int8(mlp_fc_w_numpy, + mlp_fc_w_t = mlp_fc_w.t() + int8_weights = generate_int8(mlp_fc_w_t, act_range.get(f'{prefix}.mlp.c_fc')) mlp_fc_b = split(mlp_fc_b, mapping.tp_rank, @@ -681,8 +673,8 @@ def load_weights_from_hf_model(hf_model, if gpt_variant in ['jais']: mlp_proj_w = pad_array_up_to(mlp_proj_w, 1, mapping.tp_size) if use_smooth_quant: - mlp_proj_w_numpy = mlp_proj_w.t().numpy() - int8_weights = generate_int8(mlp_proj_w_numpy, + mlp_proj_w_t = mlp_proj_w.t() + int8_weights = generate_int8(mlp_proj_w_t, 
act_range.get(f'{prefix}.mlp.c_proj')) # change it to the real smoother if proj layer is applied smooth quant fake_smoother_value = torch.ones([1, 4 * hidden_size], @@ -857,7 +849,8 @@ def quantize(hf_model_dir: str, output_dir: str, config: GPTConfig, device: str = 'cuda', - calib_dataset: str = 'cnn_dailymail'): + calib_dataset: str = 'cnn_dailymail', + trust_remote_code: bool = True): os.makedirs(output_dir, exist_ok=True) config.to_json_file(os.path.join(output_dir, 'config.json')) @@ -875,14 +868,15 @@ def quantize(hf_model_dir: str, hf_model_dir, device_map='auto' if device != 'cpu' else 'cpu', torch_dtype='auto' if not use_smooth_quant else torch.float16, - trust_remote_code=True) + trust_remote_code=trust_remote_code) os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( "TOKENIZERS_PARALLELISM", "false") - tokenizer = AutoTokenizer.from_pretrained(hf_model_dir, - trust_remote_code=True, - use_fast=False, - padding_side='left') + tokenizer = AutoTokenizer.from_pretrained( + hf_model_dir, + trust_remote_code=trust_remote_code, + use_fast=False, + padding_side='left') dataset = load_calib_dataset(calib_dataset) act_range = capture_activation_range(hf_model, tokenizer, dataset) diff --git a/tensorrt_llm/models/gpt/model.py b/tensorrt_llm/models/gpt/model.py index 7e40e5872..bde4dc991 100644 --- a/tensorrt_llm/models/gpt/model.py +++ b/tensorrt_llm/models/gpt/model.py @@ -25,6 +25,8 @@ from ...mapping import Mapping from ...module import Module from ...quantization import QuantMode +from ...quantization.functional import quantize_fp8_per_token +from ...quantization.layers import Fp8RowwiseMLP from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, QuantConfig, check_share_embedding) from .config import GPTConfig @@ -171,6 +173,10 @@ def forward(self, residual = hidden_states hidden_states = self.post_layernorm(hidden_states) + # Quantize per-token for fp8 + if isinstance(self.mlp, Fp8RowwiseMLP): + hidden_states = quantize_fp8_per_token(hidden_states) + hidden_states = self.mlp(hidden_states, lora_layer_params=lora_layer_params) diff --git a/tensorrt_llm/models/gptj/config.py b/tensorrt_llm/models/gptj/config.py index 11efbbbca..b96e7223a 100644 --- a/tensorrt_llm/models/gptj/config.py +++ b/tensorrt_llm/models/gptj/config.py @@ -30,13 +30,14 @@ def from_hugging_face( quant_config: Optional[QuantConfig] = None, **kwargs): import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) if isinstance(hf_config_or_dir, transformers.PretrainedConfig): hf_config = hf_config_or_dir else: hf_config_dir = str(hf_config_or_dir) hf_config = transformers.AutoConfig.from_pretrained( - hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) if dtype == 'auto': dtype = getattr(hf_config, 'torch_dtype', None) diff --git a/tensorrt_llm/models/gptj/model.py b/tensorrt_llm/models/gptj/model.py index 36cecfc76..1c5b45fee 100644 --- a/tensorrt_llm/models/gptj/model.py +++ b/tensorrt_llm/models/gptj/model.py @@ -192,8 +192,12 @@ def from_hugging_face( **kwargs) if not use_preloading: + trust_remote_code = kwargs.pop('trust_remote_code', True) + hf_model = transformers.AutoModelForCausalLM.from_pretrained( - hf_model_dir, torch_dtype='auto', trust_remote_code=True) + hf_model_dir, + torch_dtype='auto', + trust_remote_code=trust_remote_code) weights = load_weights_from_hf_model(hf_model, config) check_share_embedding(weights, config) diff --git a/tensorrt_llm/models/llama/config.py b/tensorrt_llm/models/llama/config.py index 
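The GPT decoder layer above quantizes its hidden states per token before handing them to an Fp8RowwiseMLP. Conceptually, per-token quantization derives one scale per row from that row's absolute maximum; the sketch below shows only that arithmetic in plain torch. It is not the quantize_fp8_per_token graph op, and 448 (the largest finite FP8 E4M3 value) plus the float8 dtype are assumptions about the target format and a recent torch build.

import torch

# Conceptual per-token FP8-style scaling, one scale per row (token).
E4M3_MAX = 448.0
x = torch.randn(4, 8)                                   # [tokens, hidden]
per_token_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / E4M3_MAX
x_fp8 = (x / per_token_scale).to(torch.float8_e4m3fn)   # requires torch >= 2.1
x_dequant = x_fp8.to(torch.float32) * per_token_scale   # round-trip for inspection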
4f03cb40f..25ea86e78 100644 --- a/tensorrt_llm/models/llama/config.py +++ b/tensorrt_llm/models/llama/config.py @@ -83,6 +83,8 @@ def from_hugging_face( **kwargs): import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) + if isinstance(hf_config_or_dir, transformers.PretrainedConfig): hf_config = hf_config_or_dir else: @@ -97,7 +99,7 @@ def from_hugging_face( LlavaLlamaConfig, LlavaLlamaModel) hf_config = transformers.AutoConfig.from_pretrained( - hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) if hf_config.model_type == "llava": # LLaVA = Vision model + Llama LLM # We load a llava config and use its' text config as llama config diff --git a/tensorrt_llm/models/llama/convert.py b/tensorrt_llm/models/llama/convert.py index a1f3b7c38..cd9bbc63f 100644 --- a/tensorrt_llm/models/llama/convert.py +++ b/tensorrt_llm/models/llama/convert.py @@ -1084,7 +1084,8 @@ def quantize(hf_model_dir: str, output_dir: str, config: LLaMAConfig, device: str = 'cuda', - calib_dataset: str = 'cnn_dailymail'): + calib_dataset: str = 'cnn_dailymail', + trust_remote_code: bool = True): ''' Quantize the save the model as TRT-LLM checkpoint to output_dir ''' @@ -1101,20 +1102,22 @@ def quantize(hf_model_dir: str, assert use_smooth_quant or int8_kv_cache, "Call from_hugging_face when there is no quantization" assert hf_model_dir is not None ## only load and call smooth quant routine once for all ranks - hf_config = AutoConfig.from_pretrained(hf_model_dir, trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(hf_model_dir, + trust_remote_code=trust_remote_code) assert "llava" not in hf_config.model_type, "Smooth quant llava/vila/llava_next is not supported yet" hf_model = AutoModelForCausalLM.from_pretrained( hf_model_dir, device_map='auto' if device != 'cpu' else 'cpu', torch_dtype='auto' if not use_smooth_quant else torch.float16, - trust_remote_code=True) + trust_remote_code=trust_remote_code) os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( "TOKENIZERS_PARALLELISM", "false") - tokenizer = AutoTokenizer.from_pretrained(hf_model_dir, - trust_remote_code=True, - use_fast=False, - padding_side='left') + tokenizer = AutoTokenizer.from_pretrained( + hf_model_dir, + trust_remote_code=trust_remote_code, + use_fast=False, + padding_side='left') dataset = load_calib_dataset(calib_dataset) diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index 0beed1df4..c8e84ba1c 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -297,7 +297,9 @@ def from_hugging_face( load_by_shard = kwargs.pop('load_by_shard', False) load_model_on_cpu = kwargs.pop('load_model_on_cpu', False) quant_ckpt_path = kwargs.pop('quant_ckpt_path', None) - if os.environ.get("TRTLLM_DISABLE_UNIFIED_CONVERTER") is not None: + if os.environ.get("TRTLLM_DISABLE_UNIFIED_CONVERTER" + ) is not None and not isinstance( + hf_model_or_dir, transformers.PreTrainedModel): if "vila" in hf_model_or_dir or "llava" in hf_model_or_dir: hf_model_or_dir = load_hf_llama(hf_model_or_dir, load_model_on_cpu) @@ -326,14 +328,15 @@ def from_hugging_face( config.num_key_value_heads = config.num_key_value_heads // 2 if os.environ.get("TRTLLM_DISABLE_UNIFIED_CONVERTER") is None: custom_dict = {} - if "llava" in hf_model_or_dir: + model_name = hf_model.config.model_type if use_preloading else hf_model_or_dir + if "llava" in model_name: custom_dict = { "transformer": "language_model.model", "lm_head": "language_model.lm_head" } - elif 
"vila" in hf_model_or_dir: + elif "vila" in model_name: hf_model_dir += "/llm" - elif "exaone" in hf_model_or_dir: + elif "exaone" in model_name: custom_dict = { "transformer": "transformer", "layers": "h", @@ -441,11 +444,14 @@ def quantize( mapping=mapping, quant_config=quant_config, **kwargs) + trust_remote_code = kwargs.pop("trust_remote_code", True) + convert.quantize(hf_model_dir, output_dir, config=config, device=device, - calib_dataset=calib_dataset) + calib_dataset=calib_dataset, + trust_remote_code=trust_remote_code) else: raise ValueError( f"The quant_config ({quant_config}) does not require calibration, try {cls.__name__}.from_hugging_face instead." diff --git a/tensorrt_llm/models/model_weights_loader.py b/tensorrt_llm/models/model_weights_loader.py index 76814df59..e7613f580 100644 --- a/tensorrt_llm/models/model_weights_loader.py +++ b/tensorrt_llm/models/model_weights_loader.py @@ -8,16 +8,17 @@ import torch from safetensors import safe_open from tqdm import tqdm - -from tensorrt_llm.layers.moe import MOEWeightWrapper -from tensorrt_llm.quantization.layers import ( - WeightOnlyGroupwiseQuantColumnLinear, WeightOnlyGroupwiseQuantRowLinear) +from transformers import PreTrainedModel from .._utils import trt_dtype_to_torch +from ..layers.moe import MOEWeightWrapper from ..logger import logger +from ..quantization.layers import (WeightOnlyGroupwiseQuantColumnLinear, + WeightOnlyGroupwiseQuantRowLinear) class ModelWeightsFormat(Enum): + IN_MEMORY = "in_mem" SAFETENSORS = "safetensors" BINARY = "bin" PYTORCH = "pth" @@ -136,9 +137,13 @@ def detect_format(self): else: raise NotImplementedError( "Only safetensors/pickle/binary directories are supported.") + elif isinstance(self.model_dir, dict) or isinstance( + self.model_dir, PreTrainedModel): + self.format = ModelWeightsFormat.IN_MEMORY else: raise NotImplementedError( - "args.model_dir is Neither a directory nor a file!") + "args.model_dir is not a directory, a file or an in-memory module!" + ) def preload(self): # Initialize shards and load_func @@ -146,9 +151,14 @@ def preload(self): shard_files = glob.glob(self.model_dir + "/*." + self.format.value) elif os.path.isfile(self.model_dir): shard_files = [self.model_dir] + elif isinstance(self.model_dir, dict): + shard_files = [self.model_dir] + elif isinstance(self.model_dir, PreTrainedModel): + shard_files = [dict(self.model_dir.named_parameters())] else: raise NotImplementedError( - "args.model_dir is Neither a directory nor a file!") + "args.model_dir is not a directory, a file or an in-memory module!" 
+ ) shard_files.sort() if self.format == ModelWeightsFormat.SAFETENSORS: self.shards = [ @@ -159,6 +169,8 @@ def preload(self): torch.load(f, weights_only=True, map_location="cpu", mmap=True) for f in shard_files ] + elif self.format == ModelWeightsFormat.IN_MEMORY: + self.shards = [shard_files[0]] else: raise NotImplementedError( "Only *.safetensors/*.pth/*.bin files are supported.") @@ -178,7 +190,7 @@ def load_tensor(self, key, tp_size=1, tp_dim=-1, tp_rank=0): if tensor_shape == []: tensor = self.shards[ptr_idx].get_tensor(key).unsqueeze(0) tensor_shape = tensor.shape - elif self.format == ModelWeightsFormat.BINARY or self.format == ModelWeightsFormat.PYTORCH: + else: tensor = self.shards[ptr_idx][key] tensor_shape = tensor.shape @@ -235,10 +247,15 @@ def load(self, tp_dim = sub_module.tp_dim if hasattr(sub_module, "tp_dim") else -1 require_weight_transpose = ( isinstance(sub_module, WeightOnlyGroupwiseQuantColumnLinear) - or isinstance(sub_module, WeightOnlyGroupwiseQuantRowLinear) - ) and tllm_key.endswith("weight") + or isinstance(sub_module, WeightOnlyGroupwiseQuantRowLinear)) if tp_dim >= 0 and require_weight_transpose: - tp_dim = 1 - tp_dim + if sub_module.prequant_scaling_factor is not None: + if tllm_key.endswith("prequant_scaling_factor"): + tp_dim = 1 - tp_dim + elif tllm_key.endswith("weights_scaling_factor"): + tp_dim = -1 + elif tllm_key.endswith("weight"): + tp_dim = 1 - tp_dim tp_size = sub_module.tp_size if hasattr(sub_module, "tp_size") else 1 if skip_tp: tp_dim = -1 diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index f42548ab6..8ef7f32a4 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -6,7 +6,8 @@ from enum import IntFlag, auto from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Union +from typing import (TYPE_CHECKING, Callable, Dict, Generator, List, Optional, + Union) import numpy as np import safetensors @@ -403,6 +404,8 @@ def forward(self, host_kv_cache_block_offsets, host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. 
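ModelWeightsLoader can now be fed an already-loaded transformers model or a plain dict of tensors instead of a checkpoint path, via the new IN_MEMORY format. A standalone sketch mirroring that dispatch; the enum and function here are illustrative, not the loader's own API:

import os
from enum import Enum

from transformers import PreTrainedModel

class WeightsSource(Enum):
    IN_MEMORY = "in_mem"
    ON_DISK = "on_disk"

def detect_source(model_dir):
    if isinstance(model_dir, (str, os.PathLike)) and os.path.exists(str(model_dir)):
        return WeightsSource.ON_DISK
    if isinstance(model_dir, (dict, PreTrainedModel)):
        return WeightsSource.IN_MEMORY
    raise NotImplementedError(
        "model_dir is not a directory, a file or an in-memory module!")

print(detect_source({"transformer.vocab_embedding.weight": None}))  # IN_MEMORY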
+ host_kv_cache_pool_mapping, cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params, **kwargs) @@ -462,10 +465,14 @@ def from_config(cls, config: PretrainedConfig): return cls(config) @classmethod - def from_checkpoint(cls, - ckpt_dir: str, - rank: Optional[int] = None, - config: Optional[PretrainedConfig] = None): + def from_checkpoint( + cls, + ckpt_dir: str, + rank: Optional[int] = None, + config: Optional[PretrainedConfig] = None, + *, + preprocess_weights_hook: Optional[Callable[[Dict[str, Tensor]], + Dict[str, Tensor]]] = None): if config is None: config = PretrainedConfig.from_json_file( os.path.join(ckpt_dir, 'config.json')) @@ -480,6 +487,10 @@ def from_checkpoint(cls, weights = safetensors.torch.load_file(weights_path) is_checkpoint_pruned = getattr(config, 'is_pruned', False) + + if preprocess_weights_hook is not None: + weights = preprocess_weights_hook(weights) + preprocess_weights(weights, config, from_pruned=is_checkpoint_pruned) model = cls(config) model.load(weights, from_pruned=is_checkpoint_pruned) @@ -629,6 +640,8 @@ def prepare_inputs( 'host_kv_cache_block_offsets'], host_kv_cache_pool_pointers=model_inputs[ 'host_kv_cache_pool_pointers'], + host_kv_cache_pool_mapping=model_inputs[ + 'host_kv_cache_pool_mapping'], cache_indirection=model_inputs['cache_indirection'], ), 'attention_params': @@ -1191,7 +1204,6 @@ def preprocess_weights(weights: Dict[str, torch.Tensor], model.load(weights) """ quant_algo = model_config.quantization.quant_algo - kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo exclude_modules = model_config.quantization.exclude_modules # INT4_AWQ @@ -1211,7 +1223,9 @@ def preprocess_weights(weights: Dict[str, torch.Tensor], weights[name] = preprocessor(param.T.contiguous(), torch.quint4x2, activation_type).view(dtype) - if name.endswith('weights_scaling_factor'): + if name.endswith('weights_scaling_factor' + ) and param.shape[0] > param.shape[1]: + # TODO: refine on supporting ModelOpt HF-AWQ weights[name] = param.T.contiguous().to( str_dtype_to_torch(model_config.dtype)) if name.endswith('prequant_scaling_factor'): @@ -1266,12 +1280,6 @@ def preprocess_weights(weights: Dict[str, torch.Tensor], exclude_modules=exclude_modules, plugin=True) - # FP8 kv_cache_scaling_factor is always 1.0 - if kv_cache_quant_algo == QuantAlgo.FP8: - for name, param in weights.items(): - if name.endswith('kv_cache_scaling_factor'): - weights[name] = torch.tensor([1.0], dtype=torch.float32) - # Parallel block rowlinear should not have duplicate bias. 
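from_checkpoint gains a keyword-only preprocess_weights_hook that receives the raw weights dict before preprocess_weights runs, which is useful for renaming or rescaling externally produced checkpoints. A usage sketch; the hook, the prefix it strips and the checkpoint path are hypothetical, and it assumes the chosen model class inherits PretrainedModel.from_checkpoint unchanged:

from tensorrt_llm.models import LLaMAForCausalLM

def strip_prefix_hook(weights):
    # Hypothetical cleanup: drop a "model." prefix left by a custom exporter.
    return {name.removeprefix("model."): tensor for name, tensor in weights.items()}

model = LLaMAForCausalLM.from_checkpoint(
    "/path/to/trtllm_ckpt",      # placeholder checkpoint directory
    rank=0,
    preprocess_weights_hook=strip_prefix_hook,
)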
elif model_config.architecture == 'GPTJForCausalLM': if model_config.mapping.tp_rank > 0: diff --git a/tensorrt_llm/models/deci/__init__.py b/tensorrt_llm/models/nemotron_nas/__init__.py similarity index 100% rename from tensorrt_llm/models/deci/__init__.py rename to tensorrt_llm/models/nemotron_nas/__init__.py diff --git a/tensorrt_llm/models/deci/config.py b/tensorrt_llm/models/nemotron_nas/config.py similarity index 86% rename from tensorrt_llm/models/deci/config.py rename to tensorrt_llm/models/nemotron_nas/config.py index b9accc61e..ca3b4fb1b 100644 --- a/tensorrt_llm/models/deci/config.py +++ b/tensorrt_llm/models/nemotron_nas/config.py @@ -21,11 +21,11 @@ from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.deci.convert import hf_block_config_to_layer_config -from tensorrt_llm.models.deci.layer_config import (AttentionConfig, - AttentionImplementation, - DeciLayerConfig, FFNConfig) from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig +from tensorrt_llm.models.nemotron_nas.convert import \ + hf_block_configs_to_layer_configs +from tensorrt_llm.models.nemotron_nas.layer_config import ( + AttentionConfig, AttentionImplementation, DeciLayerConfig, FFNConfig) class DeciConfig(PretrainedConfig): @@ -60,6 +60,7 @@ def __init__(self, Dict[str, Dict[str, Any]]]]] = None, + block_configs: Optional[object] = None, **kwargs): super().__init__(architecture=architecture, dtype=dtype, @@ -86,7 +87,13 @@ def __init__(self, self.rotary_base = rotary_base self.rotary_scaling = rotary_scaling - if layer_configs is not None: + if block_configs is not None: + assert layer_configs is None + self.layer_configs = hf_block_configs_to_layer_configs( + block_configs, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size) + elif layer_configs is not None: assert len( layer_configs ) == num_hidden_layers, f"num_hidden_layers ({num_hidden_layers}) must match len(layer_configs) ({len(layer_configs)})" @@ -102,6 +109,14 @@ def __init__(self, for layer_idx in range(self.num_hidden_layers) ] + # HACK: this is here since the runtime doesn't parse the layer_configs yet + self.num_kv_heads_per_layer = [] + for layer_idx in range(self.num_hidden_layers): + layer_config = self.get_layer_config(layer_idx) + if layer_config.is_attention_layer: + self.num_kv_heads_per_layer.append( + layer_config.attention.num_key_value_heads) + def _ensure_layer_configs( self, layer_configs: List[Union[DeciLayerConfig, Dict[str, Any]]] ) -> List[DeciLayerConfig]: @@ -154,16 +169,16 @@ def from_hugging_face( hf_config = transformers.AutoConfig.from_pretrained( hf_config_or_dir, trust_remote_code=trust_remote_code) - assert hf_config.model_type == "deci", f"Unsupported model type: {hf_config.model_type}" + assert hf_config.model_type in ( + "deci", + "nemotron-nas"), f"Unsupported model type: {hf_config.model_type}" block_configs = getattr(hf_config, "block_configs", None) if block_configs is not None: - layer_configs = [ - hf_block_config_to_layer_config(block_config, - hf_config.num_attention_heads, - hf_config.hidden_size) - for block_config in block_configs - ] + layer_configs = hf_block_configs_to_layer_configs( + block_configs, + num_attention_heads=hf_config.num_attention_heads, + hidden_size=hf_config.hidden_size) else: # older deci arch num_key_value_heads_per_layer = getattr( diff --git a/tensorrt_llm/models/deci/convert.py b/tensorrt_llm/models/nemotron_nas/convert.py similarity 
index 77% rename from tensorrt_llm/models/deci/convert.py rename to tensorrt_llm/models/nemotron_nas/convert.py index c6bff772a..06ca34b61 100644 --- a/tensorrt_llm/models/deci/convert.py +++ b/tensorrt_llm/models/nemotron_nas/convert.py @@ -17,8 +17,9 @@ import time from abc import ABC, abstractmethod from contextlib import contextmanager +from dataclasses import asdict from pathlib import Path -from typing import Any, Dict, Iterator, Optional, TypedDict, Union +from typing import Any, Dict, Iterator, List, Optional, TypedDict, Union import safetensors import torch @@ -26,10 +27,9 @@ from tensorrt_llm._utils import pad_vocab_size from tensorrt_llm.logger import logger from tensorrt_llm.models.convert_utils import dup_kv_weight, split -from tensorrt_llm.models.deci.layer_config import (AttentionConfig, - AttentionImplementation, - DeciLayerConfig, FFNConfig, - FFNImplementation) +from tensorrt_llm.models.nemotron_nas.layer_config import ( + AttentionConfig, AttentionImplementation, DeciLayerConfig, FFNConfig, + FFNImplementation) from tensorrt_llm.quantization.mode import QuantAlgo @@ -45,35 +45,39 @@ def _find_multiple(n: int, k: int) -> int: # BlockConfig is a custom class defined inside deci huggingface checkpoints, we can't import it -def hf_block_config_to_layer_config(block_config: "BlockConfig", +def hf_block_config_to_layer_config(block_config: Union["BlockConfig", dict], num_attn_heads: int, hidden_size: int) -> DeciLayerConfig: - attn = block_config.attention - if attn.no_op: + """`block_config` (`Union[BlockConfig, dict]`): A `dict` when exported from `ModelOpt`; A `dataclass` at the HF phase + """ + block_config = block_config if isinstance(block_config, + dict) else asdict(block_config) + attn = block_config["attention"] + if attn["no_op"]: attn_impl = AttentionImplementation.NO_OP num_key_value_heads = None - elif attn.replace_with_linear: + elif attn["replace_with_linear"]: attn_impl = AttentionImplementation.LINEAR num_key_value_heads = None - elif attn.sparsify: + elif attn.get("sparsify", None): raise NotImplementedError("Sparsification is not supported") else: attn_impl = AttentionImplementation.ATTENTION - num_key_value_heads = num_attn_heads // attn.n_heads_in_group + num_key_value_heads = num_attn_heads // attn["n_heads_in_group"] - ffn = block_config.ffn - if ffn.no_op: + ffn = block_config["ffn"] + if ffn["no_op"]: ffn_impl = FFNImplementation.NO_OP intermediate_size = None - elif ffn.replace_with_linear: + elif ffn["replace_with_linear"]: ffn_impl = FFNImplementation.LINEAR intermediate_size = None - elif ffn.sparsify: + elif ffn.get("sparsify", None): raise NotImplementedError("Sparsification is not supported") else: ffn_impl = FFNImplementation.MLP intermediate_size = _ffn_mult_to_intermediate_size( - ffn.ffn_mult, hidden_size) + ffn["ffn_mult"], hidden_size) return DeciLayerConfig(attention=AttentionConfig( impl=attn_impl, num_key_value_heads=num_key_value_heads), @@ -81,6 +85,16 @@ def hf_block_config_to_layer_config(block_config: "BlockConfig", intermediate_size=intermediate_size)) +def hf_block_configs_to_layer_configs( + block_configs: Union["BlockConfig", dict], *, num_attention_heads: int, + hidden_size: int) -> List[DeciLayerConfig]: + return [ + hf_block_config_to_layer_config(block_config, num_attention_heads, + hidden_size) + for block_config in block_configs + ] + + @contextmanager def timed_loading() -> Iterator[None]: tik = time.time() @@ -105,12 +119,31 @@ class SafetensorsIndex(TypedDict): class WeightsLoader(ABC): @abstractmethod + def 
read_weight(self, name: str) -> torch.Tensor: + ... + def get_weight(self, name: str, tp_dim: TpDim = TpDim.NO_TP, tp_size: int = 1, tp_rank: int = 0) -> torch.Tensor: - ... + weight = self.read_weight(name) + if tp_dim != TpDim.NO_TP: + weight = split(weight, tp_size, tp_rank, dim=tp_dim) + return weight + + def get_kv_weight(self, + name: str, + num_heads: int, + tp_size: int = 1, + tp_rank: int = 0) -> torch.Tensor: + weight = self.read_weight(name) + if tp_size > num_heads: + weight = dup_kv_weight(weight, num_heads, tp_size) + if tp_size > 1: + weight = split(weight, tp_size, tp_rank, dim=0) + + return weight class HFModelWeightsLoader(WeightsLoader): @@ -120,18 +153,11 @@ def __init__(self, *, hf_model: "transformers.PreTrainedModel", self.model_params = dict(hf_model.named_parameters()) self.dtype = getattr(torch, dtype) - def get_weight(self, - name: str, - tp_dim: TpDim = TpDim.NO_TP, - tp_size: int = 1, - tp_rank: int = 0) -> torch.Tensor: + def read_weight(self, name: str) -> torch.Tensor: weight = self.model_params[name] if weight.dtype != self.dtype: weight = weight.to(self.dtype) weight = weight.detach() - - if tp_dim != TpDim.NO_TP: - weight = split(weight, tp_size, tp_rank, dim=tp_dim) return weight @@ -163,37 +189,10 @@ def __init__(self, *, model_dir: Path, dtype: str) -> None: for shard_file in shard_files } - def get_weight(self, - name: str, - tp_dim: TpDim = TpDim.NO_TP, - tp_size: int = 1, - tp_rank: int = 0) -> torch.Tensor: + def read_weight(self, name: str) -> torch.Tensor: shard_filename = self.sharding_map['weight_map'].get( name, self.shard_files[0]) - if tp_dim == TpDim.NO_TP: - res = self.safetensors_files[shard_filename].get_tensor(name) - else: - tensor_slice = self.safetensors_files[shard_filename].get_slice( - name) - tensor_shape = tensor_slice.get_shape() - if len(tensor_shape) == 1: - if tp_dim == TpDim.COLWISE: - slice_width = tensor_shape[0] // tp_size - res = tensor_slice[slice_width * tp_rank:slice_width * - (tp_rank + 1)] - else: # row-wise, but 1-dimensional ==> no tp - res = tensor_slice[:] - else: - assert tensor_shape[ - tp_dim] % tp_size == 0, f"Current weight shape is invalid for tp_size={tp_size}" - slice_width = tensor_shape[tp_dim] // tp_size - if tp_dim == TpDim.COLWISE: - res = tensor_slice[slice_width * tp_rank:slice_width * - (tp_rank + 1), :] - else: - res = tensor_slice[:, slice_width * tp_rank:slice_width * - (tp_rank + 1)] - + res = self.safetensors_files[shard_filename].get_tensor(name) return res.to(self.dtype).contiguous() @@ -245,24 +244,20 @@ def load_weight(name: str, tp_dim: TpDim = TpDim.NO_TP) -> torch.Tensor: f"model.layers.{l}.input_layernorm.weight" ) # input_layernorm - qkv = {} - for comp in ["q", "k", "v"]: - weight_part = load_weight( - f"model.layers.{l}.self_attn.{comp}_proj.weight", - TpDim.COLWISE) - qkv[comp] = weight_part - - if layer_config.attention.num_key_value_heads < mapping.tp_size: - # duplicate the KV heads up to tensor_parallel - qkv["k"] = dup_kv_weight( - qkv["k"], layer_config.attention.num_key_value_heads, - mapping.tp_size) - qkv["v"] = dup_kv_weight( - qkv["v"], layer_config.attention.num_key_value_heads, - mapping.tp_size) - + q = load_weight(f"model.layers.{l}.self_attn.q_proj.weight", + TpDim.COLWISE) + k = loader.get_kv_weight( + f"model.layers.{l}.self_attn.k_proj.weight", + num_heads=layer_config.attention.num_key_value_heads, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank) + v = loader.get_kv_weight( + f"model.layers.{l}.self_attn.v_proj.weight", + 
num_heads=layer_config.attention.num_key_value_heads, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank) weights[f'{tllm_prex}.attention.qkv.weight'] = torch.cat( - [qkv["q"], qkv["k"], qkv["v"]], 0) + [q, k, v], 0) weights[f'{tllm_prex}.attention.dense.weight'] = load_weight( f"model.layers.{l}.self_attn.o_proj.weight", TpDim.ROWWISE) # attention.dense @@ -363,3 +358,23 @@ def load_weights_from_hf_safetensors( loader = SafetensorsWeightsLoader(model_dir=model_dir, dtype=config.dtype) logger.info('Loading weights from Huggingface safetensors...') return load_model_weights(loader=loader, config=config) + + +def update_weights_following_modelopt_optimization( + weights: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + # Rename MLPs to FFNs to match TRTLLM implementation expectation + weights = {k.replace('.mlp.', '.ffn.'): v for k, v in weights.items()} + + # Move all linear attentions to their expected locations + weights = { + k.replace('.attn_replacing_linear.', '.attention.'): v + for k, v in weights.items() + } + + # Move all linear MLPs to their expected locations + weights = { + k.replace('.mlp_replacing_linear.', '.ffn.'): v + for k, v in weights.items() + } + + return weights diff --git a/tensorrt_llm/models/deci/layer_config.py b/tensorrt_llm/models/nemotron_nas/layer_config.py similarity index 100% rename from tensorrt_llm/models/deci/layer_config.py rename to tensorrt_llm/models/nemotron_nas/layer_config.py diff --git a/tensorrt_llm/models/deci/model.py b/tensorrt_llm/models/nemotron_nas/model.py similarity index 67% rename from tensorrt_llm/models/deci/model.py rename to tensorrt_llm/models/nemotron_nas/model.py index b0d0ded0e..a3c3e2388 100644 --- a/tensorrt_llm/models/deci/model.py +++ b/tensorrt_llm/models/nemotron_nas/model.py @@ -16,9 +16,10 @@ from typing import List, Optional, Tuple, Type, Union from tensorrt_llm.bindings import KVCacheType -from tensorrt_llm.functional import (AllReduceFusionParams, AttentionMaskType, - PositionEmbeddingType, Tensor, - gather_last_token_logits, recv, send) +from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceFusionParams, + AttentionMaskType, PositionEmbeddingType, + Tensor, gather_last_token_logits, recv, + send) from tensorrt_llm.layers.attention import (Attention, AttentionParams, KeyValueCacheParams, SpecDecodingParams) @@ -29,16 +30,17 @@ from tensorrt_llm.layers.normalization import RmsNorm from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.convert_utils import has_safetensors -from tensorrt_llm.models.deci.config import DeciConfig -from tensorrt_llm.models.deci.convert import (load_weights_from_hf_model, - load_weights_from_hf_safetensors) from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM +from tensorrt_llm.models.nemotron_nas.config import DeciConfig +from tensorrt_llm.models.nemotron_nas.convert import ( + load_weights_from_hf_model, load_weights_from_hf_safetensors, + update_weights_following_modelopt_optimization) from tensorrt_llm.module import Module, ModuleList from tensorrt_llm.plugin.plugin import init_all_reduce_helper from ..._common import default_net from ..._utils import pad_vocab_size -from ..modeling_utils import QuantConfig, preprocess_weights +from ..modeling_utils import PretrainedConfig, QuantConfig, preprocess_weights @dataclass @@ -123,21 +125,106 @@ def __init__(self, config: DeciConfig, layer_idx: int): self.layer_config = self.config.get_layer_config(self.layer_idx) - layer_type_len = len(config.layer_types) - layer_types = config.layer_types * 
((layer_idx + 1) // layer_type_len) - layer_types = layer_types + config.layer_types[0:( - (layer_idx + 1) % layer_type_len)] - - attention_layer_idx = layer_types.count('attention') - 1 - self._init_attention(attention_layer_idx) + self._init_attention() self._init_ffn() - def _init_attention(self, attention_layer_idx) -> None: + @property + def input_layernorm_was_fused(self) -> bool: + """ + The previous layer ran our input_layernorm for us if: + 1. The reduce_fusion plugin is enabled and + 2. We are not the first local model layer and + 3. The previous layer is an MLP layer + """ + return default_net( + ).plugin_config.reduce_fusion and self.local_layer_idx > 0 and self.config.get_layer_config( + self.layer_idx - + 1).is_mlp_layer and self.needs_input_layernorm_fusion + + @property + def needs_input_layernorm_fusion(self) -> bool: + """ + This layer needs the previous layer to perform input_layernorm fusion if: + 1. The reduce_fusion plugin is enabled and + 2. This is not a NOOP attention layer (otherwise it has no input_layernorm) + """ + return default_net( + ).plugin_config.reduce_fusion and not self.layer_config.is_noop_attention_layer + + @property + def can_fuse_post_layernorm(self) -> bool: + """ + This layer can fuse attention and post_layernorm if: + 1. The reduce_fusion plugin is enabled and + 2. It is an attention layer and + 3. It is not a NOOP FFN layer (otherwise it has no post_layernorm) + """ + return default_net( + ).plugin_config.reduce_fusion and self.layer_config.is_attention_layer and not self.layer_config.is_noop_ffn_layer + + @property + def can_fuse_input_layernorm(self) -> bool: + """ + This layer can run the next layer's input_layernorm if: + 1. The reduce_fusion plugin is enabled and + 2. It is an MLP layer + """ + return default_net( + ).plugin_config.reduce_fusion and self.layer_config.is_mlp_layer + + def _init_attention(self) -> None: """ Initialize some attention alternative """ # normal attention if self.layer_config.is_attention_layer: + # according to recurrentgemma, len(layer_types) can be less than num_hidden_layers + # in this case, the list should wrap-around + # for example, if layer_types = ["attention", "recurrent", "recurrent"], and we have 5 layers, we get: + # layer 0 ==> attention + # layer 1 ==> recurrent + # layer 2 ==> recurrent + # layer 3 ==> attention + # layer 4 ==> recurrent + # we check which layers are local to our rank + layers_range = self.config.mapping.pp_layers( + self.config.num_hidden_layers) + # then take the size of layer_types in the config + layer_type_len = len(self.config.layer_types) + # collect the layer types of all the local layers + local_layer_types = [ + self.config.layer_types[layer_id % layer_type_len] + for layer_id in layers_range + ] + # and see how many of them are attention layers to determine our local attention layer idx + local_attn_layer_idx = local_layer_types[:self.
+ local_layer_idx].count( + "attention") + + # Iterate over all local layer configs, getting num_kv_heads of the attention ones + num_kv_heads_per_local_layer = [ + layer_config.attention.num_key_value_heads for layer_config in + [self.config.layer_configs[idx] for idx in layers_range] + if layer_config.is_attention_layer + ] + + # adjust num heads according to tp size + num_kv_heads_per_local_layer = [ + (nheads + self.config.mapping.tp_size - 1) // + self.config.mapping.tp_size + for nheads in num_kv_heads_per_local_layer + ] + nheads_tp = (self.layer_config.attention.num_key_value_heads + + self.config.mapping.tp_size - + 1) // self.config.mapping.tp_size + + # local layers with the same number of kv heads share the same cache pool + # we count how many such layers there are before us to determine our index inside that pool + layer_idx_in_cache_pool = num_kv_heads_per_local_layer[: + local_attn_layer_idx].count( + nheads_tp + ) + self.input_layernorm = RmsNorm( normalized_shape=self.config.hidden_size, eps=self.config.norm_epsilon, @@ -145,7 +232,7 @@ def _init_attention(self, attention_layer_idx) -> None: ) self.attention = Attention( - local_layer_idx=attention_layer_idx, + local_layer_idx=local_attn_layer_idx, hidden_size=self.config.hidden_size, attention_head_size=self.config.head_size, num_attention_heads=self.config.num_attention_heads, @@ -161,7 +248,7 @@ def _init_attention(self, attention_layer_idx) -> None: tp_size=self.config.mapping.tp_size, tp_rank=self.config.mapping.tp_rank, quant_mode=self.config.quant_mode, - ) + layer_idx_in_cache_pool=layer_idx_in_cache_pool) elif self.layer_config.is_noop_attention_layer: self.input_layernorm = NoOpLayerNorm() @@ -238,20 +325,34 @@ def _init_ffn(self) -> None: f"FFN of type {str(self.layer_config.ffn.impl)} is not implemented" ) - def forward( - self, - hidden_states: Tensor, - attention_mask: Optional[Tensor] = None, - use_cache: bool = False, - spec_decoding_params=None, - kv_cache_params: Optional[KeyValueCacheParams] = None, - attention_params: Optional[AttentionParams] = None, - lora_layer_params: Optional[LoraParams] = None, - ): - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + def forward(self, + hidden_states: Tensor | Tuple[Tensor, Tensor], + attention_mask: Optional[Tensor] = None, + use_cache: bool = False, + spec_decoding_params=None, + kv_cache_params: Optional[KeyValueCacheParams] = None, + attention_params: Optional[AttentionParams] = None, + lora_layer_params: Optional[LoraParams] = None, + next_layer_input_layernorm_args: Optional[Tuple[Tensor, + float]] = None): + if self.input_layernorm_was_fused: + # previous layer already performed our layer norm + assert isinstance(hidden_states, tuple) + hidden_states, residual = hidden_states + else: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.can_fuse_post_layernorm: + reduce_fusion_params = AllReduceFusionParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_layernorm.weight.value, + eps=self.post_layernorm.eps) + else: + reduce_fusion_params = None - attention_output = self.attention( + attention_output = self._run_attention( hidden_states=hidden_states, attention_mask=attention_mask, use_cache=use_cache, @@ -259,23 +360,92 @@ def forward( kv_cache_params=kv_cache_params, attention_params=attention_params, lora_layer_params=lora_layer_params, - ) + reduce_fusion_params=reduce_fusion_params) if use_cache: attention_output, present_kv = 
attention_output else: present_kv = None - hidden_states = residual + attention_output - residual = hidden_states - hidden_states = self.post_layernorm(hidden_states) - hidden_states = self.ffn(hidden_states, - lora_layer_params=lora_layer_params) - hidden_states = residual + hidden_states + if self.can_fuse_post_layernorm: + hidden_states, residual = attention_output + else: + hidden_states = residual + attention_output + residual = hidden_states + hidden_states = self.post_layernorm(hidden_states) + + if next_layer_input_layernorm_args is not None: + assert self.can_fuse_input_layernorm + norm_weight, eps = next_layer_input_layernorm_args + reduce_fusion_params = AllReduceFusionParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=norm_weight, + eps=eps) + hidden_states = self._run_ffn( + hidden_states, + lora_layer_params=lora_layer_params, + reduce_fusion_params=reduce_fusion_params) + + else: + hidden_states = self._run_ffn(hidden_states, + lora_layer_params=lora_layer_params) + hidden_states = residual + hidden_states return DeciLMLayerOutput(hidden_states=hidden_states, present_kv=present_kv) + def _run_attention( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + use_cache: bool = False, + spec_decoding_params=None, + kv_cache_params: Optional[KeyValueCacheParams] = None, + attention_params: Optional[AttentionParams] = None, + lora_layer_params: Optional[LoraParams] = None, + reduce_fusion_params: Optional[AllReduceFusionParams] = None + ) -> Union[Tensor, Tuple[Tensor, None]]: + """ + Ideally, this functionality would be encapsulated in a LinearAttention class, but during + FP8 and lower quantization, our linear classes get overrun by ModelOpt, thus we must + control the attention inputs at the DecoderLayer level. + """ + if self.layer_config.is_linear_attention_layer: + out = self.attention(hidden_states) + return out, None if use_cache else out + else: + if not self.layer_config.is_attention_layer: + assert reduce_fusion_params is None, f"Layer with attention of type {self.layer_config.attention.impl} can't do reduce_fusion" + + return self.attention(hidden_states=hidden_states, + attention_mask=attention_mask, + use_cache=use_cache, + spec_decoding_params=spec_decoding_params, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + lora_layer_params=lora_layer_params, + reduce_fusion_params=reduce_fusion_params) + + def _run_ffn(self, + hidden_states, + lora_layer_params=None, + reduce_fusion_params: Optional[AllReduceFusionParams] = None): + """ + Ideally, this functionality would be encapsulated in a LinearMLP class, but during + FP8 and lower quantization, our linear classes get overrun by ModelOpt, thus we must + control the MLP inputs at the DecoderLayer level. 
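The reduce_fusion plumbing added in this file pairs each MLP layer with its successor: when the plugin is on, an MLP layer folds the next layer's input RMSNorm into its final allreduce, but only if that next layer actually owns an input_layernorm (its attention is not a no-op). A standalone sketch of that pairing rule, using toy layer labels rather than the TRT-LLM classes:

reduce_fusion = True

# Toy per-layer descriptions; "noop_attention" layers have no input_layernorm,
# and "noop_ffn" layers have no MLP allreduce to fuse into.
layers = ["attention+mlp", "noop_attention+mlp", "attention+mlp", "attention+noop_ffn"]


def can_fuse_input_layernorm(kind: str) -> bool:
    # mirrors: reduce_fusion enabled and this layer's FFN is a real MLP
    return reduce_fusion and kind.endswith("+mlp")


def needs_input_layernorm_fusion(kind: str) -> bool:
    # mirrors: reduce_fusion enabled and this layer's attention is not a no-op
    return reduce_fusion and not kind.startswith("noop_attention")


for i in range(len(layers) - 1):
    fused = can_fuse_input_layernorm(layers[i]) and needs_input_layernorm_fusion(layers[i + 1])
    print(f"layer {i} fuses input_layernorm of layer {i + 1}: {fused}")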
+ """ + if reduce_fusion_params is not None: + assert self.layer_config.is_mlp_layer, f"Layer with FFN of type {self.layer_config.ffn.impl} can't do reduce_fusion" + + if self.layer_config.is_linear_ffn_layer: + return self.ffn(hidden_states) + else: + return self.ffn(hidden_states, + lora_layer_params=lora_layer_params, + reduce_fusion_params=reduce_fusion_params) + class DeciLMDecoderLayerList(ModuleList): @@ -311,6 +481,17 @@ def forward( past_key_values = [x for x in pkv_iter] for layer_idx, (layer, past) in enumerate(zip(self, past_key_values)): + next_layer_input_layernorm_args = None + if default_net().plugin_config.reduce_fusion: + if layer_idx < self.layer_list[-1]: + # this is not the last layer + next_layer = self[layer_idx + 1] + if layer.can_fuse_input_layernorm and next_layer.needs_input_layernorm_fusion: + # this layer can fuse the next layer's input_layernorm + next_layer_input_layernorm_args = ( + next_layer.input_layernorm.weight.value, + next_layer.input_layernorm.eps) + layer_out = layer( hidden_states=hidden_states, attention_mask=attention_mask, @@ -329,13 +510,16 @@ def forward( host_kv_cache_block_offsets, host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. + host_kv_cache_pool_mapping, cache_indirection=kv_cache_params.cache_indirection, ), spec_decoding_params=spec_decoding_params, use_cache=use_cache, lora_layer_params=lora_params.get_layer_config(layer_idx) if lora_params is not None - and lora_params.lora_ranks is not None else None) + and lora_params.lora_ranks is not None else None, + next_layer_input_layernorm_args=next_layer_input_layernorm_args) hidden_states = layer_out.hidden_states if use_cache and layer_out.present_kv is not None: @@ -511,6 +695,19 @@ def from_hugging_face(cls, model.load(weights) return model + @classmethod + def from_checkpoint(cls, + ckpt_dir: str, + rank: Optional[int] = None, + config: Optional["PretrainedConfig"] = None): + return super().from_checkpoint( + ckpt_dir, + rank, + config, + preprocess_weights_hook= + update_weights_following_modelopt_optimization, + ) + def forward( self, input_ids: Tensor, @@ -605,7 +802,6 @@ def prepare_attention_inputs( attn_layer_idx.append(layer_idx) num_kv_heads_per_layer.append( layer_config.attention.num_key_value_heads) - num_layers = len(attn_layer_idx) attention_inputs = super().prepare_attention_inputs( max_batch_size=max_batch_size, @@ -628,16 +824,4 @@ def prepare_attention_inputs( opt_batch_size=opt_batch_size, num_kv_heads_per_layer=num_kv_heads_per_layer) - kv_idx = 0 - past_key_value = [] - for i in range(self.config.num_hidden_layers): - layer_config = self.config.get_layer_config(i) - if layer_config.is_attention_layer: - past_key_value.append( - attention_inputs['past_key_value'][kv_idx]) - kv_idx += 1 - else: - past_key_value.append(None) - attention_inputs['past_key_value'] = past_key_value - return attention_inputs diff --git a/tensorrt_llm/models/phi/config.py b/tensorrt_llm/models/phi/config.py index b8bf4dc95..82fad32e5 100644 --- a/tensorrt_llm/models/phi/config.py +++ b/tensorrt_llm/models/phi/config.py @@ -54,6 +54,7 @@ def from_hugging_face( quant_config: Optional[QuantConfig] = None, **kwargs): import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) if isinstance(hf_config_or_dir, transformers.PretrainedConfig): hf_config = hf_config_or_dir @@ -61,7 +62,7 @@ def from_hugging_face( hf_config_dir = str(hf_config_or_dir) hf_config = transformers.AutoConfig.from_pretrained( - 
hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) num_key_value_heads = getattr(hf_config, "num_key_value_heads", hf_config.num_attention_heads) diff --git a/tensorrt_llm/models/phi/model.py b/tensorrt_llm/models/phi/model.py index 71f64a640..91aedd1d6 100644 --- a/tensorrt_llm/models/phi/model.py +++ b/tensorrt_llm/models/phi/model.py @@ -191,8 +191,12 @@ def from_hugging_face( quant_config=quant_config, **kwargs) if not use_preloading: + trust_remote_code = kwargs.pop('trust_remote_code', True) + hf_model = AutoModelForCausalLM.from_pretrained( - hf_model_dir, torch_dtype="auto", trust_remote_code=True) + hf_model_dir, + torch_dtype="auto", + trust_remote_code=trust_remote_code) assert isinstance(hf_model, transformers.PreTrainedModel) diff --git a/tensorrt_llm/models/phi3/config.py b/tensorrt_llm/models/phi3/config.py index ce1cfda98..558196930 100644 --- a/tensorrt_llm/models/phi3/config.py +++ b/tensorrt_llm/models/phi3/config.py @@ -18,6 +18,7 @@ import torch from ..._utils import torch_dtype_to_str +from ...layers import MoeConfig from ...logger import logger from ...mapping import Mapping from ..modeling_utils import PretrainedConfig, QuantConfig @@ -54,6 +55,7 @@ def from_hugging_face( quant_config: Optional[QuantConfig] = None, **kwargs): import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) if isinstance(hf_config_or_dir, transformers.PretrainedConfig): hf_config = hf_config_or_dir @@ -61,7 +63,7 @@ def from_hugging_face( hf_config_dir = str(hf_config_or_dir) hf_config = transformers.AutoConfig.from_pretrained( - hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) num_key_value_heads = getattr(hf_config, "num_key_value_heads", hf_config.num_attention_heads) @@ -102,9 +104,23 @@ def from_hugging_face( hf_config, "blocksparse_vert_stride", None) kwargs['dense_attention_every_n_layers'] = getattr( hf_config, "dense_attention_every_n_layers", None) + kwargs['norm_epsilon'] = hf_config.layer_norm_epsilon else: kwargs['rotary_base'] = hf_config.rope_theta kwargs['norm_epsilon'] = hf_config.rms_norm_eps + moe_variant = hf_config.architectures[0] == "PhiMoEForCausalLM" + if moe_variant: + kwargs.update({ + 'moe': { + 'num_experts': hf_config.num_local_experts, + 'top_k': hf_config.num_experts_per_tok, + 'normalization_mode': + MoeConfig.ExpertScaleNormalizationMode.SPARSE_MIXER, + 'sparse_mixer_epsilon': hf_config.router_jitter_noise, + }, + 'attention_bias': hf_config.attention_bias + }) + kwargs['position_embedding_type'] = 'rope_gpt_neox' if hf_config.max_position_embeddings >= 128000: kwargs[ @@ -114,7 +130,7 @@ def from_hugging_face( "short_factor"] kwargs['longrope_scaling_long_factors'] = hf_config.rope_scaling[ "long_factor"] - if small_variant: + if small_variant or moe_variant: kwargs['longrope_long_mscale'] = hf_config.rope_scaling[ "long_mscale"] kwargs['longrope_short_mscale'] = hf_config.rope_scaling[ diff --git a/tensorrt_llm/models/phi3/convert.py b/tensorrt_llm/models/phi3/convert.py index 9ee6821db..5a2bf59ec 100644 --- a/tensorrt_llm/models/phi3/convert.py +++ b/tensorrt_llm/models/phi3/convert.py @@ -34,6 +34,12 @@ def load_weights_from_hf_model(hf_model, config): key = key.replace("mlp.down_proj.", "mlp.proj.") #128k key = key.replace("mlp.gate_proj.", "mlp.fc.") #128k key = key.replace("o_proj.", "dense.") #128k + + #MoE + key = key.replace("block_sparse_moe.gate", "mlp.router") + key = key.replace("block_sparse_moe.experts.0.w3", "mlp.fc") + key = 
key.replace("block_sparse_moe.experts.0.w2", "mlp.proj") + #Layer norm key = key.replace("post_attention_layernorm.", "post_layernorm.") #128k @@ -54,16 +60,44 @@ def load_weights_from_hf_model(hf_model, config): # Swap the halves value = torch.cat((second_half, first_half), dim=0) + if config.architecture == "PhiMoEForCausalLM": + num_experts = config.moe["num_experts"] + mlp_hidden_size = config.intermediate_size + num_hidden = config.hidden_size + rank_experts = list(range(num_experts)) + if config.mapping.has_moe_ep(): + rank_experts = config.mapping.ep_experts(num_experts) + + def get_moe_weight(key, suffix): + param = [] + for expert in rank_experts: + name = key.replace(f"0.{suffix}", f"{expert}.{suffix}") + fc_value = hf_state_dict[name] + param.append(fc_value) + w = torch.stack(param) + return w.reshape(-1, mlp_hidden_size, num_hidden) + + if ".0.w3" in orig_key: + w3 = get_moe_weight(orig_key, 'w3') + w1 = get_moe_weight(orig_key.replace("w3", "w1"), 'w1') + value = torch.concat([w3, w1], dim=-2) + elif ".0.w2" in orig_key: + w2 = get_moe_weight(orig_key, 'w2') + value = w2.reshape(-1, num_hidden, mlp_hidden_size) + elif any([k in orig_key for k in ["w1", "w2", "w3"]]): + continue + if "q_proj" in key: #128k q_param = value k_param = hf_state_dict[orig_key.replace("q_proj", "k_proj")] v_param = hf_state_dict[orig_key.replace("q_proj", "v_proj")] value = torch.cat([q_param, k_param, v_param], dim=0) - key = key.replace("q_proj.weight", "qkv.weight") + key = key.replace("q_proj", "qkv") elif "k_proj" in key or "v_proj" in key: continue - weights[key] = value.to(torch_dtype).cpu() + dtype = torch.float if "router" in key else torch_dtype + weights[key] = value.to(dtype).cpu() if config.architecture == 'Phi3SmallForCausalLM': weights['lm_head.weight'] = weights[ @@ -74,6 +108,8 @@ def load_weights_from_hf_model(hf_model, config): if "qkv." 
in key: weights[key] = shuffle_qkv_weights(weights[key], config) + if config.architecture in ['Phi3SmallForCausalLM', "PhiMoEForCausalLM" + ] and config.mapping.has_tp(): weights = split_weights_tp(config, weights, torch_dtype) return weights diff --git a/tensorrt_llm/models/phi3/model.py b/tensorrt_llm/models/phi3/model.py index 8c4ff841b..ac29ab9a0 100644 --- a/tensorrt_llm/models/phi3/model.py +++ b/tensorrt_llm/models/phi3/model.py @@ -5,8 +5,9 @@ from ..._utils import pad_vocab_size from ...functional import PositionEmbeddingType, Tensor -from ...layers import (MLP, Attention, AttentionMaskType, BlockSparseAttnParams, - ColumnLinear, Embedding, LayerNorm, RmsNorm) +from ...layers import (MLP, MOE, Attention, AttentionMaskType, + BlockSparseAttnParams, ColumnLinear, Embedding, + LayerNorm, MoeConfig, RmsNorm) from ...lora_manager import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module @@ -31,6 +32,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): self.gegelu_limit = None self.small_variant = config.architecture == "Phi3SmallForCausalLM" + self.moe_variant = config.architecture == "PhiMoEForCausalLM" if self.small_variant: self.gegelu_limit = config.gegelu_limit @@ -51,10 +53,14 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): config.blocksparse_num_local_blocks, config.blocksparse_vertical_stride) + if self.small_variant or self.moe_variant: self.input_layernorm = LayerNorm( - normalized_shape=config.hidden_size, dtype=config.dtype) + normalized_shape=config.hidden_size, + dtype=config.dtype, + eps=config.norm_epsilon) self.post_layernorm = LayerNorm(normalized_shape=config.hidden_size, - dtype=config.dtype) + dtype=config.dtype, + eps=config.norm_epsilon) else: self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, @@ -80,7 +86,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): original_max_position_embeddings = config.original_max_position_embeddings position_embedding_type = PositionEmbeddingType.long_rope - if self.small_variant: + if self.small_variant or self.moe_variant: rope_scaling_short_mscale = config.longrope_short_mscale rope_scaling_long_mscale = config.longrope_long_mscale @@ -94,7 +100,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): max_position_embeddings=config.max_position_embeddings, dtype=config.dtype, attention_mask_type=attention_mask_type, - bias=self.small_variant, + bias=self.small_variant or self.moe_variant, q_scaling=q_scaling, tp_group=tp_group, tp_size=tp_size, @@ -106,14 +112,27 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): original_max_position_embeddings=original_max_position_embeddings, block_sparse_params=block_sparse_attn_params) - self.mlp = MLP(hidden_size=config.hidden_size, - ffn_hidden_size=config.intermediate_size, - hidden_act=config.hidden_act, - dtype=config.dtype, - tp_group=tp_group, - tp_size=tp_size, - quant_mode=config.quant_mode, - bias=self.small_variant) + ClsMLP = MLP + mlp_kwargs = {} + if hasattr(config, "moe"): + ClsMLP = MOE + moe_config = MoeConfig() + for key, value in config.moe.items(): + setattr(moe_config, key, value) + mlp_kwargs = { + "moe_config": moe_config, + "mapping": config.mapping, + } + + self.mlp = ClsMLP(hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + hidden_act=config.hidden_act, + dtype=config.dtype, + tp_group=tp_group, + tp_size=tp_size, + quant_mode=config.quant_mode, + bias=self.small_variant, + **mlp_kwargs) def 
forward( self, @@ -141,10 +160,14 @@ def forward( post_attention_input = hidden_states + attention_output post_attention_output = self.post_layernorm(post_attention_input) - feed_forward_hidden_states = self.mlp( - post_attention_output, - gegelu_limit=self.gegelu_limit, - lora_layer_params=lora_layer_params) + if self.small_variant: + feed_forward_hidden_states = self.mlp( + post_attention_output, + gegelu_limit=self.gegelu_limit, + lora_layer_params=lora_layer_params) + else: + feed_forward_hidden_states = self.mlp( + post_attention_output, lora_layer_params=lora_layer_params) hidden_states = post_attention_input + feed_forward_hidden_states if use_cache: return (hidden_states, presents) @@ -161,10 +184,13 @@ def __init__(self, config: PretrainedConfig): self.layers = DecoderLayerList(Phi3DecoderLayer, config) self.small_variant = config.architecture == "Phi3SmallForCausalLM" - if self.small_variant: + self.moe_variant = config.architecture == "PhiMoEForCausalLM" + if self.small_variant or self.moe_variant: self.ln_f = LayerNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, dtype=config.dtype) - self.mup_embedding_multiplier = config.mup_embedding_multiplier + if self.small_variant: + self.mup_embedding_multiplier = config.mup_embedding_multiplier else: self.ln_f = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, @@ -216,9 +242,10 @@ def __init__(self, config: PretrainedConfig): vocab_size_padded = pad_vocab_size(config.vocab_size, config.mapping.tp_size) + self.moe_variant = config.architecture == "PhiMoEForCausalLM" lm_head = ColumnLinear(config.hidden_size, vocab_size_padded, - bias=False, + bias=self.moe_variant, dtype=config.dtype, tp_group=config.mapping.tp_group, tp_size=config.mapping.tp_size, @@ -257,8 +284,12 @@ def from_hugging_face( **kwargs) if not use_preloading: + trust_remote_code = kwargs.pop('trust_remote_code', True) + hf_model = AutoModelForCausalLM.from_pretrained( - hf_model_dir, torch_dtype="auto", trust_remote_code=True) + hf_model_dir, + torch_dtype="auto", + trust_remote_code=trust_remote_code) assert isinstance(hf_model, transformers.PreTrainedModel) diff --git a/tensorrt_llm/models/phi3/split_weights.py b/tensorrt_llm/models/phi3/split_weights.py index fcc4d735b..62a889123 100644 --- a/tensorrt_llm/models/phi3/split_weights.py +++ b/tensorrt_llm/models/phi3/split_weights.py @@ -15,8 +15,8 @@ import torch -from tensorrt_llm.models.convert_utils import (get_weight_and_bias, split, - split_matrix_tp, +from tensorrt_llm.models.convert_utils import (get_weight, get_weight_and_bias, + split, split_matrix_tp, split_qkv_bias_tp, split_qkv_tp) from ..._utils import pad_vocab_size @@ -110,10 +110,13 @@ def split_weights_tp(config, weights, dtype): num_heads = config.num_attention_heads num_kv_heads = config.num_key_value_heads hidden_size = config.hidden_size + moe_variant = config.architecture == "PhiMoEForCausalLM" mha_mode = num_heads == num_kv_heads tp_size = config.mapping.tp_size rank = config.mapping.tp_rank + moe_tp_size = config.mapping.moe_tp_size + moe_tp_rank = config.mapping.moe_tp_rank use_weight_only = config.quant_mode.is_weight_only() plugin_weight_only_quant_type = None if use_weight_only and config.quant_mode.is_int8_weight_only() == 'int8': @@ -121,8 +124,7 @@ def split_weights_tp(config, weights, dtype): elif use_weight_only and config.quant_mode.is_int4_weight_only() == 'int4': plugin_weight_only_quant_type = torch.quint4x2 - # Helper - def get_weight(weight, prefix, bias): + def get_quant_weight(weight, 
prefix, bias): return get_tllm_linear_weight(weight, prefix, bias, use_weight_only, plugin_weight_only_quant_type) @@ -156,25 +158,43 @@ def get_weight(weight, prefix, bias): split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size, tp_size, rank) - weights.update(get_weight(split_weight, prefix, split_bias)) + weights.update(get_quant_weight(split_weight, prefix, split_bias)) prefix = layer_prefix + 'attention.dense' attn_dense_weight, attn_dense_bias = get_weight_and_bias( weights, prefix, dtype) split_v = split_matrix_tp(attn_dense_weight, tp_size, rank, dim=1) - weights.update(get_weight(split_v, prefix, attn_dense_bias)) + weights.update(get_quant_weight(split_v, prefix, attn_dense_bias)) prefix = layer_prefix + 'mlp.fc' - mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(weights, prefix, dtype) - split_v = split_matrix_tp(mlp_fc_weight, tp_size, rank, dim=0) - bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0) - weights.update(get_weight(split_v, prefix, bias)) + if not moe_variant: + mlp_fc_weight, mlp_fc_bias = get_weight_and_bias( + weights, prefix, dtype) + split_v = split_matrix_tp(mlp_fc_weight, tp_size, rank, dim=0) + bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0) + weights.update(get_quant_weight(split_v, prefix, bias)) + else: + mlp_fc_weight = get_weight(weights, prefix, dtype) + w3 = split_matrix_tp(mlp_fc_weight, 2, 0, dim=1) + split_w3 = split_matrix_tp(w3, moe_tp_size, moe_tp_rank, dim=1) + w1 = split_matrix_tp(mlp_fc_weight, 2, 1, dim=1) + split_w1 = split_matrix_tp(w1, moe_tp_size, moe_tp_rank, dim=1) + split_v = torch.concat([split_w3, split_w1], dim=-2) + weights.update(get_quant_weight(split_v, prefix, None)) prefix = layer_prefix + 'mlp.proj' - mlp_proj_weight, mlp_proj_bias = get_weight_and_bias( - weights, prefix, dtype) - split_v = split_matrix_tp(mlp_proj_weight, tp_size, rank, dim=1) - weights.update(get_weight(split_v, prefix, mlp_proj_bias)) + if not moe_variant: + mlp_proj_weight, mlp_proj_bias = get_weight_and_bias( + weights, prefix, dtype) + split_v = split_matrix_tp(mlp_proj_weight, tp_size, rank, dim=1) + weights.update(get_quant_weight(split_v, prefix, mlp_proj_bias)) + else: + mlp_proj_weight = get_weight(weights, prefix, dtype) + split_v = split_matrix_tp(mlp_proj_weight, + moe_tp_size, + moe_tp_rank, + dim=2) + weights.update(get_quant_weight(split_v, prefix, None)) weights['transformer.vocab_embedding.weight'] = split_embedding( weights['transformer.vocab_embedding.weight'], tp_size, rank) @@ -182,5 +202,10 @@ def get_weight(weight, prefix, bias): tp_size, rank, dim=0) + if moe_variant: + weights['lm_head.bias'] = split_matrix_tp(weights['lm_head.bias'], + tp_size, + rank, + dim=0) return weights diff --git a/tensorrt_llm/models/qwen/config.py b/tensorrt_llm/models/qwen/config.py index 051c72f51..3636dc61e 100644 --- a/tensorrt_llm/models/qwen/config.py +++ b/tensorrt_llm/models/qwen/config.py @@ -73,6 +73,7 @@ def from_hugging_face(cls, quant_config: Optional[QuantConfig] = None, **kwargs) -> "QWenConfig": import transformers + trust_remote_code = kwargs.pop('trust_remote_code', True) if isinstance(hf_config_or_dir, transformers.PretrainedConfig): hf_config = hf_config_or_dir @@ -80,7 +81,7 @@ def from_hugging_face(cls, hf_config_dir = str(hf_config_or_dir) hf_config = transformers.AutoConfig.from_pretrained( - hf_config_dir, trust_remote_code=True) + hf_config_dir, trust_remote_code=trust_remote_code) qwen_type = hf_config.model_type valid_types = ('qwen', 'qwen2', 'qwen2_moe') diff --git 
a/tensorrt_llm/models/qwen/convert.py b/tensorrt_llm/models/qwen/convert.py index 76769d0d8..34f26f496 100644 --- a/tensorrt_llm/models/qwen/convert.py +++ b/tensorrt_llm/models/qwen/convert.py @@ -57,7 +57,8 @@ def smooth_qwen_model(model, scales, alpha, qwen_qkv_para, qwen_smoother): scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=1)[0] # see transpose_weights function - qwen_qkv_para[layer_name] = module.attn.c_attn.weight.transpose(0, 1) + qwen_qkv_para[layer_name] = module.attn.c_attn.weight.transpose( + 0, 1).contiguous() # ================================================================= layer_name = name + ".attn.c_proj" @@ -127,7 +128,7 @@ def smooth_qwen2_model(model, scales, alpha, qwen_qkv_para, qwen_smoother): dim=0) # see transpose_weights function - qwen_qkv_para[layer_name_qkv] = weight.transpose(0, 1) + qwen_qkv_para[layer_name_qkv] = weight.transpose(0, 1).contiguous() # ================================================================= layer_name = name + ".self_attn.o_proj" @@ -293,12 +294,13 @@ def get_tllm_linear_sq_weight(vals, results = {} def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): - q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) - q_split = np.split(q, tp_size, axis=-1) - k_split = np.split(k, tp_size, axis=-1) - v_split = np.split(v, tp_size, axis=-1) + + q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1) + q_split = torch.split(q, q.shape[-1] // tp_size, dim=-1) + k_split = torch.split(k, q.shape[-1] // tp_size, dim=-1) + v_split = torch.split(v, q.shape[-1] // tp_size, dim=-1) return [ - np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) + torch.concat((q_split[ii], k_split[ii], v_split[ii]), dim=-1) for ii in range(tp_size) ][cur_rank] @@ -306,9 +308,9 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): if per_token: if per_channel: - original_weights = np.array(vals["weight.int8.col"]) + original_weights = vals["weight.int8.col"] else: - original_weights = np.array(vals["weight.int8"]) + original_weights = vals["weight.int8"] local_dim = original_weights.shape[0] head_size = (original_weights.shape[1] - local_dim) // 2 @@ -316,14 +318,14 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): cur_weights = multi_query_split(original_weights, local_dim, head_size, tensor_parallel, rank) else: - cur_weights = np.split(original_weights, - tensor_parallel, - axis=cat_dim)[rank] + cur_weights = torch.split(original_weights, + original_weights.shape[-1] // + tensor_parallel, + dim=cat_dim)[rank] if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + 'weight'] = torch.from_numpy( - cur_weights).t().clone().contiguous() + results[prefix + 'weight'] = cur_weights.t().contiguous() if smoother_value is None: results[last_prefix] = torch.from_numpy( np.array([1.0], dtype=np.float32)) @@ -332,6 +334,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): cur_per_channel_value = vals["scale_w_quant_orig.col"] if smoother_value is None: if multi_query_mode: + cur_per_channel_value = multi_query_split( vals["scale_w_quant_orig.col"], local_dim, head_size, tensor_parallel, rank) @@ -348,18 +351,18 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): vals["scale_w_quant_orig"], local_dim, head_size, tensor_parallel, rank) else: - cur_per_channel_value = np.split(vals["scale_w_quant_orig"], - tensor_parallel, - axis=cat_dim)[rank] + 
cur_per_channel_value = torch.split( + vals["scale_w_quant_orig"], + tensor_parallel, + axis=cat_dim)[rank] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array(cur_per_channel_value, - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + + 'per_channel_scale'] = cur_per_channel_value.reshape(col_shape) else: if per_channel: - original_weights = np.array(vals["weight.int8.col"]) + original_weights = vals["weight.int8.col"] else: - original_weights = np.array(vals["weight.int8"]) + original_weights = vals["weight.int8"] local_dim = original_weights.shape[0] head_size = (original_weights.shape[1] - local_dim) // 2 @@ -367,14 +370,14 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): cur_weights = multi_query_split(original_weights, local_dim, head_size, tensor_parallel, rank) else: - cur_weights = np.split(original_weights, - tensor_parallel, - axis=cat_dim)[rank] + cur_weights = torch.split(original_weights, + original_weights.shape[-1] // + tensor_parallel, + dim=cat_dim)[rank] if is_qkv: hidden_dim = cur_weights.shape[0] cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + 'weight'] = torch.from_numpy( - cur_weights).t().clone().contiguous() + results[prefix + 'weight'] = cur_weights.t().contiguous() if per_channel: cur_per_channel_value = vals["scale_y_accum_quant.col"] @@ -402,22 +405,19 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): tensor_parallel, axis=cat_dim)[rank] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array([cur_per_channel_value], - dtype=np.float32).reshape(col_shape)).contiguous() + results[prefix + 'per_channel_scale'] = cur_per_channel_value.reshape( + col_shape).contiguous() - results[last_prefix] = torch.from_numpy( - np.array([vals['scale_x_orig_quant']], - dtype=np.float32)).contiguous() + results[last_prefix] = vals['scale_x_orig_quant'].contiguous() - results[prefix + 'act_scale'] = torch.from_numpy( - np.array([[vals["scale_y_quant_orig"]]], - dtype=np.float32)).contiguous() + results[prefix + 'act_scale'] = vals["scale_y_quant_orig"].contiguous() if smoother_value is not None: - cur_smoother_value = np.split(smoother_value, - tensor_parallel, - axis=cat_dim)[rank] + cur_smoother_value = torch.split(smoother_value, + smoother_value.shape[-1] // + tensor_parallel, + dim=cat_dim)[rank] + results[prefix + 'smoother'] = cur_smoother_value.reshape( smoother_shape).contiguous().to(torch.float32) @@ -573,7 +573,6 @@ def convert_hf_qwen(hf_model, qkv_b, use_weight_only, plugin_weight_only_quant_type, dtype, use_gemm_woq_plugin)) - if int8_kv_cache: if qwen_type == 'qwen': qkv_y = act_range.get(prefix + key_list[0])["y"] @@ -768,6 +767,7 @@ def convert_hf_qwen(hf_model, mlp_fc_weight = mlp_fc_weight.t() #verified int8_weights = generate_int8( mlp_fc_weight, act_range.get(prefix + key_list[3])) + weights.update( get_tllm_linear_sq_weight( int8_weights, @@ -800,6 +800,7 @@ def convert_hf_qwen(hf_model, mlp_proj_weight = mlp_proj_weight.t() int8_weights = generate_int8( mlp_proj_weight, act_range.get(prefix + key_list[4])) + weights.update( get_tllm_linear_sq_weight( int8_weights, @@ -815,7 +816,7 @@ def convert_hf_qwen(hf_model, 1, intermediate_size // tensor_parallel ], rank=mapping.tp_rank, - cat_dim=0)) + cat_dim=-1)) else: weights.update( get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.', diff --git a/tensorrt_llm/models/qwen/model.py b/tensorrt_llm/models/qwen/model.py index 1f0605d1e..c3dd5b305 100644 --- a/tensorrt_llm/models/qwen/model.py +++ 
b/tensorrt_llm/models/qwen/model.py @@ -47,7 +47,7 @@ def __init__(self, config: QWenConfig, layer_idx: int): dtype = config.dtype self.tp_group = config.mapping.tp_group - tp_size = config.mapping.tp_size + self.tp_size = config.mapping.tp_size self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, @@ -69,7 +69,7 @@ def __init__(self, config: QWenConfig, layer_idx: int): rotary_embedding_base=config.rotary_base, rotary_embedding_scaling=config.rotary_scaling, tp_group=self.tp_group, - tp_size=tp_size, + tp_size=self.tp_size, quant_mode=config.quant_mode, dense_bias=False) @@ -90,7 +90,7 @@ def __init__(self, config: QWenConfig, layer_idx: int): dtype=dtype, bias=False, tp_group=self.tp_group, - tp_size=tp_size, + tp_size=self.tp_size, quant_mode=config.quant_mode, is_expert=True) self.shared_expert_gate = RowLinear(config.hidden_size, @@ -115,7 +115,7 @@ def __init__(self, config: QWenConfig, layer_idx: int): dtype=dtype, bias=config.mlp_bias, tp_group=self.tp_group, - tp_size=tp_size, + tp_size=self.tp_size, quant_mode=config.quant_mode, **mlp_kwargs) self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, @@ -168,7 +168,8 @@ def forward( if shared_output is not None: hidden_states = hidden_states + shared_output - hidden_states = allreduce(hidden_states, self.tp_group) + if self.tp_size > 1 and self.tp_group is not None: + hidden_states = allreduce(hidden_states, self.tp_group) hidden_states = residual + hidden_states if use_cache: diff --git a/tensorrt_llm/models/recurrentgemma/model.py b/tensorrt_llm/models/recurrentgemma/model.py index d555fc5c3..e0cbe77b1 100644 --- a/tensorrt_llm/models/recurrentgemma/model.py +++ b/tensorrt_llm/models/recurrentgemma/model.py @@ -57,6 +57,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): layer_types = layer_types + config.layer_types[0:( (layer_idx + 1) % layer_type_len)] attention_layer_idx = layer_types.count('attention') - 1 + self.attention = Attention( local_layer_idx=attention_layer_idx, hidden_size=config.hidden_size, @@ -209,6 +210,8 @@ def forward(self, host_kv_cache_block_offsets, host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. 
+ host_kv_cache_pool_mapping, cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params, conv_state=past_conv, @@ -499,7 +502,6 @@ def prepare_inputs( mapping, num_profiles) # attention inputs - num_attention_layers = self.layer_types.count('attention') attn_layer_idx = [] for i in range(self.config.num_hidden_layers): if self.layer_types[i] == 'attention': @@ -511,7 +513,7 @@ def prepare_inputs( max_seq_len=max_seq_len, num_kv_heads=self.config.num_key_value_heads, head_size=self.config.head_size, - num_layers=num_attention_layers, + num_layers=self.config.num_hidden_layers, kv_dtype=str_dtype_to_trt(self.config.kv_dtype), num_profiles=num_profiles, enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles, @@ -523,17 +525,6 @@ def prepare_inputs( streamingllm=streamingllm, attn_layer_idx=attn_layer_idx) - kv_idx = 0 - past_key_value = [] - for i in range(self.config.num_hidden_layers): - if self.layer_types[i] == 'attention' and not paged_kv_cache: - past_key_value.append( - attention_inputs['past_key_value'][kv_idx]) - kv_idx += 1 - else: - past_key_value.append(None) - attention_inputs['past_key_value'] = past_key_value - # recurrent inputs recurrent_inputs = self.prepare_recurrent_inputs( max_batch_size=max_batch_size, @@ -601,6 +592,8 @@ def prepare_inputs( 'host_kv_cache_block_offsets'], host_kv_cache_pool_pointers=attention_inputs[ 'host_kv_cache_pool_pointers'], + host_kv_cache_pool_mapping=attention_inputs[ + 'host_kv_cache_pool_mapping'], cache_indirection=attention_inputs['cache_indirection'], ), 'attention_params': diff --git a/tensorrt_llm/quantization/functional.py b/tensorrt_llm/quantization/functional.py index 7d1950ee7..6967ccef3 100644 --- a/tensorrt_llm/quantization/functional.py +++ b/tensorrt_llm/quantization/functional.py @@ -628,6 +628,25 @@ def preprocess_weights_for_mixed_gemm(weight, quant_mode): original_shape[1] // 2) +def validate_group_size(layer): + # TODO: Remove this function and its usage after W4A8-AWQ with group_size = 64 is implemented. 
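The quantization helpers added in this hunk include unpack_int32_into_int8, which reinterprets int32-packed weights as bytes and splits each byte into two 4-bit values, low nibble first (the inverse of the qweight[:, 1::2] * 16 + qweight[:, ::2] packing used further down). A minimal round-trip sketch of that byte layout, independent of the library code:

import torch

# Illustration of the uint4x2 convention assumed by unpack_int32_into_int8:
# each byte stores two 4-bit values, with the low nibble coming first.
vals = torch.randint(0, 16, (4, 8), dtype=torch.uint8)   # unpacked 4-bit values
packed = vals[:, ::2] | (vals[:, 1::2] << 4)              # two values per byte

unpacked = torch.zeros_like(vals, dtype=torch.int8)
unpacked[:, ::2] = (packed % 16).to(torch.int8)           # low nibble
unpacked[:, 1::2] = (packed // 16).to(torch.int8)         # high nibble
assert torch.equal(unpacked, vals.to(torch.int8))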
+ W4A8_AWQ = 8 + if layer.quant_algo & W4A8_AWQ and layer.group_size == 64: + raise NotImplementedError( + "W4A8_AWQ with group_size = 64 is not implemented yet!") + + +def unpack_int32_into_int8(w_packed): + # Unpack inputs packed in int32/float32 into uint4 and store them in int8 format + w_packed_int4x2 = w_packed.contiguous().view(torch.uint8) + w_unpacked = torch.zeros(w_packed_int4x2.shape[0], + w_packed_int4x2.shape[1] * 2, + dtype=torch.int8) + w_unpacked[:, ::2] = w_packed_int4x2 % 16 + w_unpacked[:, 1::2] = w_packed_int4x2 // 16 + return w_unpacked.contiguous() + + def change_qkv_leading_dim(w, num_heads): if w.dim() == 1: w = w.reshape(num_heads, 3, -1) @@ -682,6 +701,124 @@ def postprocess_weight_only(tllm_key, weights, quant_mode, layer): return {tllm_key: weights} # Bias +def postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype, layer, + **kwargs): + using_head_as_leading_dim = kwargs.get("using_head_as_leading_dim", False) + config = kwargs.get("config", None) + use_autoawq = kwargs.get("use_autoawq", None) + num_heads = config.num_attention_heads + USE_GPTQ = layer.prequant_scaling_factor is None and use_autoawq is None + USE_HF_AWQ = layer.prequant_scaling_factor is None and use_autoawq is not None + USE_MODELOPT_AWQ = layer.prequant_scaling_factor is not None + + tp_dim = 1 if isinstance(layer, ColumnLinear) else 0 + is_qkv = layer.is_qkv if hasattr(layer, "is_qkv") else False + + if using_head_as_leading_dim: + assert config.num_attention_heads == config.num_key_value_heads, "using_head_as_leading_dim require head_size to be multiple of 3." + if tllm_key.endswith("weights_scaling_factor"): + # TODO: Remove reshaping after modelopt optimizes scale shape + if is_qkv: + for idx, w in enumerate(weights): + scales = w.to(torch_dtype) + scales = scales.reshape(-1, + layer.weights_scaling_factor.shape[0]).T + scales = scales.chunk(layer.tp_size, 1)[layer.tp_rank] + weights[idx] = scales + weights = torch.cat(weights, dim=1) + else: + scales = weights.to(torch_dtype) + scales_shape = [ + layer.weights_scaling_factor.shape[1], + layer.weights_scaling_factor.shape[0] + ] + scales_shape[1 - tp_dim] *= layer.tp_size + scales = scales.reshape(scales_shape).T + weights = scales.chunk(layer.tp_size, tp_dim)[layer.tp_rank] + if is_qkv and isinstance(weights, list) and len(weights) >= 3: + if USE_MODELOPT_AWQ: + if tllm_key.endswith("prequant_scaling_factor"): + weights = weights[0] + else: + weights = torch.cat(weights, dim=0) + elif len(weights) > 3: + weights = [ + torch.cat(weights[i::len(weights) // 3], dim=1) + for i in range(len(weights) // 3) + ] + + if tllm_key.endswith("bias"): + if is_qkv and isinstance(weights, list): + weights = torch.cat(weights) + if layer.is_padded: + weights = pad_like(weights, layer.bias.shape) + if using_head_as_leading_dim: + weights = change_qkv_leading_dim(weights, num_heads) + results = {tllm_key: weights} + elif tllm_key.endswith("weight"): + if USE_GPTQ: + qweight = unpack_int32_into_int8(weights[0].T).T - 8 + elif USE_HF_AWQ: + qweight = unpack_int32_into_int8(weights[0]) - 8 + else: + qweight = unpack_int32_into_int8(weights.T) + qweight[qweight < 0] += 16 + qweight = qweight.view(torch.uint8) + if using_head_as_leading_dim: + qweight = change_qkv_leading_dim(qweight, num_heads) + if layer.is_padded: + qweight = torch.split(qweight, layer.out_features, + tp_dim)[layer.tp_rank] + qweight = pad_like(qweight, (layer.in_features, layer.out_features)) + qweight = (qweight[:, 1::2] * 16 + qweight[:, ::2]).view(torch.int8) + qweight = 
torch.ops.trtllm.preprocess_weights_for_mixed_gemm( + qweight.contiguous(), torch.quint4x2, + torch.float16).view(torch_dtype) + results = {tllm_key: qweight} + + # scales and zeros for GPTQ and HF-AWQ + if USE_GPTQ or USE_HF_AWQ: + scales = weights[1].to(torch_dtype) + qzeros = unpack_int32_into_int8(weights[2]) + if using_head_as_leading_dim: + scales = change_qkv_leading_dim(scales, num_heads) + qzeros = change_qkv_leading_dim(qzeros, num_heads) + if layer.is_padded: + scales = torch.split(scales, + layer.weights_scaling_factor.shape[tp_dim], + tp_dim)[layer.tp_rank] + scales = pad_like(scales, layer.weights_scaling_factor.shape, 1) + qzeros = torch.split(qzeros, + layer.weights_scaling_factor.shape[tp_dim], + tp_dim)[layer.tp_rank] + qzeros = pad_like(qzeros, layer.zero.shape, 7) + zeros_x_scales = (-qzeros + 8 - 1 * USE_GPTQ) * scales + zeros_x_scales = zeros_x_scales.to(torch_dtype) + results.update({ + tllm_key.replace("weight", "weights_scaling_factor"): + scales, + tllm_key.replace("weight", "zero"): + zeros_x_scales, + }) + elif tllm_key.endswith("weights_scaling_factor"): + # TODO: Remove reshaping after modelopt optimizes scale shape + if layer.is_padded: + raise NotImplementedError( + "Auto-padding is not Implemented for ModelOpt HF-AWQ.") + results = {tllm_key: weights} + elif tllm_key.endswith("prequant_scaling_factor"): + prequant_scale = weights.to(torch_dtype).reshape(1, -1) + if layer.is_padded and tp_dim == 1: + prequant_scale = torch.split(prequant_scale, + layer.prequant_scaling_factor.shape[1], + 1)[layer.tp_rank] + prequant_scale = pad_like(prequant_scale, + layer.prequant_scaling_factor.shape, 0) + results = {tllm_key: prequant_scale} + + return results + + def postprocess_fp8_rowwise(tllm_key, weights, **kwargs): if tllm_key.endswith("per_channel_scale"): return {} diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index b75d7d572..ac0b14916 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -35,11 +35,12 @@ # isort: off from .functional import ( - change_qkv_leading_dim, dequantize, fp8_rowwise_gemm, fp8_rowwise_rms_norm, - postprocess_fp8_rowwise, postprocess_weight_only, quantize, + dequantize, fp8_rowwise_gemm, fp8_rowwise_rms_norm, postprocess_fp8_rowwise, + postprocess_weight_only, postprocess_weight_only_groupwise, quantize, quantize_fp8_per_token, quantize_per_token, quantize_tensor, - smooth_quant_gemm, smooth_quant_layer_norm, smooth_quant_rms_norm, - weight_only_groupwise_quant_matmul, weight_only_quant_matmul, pad_like) + validate_group_size, smooth_quant_gemm, smooth_quant_layer_norm, + smooth_quant_rms_norm, weight_only_groupwise_quant_matmul, + weight_only_quant_matmul) # isort: on from .mode import QuantMode @@ -848,14 +849,15 @@ def __init__( else: self.register_parameter('alpha', None) + validate_group_size(self) if pre_quant_scale: self.tllm_to_externel_key_dict = { - "weight": - ["weight", "weight_scaling_factor", "prequant_scaling_factor"] + "weights_scaling_factor": "weight_scale", + "prequant_scaling_factor": "input_quantizer._pre_quant_scale", } # AWQ else: self.tllm_to_externel_key_dict = { - "weight": ["qweight", "qzeros", "scales"] + "weight": ["qweight", "scales", "qzeros"] } # GPTQ def forward(self, x, lora_runtime_params=None): @@ -882,74 +884,13 @@ def forward(self, x, lora_runtime_params=None): return x def postprocess(self, tllm_key, weights, **kwargs): - using_head_as_leading_dim = kwargs.get("using_head_as_leading_dim", - False) - config = 
kwargs.get("config", None) - num_heads = config.num_attention_heads - if using_head_as_leading_dim: - assert config.num_attention_heads == config.num_key_value_heads, "using_head_as_leading_dim require head_size to be multiple of 3." - if not (tllm_key.endswith("bias") or tllm_key.endswith("weight")): + if tllm_key.endswith("zero") or ( + self.prequant_scaling_factor is None + and tllm_key.endswith("weights_scaling_factor")): return {} - if self.is_qkv and type(weights) is list and len(weights) > 3: - weights = [ - torch.cat(weights[i::len(weights) // 3], dim=1) - for i in range(len(weights) // 3) - ] - - if tllm_key.endswith("bias"): - if self.is_padded: - weights = torch.split(weights, self.out_features, - 0)[self.tp_rank] - weights = pad_like(weights, self.bias.shape) - if using_head_as_leading_dim: - weights = change_qkv_leading_dim(weights, num_heads) - return weights - elif tllm_key.endswith("weight"): - qweight_int32, qzeros_int32, scales_fp16 = weights - qweight_unpacked_int8 = unpack_int32_into_int8( - qweight_int32.T).T.contiguous() - 8 - qweight = qweight_unpacked_int8 - qweight[qweight < 0] += 16 - qweight = qweight.view(torch.uint8) - if using_head_as_leading_dim: - qweight = change_qkv_leading_dim(qweight, num_heads) - scales_fp16 = change_qkv_leading_dim(scales_fp16, num_heads) - if self.is_padded: - qweight = torch.split(qweight, self.out_features, - 1)[self.tp_rank] - qweight = pad_like(qweight, - (self.in_features, self.out_features)) - qweight = (qweight[:, 1::2] * 16 + qweight[:, ::2]).view(torch.int8) - qweight = torch.ops.trtllm.preprocess_weights_for_mixed_gemm( - qweight.contiguous(), torch.quint4x2, - torch.float16).view(str_dtype_to_torch(self.dtype)) - # zeros = zeros * scales - qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) - if using_head_as_leading_dim: - qzeros_unpacked_int32 = change_qkv_leading_dim( - qzeros_unpacked_int32, num_heads) - if self.is_padded: - scales_fp16 = torch.split(scales_fp16, - self.weights_scaling_factor.shape[1], - 1)[self.tp_rank] - qzeros_unpacked_int32 = torch.split( - qzeros_unpacked_int32, self.weights_scaling_factor.shape[1], - 1)[self.tp_rank] - scales_fp16 = pad_like(scales_fp16, - self.weights_scaling_factor.shape, 1) - qzeros_unpacked_int32 = pad_like(qzeros_unpacked_int32, - self.zero.shape, 7) - zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 7) * scales_fp16 - zeros_x_scales_fp16 = zeros_x_scales_fp16.to( - str_dtype_to_torch(self.dtype)) - - results = { - tllm_key: qweight, - tllm_key.replace("weight", "weights_scaling_factor"): - scales_fp16, - tllm_key.replace("weight", "zero"): zeros_x_scales_fp16, - } - return results + torch_dtype = str_dtype_to_torch(self.dtype) + return postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype, + self, **kwargs) WeightOnlyGroupwiseQuantColumnLinear = WeightOnlyGroupwiseQuantLinear @@ -1021,14 +962,15 @@ def __init__( else: self.register_parameter('alpha', None) + validate_group_size(self) if pre_quant_scale: self.tllm_to_externel_key_dict = { - "weight": - ["weight", "weight_scaling_factor", "prequant_scaling_factor"] + "weights_scaling_factor": "weight_scale", + "prequant_scaling_factor": "input_quantizer._pre_quant_scale", } # AWQ else: self.tllm_to_externel_key_dict = { - "weight": ["qweight", "qzeros", "scales"] + "weight": ["qweight", "scales", "qzeros"] } # GPTQ def forward(self, x, lora_runtime_params=None, reduce_fusion_params=None): @@ -1056,54 +998,13 @@ def forward(self, x, lora_runtime_params=None, reduce_fusion_params=None): return x def 
postprocess(self, tllm_key, weights, **kwargs): - if not (tllm_key.endswith("bias") or tllm_key.endswith("weight")): + if tllm_key.endswith("zero") or ( + self.prequant_scaling_factor is None + and tllm_key.endswith("weights_scaling_factor")): return {} - if tllm_key.endswith("bias"): - if self.is_padded: - weights = pad_like(weights, self.bias.shape) - if self.tp_size > 1: - weights /= self.tp_size - return weights - elif tllm_key.endswith("weight"): - qweight_int32, qzeros_int32, scales_fp16 = weights - qweight_unpacked_int8 = unpack_int32_into_int8( - qweight_int32.T).T.contiguous() - 8 - qweight = qweight_unpacked_int8 - qweight[qweight < 0] += 16 - qweight = qweight.view(torch.uint8) - if self.is_padded: - qweight = torch.split(qweight, self.in_features, - 0)[self.tp_rank] - qweight = pad_like(qweight, - (self.in_features, self.out_features)) - qweight = (qweight[:, 1::2] * 16 + qweight[:, ::2]).view(torch.int8) - qweight = torch.ops.trtllm.preprocess_weights_for_mixed_gemm( - qweight.contiguous(), torch.quint4x2, - torch.float16).view(str_dtype_to_torch(self.dtype)) - # zeros = zeros * scales - qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) - if self.is_padded: - scales_fp16 = torch.split(scales_fp16, - self.weights_scaling_factor.shape[0], - 0)[self.tp_rank] - qzeros_unpacked_int32 = torch.split( - qzeros_unpacked_int32, self.weights_scaling_factor.shape[0], - 0)[self.tp_rank] - scales_fp16 = pad_like(scales_fp16, - self.weights_scaling_factor.shape, 1) - qzeros_unpacked_int32 = pad_like(qzeros_unpacked_int32, - self.zero.shape, 7) - zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 7) * scales_fp16 - zeros_x_scales_fp16 = zeros_x_scales_fp16.to( - str_dtype_to_torch(self.dtype)) - - results = { - tllm_key: qweight, - tllm_key.replace("weight", "weights_scaling_factor"): - scales_fp16, - tllm_key.replace("weight", "zero"): zeros_x_scales_fp16, - } - return results + torch_dtype = str_dtype_to_torch(self.dtype) + return postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype, + self, **kwargs) class SmoothQuantMLP(Module): @@ -1594,7 +1495,10 @@ def forward(self, hidden_states, lora_layer_params=None): if self.quant_mode.has_fp8_rowwise(): # Quantize per token outputs tuple: # quantized tensor and scaling factors per token - inter = quantize_fp8_per_token(inter, self.clamp_val.val) + if hasattr(self.clamp_val, "val"): + inter = quantize_fp8_per_token(inter, self.clamp_val.val) + else: + inter = quantize_fp8_per_token(inter) output = self.proj(inter) return output @@ -1718,32 +1622,31 @@ def forward(self, hidden_states, lora_layer_params=None): class SmoothQuantAttention(Module): - def __init__( - self, - *, - local_layer_idx, - hidden_size, - num_attention_heads, - num_kv_heads=None, - max_position_embeddings=1024, - num_layers=1, - apply_query_key_layer_scaling=False, - attention_head_size=None, - attention_mask_type=AttentionMaskType.padding, - bias=True, - dense_bias=None, - dtype=None, - position_embedding_type=PositionEmbeddingType.learned_absolute, - rotary_embedding_base=10000.0, - rotary_embedding_scaling=None, - rotary_embedding_percentage=1.0, - tp_group=None, - tp_size=1, - tp_rank=0, - scale_alibi_bias=False, - paged_kv_cache=False, - quant_mode=QuantMode(0), - ): + def __init__(self, + *, + local_layer_idx, + hidden_size, + num_attention_heads, + num_kv_heads=None, + max_position_embeddings=1024, + num_layers=1, + apply_query_key_layer_scaling=False, + attention_head_size=None, + attention_mask_type=AttentionMaskType.padding, + bias=True, + 
dense_bias=None, + dtype=None, + position_embedding_type=PositionEmbeddingType.learned_absolute, + rotary_embedding_base=10000.0, + rotary_embedding_scaling=None, + rotary_embedding_percentage=1.0, + tp_group=None, + tp_size=1, + tp_rank=0, + scale_alibi_bias=False, + paged_kv_cache=False, + quant_mode=QuantMode(0), + layer_idx_in_cache_pool=None): super().__init__() self.local_layer_idx = local_layer_idx self.attention_mask_type = attention_mask_type @@ -1752,6 +1655,7 @@ def __init__( self.num_attention_kv_heads = ( num_kv_heads + tp_size - 1 ) // tp_size if num_kv_heads is not None else self.num_attention_heads + self.layer_idx_in_cache_pool = layer_idx_in_cache_pool self.hidden_size = hidden_size // tp_size self.max_position_embeddings = 0 if max_position_embeddings is None else max_position_embeddings self.tp_size = tp_size @@ -1916,6 +1820,7 @@ def forward( layer_idx=self.local_layer_idx, num_heads=self.num_attention_heads, num_kv_heads=self.num_attention_kv_heads, + layer_idx_in_cache_pool=self.layer_idx_in_cache_pool, hidden_size_per_head=self.attention_head_size, q_scaling=self.q_scaling, rotary_embedding_dim=self.rotary_embedding_dim, @@ -1938,6 +1843,8 @@ def forward( host_kv_cache_block_offsets, host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. + host_kv_cache_pool_mapping, host_context_lengths=attention_params.host_context_lengths, use_cache=use_cache, spec_decoding_generation_lengths=spec_decoding_params. diff --git a/tensorrt_llm/quantization/quantize_by_modelopt.py b/tensorrt_llm/quantization/quantize_by_modelopt.py index 7883ca479..769f0e0f0 100644 --- a/tensorrt_llm/quantization/quantize_by_modelopt.py +++ b/tensorrt_llm/quantization/quantize_by_modelopt.py @@ -31,6 +31,7 @@ from torch.utils.data import DataLoader from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from .._utils import release_gc from ..logger import logger from ..mapping import Mapping from .mode import QuantAlgo @@ -123,7 +124,9 @@ def quant_cfg_choices(): "Phi3SmallForCausalLM": "phi3small", "Phi3ForCausalLM": "phi3", "Starcoder2ForCausalLM": "gptnext", + "GPTBigCodeForCausalLM": "gptnext", "GLM": "glm", + "DeciLMForCausalLM": "deci", } @@ -542,9 +545,6 @@ def quantize_and_export(*, with open(f"{export_path}/config.json", "w") as f: json.dump(tensorrt_llm_config, f, indent=4) - torch.cuda.empty_cache( - ) # otherwise torch is keeping using GPU, other routine like build engine has less free GPU to use - # Workaround for combining medusa head # TODO: move these integration into modelopt to avoid redundant reading and writing if medusa_model_dir is not None: @@ -557,6 +557,12 @@ def quantize_and_export(*, "Quantized model exported to {} \nTotal time used {:.2f} s.".format( export_path, end_time - start_time)) + # Need to delete the model and release memory explicitly; + # otherwise torch may retain its GPU memory until a delayed GC running, + # which reduces the available GPU memory for subsequent stages. + del model + release_gc() + def unwrap_model(model, module_instances=None): # Reference: https://github.com/NVIDIA/Megatron-LM/blob/core_r0.8.0/megatron/training/utils.py @@ -823,11 +829,15 @@ def forward_loop(model): inference_pipeline_parallel=pp_size, ) - torch.cuda.empty_cache( - ) # otherwise torch is keeping using GPU, other routine like build engine has less free GPU to use end_time = time.time() print_rank_0( f"Model config exported to: {output_dir}. 
Total time used {end_time - start_time}s" ) if torch.distributed.get_rank() == 0: save_artifacts(model, output_dir, use_abspath=True) + + # Need to delete the model and release memory explicitly; + # otherwise torch may retain its GPU memory until a delayed GC running, + # which reduces the available GPU memory for subsequent stages. + del model + release_gc() diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index dba5aac5b..983d458b8 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -16,6 +16,7 @@ import copy import math import platform +from collections import Counter from dataclasses import dataclass, field from functools import reduce, wraps from pathlib import Path @@ -29,6 +30,10 @@ # isort: on from cuda import cudart +from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \ + MemoryPoolsAllocator +from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \ + PoolsKVCacheManager from tensorrt_llm.runtime.redrafter_utils import * from .._utils import (pad_vocab_size, str_dtype_to_torch, torch_to_numpy, @@ -39,7 +44,7 @@ from ..mapping import Mapping from ..plugin.plugin import CustomAllReduceHelper from ..quantization import QuantMode -from .kv_cache_manager import GenerationSequence, KVCacheManager, KVCacheUpdater +from .kv_cache_manager import GenerationSequence, KVCacheUpdater from .session import _scoped_stream @@ -809,10 +814,12 @@ def __init__(self, expected_tensor_names += [f'kv_cache_block_offsets'] expected_tensor_names += [f'host_kv_cache_block_offsets'] expected_tensor_names += [f'host_kv_cache_pool_pointers'] + expected_tensor_names += [f'host_kv_cache_pool_mapping'] if self.cross_attention: expected_tensor_names += [f'cross_kv_cache_block_offsets'] expected_tensor_names += [f'host_cross_kv_cache_block_offsets'] expected_tensor_names += [f'host_cross_kv_cache_pool_pointers'] + expected_tensor_names += [f'host_cross_kv_cache_pool_mapping'] else: # Refer to gpt_attention() inside functional.py if self.use_kv_cache and not self.paged_kv_cache: @@ -1695,40 +1702,42 @@ def setup(self, num_blocks, _ = self._get_num_paged_blocks( self.max_attention_window_size, self.sink_token_length, self.use_one_more_block) - cache_shape = ( - num_blocks, - self.num_attn_layers, - 2, - self.get_num_heads_kv(), - self.tokens_per_block, - self.head_size, - ) - self.kv_cache_pool = torch.empty(cache_shape, - dtype=kv_cache_type, - device=self.device) + self._memory_pool_allocator = MemoryPoolsAllocator( + num_blocks=num_blocks, + tokens_per_block=self.tokens_per_block, + head_size=self.head_size) + if self._model_config.num_kv_heads_per_layer is None: + num_kv_heads_per_layer = MemoryPoolsAllocator.prepare_num_kv_heads_per_layer( + self.get_num_heads_kv(), self.num_attn_layers) + else: + num_kv_heads_per_layer = self._model_config.num_kv_heads_per_layer + + self._memory_pool_allocator.allocate(kv_cache_type, + num_kv_heads_per_layer) + if self.cross_attention: # As for now we enable cross paged kv and self paged kv to share the same tokens_per_block cross_num_blocks, _ = self._get_num_paged_blocks( self.encoder_max_input_length, sink_token_length=0, use_one_more_block=False) - cross_cache_shape = ( - cross_num_blocks, - self.num_layers, - 2, - self.get_num_heads_kv(), - self.tokens_per_block, - self.head_size, - ) - self.cross_kv_cache_pool = torch.empty(cross_cache_shape, - dtype=kv_cache_type, - device=self.device) + + num_kv_heads_per_layer = 
MemoryPoolsAllocator.prepare_num_kv_heads_per_layer( + self.get_num_heads_kv(), self.num_layers) + + self._cross_memory_pool_allocator = MemoryPoolsAllocator( + num_blocks=cross_num_blocks, + tokens_per_block=self.tokens_per_block, + head_size=self.head_size) + self._cross_memory_pool_allocator.allocate( + kv_cache_type, num_kv_heads_per_layer) + elif self.has_attn_layers: for i in range(self.first_layer, self.last_layer): if self.layer_types[i] == 'attention': cache_shape = ( batch_size, 2, - self.get_num_heads_kv(self.general_to_attn_idx[i]), + self.get_num_heads_kv(i), self.max_attention_window_size, self.head_size, ) @@ -1844,6 +1853,43 @@ def setup(self, if self.is_medusa_mode: return self.num_draft_tokens + def _allocate_empty_kv_cache_pools(self, kv_cache_type, num_blocks): + # Layers are homogeneous, use old kv cache shape + unique_cache_pools = [] + if self._model_config.num_kv_heads_per_layer is None: + cache_shape = ( + num_blocks, + self.num_attn_layers, + 2, + self.get_num_heads_kv(), + self.tokens_per_block, + self.head_size, + ) + unique_cache_pools.append( + torch.empty(cache_shape, + dtype=kv_cache_type, + device=self.device)) + + # Layers are not homogeneous, use new kv cache shape + else: + kv_heads_unique_counter = Counter( + self._model_config.num_kv_heads_per_layer) + for kv_head, num_layers in kv_heads_unique_counter.items(): + cache_shape = ( + num_blocks, + num_layers, + 2, + kv_head, + self.tokens_per_block, + self.head_size, + ) + unique_cache_pools.append( + torch.empty(cache_shape, + dtype=kv_cache_type, + device=self.device)) + + return unique_cache_pools + def _get_context_shape_buffer( self, input_ids: torch.Tensor, @@ -1962,17 +2008,20 @@ def add_tensor_with_bs(x, name, bs): if self.paged_kv_cache and self.has_attn_layers: buffer = kv_cache_block_offsets.contiguous() shape = kv_cache_block_offsets.shape - shape = [shape[0] * shape[1], *shape[2:]] + shape = [shape[0], shape[1] * shape[2], *shape[3:]] add_tensor_with_shape(buffer, f'kv_cache_block_offsets', shape) add_tensor_with_shape(host_kv_cache_block_offsets, f'host_kv_cache_block_offsets', shape) pool_pointers = f'host_kv_cache_pool_pointers' + pool_mapping = f'host_kv_cache_pool_mapping' add_tensor(self.buffer[pool_pointers], pool_pointers) + add_tensor(self.buffer[pool_mapping], pool_mapping) if self.cross_attention: cross_buffer = cross_kv_cache_block_offsets.contiguous() cross_shape = cross_kv_cache_block_offsets.shape cross_shape = [ - cross_shape[0] * cross_shape[1], *cross_shape[2:] + cross_shape[0], cross_shape[1] * cross_shape[2], + *cross_shape[3:] ] add_tensor_with_shape(cross_buffer, f'cross_kv_cache_block_offsets', @@ -1981,8 +2030,10 @@ def add_tensor_with_bs(x, name, bs): f'host_cross_kv_cache_block_offsets', cross_shape) cross_pool_pointers = f'host_cross_kv_cache_pool_pointers' + cross_pool_mapping = f'host_cross_kv_cache_pool_mapping' add_tensor(self.buffer[cross_pool_pointers], cross_pool_pointers) + add_tensor(self.buffer[cross_pool_mapping], cross_pool_mapping) batch_size = context_lengths.shape[0] if self.use_kv_cache and not self.paged_kv_cache: @@ -2245,17 +2296,20 @@ def add_tensor_with_shape(x, name, shape): if self.paged_kv_cache and self.has_attn_layers: shape = kv_cache_block_offsets.shape - shape = [shape[0] * shape[1], *shape[2:]] + shape = [shape[0], shape[1] * shape[2], *shape[3:]] add_tensor_with_shape(kv_cache_block_offsets, f'kv_cache_block_offsets', shape) add_tensor_with_shape(host_kv_cache_block_offsets, f'host_kv_cache_block_offsets', shape) pool_pointers = 
f'host_kv_cache_pool_pointers' + pool_mapping = f'host_kv_cache_pool_mapping' add_tensor(self.buffer[pool_pointers], pool_pointers) + add_tensor(self.buffer[pool_mapping], pool_mapping) if self.cross_attention: cross_shape = cross_kv_cache_block_offsets.shape cross_shape = [ - cross_shape[0] * cross_shape[1], *cross_shape[2:] + cross_shape[0], cross_shape[1] * cross_shape[2], + *cross_shape[3:] ] add_tensor_with_shape(cross_kv_cache_block_offsets, f'cross_kv_cache_block_offsets', @@ -2264,8 +2318,10 @@ def add_tensor_with_shape(x, name, shape): f'host_cross_kv_cache_block_offsets', cross_shape) cross_pool_pointers = f'host_cross_kv_cache_pool_pointers' + cross_pool_mapping = f'host_cross_kv_cache_pool_mapping' add_tensor(self.buffer[cross_pool_pointers], cross_pool_pointers) + add_tensor(self.buffer[cross_pool_mapping], cross_pool_mapping) if prompt_embedding_table is not None: add_tensor(prompt_embedding_table, 'prompt_embedding_table') @@ -3054,11 +3110,11 @@ def handle_per_step( 'host_runtime_perf_knobs', None) if self.paged_kv_cache and self.has_attn_layers: - host_kv_cache_block_offsets = self.kv_cache_manager.get_block_offsets( + host_kv_cache_block_offsets = self.pools_kv_cache_manager.get_block_offsets( beam_width=1) kv_cache_block_offsets = host_kv_cache_block_offsets.to('cuda') if self.cross_attention: - host_cross_kv_cache_block_offsets = self.cross_kv_cache_manager.get_block_offsets( + host_cross_kv_cache_block_offsets = self.cross_pools_kv_cache_manager.get_block_offsets( beam_width=1) cross_kv_cache_block_offsets = host_cross_kv_cache_block_offsets.to( 'cuda') @@ -3235,7 +3291,7 @@ def handle_per_step( self.accept_lengths).item() assert add_token_count > 0 for _ in range(add_token_count): - self.kv_cache_manager.step([False] * batch_size) + self.pools_kv_cache_manager.step([False] * batch_size) if self.is_medusa_mode and self.num_draft_tokens > 0: # Allocate kv cache token slots for next step. # Make sure there are always > (num_draft_tokens + 1) free token slots. @@ -3245,16 +3301,16 @@ def handle_per_step( self.accept_lengths).item() assert add_token_count > 0 for _ in range(add_token_count): - self.kv_cache_manager.step([False] * batch_size) + self.pools_kv_cache_manager.step([False] * batch_size) else: - self.kv_cache_manager.step([False] * batch_size) + self.pools_kv_cache_manager.step([False] * batch_size) torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_push("paged_kv_post_alloc") - host_kv_cache_block_offsets = self.kv_cache_manager.get_block_offsets( + host_kv_cache_block_offsets = self.pools_kv_cache_manager.get_block_offsets( beam_width) kv_cache_block_offsets = host_kv_cache_block_offsets.to('cuda') if self.cross_attention: - host_cross_kv_cache_block_offsets = self.cross_kv_cache_manager.get_block_offsets( + host_cross_kv_cache_block_offsets = self.cross_pools_kv_cache_manager.get_block_offsets( beam_width) cross_kv_cache_block_offsets = host_cross_kv_cache_block_offsets.to( 'cuda') @@ -3385,9 +3441,9 @@ def handle_per_step( and should_stop.item()): # Free all blocks in all sequences. 
# With in-flight batching and while loop we'll free some sequences, when they are done - self.kv_cache_manager.step([True] * batch_size) + self.pools_kv_cache_manager.step([True] * batch_size) if self.cross_attention: - self.cross_kv_cache_manager.step([True] * batch_size) + self.cross_pools_kv_cache_manager.step([True] * batch_size) if self.debug_mode: self.dump_debug_buffers(step) @@ -3763,21 +3819,23 @@ def decode(self, num_blocks, max_blocks_per_seq = self._get_num_paged_blocks( self.max_attention_window_size, self.sink_token_length, self.use_one_more_block) - self.buffer[f'host_kv_cache_pool_pointers'] = torch.tensor( - [self.kv_cache_pool.data_ptr(), 0], dtype=torch.int64) - - block_size = self.get_num_heads_kv( - ) * self.tokens_per_block * self.head_size - self.kv_cache_manager = KVCacheManager( - num_layers=self.num_attn_layers, - num_blocks=num_blocks, - block_size=block_size, - tokens_per_block=self.tokens_per_block, - max_blocks_per_seq=max_blocks_per_seq, + + self.buffer[ + f'host_kv_cache_pool_pointers'] = self._memory_pool_allocator.get_kv_cache_pool_pointers( + ) + self.buffer[ + f'host_kv_cache_pool_mapping'] = self._memory_pool_allocator.pool_mapping + + self.pools_kv_cache_manager = PoolsKVCacheManager( + self._memory_pool_allocator.pools_metadata, + max_blocks_per_seq, + num_blocks, + self.tokens_per_block, + self.head_size, max_attention_window_size=self.max_attention_window_size, - sink_token_len=self.sink_token_length, beam_width=beam_width, - use_one_more_block=self.use_one_more_block) + use_one_more_block=self.use_one_more_block, + sink_token_len=self.sink_token_length) if self.cross_attention: cross_num_blocks, max_cross_blocks_per_seq = self._get_num_paged_blocks( @@ -3785,33 +3843,32 @@ def decode(self, sink_token_length=0, use_one_more_block=False) self.buffer[ - f'host_cross_kv_cache_pool_pointers'] = torch.tensor( - [self.cross_kv_cache_pool.data_ptr(), 0], - dtype=torch.int64) - - cross_block_size = self.get_num_heads_kv( - ) * self.tokens_per_block * self.head_size - self.cross_kv_cache_manager = KVCacheManager( - num_layers=self.num_layers, - num_blocks=cross_num_blocks, - block_size=cross_block_size, - tokens_per_block=self.tokens_per_block, - max_blocks_per_seq=max_cross_blocks_per_seq, + f'host_cross_kv_cache_pool_pointers'] = self._cross_memory_pool_allocator.get_kv_cache_pool_pointers( + ) + self.buffer[ + f'host_cross_kv_cache_pool_mapping'] = self._cross_memory_pool_allocator.pool_mapping + + self.cross_pools_kv_cache_manager = PoolsKVCacheManager( + self._memory_pool_allocator.pools_metadata, + max_cross_blocks_per_seq, + cross_num_blocks, + self.tokens_per_block, + self.head_size, max_attention_window_size=self.encoder_max_input_length, - sink_token_len=self.sink_token_length, beam_width=beam_width, - use_one_more_block=False) + use_one_more_block=False, + sink_token_len=self.sink_token_length) # Add sequences to the manager for bi in range(batch_size): generation_sequence = GenerationSequence(seq_idx=bi, batch_idx=bi) - self.kv_cache_manager.add_sequence(generation_sequence, - max_context_length) + self.pools_kv_cache_manager.add_sequence( + generation_sequence, max_context_length) if self.cross_attention: cross_generation_sequence = GenerationSequence(seq_idx=bi, batch_idx=bi) - self.cross_kv_cache_manager.add_sequence( + self.cross_pools_kv_cache_manager.add_sequence( cross_generation_sequence, self.encoder_max_input_length, always_share_across_beam=True) @@ -3833,7 +3890,7 @@ def decode(self, if self.paged_kv_cache: 
self.kv_cache_updater.init_paged_kv_cache( self.num_layers, self.get_num_heads_kv(), self.head_size, - kv_cache_type, self.kv_cache_manager, + kv_cache_type, self.pools_kv_cache_manager, self.buffer[f'host_kv_cache_pool_pointers']) else: past_key_value_list = [ diff --git a/tensorrt_llm/runtime/kv_cache_manager.py b/tensorrt_llm/runtime/kv_cache_manager.py index f7b33c336..c2b6c3f9b 100644 --- a/tensorrt_llm/runtime/kv_cache_manager.py +++ b/tensorrt_llm/runtime/kv_cache_manager.py @@ -79,7 +79,8 @@ def __init__(self, max_blocks_per_seq: int = 128, beam_width: int = 1): """ - expected block pool shape: [num_blocks, num_layers, 2, block_size] + If layers are homogeneous then the expected block pool shape is: [num_blocks, num_layers, 2, block_size] + Otherwise, the expected block pool shape is: [num_blocks, 2, block_size] """ self.max_blocks_per_seq = max_blocks_per_seq @@ -263,6 +264,7 @@ def __init__(self, block_size=block_size, max_blocks_per_seq=max_blocks_per_seq, beam_width=beam_width) + self.tokens_per_block = tokens_per_block self.max_attention_window_size = max_attention_window_size self.sink_token_len = sink_token_len @@ -422,8 +424,15 @@ def update(self, accepted_draft_token_offsets, int) else 0 assert self.use_paged_kv_cache is not None if self.use_paged_kv_cache: - host_kv_cache_block_offsets = self.kv_cache_manager.get_block_offsets( - 1) + if self.kv_cache_manager.has_single_pool(): + kv_cache_manager = self.kv_cache_manager.get_single_kv_cache_manager( + ) + else: + raise RuntimeError( + "Currently, using KVCacheUpdater with more then single memory pool is not supported" + ) + + host_kv_cache_block_offsets = kv_cache_manager.get_block_offsets(1) kv_cache_block_offsets = host_kv_cache_block_offsets.to('cuda') torch.ops.tensorrt_llm.update_kv_cache_draft_token_location( accepted_draft_token_offsets, @@ -434,13 +443,13 @@ def update(self, accepted_draft_token_offsets, self.num_kv_heads, self.head_dim * self.elt_size, rewind_tokens_count, - self.kv_cache_manager.max_attention_window_size, + kv_cache_manager.max_attention_window_size, rewind_tokens_tensor, None, self.host_kv_cache_pool_pointers, kv_cache_block_offsets, - self.kv_cache_manager.blocks_manager.max_blocks_per_seq, - self.kv_cache_manager.tokens_per_block, + kv_cache_manager.blocks_manager.max_blocks_per_seq, + kv_cache_manager.tokens_per_block, None, ) else: diff --git a/tensorrt_llm/runtime/memory_pools/__init__.py b/tensorrt_llm/runtime/memory_pools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorrt_llm/runtime/memory_pools/memory_pools_allocator.py b/tensorrt_llm/runtime/memory_pools/memory_pools_allocator.py new file mode 100644 index 000000000..d24d8d68b --- /dev/null +++ b/tensorrt_llm/runtime/memory_pools/memory_pools_allocator.py @@ -0,0 +1,80 @@ +from collections import Counter +from typing import List + +import torch + +import tensorrt_llm +from tensorrt_llm.runtime.memory_pools.pool import Pool + + +class MemoryPoolsAllocator(object): + + def __init__(self, num_blocks, tokens_per_block, head_size): + self._pools_metadata = [] + self._pool_pointers = [] + self._pool_mapping = None + + self._num_blocks = num_blocks + self._tokens_per_block = tokens_per_block + self._head_size = head_size + + def allocate(self, dtype, num_kv_heads_per_layer: List[int], device="cuda"): + self._num_kv_heads_per_layer = num_kv_heads_per_layer + + if isinstance(dtype, str): + dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) + kv_heads_unique_counter = Counter(self._num_kv_heads_per_layer) + 
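# The Counter above groups layers by their KV-head count; each distinct count
# becomes one memory pool. keys_to_indices (built just below) records which pool
# index a given head count landed in, so _set_layers_mapping can translate the
# per-layer head counts into the per-layer pool-index tensor exposed as pool_mapping.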
keys_to_indices = {} + + for idx, (kv_head, + num_layers) in enumerate(kv_heads_unique_counter.items()): + keys_to_indices[kv_head] = idx + cache_shape = ( + self._num_blocks, + num_layers, + 2, + kv_head, + self._tokens_per_block, + self._head_size, + ) + self._pool_pointers.append( + torch.empty(cache_shape, dtype=dtype, device=device)) + self._pools_metadata.append( + Pool(num_kv_heads=kv_head, num_layers=num_layers)) + + self._set_layers_mapping(keys_to_indices) + + def get_kv_cache_pool_pointers(self): + return self._get_primarmy_secondary_pool_pointers() + + def _set_layers_mapping(self, keys_to_indices): + layers_mapping = [] + for kv_size in self._num_kv_heads_per_layer: + layers_mapping.append(keys_to_indices[kv_size]) + + self._pool_mapping = torch.tensor(layers_mapping, dtype=torch.int32) + + def _get_primarmy_secondary_pool_pointers(self): + assert len(self._pool_pointers + ) >= 1, "pool pointers haven't been initiated yet" + data_ptr_pointers = torch.tensor(list( + map(lambda x: x.data_ptr(), self._pool_pointers)), + dtype=torch.int64) + host_kv_cache_pool_pointers = torch.cat( + (data_ptr_pointers.view(-1, 1), + torch.zeros(len(self._pool_pointers), 1, dtype=torch.int64)), + dim=1) + + return host_kv_cache_pool_pointers + + @classmethod + def prepare_num_kv_heads_per_layer(cls, kv_head, num_layers): + return [kv_head] * num_layers + + @property + def pools_metadata(self): + return self._pools_metadata + + @property + def pool_mapping(self): + return self._pool_mapping diff --git a/tensorrt_llm/runtime/memory_pools/pool.py b/tensorrt_llm/runtime/memory_pools/pool.py new file mode 100644 index 000000000..63308ad0d --- /dev/null +++ b/tensorrt_llm/runtime/memory_pools/pool.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass +class Pool(object): + num_kv_heads: int + num_layers: int diff --git a/tensorrt_llm/runtime/memory_pools/pools_kv_cache_manager.py b/tensorrt_llm/runtime/memory_pools/pools_kv_cache_manager.py new file mode 100644 index 000000000..4baf86ad3 --- /dev/null +++ b/tensorrt_llm/runtime/memory_pools/pools_kv_cache_manager.py @@ -0,0 +1,67 @@ +from typing import List + +import torch + +from tensorrt_llm.runtime.kv_cache_manager import (GenerationSequence, + KVCacheManager) +from tensorrt_llm.runtime.memory_pools.pool import Pool + + +class PoolsKVCacheManager(object): + + def __init__(self, + pools_metadata: List[Pool], + max_blocks_per_seq, + num_blocks, + tokens_per_block, + head_size, + max_attention_window_size, + beam_width, + sink_token_len, + use_one_more_block: bool = False) -> None: + self._num_pools = len(pools_metadata) + self._kv_cache_managers = [] + + for pool in pools_metadata: + block_size = pool.num_kv_heads * tokens_per_block * head_size + self._kv_cache_managers.append( + KVCacheManager( + num_layers=pool.num_layers, + num_blocks=num_blocks, + block_size=block_size, + tokens_per_block=tokens_per_block, + max_blocks_per_seq=max_blocks_per_seq, + max_attention_window_size=max_attention_window_size, + sink_token_len=sink_token_len, + use_one_more_block=use_one_more_block, + beam_width=beam_width, + )) + + def add_sequence(self, + sequence: GenerationSequence, + context_len: int, + always_share_across_beam: bool = False): + for kv_cache_manager in self._kv_cache_managers: + kv_cache_manager.add_sequence(sequence, context_len, + always_share_across_beam) + + def step(self, finished: List[bool]): + for kv_cache_manager in self._kv_cache_managers: + kv_cache_manager.step(finished) + + def get_block_offsets(self, beam_width: int) -> 
torch.Tensor: + offsets = [] + for kv_cache_manager in self._kv_cache_managers: + block_offset = kv_cache_manager.get_block_offsets(beam_width) + offsets.append(block_offset) + + return torch.stack(offsets) + + def get_single_kv_cache_manager(self): + assert len(self._kv_cache_managers + ) == 1, f"More then one kv cache manager exists" + + return self._kv_cache_managers[0] + + def has_single_pool(self): + return len(self._kv_cache_managers) == 1 diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 1a35299a8..d2ba7edfa 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -830,6 +830,12 @@ def generate(self, self._check_inputs(batch_input_ids, sampling_config) + if kwargs.get('num_return_sequences', 1) > 1: + logger.warning( + 'num_return_sequences will be ignored since ' + 'num_return_sequences > 1 is not supported on python runtime. ' + 'Please use C++ runtime.') + batch_size = len(batch_input_ids) batch_input_ids, input_lengths = self._prepare_inputs( batch_input_ids, sampling_config.pad_id) diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index 3d576db35..daa5c608d 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -20,6 +20,7 @@ import torch from .. import profiler +from .._utils import mpi_broadcast from ..bindings import (DataType, GptJsonConfig, KVCacheType, ModelConfig, WorldConfig) from ..bindings import executor as trtllm @@ -73,31 +74,30 @@ def __init__(self, self.lora_manager = lora_manager @classmethod - def from_dir( - cls, - engine_dir: str, - *, - lora_dir: Optional[str] = None, - rank: int = 0, - max_batch_size: Optional[int] = None, - max_input_len: Optional[int] = None, - max_output_len: Optional[int] = None, - max_beam_width: Optional[int] = None, - max_attention_window_size: Optional[list[int]] = None, - sink_token_length: Optional[int] = None, - kv_cache_free_gpu_memory_fraction: Optional[float] = None, - medusa_choices: list[list[int]] | None = None, - lookahead_config: list[int] | None = None, - debug_mode: bool = False, - lora_ckpt_source: str = "hf", - gpu_weights_percent: float = 1, - max_tokens_in_paged_kv_cache: int | None = None, - kv_cache_enable_block_reuse: bool = False, - enable_chunked_context: bool = False, - is_enc_dec: bool = False, - multi_block_mode: bool = True, - enable_context_fmha_fp32_acc: Optional[bool] = None - ) -> 'ModelRunnerCpp': + def from_dir(cls, + engine_dir: str, + *, + lora_dir: Optional[str] = None, + rank: int = 0, + max_batch_size: Optional[int] = None, + max_input_len: Optional[int] = None, + max_output_len: Optional[int] = None, + max_beam_width: Optional[int] = None, + max_attention_window_size: Optional[list[int]] = None, + sink_token_length: Optional[int] = None, + kv_cache_free_gpu_memory_fraction: Optional[float] = None, + medusa_choices: list[list[int]] | None = None, + lookahead_config: list[int] | None = None, + debug_mode: bool = False, + lora_ckpt_source: str = "hf", + gpu_weights_percent: float = 1, + max_tokens_in_paged_kv_cache: int | None = None, + kv_cache_enable_block_reuse: bool = False, + enable_chunked_context: bool = False, + is_enc_dec: bool = False, + multi_block_mode: bool = True, + enable_context_fmha_fp32_acc: Optional[bool] = None, + cuda_graph_mode: Optional[bool] = None) -> 'ModelRunnerCpp': """ Create a ModelRunnerCpp instance from an engine directory. 
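A minimal sketch of how the new memory-pool classes above are meant to be wired together; the block, head, and size values are invented for illustration, device="cpu" is passed so the sketch runs without a GPU, and the expected values in the comments follow from the allocate() and pool-mapping code shown in this diff:

from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import MemoryPoolsAllocator
from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import PoolsKVCacheManager

# Two layers with 8 KV heads and one layer with 2 KV heads -> two pools.
num_kv_heads_per_layer = [8, 8, 2]
allocator = MemoryPoolsAllocator(num_blocks=16, tokens_per_block=64, head_size=128)
allocator.allocate("float16", num_kv_heads_per_layer, device="cpu")

print(allocator.pool_mapping)                        # tensor([0, 0, 1], dtype=torch.int32)
print(allocator.get_kv_cache_pool_pointers().shape)  # torch.Size([2, 2]): (primary ptr, 0) per pool

# The pool metadata then drives one KVCacheManager per pool, mirroring the
# constructor call used in generation.py above.
pools_manager = PoolsKVCacheManager(allocator.pools_metadata,
                                    max_blocks_per_seq=8,
                                    num_blocks=16,
                                    tokens_per_block=64,
                                    head_size=128,
                                    max_attention_window_size=512,
                                    beam_width=1,
                                    sink_token_len=0)
print(pools_manager.has_single_pool())               # False: two pools were created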
@@ -148,6 +148,8 @@ def from_dir( Whether to distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel. enable_context_fmha_fp32_acc (bool): Enable FMHA runner FP32 accumulation. + cuda_graph_mode (bool): + Whether to use cuda graph for inference. Returns: ModelRunnerCpp: An instance of ModelRunnerCpp. """ @@ -157,6 +159,8 @@ def from_dir( extended_runtime_perf_knob_config.multi_block_mode = multi_block_mode if enable_context_fmha_fp32_acc is not None: extended_runtime_perf_knob_config.enable_context_fmha_fp32_acc = enable_context_fmha_fp32_acc + if cuda_graph_mode is not None: + extended_runtime_perf_knob_config.cuda_graph_mode = cuda_graph_mode if is_enc_dec: encoder_config_path = Path(engine_dir) / "encoder" / "config.json" @@ -268,6 +272,9 @@ def from_dir( lora_manager = None peft_cache_config = trtllm.PeftCacheConfig() + if world_config.size > 1: + peft_cache_config = mpi_broadcast(peft_cache_config, 0) + profiler.start('load tensorrt_llm engine') kv_cache_config = trtllm.KvCacheConfig( @@ -313,8 +320,8 @@ def from_dir( # To debug specific tensors, add tensor names in the following list # if none provided, all input and output tensors will be dumped # if not none, it will disable all input/output dump - debug_tensor_names: List[ - str] = None # modify this list for specific tensor dump + debug_tensor_names: List[str] = [ + ] # modify this list for specific tensor dump debug_config = trtllm.DebugConfig( debug_input_tensors=True, debug_output_tensors=True, @@ -432,6 +439,7 @@ def generate( stopping_criteria: Optional[StoppingCriteria] = None, logits_processor: Optional[LogitsProcessor] = None, max_new_tokens: int = 1, + num_return_sequences: int = 1, end_id: int | None = None, pad_id: int | None = None, bad_words_list: list[list[int]] | None = None, @@ -481,6 +489,9 @@ def generate( Custom logits processors. return_all_generated_tokens (bool): Whether the full output is returned at each streaming step + num_return_sequences (int): + The number of sequences to generate for each input. It will + return (batch_size * num_return_sequences) sequences in total. kwargs (Dict[str, Any]: Ad hoc parametrization of sampling_config. The passed **kwargs matching the sampling_config's attributes will override them. 
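A hedged usage sketch for the two user-facing additions documented above, num_return_sequences on generate() and cuda_graph_mode on from_dir(); the engine path, token ids, and sampling values below are placeholders rather than values taken from this change:

import torch
from tensorrt_llm.runtime import ModelRunnerCpp

runner = ModelRunnerCpp.from_dir(engine_dir="/path/to/engine_dir",
                                 cuda_graph_mode=True)

batch_input_ids = [torch.tensor([1, 2, 3, 4], dtype=torch.int32)]
outputs = runner.generate(batch_input_ids,
                          max_new_tokens=32,
                          num_return_sequences=4,
                          end_id=2,
                          pad_id=2,
                          return_dict=True,
                          output_sequence_lengths=True)

# Per the generate() docstring above, the runner returns
# batch_size * num_return_sequences sequences in total, so output_ids is
# laid out as (len(batch_input_ids) * 4, num_beams, seq_len).
print(outputs["output_ids"].shape)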
@@ -576,6 +587,7 @@ def generate( position_ids=position_ids[i].tolist() if position_ids is not None else None, max_tokens=max_new_tokens, + num_return_sequences=num_return_sequences, pad_id=pad_id, end_id=end_id, stop_words=stop_words, @@ -595,18 +607,17 @@ def generate( request_ids = self.session.enqueue_requests(requests) if not streaming: - return self._initialize_and_fill_output(request_ids, end_id, - return_dict, - output_sequence_lengths, - output_log_probs, - output_cum_log_probs, - batch_input_ids, streaming) + return self._initialize_and_fill_output( + request_ids, end_id, return_dict, output_sequence_lengths, + output_log_probs, output_cum_log_probs, batch_input_ids, + streaming, max_new_tokens, num_return_sequences) else: return self._stream(request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, batch_input_ids_list, streaming, - return_all_generated_tokens) + return_all_generated_tokens, max_new_tokens, + num_return_sequences) def _prepare_words_list(self, words_list: List[List[List[int]]], batch_size: int): @@ -655,13 +666,19 @@ def _prepare_lora_configs(self, lora_uids, batch_size): if int(uid) >= 0 else None for uid in lora_uids ] - def _initialize_and_fill_output(self, request_ids, end_id, return_dict, - output_sequence_lengths, output_log_probs, - output_cum_log_probs, batch_input_ids, - streaming): - output_ids = [[] for _ in range(len(request_ids))] - for reqid_pos in range(len(request_ids)): - output_ids[reqid_pos] = [[] for _ in range(self.max_beam_width)] + def _initialize_and_fill_output(self, + request_ids, + end_id, + return_dict, + output_sequence_lengths, + output_log_probs, + output_cum_log_probs, + batch_input_ids, + streaming, + max_new_tokens: int, + num_return_sequences: int = 1): + output_ids = [[[] for _ in range(self.max_beam_width)] + for _ in range(len(request_ids) * num_return_sequences)] multi_responses = self.session.await_responses(request_ids) responses = [ @@ -671,121 +688,178 @@ def _initialize_and_fill_output(self, request_ids, end_id, return_dict, return self._fill_output(responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, [], - streaming, request_ids, False) - - def _stream(self, request_ids, end_id, return_dict, output_sequence_lengths, - output_log_probs, output_cum_log_probs, batch_input_ids, - batch_input_ids_list, streaming, return_all_generated_tokens): - output_ids = [[] for _ in range(len(request_ids))] - for reqid_pos in range(len(request_ids)): + streaming, request_ids, False, max_new_tokens, + num_return_sequences) + + def _stream(self, + request_ids, + end_id, + return_dict, + output_sequence_lengths, + output_log_probs, + output_cum_log_probs, + batch_input_ids, + batch_input_ids_list, + streaming, + return_all_generated_tokens, + max_new_tokens: int, + num_return_sequences: int = 1): + + output_ids = [[] + for _ in range(len(request_ids) * num_return_sequences)] + for reqid_pos in range(len(request_ids) * num_return_sequences): + batch_idx = reqid_pos // num_return_sequences output_ids[reqid_pos] = [ - copy.deepcopy(batch_input_ids_list[reqid_pos]) + copy.deepcopy(batch_input_ids_list[batch_idx]) for _ in range(self.max_beam_width) ] - finished_reqs = 0 - while finished_reqs < len(request_ids): + finished_request_ids = set() + while finished_request_ids != set(request_ids): responses = self.session.await_responses() - for response in responses: if response.result.is_final: - finished_reqs += 1 + 
finished_request_ids.add(response.request_id) yield self._fill_output(responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, batch_input_ids_list, streaming, - request_ids, return_all_generated_tokens) + request_ids, return_all_generated_tokens, + max_new_tokens, num_return_sequences) def _fill_output(self, responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, batch_input_ids_list, streaming, request_ids, - return_all_generated_tokens): + return_all_generated_tokens, max_new_tokens, + num_return_sequences): cuda_device = torch.device("cuda") + # Total number of output sequences = batch_size * num_return_sequences. + batch_size = len(batch_input_ids) + num_output_sequences = len(output_ids) + num_beams = len(output_ids[0]) + assert batch_size * num_return_sequences == num_output_sequences + + def req_idx(response: trtllm.Response): + batch_idx = request_ids.index(response.request_id) + seq_idx = response.result.sequence_index + return batch_idx * num_return_sequences + seq_idx + for response in responses: if response.has_error(): raise RuntimeError(response.error_msg) - reqid_pos = request_ids.index(response.request_id) - for beam, output_tokens in enumerate( - response.result.output_token_ids): + result = response.result + batch_idx = request_ids.index(response.request_id) + + for beam, output_tokens in enumerate(result.output_token_ids): + # Return shape = (batch_size * num_return_seq, beam, seq_len) if return_all_generated_tokens: - output_ids[reqid_pos][ - beam] = batch_input_ids_list[reqid_pos] + output_tokens + output_ids[req_idx(response)][beam] = ( + batch_input_ids_list[batch_idx] + output_tokens) else: - output_ids[reqid_pos][beam] += output_tokens + output_ids[req_idx(response)][beam] += output_tokens - sequence_lengths = [] - for output in output_ids: - sequence_lengths.append([len(a) for a in output]) + if output_sequence_lengths: + sequence_lengths = [[len(b) for b in beam] for beam in output_ids] if streaming: output_ids = copy.deepcopy(output_ids) - for beam in output_ids: - for output_tokens in beam: - output_tokens += (self.max_seq_len - - len(output_tokens)) * [end_id] - + # Pad by end_id tokens (batch * n, num_beams, max_seq_len). 
+ for beams in output_ids: + for token_ids in beams: + token_ids += [end_id] * (self.max_seq_len - len(token_ids)) output_ids = torch.tensor(output_ids, dtype=torch.int32, device=cuda_device) if return_dict: outputs = {'output_ids': output_ids} + + input_lengths = torch.tensor([x.size(0) for x in batch_input_ids], + dtype=torch.int32, + device=cuda_device) + if output_sequence_lengths: outputs['sequence_lengths'] = torch.tensor(sequence_lengths, dtype=torch.int32, device=cuda_device) if self.gather_context_logits: - outputs['context_logits'] = [ - a.result.context_logits.cuda() for a in responses - if a.result.context_logits is not None - ] - # Pad context_logits into a rectangle - max_input_length = max(a.shape[0] - for a in outputs['context_logits']) - for i, a in enumerate(outputs['context_logits']): - pad_length = max_input_length - a.shape[0] - outputs['context_logits'][i] = torch.nn.functional.pad( - a, [0, 0, 0, pad_length]) - outputs['context_logits'] = torch.stack( - outputs['context_logits']) + context_logits = None + max_input_len = input_lengths.max() + for response in responses: + result = response.result + logits = result.context_logits + if logits is None: + continue + input_len, vocab_size = logits.shape + if context_logits is None: + context_logits = torch.zeros( + (num_output_sequences, max_input_len, vocab_size), + dtype=logits.dtype, + device=cuda_device) + context_logits[req_idx(response), :input_len, :] = logits + assert context_logits is not None + outputs['context_logits'] = context_logits + if self.gather_generation_logits: - outputs['generation_logits'] = [ - a.result.generation_logits.cuda() for a in responses - if a.result.generation_logits is not None - ] - outputs['generation_logits'] = torch.stack( - outputs['generation_logits']) + if not streaming: + gen_shape = (num_beams, max_new_tokens, vocab_size) + elif streaming and return_all_generated_tokens: + gen_shape = (max_new_tokens, num_beams, vocab_size) + else: # streaming and not return_all_generated_tokens + gen_shape = (1, num_beams, vocab_size) + + gen_logits = None + for response in responses: + # gen logits shape: (beam, seq, vocab) + logits = response.result.generation_logits + if logits is None: + continue + num_beams, seq_len, vocab_size = logits.shape + if gen_logits is None: + gen_logits = torch.zeros( + (num_output_sequences, *gen_shape), + dtype=logits.dtype, + device=cuda_device) + batch_idx = request_ids.index(response.request_id) + seq_idx = response.result.sequence_index + reqid_pos = batch_idx * num_return_sequences + seq_idx + if streaming: + gen_logits[reqid_pos, :seq_len, ...] = logits[0] + else: + gen_logits[reqid_pos, :, :seq_len, ...] = logits[0] + outputs['generation_logits'] = gen_logits + if output_log_probs: - outputs['log_probs'] = [ - a.result.log_probs for a in responses - if a.result.log_probs is not None - ] - # Pad log_probs into a rectangle - max_seq_len = max( - len(a) for beam_list in outputs['log_probs'] - for a in beam_list) - for i, a in enumerate(outputs['log_probs']): - for j, b in enumerate(a): - pad_length = max_seq_len - len(b) - outputs['log_probs'][i][j] = b + [0.0] * pad_length - outputs['log_probs'] = torch.tensor(outputs['log_probs'], - device=cuda_device) + log_probs = None + for response in responses: + if log_probs is None: + # TODO: Refactor not to allocate a buffer per step. 
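# (The block below allocates one zero-filled (num_output_sequences, num_beams,
# output_len) log-probs buffer on the first response it sees, then writes each
# response's per-beam log probs into the row selected by req_idx(), replacing
# the earlier list-append-and-pad approach.)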
+ output_len = len(response.result.log_probs[0]) + log_probs = torch.zeros( + (num_output_sequences, num_beams, output_len), + dtype=torch.float32) + for i, lprobs in enumerate(response.result.log_probs): + log_probs[req_idx(response), i, :len(lprobs)] = \ + torch.tensor(lprobs) + assert isinstance(log_probs, torch.Tensor) + outputs['log_probs'] = log_probs.to(cuda_device) + if output_cum_log_probs: - outputs['cum_log_probs'] = [ - a.result.cum_log_probs for a in responses - if a.result.cum_log_probs is not None - ] - outputs['cum_log_probs'] = torch.tensor( - outputs['cum_log_probs'], device=cuda_device) - input_lengths = torch.tensor([x.size(0) for x in batch_input_ids], - dtype=torch.int32, - device=cuda_device) + cum_log_probs = torch.zeros((num_output_sequences, num_beams), + dtype=torch.float32) + for response in responses: + if response.result.cum_log_probs is not None: + cum_log_probs[req_idx(response), :] = \ + torch.tensor(response.result.cum_log_probs) + outputs['cum_log_probs'] = cum_log_probs.to(cuda_device) + outputs = self._prepare_outputs(outputs, input_lengths) else: outputs = output_ids + return outputs diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index 3c66b5261..5981441d5 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -18,7 +18,6 @@ Pix2StructForConditionalGeneration, VisionEncoderDecoderModel) # isort: on -import json import math import torch.nn.functional as F @@ -731,13 +730,6 @@ def build_phi_engine(args): images=raw_image, return_tensors="pt")['pixel_values'].to( args.device, torch.float16) - try: - with open(f"{args.model_path}/preprocessor_config.json", "r") as file: - config = file.read() - config_dict = json.loads(config) - num_crops = config_dict.get("num_crops") - except: - num_crops = 16 class Phi3VisionWrapper(torch.nn.Module): @@ -792,7 +784,8 @@ def apply_img_projection(self, input): tensors = {"glb_GN": glb_GN, "sub_GN": sub_GN} save_file(tensors, args.output_dir + "/image_newlines.safetensors") export_onnx(wrapper, image, f'{args.output_dir}/onnx') - build_trt_engine( - args.model_type, [image.shape[1], image.shape[2], image.shape[3]], - f'{args.output_dir}/onnx', args.output_dir, - args.max_batch_size * (num_crops + 1)) #TODO: Take input from config + num_crops = processor.image_processor.num_crops + build_trt_engine(args.model_type, + [image.shape[1], image.shape[2], image.shape[3]], + f'{args.output_dir}/onnx', args.output_dir, + args.max_batch_size * (num_crops + 1)) diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index b1ebedec5..ed0a116b5 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.14.0.dev2024092401" +__version__ = "0.14.0.dev2024100100" diff --git a/tests/attention/test_gpt_attention.py b/tests/attention/test_gpt_attention.py index 9fdbaec30..dc7758898 100644 --- a/tests/attention/test_gpt_attention.py +++ b/tests/attention/test_gpt_attention.py @@ -40,13 +40,18 @@ RotaryScalingType) from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import GenerationSequence, KVCacheManager +from tensorrt_llm.runtime import GenerationSequence +from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \ + PoolsKVCacheManager sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.util import (getSMVersion, skip_bf16_fp32_accum, skip_bf16_pre_ampere, skip_fp8_pre_ada, skip_fp32_accum_pre_ampere, unittest_name_func) +from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \ + MemoryPoolsAllocator + class TestFunctional(unittest.TestCase): @@ -399,11 +404,12 @@ def test_gpt_attention(self, def _construct_execution( session, input_tensor, weight, bias, past_key_value, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, - packed_mask_for_fmha, sequence_length, - host_past_key_value_lengths, host_max_attention_window_sizes, - host_sink_token_length, context_lengths, host_context_lengths, - cache_indirection, host_request_types, num_heads, hidden_size, - num_kv_heads, output, dtype, max_context_length, shape_dict, + host_kv_cache_pool_mapping, packed_mask_for_fmha, + sequence_length, host_past_key_value_lengths, + host_max_attention_window_sizes, host_sink_token_length, + context_lengths, host_context_lengths, cache_indirection, + host_request_types, num_heads, hidden_size, num_kv_heads, + output, dtype, max_context_length, shape_dict, kv_int8_quant_scale, kv_int8_dequant_scale, configuration, host_runtime_perf_knobs): kv_cache_block_offsets = None @@ -480,6 +486,7 @@ def _construct_execution( kv_cache_block_offsets_tensor = None host_kv_cache_block_offsets_tensor = None host_kv_cache_pool_pointers_tensor = None + host_kv_cache_pool_mapping_tensor = None if paged_kv_cache: kv_cache_block_offsets_tensor = Tensor( name='kv_cache_block_offsets', @@ -491,8 +498,15 @@ def _construct_execution( dtype=tensorrt_llm.str_dtype_to_trt('int32')) host_kv_cache_pool_pointers_tensor = Tensor( name='host_kv_cache_pool_pointers', - shape=(1, ), + shape=( + 1, + 1, + ), dtype=tensorrt_llm.str_dtype_to_trt('int64')) + host_kv_cache_pool_mapping_tensor = Tensor( + name='host_kv_cache_pool_mapping', + shape=(1, ), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) else: past_key_value_tensor = Tensor( name='past_key_value', @@ -606,6 +620,7 @@ def _construct_execution( host_kv_cache_block_offsets_tensor, host_kv_cache_pool_pointers= host_kv_cache_pool_pointers_tensor, + host_kv_cache_pool_mapping=host_kv_cache_pool_mapping_tensor, max_context_length=max_context_length, qkv_bias=qkv_bias, host_runtime_perf_knobs=host_runtime_perf_knobs_tensor) @@ -639,6 +654,8 @@ def _construct_execution( 'host_kv_cache_block_offsets'] = host_kv_cache_block_offsets inputs[ 'host_kv_cache_pool_pointers'] = host_kv_cache_pool_pointers + inputs[ + 'host_kv_cache_pool_mapping'] = host_kv_cache_pool_mapping else: inputs['past_key_value'] = past_key_value @@ -725,24 +742,34 @@ def _construct_execution( dtype=torch_kv_cache_dtype, device='cuda') host_kv_cache_pool_pointers = None + host_kv_cache_pool_mapping = None # Init KV cache block manager if paged_kv_cache: - block_size = 
plugin_kv_num_heads * tokens_per_block * head_size - kv_cache_manager = KVCacheManager( - num_layers=1, + memory_pools_allocator = MemoryPoolsAllocator( num_blocks=num_blocks, - block_size=block_size, tokens_per_block=tokens_per_block, - max_blocks_per_seq=max_blocks_per_seq, + head_size=head_size) + + num_kv_heads_per_layer = MemoryPoolsAllocator.prepare_num_kv_heads_per_layer( + plugin_kv_num_heads, 1) + memory_pools_allocator.allocate(dtype, num_kv_heads_per_layer) + pools_kv_cache_manager = PoolsKVCacheManager( + memory_pools_allocator.pools_metadata, + max_blocks_per_seq, + num_blocks, + tokens_per_block, + head_size, max_attention_window_size=max_seq_len, - sink_token_len=sink_token_len, - beam_width=beam_width) + beam_width=beam_width, + sink_token_len=sink_token_len) + host_kv_cache_pool_pointers = torch.tensor( [present_key_value.data_ptr(), 0], dtype=torch.int64) + host_kv_cache_pool_mapping = memory_pools_allocator.pool_mapping # Add sequences to the kv_cache_manager for bi in range(batch_size): - kv_cache_manager.add_sequence( + pools_kv_cache_manager.add_sequence( GenerationSequence(seq_idx=bi, batch_idx=bi), in_len) weight = torch.randn(shape_dict['weight'], @@ -992,6 +1019,10 @@ def verify_kv_cache(torch_present): if not use_int8_kv_cache and not use_fp8_kv_cache and num_kv_heads == num_heads and beam_width == 1: if paged_kv_cache: + assert pools_kv_cache_manager.has_single_pool( + ) is True, f"Current test assuming only one memory pool" + kv_cache_manager = pools_kv_cache_manager.get_single_kv_cache_manager( + ) kv_cache_cont = kv_cache_manager.blocks_manager.get_continuous_caches( present_key_value) kv_cache_cont = kv_cache_cont.permute(1, 0, 2) @@ -1054,9 +1085,12 @@ def verify_kv_cache(torch_present): kv_cache_block_offsets = None if paged_kv_cache: # Get arrays of pointers to the "pages" of KV values + assert pools_kv_cache_manager.has_single_pool( + ) is True, f"Current test assuming only one memory pool" + kv_cache_manager = pools_kv_cache_manager.get_single_kv_cache_manager( + ) kv_cache_block_offsets = kv_cache_manager.get_block_offsets( beam_width) - if step == 0: host_request_types = torch.tensor([0] * batch_size, dtype=torch.int32) @@ -1181,8 +1215,9 @@ def verify_kv_cache(torch_present): session, output, present_key_value = _construct_execution( session, input_tensor, weight_plugin, bias_plugin, present_key_value, kv_cache_block_offsets, - host_kv_cache_pool_pointers, packed_mask_for_fmha, - sequence_length, host_past_key_value_lengths, + host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, + packed_mask_for_fmha, sequence_length, + host_past_key_value_lengths, host_max_attention_window_sizes, host_sink_token_length, input_lengths, host_context_lengths, cache_indirection, host_request_types, num_heads, hidden_size, num_kv_heads, @@ -1191,7 +1226,6 @@ def verify_kv_cache(torch_present): context_host_runtime_perf_knobs) del session session = None - # Note: Volta has larger errors. # We speculate it’s because Volta’s TC is smaller and more calculations are required, # which may lead to more error accumulation. 
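A small, hedged illustration of the single-pool path the updated tests rely on; the sizes are arbitrary, device="cpu" keeps it GPU-free, and only calls visible in this diff are used:

from tensorrt_llm.runtime import GenerationSequence
from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import MemoryPoolsAllocator
from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import PoolsKVCacheManager

alloc = MemoryPoolsAllocator(num_blocks=4, tokens_per_block=32, head_size=64)
# One layer with 2 KV heads -> exactly one pool.
alloc.allocate("float16",
               MemoryPoolsAllocator.prepare_num_kv_heads_per_layer(2, 1),
               device="cpu")
pools = PoolsKVCacheManager(alloc.pools_metadata,
                            max_blocks_per_seq=4,
                            num_blocks=4,
                            tokens_per_block=32,
                            head_size=64,
                            max_attention_window_size=128,
                            beam_width=1,
                            sink_token_len=0)
pools.add_sequence(GenerationSequence(seq_idx=0, batch_idx=0), 16)

# get_block_offsets() stacks one offset tensor per pool, so the leading
# dimension is the pool count (1 here); the test above unwraps that single
# pool via has_single_pool() and get_single_kv_cache_manager().
offsets = pools.get_block_offsets(beam_width=1)
assert pools.has_single_pool() and offsets.shape[0] == 1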
@@ -1353,7 +1387,8 @@ def tile_beam_width(tensor: torch.Tensor, num_beams: int): session, tiled_output, present_key_value = _construct_execution( session, tiled_input_tensor, weight_plugin, bias_plugin, tiled_present_key_value, kv_cache_block_offsets, - host_kv_cache_pool_pointers, None, tiled_sequence_length, + host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, + None, tiled_sequence_length, tiled_host_past_key_value_lengths, host_max_attention_window_sizes, host_sink_token_length, tiled_input_lengths, tiled_host_context_lengths, @@ -1374,7 +1409,7 @@ def tile_beam_width(tensor: torch.Tensor, num_beams: int): if paged_kv_cache: # Iterate to the next step. Increase number of tokens for all unfinished sequences # And allocate new blocks if needed - kv_cache_manager.step([False] * batch_size) + pools_kv_cache_manager.step([False] * batch_size) # assert False, "Force fail" return diff --git a/tests/attention/test_gpt_attention_IFB.py b/tests/attention/test_gpt_attention_IFB.py index 08d327fe3..4e6e67c12 100644 --- a/tests/attention/test_gpt_attention_IFB.py +++ b/tests/attention/test_gpt_attention_IFB.py @@ -45,13 +45,18 @@ RotaryScalingType) from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import GenerationSequence, KVCacheManager +from tensorrt_llm.runtime import GenerationSequence sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.util import (skip_bf16_fp32_accum, skip_bf16_pre_ampere, skip_fp8_pre_ada, skip_fp32_accum_pre_ampere, unittest_name_func) +from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \ + MemoryPoolsAllocator +from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \ + PoolsKVCacheManager + class TestFunctional(unittest.TestCase): @@ -217,6 +222,7 @@ def _construct_execution(session, bias, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping, sequence_length, host_past_key_value_lengths, host_max_attention_window_sizes, @@ -291,8 +297,15 @@ def _construct_execution(session, dtype=tensorrt_llm.str_dtype_to_trt('int32')) host_kv_cache_pool_pointers_tensor = Tensor( name='host_kv_cache_pool_pointers', - shape=(1, ), + shape=( + 1, + 1, + ), dtype=tensorrt_llm.str_dtype_to_trt('int64')) + host_kv_cache_pool_mapping_tensor = Tensor( + name='host_kv_cache_pool_mapping', + shape=(1, ), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) host_runtime_perf_knobs_tensor = Tensor( name='host_runtime_perf_knobs', shape=[16], @@ -419,6 +432,7 @@ def _construct_execution(session, host_kv_cache_block_offsets_tensor, host_kv_cache_pool_pointers= host_kv_cache_pool_pointers_tensor, + host_kv_cache_pool_mapping=host_kv_cache_pool_mapping_tensor, host_context_lengths=host_context_lengths_tensor, qkv_bias=qkv_bias, host_runtime_perf_knobs=host_runtime_perf_knobs_tensor) @@ -443,6 +457,7 @@ def _construct_execution(session, 'kv_cache_block_offsets': kv_cache_block_offsets, 'host_kv_cache_block_offsets': host_kv_cache_block_offsets, 'host_kv_cache_pool_pointers': host_kv_cache_pool_pointers, + 'host_kv_cache_pool_mapping': host_kv_cache_pool_mapping, 'host_runtime_perf_knobs': host_runtime_perf_knobs } if use_int8_kv_cache or use_fp8_kv_cache: @@ -780,18 +795,27 @@ def torch_exec(step: int, torch.cuda.synchronize() return torch_output, torch_present - # Init KV cache block manager - block_size = plugin_kv_num_heads * tokens_per_block * head_size - kv_cache_manager = KVCacheManager(num_layers=1, - num_blocks=num_blocks, - 
block_size=block_size, - tokens_per_block=tokens_per_block, - max_blocks_per_seq=max_blocks_per_seq, - max_attention_window_size=max_seq_len, - sink_token_len=sink_token_len, - beam_width=beam_width) - host_kv_cache_pool_pointers = torch.tensor( - [ordered_key_value.data_ptr(), 0], dtype=torch.int64) + # Init Pools KV cache manager + memory_pools_allocator = MemoryPoolsAllocator( + num_blocks=num_blocks, + tokens_per_block=tokens_per_block, + head_size=head_size) + num_kv_heads_per_layer = MemoryPoolsAllocator.prepare_num_kv_heads_per_layer( + plugin_kv_num_heads, 1) + memory_pools_allocator.allocate(dtype, num_kv_heads_per_layer) + pools_kv_cache_manager = PoolsKVCacheManager( + memory_pools_allocator.pools_metadata, + max_blocks_per_seq, + num_blocks, + tokens_per_block, + head_size, + max_attention_window_size=max_seq_len, + beam_width=beam_width, + sink_token_len=sink_token_len) + + host_kv_cache_pool_pointers = memory_pools_allocator.get_kv_cache_pool_pointers( + ) + host_kv_cache_pool_mapping = memory_pools_allocator.pool_mapping print("pool ptr ", ordered_key_value.data_ptr()) torch_cache_list = [None] * num_req @@ -848,11 +872,15 @@ def torch_exec(step: int, # Add sequence to the manager sequence = GenerationSequence(seq_idx=iteration, batch_idx=iteration) - kv_cache_manager.add_sequence(sequence, in_len_req.clone()) + pools_kv_cache_manager.add_sequence(sequence, + in_len_req.clone()) # Get arrays of pointers to the "pages" of KV values - offset_array = kv_cache_manager.get_block_offsets(beam_width) - dense_offset_array = offset_array[sequence_selection] + offset_array = pools_kv_cache_manager.get_block_offsets(beam_width) + assert offset_array.shape[ + 0] == 1, f"test is suppose to use only one pool. sequence_selection is based on a single pool" + # assume only one pool + dense_offset_array = offset_array[0][sequence_selection] host_input_lengths = np.concatenate(input_length_list) host_input_lengths = torch.tensor(host_input_lengths, @@ -1022,11 +1050,11 @@ def torch_exec(step: int, session, output = _construct_execution( session, input_tensor, weight_plugin, bias_plugin, dense_offset_array, host_kv_cache_pool_pointers, - sequence_lengths, host_past_key_value_lengths, - host_max_attention_window_sizes, host_sink_token_length, - context_lengths, max_context_length, cache_indirection, - num_heads, hidden_size, num_kv_heads, output, dtype, - kv_quant_scale, kv_dequant_scale, host_context_lengths, + host_kv_cache_pool_mapping, sequence_lengths, + host_past_key_value_lengths, host_max_attention_window_sizes, + host_sink_token_length, context_lengths, max_context_length, + cache_indirection, num_heads, hidden_size, num_kv_heads, output, + dtype, kv_quant_scale, kv_dequant_scale, host_context_lengths, host_request_types, generation_host_runtime_perf_knobs, use_fp8_context_fmha, atten_output_quant_scale) @@ -1050,7 +1078,7 @@ def torch_exec(step: int, finished = [False for _ in range(cache_num_req)] # Iterate to the next step. 
Increase number of tokens for all unfinished sequences # And allocate new blocks if needed - kv_cache_manager.step(finished) + pools_kv_cache_manager.step(finished) if __name__ == "__main__": diff --git a/tests/bindings/test_bindings_ut.py b/tests/bindings/test_bindings_ut.py index 573368ead..3cb1ba598 100644 --- a/tests/bindings/test_bindings_ut.py +++ b/tests/bindings/test_bindings_ut.py @@ -40,9 +40,10 @@ def test_model_config(): num_heads = 16 hidden_size = 768 data_type = _tb.DataType.FLOAT - model_config = _tb.ModelConfig(vocab_size, num_attention_layers, - num_rnn_layers, num_heads, hidden_size, - data_type) + model_config = _tb.ModelConfig(vocab_size, + num_attention_layers + num_rnn_layers, + num_attention_layers, num_rnn_layers, + num_heads, hidden_size, data_type) assert model_config.vocab_size == vocab_size assert model_config.num_attention_layers() == num_attention_layers assert model_config.num_rnn_layers() == num_rnn_layers @@ -53,10 +54,23 @@ def test_model_config(): assert model_config.vocab_size_padded(1) is not None assert model_config.size_per_head == hidden_size // num_heads - assert model_config.num_kv_heads == num_heads + num_kv_heads_per_layer = model_config.num_kv_heads_per_layer + for layer_idx in range(num_attention_layers): + assert model_config.num_kv_heads(layer_idx) == num_heads + assert num_kv_heads_per_layer[layer_idx] == num_heads + num_kv_heads = 1 - model_config.num_kv_heads = num_kv_heads - assert model_config.num_kv_heads == num_kv_heads + model_config.set_num_kv_heads(num_kv_heads) + num_kv_heads_per_layer = model_config.num_kv_heads_per_layer + for layer_idx in range(num_attention_layers): + assert model_config.num_kv_heads(layer_idx) == num_kv_heads + assert num_kv_heads_per_layer[layer_idx] == num_kv_heads + + num_kv_heads_per_layer[-1] = 2 + model_config.num_kv_heads_per_layer = num_kv_heads_per_layer + for nheads, ref in zip(model_config.num_kv_heads_per_layer, + num_kv_heads_per_layer): + assert nheads == ref assert not model_config.use_gpt_attention_plugin model_config.use_gpt_attention_plugin = True @@ -182,6 +196,7 @@ def check_empty_then_set(member, value): def test_gpt_json_config(): model_config = { "vocab_size": 1000, + "num_layers": 18, # >= attn + rnn "num_attention_layers": 12, "num_rnn_layers": 2, "num_heads": 4, diff --git a/tests/bindings/test_executor_bindings.py b/tests/bindings/test_executor_bindings.py index 1a017313f..315b55158 100644 --- a/tests/bindings/test_executor_bindings.py +++ b/tests/bindings/test_executor_bindings.py @@ -1023,6 +1023,11 @@ def test_scheduler_config(): assert config.capacity_scheduler_policy == capacity_scheduler_policy assert config.context_chunking_policy == None + capacity_scheduler_policy = trtllm.CapacitySchedulerPolicy.STATIC_BATCH + config = trtllm.SchedulerConfig(capacity_scheduler_policy) + assert config.capacity_scheduler_policy == capacity_scheduler_policy + assert config.context_chunking_policy == None + context_chunking_policy = trtllm.ContextChunkingPolicy.FIRST_COME_FIRST_SERVED config = trtllm.SchedulerConfig(capacity_scheduler_policy, context_chunking_policy) diff --git a/tests/hlapi/apps/_test_llm_server.py b/tests/hlapi/apps/_test_llm_server.py index bf6b97881..ff276683e 100644 --- a/tests/hlapi/apps/_test_llm_server.py +++ b/tests/hlapi/apps/_test_llm_server.py @@ -31,6 +31,11 @@ def test_health(client): assert response.status_code == 200 +def test_health(client): + response = client.get("/health") + assert response.status_code == 200 + + def test_generate(client): response = 
client.post("/generate", json={"prompt": "A B C"}) assert response.status_code == 200 diff --git a/tests/hlapi/apps/_test_openai_chat.py b/tests/hlapi/apps/_test_openai_chat.py index 22a77acb3..bcc93b70b 100644 --- a/tests/hlapi/apps/_test_openai_chat.py +++ b/tests/hlapi/apps/_test_openai_chat.py @@ -4,6 +4,7 @@ import sys from typing import List +import numpy as np import openai import pytest from openai_server import RemoteOpenAIServer @@ -49,6 +50,7 @@ def test_single_chat_session(client: openai.OpenAI, model_name: str): model=model_name, messages=messages, max_tokens=10, + logprobs=True, ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 @@ -56,9 +58,18 @@ def test_single_chat_session(client: openai.OpenAI, model_name: str): message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" - messages.append({"role": "assistant", "content": message.content}) + + # test logprobs + logprobs = chat_completion.choices[0].logprobs.content + assert len(logprobs) == 10 + for logprob in logprobs: + assert logprob.token is not None + assert logprob.logprob is not None + assert logprob.bytes is not None + assert len(logprob.top_logprobs) == 0 # test multi-turn dialogue + messages.append({"role": "assistant", "content": message.content}) messages.append({"role": "user", "content": "express your result in json"}) chat_completion = client.chat.completions.create( model=model_name, @@ -99,8 +110,13 @@ async def test_chat_streaming(async_client: openai.AsyncOpenAI, messages=messages, max_tokens=10, temperature=0.0, + logprobs=True, ) output = chat_completion.choices[0].message.content + logprobs = [ + logprob_content.logprob + for logprob_content in chat_completion.choices[0].logprobs.content + ] # test streaming stream = await async_client.chat.completions.create( @@ -108,18 +124,28 @@ async def test_chat_streaming(async_client: openai.AsyncOpenAI, messages=messages, max_tokens=10, temperature=0.0, + logprobs=True, stream=True, ) - chunks: List[str] = [] + str_chunks: List[str] = [] + logprob_chunks: List[float] = [] + # TODO{pengyunl}: add stop_reason test when supported async for chunk in stream: delta = chunk.choices[0].delta + if logprob_chunk := chunk.choices[0].logprobs: + assert len(logprob_chunk.content) == 1 + assert len(logprob_chunk.content[0].top_logprobs) == 0 + logprob_chunks.append(logprob_chunk.content[0].logprob) if delta.role: assert delta.role == "assistant" if delta.content: - chunks.append(delta.content) + str_chunks.append(delta.content) assert delta.content - assert "".join(chunks) == output + assert "".join(str_chunks) == output + assert len(logprob_chunks) == len(logprobs) + logprobs, logprob_chunks = np.array(logprobs), np.array(logprob_chunks) + assert np.allclose(logprobs, logprob_chunks) @pytest.mark.asyncio diff --git a/tests/hlapi/run_llm.py b/tests/hlapi/run_llm.py index f2bea7045..fd69ae7bf 100644 --- a/tests/hlapi/run_llm.py +++ b/tests/hlapi/run_llm.py @@ -3,7 +3,7 @@ import click -from tensorrt_llm.hlapi import LLM, SamplingParams +from tensorrt_llm.hlapi import LLM, KvCacheConfig, SamplingParams @click.command() @@ -11,7 +11,9 @@ @click.option("--tp_size", type=int, required=True) @click.option("--engine_dir", type=str, default=None) def main(model_dir: str, tp_size: int, engine_dir: str): - llm = LLM(model_dir, tensor_parallel_size=tp_size) + llm = LLM(model_dir, + tensor_parallel_size=tp_size, + 
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) if engine_dir is not None and os.path.abspath( engine_dir) != os.path.abspath(model_dir): diff --git a/tests/hlapi/test_build_cache.py b/tests/hlapi/test_build_cache.py index d438b5b04..7b1cef7ef 100644 --- a/tests/hlapi/test_build_cache.py +++ b/tests/hlapi/test_build_cache.py @@ -1,14 +1,12 @@ import json from tempfile import TemporaryDirectory -from tensorrt_llm.builder import EngineConfig from tensorrt_llm.hlapi.build_cache import * -from tensorrt_llm.hlapi.llm_utils import CachedModelLoader, LlmArgs, ModelLoader try: - from test_llm import llama_model_path + pass except ImportError: - from .test_llm import llama_model_path + pass def test_BuildStep(): @@ -16,7 +14,7 @@ def test_BuildStep(): build_cache = BuildCache(BuildCacheConfig(Path(tempdir))) build_step = build_cache.get_engine_building_cache_stage( build_config=BuildConfig(), hf_model_name="test") - assert not build_step.cache_hitted() + assert not build_step.is_cached() print(build_step.get_cache_path().absolute()) assert build_step.get_cache_metadata( )["version"] == BuildCache.CACHE_VERSION @@ -26,7 +24,7 @@ def test_BuildStep(): with open(product_path / 'config.json', 'w') as f: f.write(json.dumps({"a": 1, "b": 2})) - assert build_step.cache_hitted() + assert build_step.is_cached() def test_BuildCache_clean_untracked_path(): @@ -49,7 +47,7 @@ def test_BuildCache_clean_cache_exceed_record_limit(): def create_cache(hf_model_name: str): step = build_cache.get_engine_building_cache_stage( build_config=build_config, hf_model_name=hf_model_name) - assert not step.cache_hitted() + assert not step.is_cached() with step.write_guard() as product_path: product_path.mkdir() with open(product_path / 'config.json', 'w') as f: @@ -92,34 +90,6 @@ def test_build_cache_prune_untracked_files(): assert not (build_cache.cache_root / 'broken_cache').exists() -def test_build_get_updated_build_cache(): - # Test the build method in commands/build.py get an updated BuildConfig that is the same as the real one in engine - # directory - build_config = BuildConfig() - build_config.max_batch_size = 100 - build_config.max_beam_width = 4 - - args = LlmArgs(model=llama_model_path, - build_config=build_config, - enable_tqdm=True) - args.setup() - ml = ModelLoader(args) - - with TemporaryDirectory() as engine_dir: - engine_dir = ml(Path(engine_dir)) - - updated_build_config = CachedModelLoader.get_final_build_config( - args, Path(llama_model_path)) - - actual_build_config = EngineConfig.from_json_file( - engine_dir / 'config.json').build_config - - assert BuildCache.prune_build_config_for_cache_key( - updated_build_config.to_dict( - )) == BuildCache.prune_build_config_for_cache_key( - actual_build_config.to_dict()) - - if __name__ == '__main__': #test_build_get_updated_build_cache() test_build_cache_prune_untracked_files() diff --git a/tests/hlapi/test_llm.py b/tests/hlapi/test_llm.py index f21875227..7ebc7d717 100644 --- a/tests/hlapi/test_llm.py +++ b/tests/hlapi/test_llm.py @@ -4,12 +4,13 @@ import sys import tempfile import time -from typing import List, Optional +from typing import List, Optional, Union import pytest import torch import transformers +from tensorrt_llm._utils import release_gc from tensorrt_llm.executor import (ExecutorBindingsWorker, GenerationRequest, GenerationResult, LoRARequest) from tensorrt_llm.hlapi import (LLM, BuildCacheConfig, KvCacheConfig, @@ -18,7 +19,6 @@ from tensorrt_llm.hlapi.tokenizer import TokenizerBase, TransformersTokenizer from tensorrt_llm.hlapi.utils 
import get_total_gpu_memory from tensorrt_llm.lora_manager import LoraConfig -from tensorrt_llm.models import PretrainedConfig sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.llm_data import llm_models_root @@ -72,24 +72,41 @@ def llm_test_harness(model_dir: str, ref, threshold=similar_threshold) + del llm + release_gc() + def llm_check_output(llm: LLM, inputs: List[str], references: List[str], + *, + sampling_params: Optional[SamplingParams] = None, similar_threshold: float = 0.8, - *gen_args, + finish_reasons: Optional[List[str]] = None, + stop_reasons: Optional[List[Union[int, str]]] = None, **gen_kwargs): - outputs = llm.generate(inputs, *gen_args, **gen_kwargs) + outputs = llm.generate(inputs, + sampling_params=sampling_params, + **gen_kwargs) assert len(outputs) == len(references) - for output, target_output in zip(outputs, references): + for i, (output, target_output) in enumerate(zip(outputs, references)): if isinstance(target_output, list): # N output assert len(output.outputs) == len(target_output) - for out, ref in zip(output.outputs, target_output): + for j, (out, ref) in enumerate(zip(output.outputs, target_output)): assert similar(out.text, ref, threshold=similar_threshold) + if finish_reasons is not None: + assert out.finish_reason == finish_reasons[i][j] + if stop_reasons is not None: + assert out.stop_reason == stop_reasons[i][j] else: - assert similar(output.outputs[0].text, target_output) + out = output.outputs[0] + assert similar(out.text, target_output, threshold=similar_threshold) + if finish_reasons is not None: + assert out.finish_reason == finish_reasons[i] + if stop_reasons is not None: + assert out.stop_reason == stop_reasons[i] default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" @@ -101,6 +118,7 @@ def llm_check_output(llm: LLM, cnn_dailymail_path = str(llm_models_root() / "datasets" / "cnn_dailymail") prompts = ["A B C"] +global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) @force_ampere @@ -116,7 +134,7 @@ def test_llm_build_config(): llm = LLM(model=llama_model_path, build_config=build_config, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) + kv_cache_config=global_kvcache_config) tmpdir = tempfile.TemporaryDirectory() llm.save(tmpdir.name) @@ -124,8 +142,6 @@ def test_llm_build_config(): # read the build_config and check if the parameters are correctly saved engine_config = json.load(f) - pretrained_config = PretrainedConfig.from_dict( - engine_config["pretrained_config"]) build_config1 = BuildConfig.from_dict(engine_config["build_config"]) # Know issue: this will be converted to None after save engine for single-gpu @@ -140,11 +156,10 @@ def test_llm_build_config(): def test_llm_loading_from_hf(): sampling_params = SamplingParams(max_tokens=8) - llm_test_harness( - llama_model_path, - prompts, ["D E F G H I J K"], - sampling_params=sampling_params, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) + llm_test_harness(llama_model_path, + prompts, ["D E F G H I J K"], + sampling_params=sampling_params, + kv_cache_config=global_kvcache_config) @force_ampere @@ -157,12 +172,11 @@ def test_llm_loading_from_ckpt(): llama.save_checkpoint(ckpt_dir.name) del llama - llm_test_harness( - ckpt_dir.name, - prompts, ["D E F G H I J K"], - tokenizer=tokenizer, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - sampling_params=SamplingParams(max_tokens=8)) + llm_test_harness(ckpt_dir.name, + prompts, ["D E F G H I J K"], + tokenizer=tokenizer, + 
kv_cache_config=global_kvcache_config, + sampling_params=SamplingParams(max_tokens=8)) @pytest.mark.parametrize('model_format', ['hf', 'ckpt']) @@ -180,13 +194,13 @@ def test_llm_with_dummy_weights(model_format): tokenizer.save_pretrained(dummy_dir.name) sampling_params = SamplingParams(max_tokens=8) - llm_test_harness( - dummy_dir.name, - prompts, ["A placeholder reference for dummy-weight engine."], - sampling_params=sampling_params, - similar_threshold=0.0, - load_format='dummy', - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) + llm_test_harness(dummy_dir.name, + prompts, + ["A placeholder reference for dummy-weight engine."], + sampling_params=sampling_params, + similar_threshold=0.0, + load_format='dummy', + kv_cache_config=global_kvcache_config) class MyTokenizer(TokenizerBase): @@ -225,19 +239,16 @@ def test_llm_with_customized_tokenizer(): model=llama_model_path, # a customized tokenizer is passed to override the default one tokenizer=MyTokenizer.from_pretrained(llama_model_path), - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) + kv_cache_config=global_kvcache_config) for output in llm.generate(prompts): print(output) def test_llm_without_tokenizer(): - llm = LLM( - model=llama_model_path, - skip_tokenizer_init=True, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) + llm = LLM(model=llama_model_path, + skip_tokenizer_init=True, + kv_cache_config=global_kvcache_config) sampling_params = SamplingParams(end_id=2, pad_id=2, max_tokens=8) @@ -272,15 +283,13 @@ def _test_llm_generate_async(model_name=default_model_name, if "Mixtral" in model_name and use_auto_parallel: pytest.skip("Auto parallel is not supported for Mixtral models") - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) - tp_size = tp_size if not use_auto_parallel else 1 world_size = tp_size if use_auto_parallel else None llm = LLM( model=get_model_path(model_name), tokenizer=tokenizer, - kv_cache_config=kv_cache_config, + kv_cache_config=global_kvcache_config, tensor_parallel_size=tp_size, auto_parallel=use_auto_parallel, world_size=world_size, @@ -364,7 +373,7 @@ async def main(): def llm_for_sampling_params() -> LLM: llm = LLM( model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=global_kvcache_config, ) return llm @@ -475,7 +484,7 @@ def test_generate_with_OutputConfig(gather_context_logits: bool, llm = LLM( model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=global_kvcache_config, build_config=build_config, ) sampling_params = SamplingParams( @@ -502,40 +511,68 @@ def test_generate_with_OutputConfig(gather_context_logits: bool, def test_generate_with_stop_words(): llm = LLM( model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=global_kvcache_config, ) stop_id = llm.tokenizer.encode("N", add_special_tokens=False)[-1] llm_check_output(llm, prompts, ["D E F G H I J K L M"], - sampling_params=SamplingParams(stop_token_ids=[stop_id])) + sampling_params=SamplingParams(end_id=stop_id), + finish_reasons=['stop'], + stop_reasons=[None]) + + llm_check_output(llm, + prompts, ["D E F G H"], + sampling_params=SamplingParams(max_tokens=5), + finish_reasons=['length'], + stop_reasons=[None]) + + llm_check_output(llm, + prompts, ["D E F G H I J K L M"], + sampling_params=SamplingParams(stop_token_ids=[stop_id]), + finish_reasons=['stop'], + stop_reasons=[stop_id]) llm_check_output(llm, prompts, ["D E F G H I J K L M 
N"], sampling_params=SamplingParams( stop_token_ids=[stop_id], - include_stop_str_in_output=True)) + include_stop_str_in_output=True), + finish_reasons=['stop'], + stop_reasons=[stop_id]) llm_check_output(llm, prompts, ["D E F G H"], - sampling_params=SamplingParams(stop="I J")) + sampling_params=SamplingParams(stop="I J"), + finish_reasons=['stop'], + stop_reasons=["I J"]) + + llm_check_output(llm, + prompts, ["D E F G H I J K L M"], + sampling_params=SamplingParams(stop="I E", max_tokens=10), + finish_reasons=['length'], + stop_reasons=[None]) llm_check_output(llm, prompts, ["D E F G H I J"], sampling_params=SamplingParams( - stop="I J", include_stop_str_in_output=True)) + stop="I J", include_stop_str_in_output=True), + finish_reasons=['stop'], + stop_reasons=["I J"]) llm_check_output(llm, prompts, ["D E F G H"], sampling_params=SamplingParams(stop=["F E", "I J"], - stop_token_ids=[stop_id])) + stop_token_ids=[stop_id]), + finish_reasons=['stop'], + stop_reasons=["I J"]) @force_ampere def test_generate_with_bad_words(): llm = LLM( model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=global_kvcache_config, ) bad_id = llm.tokenizer.encode("N", add_special_tokens=False)[-1] @@ -600,6 +637,7 @@ def llama_v2_13b_lora_test_harness(**llm_kwargs): hf_model_dir = get_model_path("llama-models-v2/llama-v2-13b-hf") hf_lora_dir = get_model_path("llama-models-v2/chinese-llama-2-lora-13b") + # For LoRA checkpoints with finetuned embedding and lm_head, lora_dir must be provided at build time. build_config = BuildConfig(lora_config=LoraConfig(lora_dir=[hf_lora_dir])) llm = LLM(hf_model_dir, tokenizer=hf_lora_dir, @@ -630,8 +668,10 @@ def llama_7b_multi_lora_test_harness(**llm_kwargs): hf_lora_dir1 = get_model_path("llama-models/luotuo-lora-7b-0.1") hf_lora_dir2 = get_model_path("llama-models/Japanese-Alpaca-LoRA-7b-v0") + # For LoRA checkpoints without finetuned embedding and lm_head, we can either: + # (1) specify lora_target_modules, or + # (2) provide a lora_dir to infer the lora_target_modules. 
build_config = BuildConfig(lora_config=LoraConfig( - lora_dir=[hf_lora_dir1, hf_lora_dir2], lora_target_modules=['attn_q', 'attn_k', 'attn_v'])) llm = LLM(hf_model_dir, enable_lora=True, @@ -692,10 +732,12 @@ def test_generate_block_reuse(): for output in llm.generate(prompts, sampling_params=sampling_params): print(output) + del llm + release_gc() + def test_executor_results_cleanup(): - llm = LLM(model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) + llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config) sampling_params = SamplingParams(max_tokens=6) for i in range(20): llm.generate(prompts, sampling_params=sampling_params) @@ -704,6 +746,41 @@ def test_executor_results_cleanup(): print(f"result.size: {num_remaining_results}") assert num_remaining_results == 0 + del llm + release_gc() + + +@pytest.mark.parametrize("trust_remote_code", [True, False]) +def _test_llm_trust_remote_code(trust_remote_code: bool): + # OOM when tested with other cases + # TODO[chunweiy]: Enable this later + release_gc() + + if trust_remote_code: + internlm_model_path = get_model_path("internlm-chat-7b") + llm = LLM(model=internlm_model_path, + trust_remote_code=trust_remote_code, + tokenizer=TransformersTokenizer.from_pretrained( + internlm_model_path, trust_remote_code=trust_remote_code), + kv_cache_config=global_kvcache_config) + sampling_params = SamplingParams(max_tokens=6, + temperature=0.8, + top_p=0.95) + prompts = [ + "The future of AI is", + ] + + for output in llm.generate(prompts, sampling_params=sampling_params): + print(output) + del llm + release_gc() + else: + with pytest.raises(ValueError): + llm = LLM(model="internlm/internlm-chat-7b", + trust_remote_code=trust_remote_code, + tokenizer="internlm/internlm-chat-7b", + kv_cache_config=global_kvcache_config) + def test_llm_build_cache(): # Activate the build-cache @@ -712,70 +789,47 @@ def test_llm_build_cache(): def first_run(): llm = LLM(model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=global_kvcache_config, enable_build_cache=cache_config) llm_check_output(llm, prompts, ["D E F G H I J K"], sampling_params=sampling_params) + del llm + release_gc() + def second_run(): llm = LLM(model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=global_kvcache_config, enable_build_cache=cache_config) llm_check_output(llm, prompts, ["D E F G H I J K"], sampling_params=sampling_params) # the cache should be hitted - assert llm.llm_build_stats.cache_hitted + assert llm.llm_build_stats.cache_hitted, llm.llm_build_stats.cache_info + del llm + release_gc() first_run() second_run() -def test_executor_catching_exception(): +class DummyError(Exception): + pass - class DummyError(Exception): - pass - llm = LLM(model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) +class DummyExecutorMeta(type): - sampling_params = SamplingParams(max_tokens=6) - - def test_before_generation(): - # Since we cannot alter the executor's behavior, we put a dummy error in the error queue - llm._executor._error_queue.put(DummyError("Test exception")) - # The dummy error should be caught and raised in the main thread. 
- with pytest.raises(DummyError): - for output in llm.generate(prompts, - sampling_params=sampling_params): - pass - - def test_during_generation(): - with pytest.raises(DummyError): - prompts = ["A B C"] * 10 - futures = [] - for no, prompt in enumerate(prompts): - futures.append( - llm.generate_async(prompt, sampling_params=sampling_params)) - if no == 3: - # This exception should be caught and raised in the main thread. Before the 4-th output is generated. - llm._executor._error_queue.put(DummyError("Test exception")) - - for no, future in enumerate(futures): - print(future.result()) - - test_before_generation() - test_during_generation() + def __new__(cls, name, bases, dic, worker_cls): + new_cls = super().__new__(cls, name, bases, dic) + @classmethod + def create(cls, engine, executor_config, *args, **kwargs): + return worker_cls(engine=engine, executor_config=executor_config) -class DummyExecutor: - - @classmethod - def create(cls, engine, executor_config, *args, **kwargs): - return DummyExecutorWorker(engine=engine, - executor_config=executor_config) + new_cls.create = create + return new_cls class DummyExecutorWorker(ExecutorBindingsWorker): @@ -806,8 +860,14 @@ def submit(self, request: GenerationRequest) -> GenerationResult: return result +DummyExecutor = DummyExecutorMeta("DummyExecutor", (), {}, + worker_cls=DummyExecutorWorker) + + def test_executor_pending_requests(): - llm = LLM(model=llama_model_path, executor_cls=DummyExecutor) + llm = LLM(model=llama_model_path, + executor_cls=DummyExecutor, + kv_cache_config=global_kvcache_config) # The dummy executor will delay the responses sampling_params = SamplingParams(max_tokens=6) @@ -830,14 +890,6 @@ async def task(): test_streaming() -class DummyExecutor2: - - @classmethod - def create(cls, engine, executor_config, *args, **kwargs): - return DummyExecutorWorker2(engine=engine, - executor_config=executor_config) - - class DummyExecutorWorker2(ExecutorBindingsWorker): def __init__(self, *args, **kwargs): @@ -850,19 +902,25 @@ def await_response_task(self) -> bool: if self.counter == 2: # raise exception on the third token print(f"To raise exception") - raise ValueError("Test exception") + raise DummyError("Test exception") return super().await_response_task() +DummyExecutor2 = DummyExecutorMeta("DummyExecutor2", (), {}, + worker_cls=DummyExecutorWorker2) + + def test_executor_process_background_error(): - llm = LLM(model=llama_model_path, executor_cls=DummyExecutor2) + llm = LLM(model=llama_model_path, + executor_cls=DummyExecutor2, + kv_cache_config=global_kvcache_config) # The dummy executor will delay the responses sampling_params = SamplingParams(max_tokens=6) # test in streaming mode async def task(): - with pytest.raises(ValueError): + with pytest.raises(DummyError): async for output in llm.generate_async( prompts[0], streaming=True, sampling_params=sampling_params): @@ -871,11 +929,61 @@ async def task(): asyncio.run(task()) -# TODO[chunweiy]: Add test for loading inmemory model +def test_llm_apidocs(): + doc = LLM.__doc__ + assert doc + assert doc.find('pipeline_parallel_size') != -1 + assert doc.find('tensor_parallel_size') != -1 + assert doc.find('auto_parallel') != -1 -if __name__ == '__main__': - #test_executor_results_cleanup() - #test_llm_loading_from_hf() - #test_executor_catching_exception() - test_executor_pending_requests() - #test_executor_process_background_error() + +def check_llm_return_context_logits(tp_size=1): + build_config = BuildConfig(gather_context_logits=True) + + llm = LLM(llama_model_path, + 
tensor_parallel_size=tp_size, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + build_config=build_config) + + sampling_params = SamplingParams(max_tokens=8, return_context_logits=True) + + prompts = ["A B C D E F G H I J K"] * 8 + + for output in llm.generate(prompts, sampling_params=sampling_params): + assert isinstance(output.context_logits, torch.Tensor) + print(output) + + del llm + release_gc() + + +def check_llm_return_generation_logits(tp_size=1): + build_config = BuildConfig(gather_generation_logits=True) + + llm = LLM(llama_model_path, + tensor_parallel_size=tp_size, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + build_config=build_config) + + sampling_params = SamplingParams(max_tokens=8, + return_generation_logits=True) + + prompts = ["A B C D E F G H I J K"] * 8 + + for output in llm.generate(prompts, sampling_params=sampling_params): + assert isinstance(output.outputs[0].generation_logits, torch.Tensor) + print(output) + + del llm + release_gc() + + +def test_llm_return_context_logits(): + check_llm_return_context_logits(tp_size=1) + + +def test_llm_return_generation_logits(): + check_llm_return_generation_logits(tp_size=1) + + +# TODO[chunweiy]: Add test for loading inmemory model diff --git a/tests/hlapi/test_llm_models.py b/tests/hlapi/test_llm_models.py index 06d16f25b..a659f8002 100644 --- a/tests/hlapi/test_llm_models.py +++ b/tests/hlapi/test_llm_models.py @@ -168,7 +168,8 @@ def test_llm_phi_3_small_8k(): inputs=["where is France's capital?"], references=[' Paris is the capital of France. It is known'], sampling_params=sampling_params, - build_config=build_config) + build_config=build_config, + trust_remote_code=True) @force_ampere @@ -225,7 +226,8 @@ def test_llm_glm(): llm_test_harness(glm_model_path, inputs=['A B C'], references=['D E F G H I J K L M'], - sampling_params=sampling_params) + sampling_params=sampling_params, + trust_remote_code=True) @force_ampere @@ -233,7 +235,8 @@ def test_llm_baichuan_7b(): llm_test_harness(baichuan_7b_model_path, inputs=['A B C'], references=['D E F G H I J K L M'], - sampling_params=sampling_params) + sampling_params=sampling_params, + trust_remote_code=True) @force_ampere @@ -241,7 +244,8 @@ def test_llm_baichuan2_7b(): llm_test_harness(baichuan2_7b_model_path, inputs=['A B C'], references=['D E F G H I J K L M'], - sampling_params=sampling_params) + sampling_params=sampling_params, + trust_remote_code=True) @force_ampere @@ -250,7 +254,8 @@ def test_llm_baichuan_13b(): llm_test_harness(baichuan_13b_model_path, inputs=['A B C'], references=['D E F G H I J K L M'], - sampling_params=sampling_params) + sampling_params=sampling_params, + trust_remote_code=True) @force_ampere @@ -259,7 +264,8 @@ def test_llm_baichuan2_13b(): llm_test_harness(baichuan2_13b_model_path, inputs=['A B C'], references=['D E F G H I J K L M'], - sampling_params=sampling_params) + sampling_params=sampling_params, + trust_remote_code=True) @force_ampere @@ -271,7 +277,8 @@ def test_llm_baichuan2_7b_int4weight_only(): references=['D E F G H I J K L M'], sampling_params=sampling_params, quant_config=quant_config, - calib_config=calib_config) + calib_config=calib_config, + trust_remote_code=True) @skip_pre_ampere @@ -283,7 +290,8 @@ def test_llm_qwen(): llm_test_harness(qwen_model_path, inputs=['A B C'], references=['D E F G H I J K L M'], - sampling_params=sampling_params) + sampling_params=sampling_params, + trust_remote_code=True) @skip_pre_ampere @@ -292,7 +300,8 @@ def test_llm_qwen1_5(): 
llm_test_harness(qwen1_5_model_path, inputs=['1+1='], references=['2'], - sampling_params=qwen1_5_sampling_params) + sampling_params=qwen1_5_sampling_params, + trust_remote_code=True) @skip_pre_ampere @@ -300,7 +309,8 @@ def test_llm_qwen2(): llm_test_harness(qwen2_model_path, inputs=['A B C'], references=['D E F G H I J K L M'], - sampling_params=sampling_params) + sampling_params=sampling_params, + trust_remote_code=True) @skip_pre_ampere @@ -312,7 +322,8 @@ def test_llm_qwen2_int4_weight_only(): references=['D E F G H I J K L M'], sampling_params=sampling_params, quant_config=quant_config, - calib_config=calib_config) + calib_config=calib_config, + trust_remote_code=True) @skip_pre_hopper @@ -324,7 +335,8 @@ def test_llm_qwen2_fp8(): references=['D E F G H I J K L M'], sampling_params=sampling_params, quant_config=quant_config, - calib_config=calib_config) + calib_config=calib_config, + trust_remote_code=True) if __name__ == '__main__': diff --git a/tests/hlapi/test_llm_models_multi_gpu.py b/tests/hlapi/test_llm_models_multi_gpu.py index b8a6890b8..0582647c8 100644 --- a/tests/hlapi/test_llm_models_multi_gpu.py +++ b/tests/hlapi/test_llm_models_multi_gpu.py @@ -40,7 +40,8 @@ def test_llm_baichuan2_7b_tp2(): inputs=['A B C'], references=['D E F G H I J K L M'], sampling_params=sampling_params, - tensor_parallel_size=2) + tensor_parallel_size=2, + trust_remote_code=True) @skip_single_gpu diff --git a/tests/hlapi/test_llm_multi_gpu.py b/tests/hlapi/test_llm_multi_gpu.py index de06dec0e..782a8037b 100644 --- a/tests/hlapi/test_llm_multi_gpu.py +++ b/tests/hlapi/test_llm_multi_gpu.py @@ -11,8 +11,7 @@ from tensorrt_llm.executor import (ExecutorBindingsProxy, GenerationRequest, GenerationResult) -from tensorrt_llm.hlapi.llm import LLM, SamplingParams -from tensorrt_llm.hlapi.llm_utils import KvCacheConfig +from tensorrt_llm.hlapi import LLM, KvCacheConfig, SamplingParams from tensorrt_llm.hlapi.tokenizer import TransformersTokenizer from tensorrt_llm.hlapi.utils import get_total_gpu_memory from tensorrt_llm.mapping import Mapping @@ -22,17 +21,23 @@ from utils.util import skip_single_gpu, unittest_name_func try: - from .test_llm import (_test_llm_generate_async, default_model_name, - get_model_path, llama_7b_multi_lora_test_harness, - llama_model_path, llama_v2_13b_lora_test_harness, - llm_check_output, llm_test_harness, - mixtral_model_name, prompts) + from .test_llm import DummyExecutorWorker2 # isort:skip + from .test_llm import (DummyError, _test_llm_generate_async, + check_llm_return_context_logits, + check_llm_return_generation_logits, + default_model_name, get_model_path, + llama_7b_multi_lora_test_harness, llama_model_path, + llama_v2_13b_lora_test_harness, llm_check_output, + llm_test_harness, mixtral_model_name, prompts) except ImportError: - from test_llm import (_test_llm_generate_async, default_model_name, - get_model_path, llama_7b_multi_lora_test_harness, - llama_model_path, llama_v2_13b_lora_test_harness, - llm_check_output, llm_test_harness, - mixtral_model_name, prompts) + from test_llm import DummyExecutorWorker2 # isort:skip + from test_llm import (DummyError, _test_llm_generate_async, + check_llm_return_context_logits, + check_llm_return_generation_logits, + default_model_name, get_model_path, + llama_7b_multi_lora_test_harness, llama_model_path, + llama_v2_13b_lora_test_harness, llm_check_output, + llm_test_harness, mixtral_model_name, prompts) @pytest.fixture(scope="module") @@ -61,6 +66,10 @@ def engine_from_checkpoint() -> tempfile.TemporaryDirectory: return tmpdir +# 
shrink the kv_cache_config to avoid OOM in CI +global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + + @skip_single_gpu def test_llm_loading_from_ckpt_for_tp2( engine_from_checkpoint: tempfile.TemporaryDirectory): @@ -68,7 +77,8 @@ def test_llm_loading_from_ckpt_for_tp2( llm_test_harness(engine_from_checkpoint.name, prompts, ["D E F G H I J K"], sampling_params=SamplingParams(max_tokens=8), - tokenizer=tokenizer) + tokenizer=tokenizer, + kv_cache_config=global_kv_cache_config) @skip_single_gpu @@ -76,13 +86,22 @@ def test_llm_generate_tp2(engine_from_checkpoint): model_dir = engine_from_checkpoint.name tokenizer = TransformersTokenizer.from_pretrained(llama_model_path) - llm_test_harness( - model_dir, - prompts, ["D E F G H I J K"], - sampling_params=SamplingParams(max_tokens=8), - tokenizer=tokenizer, - tensor_parallel_size=2, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) + llm_test_harness(model_dir, + prompts, ["D E F G H I J K"], + sampling_params=SamplingParams(max_tokens=8), + tokenizer=tokenizer, + tensor_parallel_size=2, + kv_cache_config=global_kv_cache_config) + + +@skip_single_gpu +def test_llm_return_context_logits_tp2(): + check_llm_return_context_logits(tp_size=2) + + +@skip_single_gpu +def test_llm_return_generation_logits_tp2(): + check_llm_return_generation_logits(tp_size=2) @pytest.mark.parametrize("use_auto_parallel", [True, False], @@ -99,12 +118,10 @@ def test_llm_generate_async_tp2( llama_model_path) tokenizer_dir = get_model_path(llama_model_path) tokenizer = TransformersTokenizer.from_pretrained(tokenizer_dir) - _test_llm_generate_async( - model_dir, - tp_size=2, - use_auto_parallel=use_auto_parallel, - tokenizer=tokenizer, - ) + _test_llm_generate_async(model_dir, + tp_size=2, + use_auto_parallel=use_auto_parallel, + tokenizer=tokenizer) # TODO[chunweiy]: Move mixtral test to the e2e test @@ -124,23 +141,20 @@ def is_memory_enough_for_mixtral(): @pytest.mark.skipif(not is_memory_enough_for_mixtral(), reason="The test needs at least 160GB memory, skipping") def test_llm_generate_mixtral_for_tp2(): - llm = LLM( - get_model_path(mixtral_model_name), - tensor_parallel_size=2, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) + llm = LLM(get_model_path(mixtral_model_name), + tensor_parallel_size=2, + kv_cache_config=global_kv_cache_config) for output in llm.generate(prompts): print(output) def test_llm_pp2(): - llm_test_harness( - llama_model_path, - prompts, ["D E F G H I J K"], - sampling_params=SamplingParams(max_tokens=8), - pipeline_parallel_size=2, - auto_parallel=False, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) + llm_test_harness(llama_model_path, + prompts, ["D E F G H I J K"], + sampling_params=SamplingParams(max_tokens=8), + pipeline_parallel_size=2, + auto_parallel=False, + kv_cache_config=global_kv_cache_config) def llm_end2end_tp2_cases(): @@ -158,7 +172,10 @@ def llm_end2end_tp2_cases(): def test_llm_end2end_tp2(llm_additional_options): model_path = get_model_path(default_model_name) - llm = LLM(model_path, tensor_parallel_size=2, **llm_additional_options) + llm = LLM(model_path, + tensor_parallel_size=2, + **llm_additional_options, + kv_cache_config=global_kv_cache_config) assert llm.args._convert_checkpoint_options embedding_parallel_mode = llm_additional_options.pop( @@ -194,14 +211,16 @@ def test_llm_end2end_tp2(llm_additional_options): @skip_single_gpu def test_llama_v2_13b_lora_tp2(): - llama_v2_13b_lora_test_harness(tensor_parallel_size=2) + 
llama_v2_13b_lora_test_harness(tensor_parallel_size=2, + kv_cache_config=global_kv_cache_config) @skip_single_gpu def test_llama_7b_multi_lora_tp2(): llama_7b_multi_lora_test_harness(tensor_parallel_size=2, max_loras=1, - max_cpu_loras=8) + max_cpu_loras=8, + kv_cache_config=global_kv_cache_config) @skip_single_gpu @@ -218,7 +237,7 @@ def _test_llm_multi_node(engine_from_checkpoint: tempfile.TemporaryDirectory): @skip_single_gpu def test_executor_results_cleanup(): llm = LLM(model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=global_kv_cache_config, tensor_parallel_size=2) sampling_params = SamplingParams(max_new_tokens=6) for i in range(20): @@ -228,26 +247,6 @@ def test_executor_results_cleanup(): assert num_remaining_results == 0 -class DummyExecutor: - - @staticmethod - def create( - engine, - executor_config, - model_world_size: int = 1, - world_size: int = 0, - mpi_session=None, - reuse_mpi_comm: bool = False, - ): - worker_kwargs = { - "engine": engine, - "executor_config": executor_config, - } - return DummyExecutorProxy(worker_kwargs, - model_world_size=model_world_size, - mpi_session=mpi_session) - - class DummyExecutorProxy(ExecutorBindingsProxy): def __init__( @@ -271,8 +270,9 @@ def submit(self, request: GenerationRequest) -> GenerationResult: result = GenerationResult( request, background_error_handler=self._handle_background_error) - # Force the responses to be delayed - time.sleep(1) + # Force the responses to be delayed, need a long time to ensure at least one response is generated, especially + # for the non-streaming mode when some internal lasy-setup considered + time.sleep(10) print(f"number of pending responses: {len(self._pending_responses)}") assert self._pending_responses @@ -287,7 +287,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult: def test_executor_pending_requests(): llm = LLM(model=llama_model_path, executor_cls=DummyExecutor, - tensor_parallel_size=2) + tensor_parallel_size=2, + kv_cache_config=global_kv_cache_config) # The dummy executor will delay the responses sampling_params = SamplingParams(max_tokens=6) @@ -310,27 +311,36 @@ async def task(): test_streaming() -class DummyExecutor2: +class DummyExecutorMeta(type): - @staticmethod - def create( - engine, - executor_config, - model_world_size: int = 1, - world_size: int = 0, - mpi_session=None, - reuse_mpi_comm: bool = False, - ): - worker_kwargs = { - "engine": engine, - "executor_config": executor_config, - } - return DummyExecutorProxy2(worker_kwargs, - model_world_size=model_world_size, - mpi_session=mpi_session) + def __new__(cls, name, bases, dic, proxy_class): + new_cls = super().__new__(cls, name, bases, dic) + + @staticmethod + def create(engine, + executor_config, + model_world_size: int = 1, + world_size: int = 0, + mpi_session=None, + reuse_mpi_comm: bool = False): + worker_kwargs = { + "engine": engine, + "executor_config": executor_config, + } + return proxy_class(worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session) + + new_cls.create = create + return new_cls + + +DummyExecutor = DummyExecutorMeta("DummyExecutor", (), {}, + proxy_class=DummyExecutorProxy) class DummyExecutorProxy2(ExecutorBindingsProxy): + ''' This is for testing the error occur in the thread in the Proxy. 
''' def __init__( self, @@ -345,21 +355,64 @@ def dispatch_result_task(self) -> bool: self.counter += 1 if self.counter == 2: - raise ValueError("Test error") + raise DummyError("Test error") return super().dispatch_result_task() +DummyExecutor2 = DummyExecutorMeta("DummyExecutor2", (), {}, + proxy_class=DummyExecutorProxy2) + + +class DummyExecutorProxy3(ExecutorBindingsProxy): + ''' This is for testing the error occur in a Worker process in the Proxy. ''' + + def __init__( + self, + workers_kwargs, + model_world_size: int = 1, + mpi_session=None, + ) -> None: + super().__init__(workers_kwargs, + model_world_size, + mpi_session, + worker_cls=DummyExecutorWorker2) + + +DummyExecutor3 = DummyExecutorMeta("DummyExecutor3", (), {}, + proxy_class=DummyExecutorProxy3) + + # TODO[chunweiy]: This test is not stable, need to investigate -def _test_executor_process_background_error(): +def test_executor_handle_background_error(): + + llm = LLM(model=llama_model_path, + executor_cls=DummyExecutor2, + kv_cache_config=global_kv_cache_config) + # The dummy executor will delay the responses + sampling_params = SamplingParams(max_tokens=6) - llm = LLM(model=llama_model_path, executor_cls=DummyExecutor2) + # test in streaming mode + async def task(): + with pytest.raises(DummyError): + async for output in llm.generate_async( + prompts[0], streaming=True, + sampling_params=sampling_params): + print(output) + + asyncio.run(task()) + + +def test_executor_handle_background_error_in_worker(): + llm = LLM(model=llama_model_path, + executor_cls=DummyExecutor2, + kv_cache_config=global_kv_cache_config) # The dummy executor will delay the responses sampling_params = SamplingParams(max_tokens=6) # test in streaming mode async def task(): - with pytest.raises(ValueError): + with pytest.raises(DummyError): async for output in llm.generate_async( prompts[0], streaming=True, sampling_params=sampling_params): @@ -369,5 +422,7 @@ async def task(): if __name__ == '__main__': - test_llm_pp2() - test_executor_pending_requests() + #test_llama_v2_13b_lora_tp2() + #test_llm_end2end_tp2({'embedding_parallel_mode': 'NONE'}) + test_llm_return_context_logits_tp2() + test_llm_return_generation_logits_tp2() diff --git a/tests/hlapi/test_llm_utils.py b/tests/hlapi/test_llm_utils.py index ae3cae043..8bc460725 100644 --- a/tests/hlapi/test_llm_utils.py +++ b/tests/hlapi/test_llm_utils.py @@ -2,7 +2,6 @@ from pathlib import Path import pytest -from transformers import AutoTokenizer from tensorrt_llm.builder import PluginConfig from tensorrt_llm.hlapi.llm_utils import * @@ -170,11 +169,10 @@ def fallback(): def test_ModelLoader(): - args = LlmArgs(llama_model_path) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + args = LlmArgs(llama_model_path, kv_cache_config=kv_cache_config) args.setup() - tokenizer = AutoTokenizer.from_pretrained(args.model) - # Test with HF model temp_dir = tempfile.TemporaryDirectory() @@ -196,7 +194,8 @@ def build_engine(): def test_CachedModelLoader(): # CachedModelLoader enables engine caching and multi-gpu building - args = LlmArgs(llama_model_path) + args = LlmArgs(llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) args.enable_build_cache = True args.setup() stats = LlmBuildStats() diff --git a/tests/model/test_decilm.py b/tests/model/test_decilm.py deleted file mode 100644 index 083db6671..000000000 --- a/tests/model/test_decilm.py +++ /dev/null @@ -1,602 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import itertools -import os -import sys -import tempfile -import unittest -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import tensorrt as trt -import torch -import transformers -from parameterized import parameterized - -import tensorrt_llm -from tensorrt_llm import logger -from tensorrt_llm._utils import str_dtype_to_torch -from tensorrt_llm.builder import Builder -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.deci.config import DeciConfig, DeciLayerConfig -from tensorrt_llm.models.deci.convert import _ffn_mult_to_intermediate_size -from tensorrt_llm.models.deci.layer_config import (AttentionImplementation, - FFNImplementation) -from tensorrt_llm.models.deci.model import DeciLMForCausalLM -from tensorrt_llm.network import Network, net_guard -from tensorrt_llm.plugin.plugin import ContextFMHAType -from tensorrt_llm.runtime.generation import _Runtime - -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.llm_data import llm_models_root -from utils.util import unittest_name_func - - -class TestDeciLM(unittest.TestCase): - - def _make_decilm_config(self, - layer_configs: List[Union[DeciLayerConfig, - Dict[str, Dict[str, - Any]]]], - dtype: str = 'bfloat16', - num_attention_heads: int = 32, - num_key_value_heads: Optional[int] = None, - hidden_size: int = 4096, - intermediate_size: int = 16384, - vocab_size: int = 32128, - max_positions_embedding: int = 1024, - norm_epsilon: float = 1e-05) -> DeciConfig: - config = { - 'architecture': 'DeciLMForCausalLM', - 'num_hidden_layers': len(layer_configs), - 'num_attention_heads': num_attention_heads, - 'num_key_value_heads': num_key_value_heads, - 'dtype': dtype, - 'logits_dtype': dtype, - 'hidden_size': hidden_size, - 'intermediate_size': intermediate_size, - 'vocab_size': vocab_size, - 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': max_positions_embedding, - 'hidden_act': 'silu', - 'norm_epsilon': norm_epsilon, - 'layer_configs': layer_configs - } - - config = DeciConfig.from_dict(config) - return config - - def _gen_tensorrt_llm_network(self, network: Network, - decilm: DeciLMForCausalLM, batch_size: int, - beam_width: int, input_len: int, - output_len: int, rank: int, - tensor_parallel: int, **opt_flags): - list(range(tensor_parallel)) - - with net_guard(network): - # optimize_model(decilm, **opt_flags) - # Prepare - network.set_named_parameters(decilm.named_parameters()) - inputs = decilm.prepare_inputs(max_batch_size=batch_size, - max_input_len=input_len, - max_seq_len=input_len + output_len, - max_num_tokens=batch_size * - input_len, - use_cache=True, - max_beam_width=beam_width) - # Forward - decilm(**inputs) - return network - - def _gen_tensorrt_llm_engine( - self, - rank: int, - world_size: int, - decilm: DeciLMForCausalLM, - model_name: str, - use_plugin: bool, - batch_size: int, - beam_width: int, - 
input_len: int, - output_len: int, - use_refit: bool, - use_gemm: bool = False, - context_fmha_flag: ContextFMHAType = ContextFMHAType.disabled, - enable_remove_input_padding: bool = False, - **opt_flags) -> trt.IHostMemory: - - builder = Builder() - dtype = decilm.config.dtype - - with tempfile.TemporaryDirectory(): - builder_config = builder.create_builder_config( - name=model_name, - precision=dtype, - timing_cache='model.cache', - tensor_parallel=world_size, # TP only - use_refit=use_refit, - strongly_typed=True, - ) - network = builder.create_network() - network.plugin_config.to_legacy_setting() - if use_plugin: - network.plugin_config.gpt_attention_plugin = dtype - if use_gemm: - network.plugin_config.gemm_plugin = dtype - if enable_remove_input_padding: - network.plugin_config.remove_input_padding = True - network.plugin_config.set_context_fmha(context_fmha_flag) - - self._gen_tensorrt_llm_network(network=network, - decilm=decilm, - batch_size=batch_size, - beam_width=beam_width, - input_len=input_len, - output_len=output_len, - rank=rank, - tensor_parallel=world_size, - **opt_flags) - engine_buffer = builder.build_engine(network, builder_config) - return engine_buffer - - def _gen_tensorrt_llm_runtime( - self, - log_level: str, - world_size: int, - rank: int, - decilm: DeciLMForCausalLM, - model_name: str, - use_plugin: bool, - batch_size: int, - beam_width: int, - input_len: int, - output_len: int, - use_refit: bool, - use_gemm: bool = False, - context_fmha_flag: ContextFMHAType = ContextFMHAType.disabled, - enable_remove_input_padding: bool = False, - **opt_flags) -> Tuple[_Runtime, trt.IHostMemory]: - logger.set_level(log_level) - mapping = Mapping(world_size, rank, tp_size=world_size) - engine_buffer = self._gen_tensorrt_llm_engine( - rank=rank, - world_size=world_size, - decilm=decilm, - model_name=model_name, - use_plugin=use_plugin, - batch_size=batch_size, - beam_width=beam_width, - input_len=input_len, - output_len=output_len, - use_refit=use_refit, - use_gemm=use_gemm, - context_fmha_flag=context_fmha_flag, - enable_remove_input_padding=enable_remove_input_padding, - **opt_flags) - runtime = _Runtime(engine_buffer, mapping) - return runtime, engine_buffer - - def test_config_to_from_dict(self) -> None: - config = self._make_decilm_config(layer_configs=[{ - "attention": { - "num_key_value_heads": 4 - }, - "ffn": {} - }, { - "attention": { - "num_key_value_heads": 2 - }, - "ffn": { - "impl": "no_op" - } - }, { - "attention": { - "impl": "no_op" - }, - "ffn": { - "intermediate_size": 8192 - } - }]) - - config2 = DeciConfig.from_dict(config.to_dict()) - self.assertListEqual(config.layer_configs, config2.layer_configs) - - def test_save_load_config(self) -> None: - config = self._make_decilm_config(layer_configs=[{ - "attention": { - "num_key_value_heads": 4 - }, - "ffn": {} - }, { - "attention": { - "num_key_value_heads": 2 - }, - "ffn": { - "impl": "no_op" - } - }, { - "attention": { - "impl": "no_op" - }, - "ffn": { - "intermediate_size": 8192 - } - }]) - - with tempfile.TemporaryDirectory( - prefix="test_save_load_checkpoint") as ckpt_dir: - config_file = f"{ckpt_dir}/config.json" - config.to_json_file(config_file) - config2 = DeciConfig.from_json_file(config_file) - - self.assertDictEqual(config.to_dict(), config2.to_dict()) - self.assertListEqual(config.layer_configs, config2.layer_configs) - - def get_loader_test_cases(): - model_root = llm_models_root(check=True) - test_models_base_path = Path(model_root, "nvsmall/tests") - - models_path = [ - 
os.path.join(test_models_base_path, x) - for x in os.listdir(test_models_base_path) - ] - test_cases = list( - itertools.product(models_path, ["bfloat16", "float16"])) - - return test_cases - - @parameterized.expand(get_loader_test_cases, name_func=unittest_name_func) - def test_allclose_to_hf(self, hf_model_dir, dtype): - if hf_model_dir is None: - self.skipTest( - f"Missing nvsmall checkpoint, define a valid checkpoint path with the NVSMALL_CKPT environment variable" - ) - - dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) - - hf_model = transformers.AutoModelForCausalLM.from_pretrained( - hf_model_dir, trust_remote_code=True, torch_dtype=dtype).cuda() - decilm = DeciLMForCausalLM.from_hugging_face(hf_model) - config = decilm.config - - log_level = "warning" - batch_size = 1 - beam_width = 1 - input_len = 4 - output_len = 2 - max_seq_len = input_len + output_len - dtype = config.dtype - enable_remove_input_padding = False - use_gpt_plugin = True - use_gemm = True - - runtime, engine_buffer = self._gen_tensorrt_llm_runtime( - log_level=log_level, - decilm=decilm, - batch_size=batch_size, - beam_width=beam_width, - input_len=input_len, - output_len=output_len, - rank=0, - world_size=1, - model_name="decilm", - use_gemm=use_gemm, - use_plugin=use_gpt_plugin, - use_refit=False) - - key_value_cache_buffers = [] - head_size = config.hidden_size // config.num_attention_heads - - attn_layer_idx = [ - i for i in range(config.num_hidden_layers) - if config.get_layer_config(i).attention.needs_kv_cache - ] - for layer_idx in attn_layer_idx: - layer_config = config.get_layer_config(layer_idx) - new_cache = torch.zeros(( - batch_size, - 2, - layer_config.attention.num_key_value_heads, - max_seq_len, - head_size, - ), - dtype=str_dtype_to_torch(dtype), - device='cuda') - key_value_cache_buffers.append(new_cache) - - # compare context - ctx_ids = torch.randint(100, (batch_size, input_len)).int().cuda() - ctx_context_lengths = input_len * torch.ones( - (batch_size), dtype=torch.int32, device='cuda') - ctx_position_ids = torch.tensor(range(input_len), - dtype=torch.int32).reshape([ - 1, input_len - ]).expand([batch_size, - input_len]).cuda() - ctx_last_token_ids = ctx_context_lengths.clone() - ctx_host_request_types = torch.tensor([0] * batch_size, - dtype=torch.int32) - - # We need sequence_lengths start as context_lengths for step 0, - # and it will be added one after each step. 
- sequence_length_buffer = ctx_context_lengths.detach().clone() - - with torch.no_grad(): - hf_outputs = hf_model.forward(ctx_ids, - output_hidden_states=True, - output_attentions=True) - - torch.cuda.synchronize() - ref = hf_outputs.logits[:, -1, :] - - if enable_remove_input_padding: - ctx_ids = ctx_ids.view([batch_size * input_len]) - ctx_position_ids = ctx_position_ids.view([batch_size * input_len]) - ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, dim=0).int() - - cache_indirections = [ - torch.full(( - batch_size, - beam_width, - max_seq_len, - ), - 0, - dtype=torch.int32, - device='cuda'), - torch.full(( - batch_size, - beam_width, - max_seq_len, - ), - 0, - dtype=torch.int32, - device='cuda') - ] # ping-pong buffers - - perf_knob_tensor_size = 16 - # runtime_perf_knobs is not used in context phase - context_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size, - dtype=torch.int64) - - ctx_buffer = { - 'input_ids': ctx_ids, - 'context_lengths': ctx_context_lengths, - 'position_ids': ctx_position_ids, - 'last_token_ids': ctx_last_token_ids, - 'cache_indirection': cache_indirections[0], - 'host_request_types': ctx_host_request_types, - 'host_runtime_perf_knobs': context_runtime_perf_knobs, - } - if enable_remove_input_padding: - ctx_buffer['host_context_lengths'] = ctx_context_lengths.cpu() - - ctx_shape = {k: v.shape for k, v in ctx_buffer.items()} - - ctx_buffer[f'host_max_attention_window_sizes'] = torch.tensor( - [max_seq_len] * len(attn_layer_idx), dtype=torch.int32) - ctx_shape[f'host_max_attention_window_sizes'] = (len(attn_layer_idx), ) - for layer_idx, buf in zip(attn_layer_idx, key_value_cache_buffers): - layer_config = config.get_layer_config(layer_idx) - kv_shape = (batch_size, 2, - layer_config.attention.num_key_value_heads, max_seq_len, - head_size) - ctx_shape[f'past_key_value_{layer_idx}'] = kv_shape - ctx_buffer[f'past_key_value_{layer_idx}'] = buf - ctx_buffer[f'present_key_value_{layer_idx}'] = buf - - ctx_buffer['sequence_length'] = sequence_length_buffer - ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape - ctx_shape['host_past_key_value_lengths'] = (batch_size, ) - ctx_buffer['host_past_key_value_lengths'] = torch.tensor( - [0] * batch_size, dtype=torch.int32) - ctx_shape['host_sink_token_length'] = (1, ) - ctx_buffer['host_sink_token_length'] = torch.tensor([0], - dtype=torch.int32) - - context = runtime.ctx_context - runtime._set_shape(context, ctx_shape) - runtime._set_buffer(context, ctx_buffer) - runtime._run(context) - torch.cuda.synchronize() - - res = ctx_buffer['logits'] - np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(), - res.to(torch.float32).cpu().numpy(), - atol=0.12) - - # compare generation - step = 1 - step1_id = torch.randint(100, (batch_size, 1)).int().cuda() - gen_context_lengths = ctx_context_lengths.clone() - gen_position_ids = torch.ones_like(step1_id).int().cuda() * input_len - gen_last_token_ids = torch.zeros_like(gen_context_lengths).int().cuda() - gen_host_request_types = torch.tensor([1] * batch_size, - dtype=torch.int32) - gen_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size, - dtype=torch.int64) - - with torch.no_grad(): - hf_outputs = hf_model.forward( - step1_id, - past_key_values=hf_outputs.past_key_values, - use_cache=True, - output_hidden_states=True) - - torch.cuda.synchronize() - ref = hf_outputs.logits[:, -1, :] - - if enable_remove_input_padding: - step1_id = step1_id.view([batch_size]) - gen_position_ids = gen_position_ids.view([batch_size]) - gen_last_token_ids = 
torch.ones_like( - gen_context_lengths).int().cuda() - gen_last_token_ids = torch.cumsum(gen_last_token_ids, dim=0).int() - - step1_buffer = { - 'input_ids': step1_id, - 'context_lengths': gen_context_lengths, - 'position_ids': gen_position_ids, - 'last_token_ids': gen_last_token_ids, - 'host_request_types': gen_host_request_types, - 'cache_indirection': cache_indirections[1], - 'host_runtime_perf_knobs': gen_runtime_perf_knobs, - } - if enable_remove_input_padding: - step1_buffer['host_context_lengths'] = gen_context_lengths.cpu() - - step1_shape = {k: v.shape for k, v in step1_buffer.items()} - - sequence_length_buffer = torch.add(sequence_length_buffer, step) - step1_buffer[f'host_max_attention_window_sizes'] = torch.tensor( - [max_seq_len] * len(attn_layer_idx), dtype=torch.int32) - step1_shape[f'host_max_attention_window_sizes'] = ( - len(attn_layer_idx), ) - for layer_idx, buf in zip(attn_layer_idx, key_value_cache_buffers): - layer_config = config.get_layer_config(layer_idx) - kv_shape = (batch_size, 2, - layer_config.attention.num_key_value_heads, max_seq_len, - head_size) - step1_shape[f"past_key_value_{layer_idx}"] = kv_shape - step1_buffer[f"past_key_value_{layer_idx}"] = buf - step1_buffer[f"present_key_value_{layer_idx}"] = buf - - step1_buffer['sequence_length'] = sequence_length_buffer - step1_shape['sequence_length'] = ctx_buffer['sequence_length'].shape - step1_shape['sequence_length'] = (batch_size, ) - step1_shape['host_past_key_value_lengths'] = (batch_size, ) - step1_buffer[ - 'host_past_key_value_lengths'] = sequence_length_buffer.cpu() - step1_shape['host_sink_token_length'] = (1, ) - step1_buffer['host_sink_token_length'] = torch.tensor([0], - dtype=torch.int32) - - context = runtime.context_1 - runtime._set_shape(context, step1_shape) - runtime._set_buffer(context, step1_buffer) - runtime._run(context) - torch.cuda.synchronize() - res = step1_buffer['logits'] - - np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(), - res.to(torch.float32).cpu().numpy(), - atol=0.12) - - @parameterized.expand( - itertools.product( - (os.getenv("NVSMALL_CKPT"), ), # "deci/decilm-7b"), - (True, False), - (1, 2), - (1, 2), - ("auto", "float16", "bfloat16"))) - def test_convert_config_from_hf(self, ckpt_path: Optional[str], - preloaded: bool, tp_size: int, pp_size: int, - dtype: str) -> None: - if ckpt_path is None: - self.skipTest( - f"Missing nvsmall checkpoint, define a valid checkpoint path with the NVSMALL_CKPT environment variable" - ) - - hf_config = transformers.AutoConfig.from_pretrained( - ckpt_path, trust_remote_code=True) - - mapping = Mapping(world_size=(tp_size * pp_size), - rank=0, - gpus_per_node=1, - tp_size=tp_size, - pp_size=pp_size) - - config = DeciConfig.from_hugging_face( - hf_config if preloaded else ckpt_path, - dtype=dtype, - mapping=mapping, - trust_remote_code=not preloaded) - - if getattr(hf_config, "num_key_value_heads_per_layer", - None) is not None: - # verify layers for old config - for layer_idx, num_kv_heads in enumerate( - hf_config.num_key_value_heads_per_layer): - layer_config = config.get_layer_config(layer_idx) - self.assertEqual(layer_config.attention.impl, - AttentionImplementation.ATTENTION) - self.assertEqual(num_kv_heads, - layer_config.attention.num_key_value_heads) - self.assertEqual(layer_config.ffn.impl, FFNImplementation.MLP) - self.assertEqual(layer_config.ffn.intermediate_size, - config.intermediate_size) - - elif getattr(hf_config, "block_configs", None) is not None: - # verify layers for new config - for layer_idx, 
block_config in enumerate(hf_config.block_configs): - layer_config = config.get_layer_config(layer_idx) - if layer_config.attention.impl == AttentionImplementation.ATTENTION: - self.assertFalse(block_config.attention.no_op) - self.assertFalse(block_config.attention.replace_with_linear) - self.assertEqual( - config.num_attention_heads // - block_config.attention.n_heads_in_group, - layer_config.attention.num_key_value_heads) - elif layer_config.attention.impl == AttentionImplementation.NO_OP: - self.assertTrue(block_config.attention.no_op) - elif layer_config.attention.impl == AttentionImplementation.LINEAR: - self.assertTrue(block_config.attention.replace_with_linear) - - if layer_config.ffn.impl == FFNImplementation.MLP: - self.assertFalse(block_config.ffn.no_op) - self.assertFalse(block_config.ffn.replace_with_linear) - self.assertEqual( - _ffn_mult_to_intermediate_size( - block_config.ffn.ffn_mult, config.hidden_size), - layer_config.ffn.intermediate_size) - elif layer_config.ffn.impl == FFNImplementation.NO_OP: - self.assertTrue(block_config.ffn.no_op) - elif layer_config.ffn.impl == FFNImplementation.LINEAR: - self.assertTrue(block_config.ffn.replace_with_linear) - - # verify config is valid enough for model creation - DeciLMForCausalLM(config) - - @parameterized.expand( - itertools.product( - (os.getenv("NVSMALL_CKPT"), ), # "deci/decilm-7b"), - (True, False), - (1, 2), - (1, 2), - ("auto", "float16", "bfloat16"))) - def test_convert_model_from_hf(self, ckpt_path: Optional[str], - preloaded: bool, tp_size: int, pp_size: int, - dtype: str) -> None: - if ckpt_path is None: - self.skipTest( - f"Missing nvsmall checkpoint, define a valid checkpoint path with the NVSMALL_CKPT environment variable" - ) - - if preloaded: - hf_model_or_dir = transformers.AutoModelForCausalLM.from_pretrained( - ckpt_path, trust_remote_code=True) - else: - hf_model_or_dir = ckpt_path - - mapping = Mapping(world_size=(tp_size * pp_size), - rank=0, - gpus_per_node=1, - tp_size=tp_size, - pp_size=pp_size) - - DeciLMForCausalLM.from_hugging_face(hf_model_or_dir=hf_model_or_dir, - dtype=dtype, - mapping=mapping, - trust_remote_code=not preloaded) diff --git a/tests/model/test_gpt.py b/tests/model/test_gpt.py index 54b2822b4..7123ab296 100644 --- a/tests/model/test_gpt.py +++ b/tests/model/test_gpt.py @@ -38,12 +38,16 @@ from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.runtime import ModelConfig, SamplingConfig from tensorrt_llm.runtime.generation import _prepare_attention_mask -from tensorrt_llm.runtime.kv_cache_manager import (GenerationSequence, - KVCacheManager) +from tensorrt_llm.runtime.kv_cache_manager import GenerationSequence +from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \ + PoolsKVCacheManager sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.util import skip_fp32_accum_pre_ampere, unittest_name_func +from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \ + MemoryPoolsAllocator + class TestGPT(unittest.TestCase): @@ -513,27 +517,50 @@ def test_gpt_plugin(self, test_partition, use_refit, fast_building, if enable_paged_kv_cache: max_blocks_per_seq = math.ceil(total_length / tokens_per_block) num_blocks = batch_size * beam_width * max_blocks_per_seq - block_size = gpt_config.n_head * tokens_per_block * head_size - kv_cache_manager = KVCacheManager( - num_layers=gpt_config.n_layer, + + memory_pools_allocator = MemoryPoolsAllocator( num_blocks=num_blocks, - block_size=block_size, tokens_per_block=tokens_per_block, - 
max_blocks_per_seq=max_blocks_per_seq, + head_size=head_size) + num_kv_heads_per_layer = MemoryPoolsAllocator.prepare_num_kv_heads_per_layer( + gpt_config.n_head, gpt_config.n_layer) + memory_pools_allocator.allocate(dtype, num_kv_heads_per_layer) + pools_kv_cache_manager = PoolsKVCacheManager( + memory_pools_allocator.pools_metadata, + max_blocks_per_seq, + num_blocks, + tokens_per_block, + head_size, max_attention_window_size=total_length, - sink_token_len=0, - beam_width=beam_width) - host_kv_cache_pool_pointers = torch.tensor( - [key_value_cache_buffers[0].data_ptr(), 0], dtype=torch.int64) + beam_width=beam_width, + sink_token_len=0) + + host_kv_cache_pool_pointers = memory_pools_allocator.get_kv_cache_pool_pointers( + ) + host_kv_cache_pool_mapping = memory_pools_allocator.pool_mapping + + # block_size = gpt_config.n_head * tokens_per_block * head_size + # kv_cache_manager = KVCacheManager( + # num_layers=gpt_config.n_layer, + # num_blocks=num_blocks, + # block_size=block_size, + # tokens_per_block=tokens_per_block, + # max_blocks_per_seq=max_blocks_per_seq, + # max_attention_window_size=total_length, + # sink_token_len=0, + # beam_width=beam_width) + # host_kv_cache_pool_pointers = torch.tensor( + # [key_value_cache_buffers[0].data_ptr(), 0], dtype=torch.int64) # Add sequences to the manager for bi in range(batch_size): generation_sequence = GenerationSequence(seq_idx=bi, batch_idx=bi) - kv_cache_manager.add_sequence(generation_sequence, seq_len) + pools_kv_cache_manager.add_sequence(generation_sequence, + seq_len) # Pre allocate the kv cache for the generated tokens. - kv_cache_manager.step([False] * batch_size) + pools_kv_cache_manager.step([False] * batch_size) def run_engine(context, input_ids, @@ -570,7 +597,7 @@ def run_engine(context, if enable_paged_kv_cache: assert beam_width == 1 # for beam_width > 1 the argument must be '1' in ctx phase and 'beam_width' in gen phase - host_kv_cache_block_offsets = kv_cache_manager.get_block_offsets( + host_kv_cache_block_offsets = pools_kv_cache_manager.get_block_offsets( beam_width=1) kv_cache_block_offsets = host_kv_cache_block_offsets.to('cuda') @@ -585,6 +612,10 @@ def run_engine(context, ctx_buffer[ f'host_kv_cache_pool_pointers'] = host_kv_cache_pool_pointers.contiguous( ) + ctx_buffer[ + f'host_kv_cache_pool_mapping'] = memory_pools_allocator.pool_mapping.contiguous( + ) + ctx_buffer[ f'host_max_attention_window_sizes'] = host_max_attention_window_sizes else: diff --git a/tests/model/test_nemotron_nas.py b/tests/model/test_nemotron_nas.py new file mode 100644 index 000000000..469a65b4e --- /dev/null +++ b/tests/model/test_nemotron_nas.py @@ -0,0 +1,989 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
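Note: both the updated tests/model/test_gpt.py hunk above and the new tests/model/test_nemotron_nas.py that follows drive the paged KV cache through the pools-based API (MemoryPoolsAllocator + PoolsKVCacheManager) instead of the older single-pool KVCacheManager. The following is a minimal sketch of that flow, using only the calls that appear in these tests; variable names and all sizes/dtypes are illustrative placeholders, not values from this patch.

    import math

    from tensorrt_llm.runtime.kv_cache_manager import GenerationSequence
    from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import MemoryPoolsAllocator
    from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import PoolsKVCacheManager

    # Illustrative sizes only.
    batch_size, beam_width, seq_len, total_length = 1, 1, 128, 130
    tokens_per_block, head_size, num_layers, num_kv_heads = 128, 64, 2, 8

    max_blocks_per_seq = math.ceil(total_length / tokens_per_block)
    num_blocks = batch_size * beam_width * max_blocks_per_seq

    # One allocator owns the KV cache pools; KV-head counts may differ per layer.
    allocator = MemoryPoolsAllocator(num_blocks=num_blocks,
                                     tokens_per_block=tokens_per_block,
                                     head_size=head_size)
    kv_heads_per_layer = MemoryPoolsAllocator.prepare_num_kv_heads_per_layer(
        num_kv_heads, num_layers)
    allocator.allocate("float16", kv_heads_per_layer)

    # The manager tracks sequences and hands out block offsets over those pools.
    kv_cache_manager = PoolsKVCacheManager(allocator.pools_metadata,
                                           max_blocks_per_seq,
                                           num_blocks,
                                           tokens_per_block,
                                           head_size,
                                           max_attention_window_size=total_length,
                                           beam_width=beam_width,
                                           sink_token_len=0)
    for bi in range(batch_size):
        kv_cache_manager.add_sequence(GenerationSequence(seq_idx=bi, batch_idx=bi),
                                      seq_len)
    # Pre-allocate blocks for the tokens that will be generated next step.
    kv_cache_manager.step([False] * batch_size)

    # Buffers the engine expects, as wired up in the tests:
    host_kv_cache_block_offsets = kv_cache_manager.get_block_offsets(beam_width=1)
    kv_cache_block_offsets = host_kv_cache_block_offsets.to('cuda')
    host_kv_cache_pool_pointers = allocator.get_kv_cache_pool_pointers()
    host_kv_cache_pool_mapping = allocator.pool_mapping

The main difference from the replaced KVCacheManager path is that pool pointers and the per-layer pool mapping now come from the allocator rather than being built by hand from a single cache buffer's data pointer.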
+import itertools +import math +import os +import re +import subprocess +import sys +import tempfile +import unittest +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import pytest +import tensorrt as trt +import torch +import transformers +from parameterized import parameterized +from transformers import AutoTokenizer +from typing_extensions import Literal + +import tensorrt_llm +from tensorrt_llm import logger +from tensorrt_llm._utils import str_dtype_to_torch +from tensorrt_llm.builder import Builder, Engine, EngineConfig +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import PretrainedConfig +from tensorrt_llm.models.nemotron_nas.config import DeciConfig, DeciLayerConfig +from tensorrt_llm.models.nemotron_nas.convert import ( + _ffn_mult_to_intermediate_size, load_weights_from_hf_safetensors) +from tensorrt_llm.models.nemotron_nas.layer_config import ( + AttentionImplementation, FFNImplementation) +from tensorrt_llm.models.nemotron_nas.model import DeciLMForCausalLM +from tensorrt_llm.network import Network, net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.runtime.generation import _Runtime + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.llm_data import llm_models_root +from utils.util import get_project_root, unittest_name_func + +sys.path.append( + os.path.join(os.path.dirname(__file__), '../..', 'examples/nemotron_nas')) +from calibration_utils import create_trtllm_magpie_calibration_dataset + +from tensorrt_llm.runtime.kv_cache_manager import GenerationSequence +from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \ + MemoryPoolsAllocator +from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \ + PoolsKVCacheManager +from tensorrt_llm.runtime.model_runner import ModelRunner +from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCpp + + +@dataclass(kw_only=True, frozen=True) +class TestParams: + enable_paged_kv_cache: bool + enable_remove_input_padding: bool + dtype: Literal["float16", "bfloat16"] + + batch_size: int = 1 + beam_width: int = 1 + seq_len: int = 128 + total_length: int = seq_len + 2 + tokens_per_block: int = 128 + + @property + def output_len(self): + return self.total_length - self.seq_len + + def __str__(self) -> str: + """tests/utils/util.py#L143 - > `str(x)`: parameterized test name""" + properties_without_default = (self.enable_paged_kv_cache, + self.enable_remove_input_padding, + self.dtype) + return "_".join((parameterized.to_safe_name(prop).lower() + for prop in properties_without_default)) + + @property + def mapping(self) -> Mapping: + return Mapping(world_size=1, rank=0, tp_size=1) + + +@dataclass +class RuntimeHandle: + """Deleting `Runtime().runtime` will **definitively** deallocate the weights.""" + runtime: _Runtime + + +class TestNemotronNas(unittest.TestCase): + + def _make_config(self, + layer_configs: List[Union[DeciLayerConfig, + Dict[str, Dict[str, Any]]]], + dtype: str = 'bfloat16', + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = None, + hidden_size: int = 4096, + intermediate_size: int = 16384, + vocab_size: int = 32128, + max_positions_embedding: int = 1024, + norm_epsilon: float = 1e-05) -> DeciConfig: + config = { + 'architecture': 'DeciLMForCausalLM', + 'num_hidden_layers': len(layer_configs), + 'num_attention_heads': num_attention_heads, + 'num_key_value_heads': 
num_key_value_heads, + 'dtype': dtype, + 'logits_dtype': dtype, + 'hidden_size': hidden_size, + 'intermediate_size': intermediate_size, + 'vocab_size': vocab_size, + 'position_embedding_type': 'rope_gpt_neox', + 'max_position_embeddings': max_positions_embedding, + 'hidden_act': 'silu', + 'norm_epsilon': norm_epsilon, + 'layer_configs': layer_configs + } + + config = DeciConfig.from_dict(config) + return config + + def _gen_tensorrt_llm_network(self, network: Network, + model: DeciLMForCausalLM, batch_size: int, + beam_width: int, input_len: int, + output_len: int, rank: int, + tensor_parallel: int, **opt_flags): + list(range(tensor_parallel)) + + with net_guard(network): + # Prepare + network.set_named_parameters(model.named_parameters()) + inputs = model.prepare_inputs(max_batch_size=batch_size, + max_input_len=input_len, + max_seq_len=input_len + output_len, + max_num_tokens=batch_size * input_len, + use_cache=True, + max_beam_width=beam_width) + # Forward + model(**inputs) + return network + + def _gen_tensorrt_llm_engine( + self, + rank: int, + world_size: int, + model: DeciLMForCausalLM, + model_name: str, + use_plugin: bool, + batch_size: int, + beam_width: int, + input_len: int, + output_len: int, + tokens_per_block: int, + use_refit: bool, + use_gemm: bool = False, + context_fmha_flag: ContextFMHAType = ContextFMHAType.disabled, + enable_remove_input_padding: bool = False, + enable_paged_kv_cache: bool = False, + **opt_flags) -> trt.IHostMemory: + + builder = Builder() + dtype = model.config.dtype + + with tempfile.TemporaryDirectory(): + builder_config = builder.create_builder_config( + name=model_name, + precision=dtype, + timing_cache='model.cache', + tensor_parallel=world_size, # TP only + use_refit=use_refit, + strongly_typed=True, + ) + network = builder.create_network() + network.plugin_config.to_legacy_setting() + if use_plugin: + network.plugin_config.gpt_attention_plugin = dtype + if use_gemm: + network.plugin_config.gemm_plugin = dtype + if enable_remove_input_padding: + network.plugin_config.remove_input_padding = True + if enable_paged_kv_cache: + network.plugin_config.enable_paged_kv_cache(tokens_per_block) + + network.plugin_config.set_context_fmha(context_fmha_flag) + + self._gen_tensorrt_llm_network(network=network, + model=model, + batch_size=batch_size, + beam_width=beam_width, + input_len=input_len, + output_len=output_len, + rank=rank, + tensor_parallel=world_size, + **opt_flags) + engine_buffer = builder.build_engine(network, builder_config) + return engine_buffer + + def _from_hf_model( + self, + hf_model: transformers.AutoModelForCausalLM, + params: TestParams, + *, + model_name: str = "nemotron-nas", + use_plugin: bool = True, + use_refit: bool = False, + use_gemm: bool = True, + context_fmha_flag: ContextFMHAType = ContextFMHAType.disabled, + **opt_flags) -> Tuple[RuntimeHandle, PretrainedConfig]: + model = DeciLMForCausalLM.from_hugging_face(hf_model) + logger.set_level("warning") + mapping = params.mapping + engine_buffer = self._gen_tensorrt_llm_engine( + rank=mapping.rank, + world_size=mapping.world_size, + model=model, + model_name=model_name, + use_plugin=use_plugin, + batch_size=params.batch_size, + beam_width=params.beam_width, + input_len=params.seq_len, + output_len=params.output_len, + use_refit=use_refit, + use_gemm=use_gemm, + context_fmha_flag=context_fmha_flag, + enable_remove_input_padding=params.enable_remove_input_padding, + tokens_per_block=params.tokens_per_block, + enable_paged_kv_cache=params.enable_paged_kv_cache, + **opt_flags) + 
runtime = RuntimeHandle(_Runtime(engine_buffer, mapping)) + return runtime, model.config + + def _from_fp8_quantized_engine( + self, + *, + model_dir: str, + quantize_dir: str, + dataset: Optional[str] = "cnn_dailymail", + params: TestParams) -> Tuple[RuntimeHandle, PretrainedConfig]: + root = get_project_root(__file__) + quantize_path = str(root / "examples/quantization/quantize.py") + + with tempfile.TemporaryDirectory( + prefix="transformed_magpie") as dataset_dir: + create_trtllm_magpie_calibration_dataset(dataset_dir) + quantize = [ + sys.executable, + quantize_path, + f"--model_dir={model_dir}", + f"--output_dir={quantize_dir}", + f"--calib_dataset={dataset_dir}", + "--dtype=bfloat16", + "--kv_cache_dtype=fp8", + "--qformat=fp8", + "--calib_size=512", + ] + print(f"Running quantize: {quantize}") + subprocess.run(quantize, check=True) + + engine_path = f"{quantize_dir}/engine" + build = [ + "trtllm-build", + f"--checkpoint_dir={quantize_dir}", + f"--output_dir={engine_path}", + f"--max_input_len={params.seq_len}", + f"--max_batch_size={params.batch_size}", + f"--remove_input_padding={'enable' if params.enable_remove_input_padding else 'disable'}", + f"--kv_cache_type={'paged' if params.enable_paged_kv_cache else 'continuous'}", + "--gemm_plugin=auto", + "--gpt_attention_plugin=auto", + ] + + if params.enable_paged_kv_cache: + build.append(f"--tokens_per_block={params.tokens_per_block}") + + print(f"Running trtllm-build: {build}") + subprocess.run(build, check=True) + + engine = Engine.from_dir(engine_path) + runtime = RuntimeHandle(_Runtime(engine.engine, params.mapping)) + config = EngineConfig.from_json_file(f"{engine_path}/config.json") + + return runtime, config.pretrained_config + + def test_config_to_from_dict(self) -> None: + config = self._make_config(layer_configs=[{ + "attention": { + "num_key_value_heads": 4 + }, + "ffn": {} + }, { + "attention": { + "num_key_value_heads": 2 + }, + "ffn": { + "impl": "no_op" + } + }, { + "attention": { + "impl": "no_op" + }, + "ffn": { + "intermediate_size": 8192 + } + }]) + + config2 = DeciConfig.from_dict(config.to_dict()) + self.assertListEqual(config.layer_configs, config2.layer_configs) + + def test_save_load_config(self) -> None: + config = self._make_config(layer_configs=[{ + "attention": { + "num_key_value_heads": 4 + }, + "ffn": {} + }, { + "attention": { + "num_key_value_heads": 2 + }, + "ffn": { + "impl": "no_op" + } + }, { + "attention": { + "impl": "no_op" + }, + "ffn": { + "intermediate_size": 8192 + } + }]) + + with tempfile.TemporaryDirectory( + prefix="test_save_load_checkpoint") as ckpt_dir: + config_file = f"{ckpt_dir}/config.json" + config.to_json_file(config_file) + config2 = DeciConfig.from_json_file(config_file) + + self.assertDictEqual(config.to_dict(), config2.to_dict()) + self.assertListEqual(config.layer_configs, config2.layer_configs) + + def get_loader_test_cases(): + model_root = llm_models_root(check=True) + test_models_base_path = Path(model_root, "nvsmall/tests") + models_path = [ + os.path.join(test_models_base_path, x) + for x in os.listdir(test_models_base_path) + ] + + params_product = [ + TestParams( + enable_paged_kv_cache=paged, + enable_remove_input_padding=padded, + dtype=dtype, + ) for paged, padded, dtype in itertools.product( + [True, False], + [True, False], + ["bfloat16", "float16"], + ) + ] + test_cases = list(itertools.product(models_path, params_product)) + + return test_cases + + @parameterized.expand(get_loader_test_cases, name_func=unittest_name_func) + def test_allclose_to_hf(self, 
hf_model_dir: str, params: TestParams): + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + hf_model_dir, + trust_remote_code=True, + torch_dtype=tensorrt_llm._utils.str_dtype_to_torch(params.dtype), + ).cuda() + runtime, config = self._from_hf_model(hf_model, params) + self.allclose( + runtime, + config=config, + params=params, + obtain_hf_model=lambda: hf_model, + ) + + def allclose( + self, + runtime_handle: RuntimeHandle, + *, + config: PretrainedConfig, + params: TestParams, + obtain_hf_model: Callable[[], transformers.AutoModelForCausalLM], + ): + batch_size = params.batch_size + beam_width = params.beam_width + seq_len = params.seq_len + total_length = params.total_length + dtype = config.dtype + tokens_per_block = params.tokens_per_block + enable_remove_input_padding = params.enable_remove_input_padding + enable_paged_kv_cache = params.enable_paged_kv_cache + + key_value_cache_buffers = [] + head_size = config.hidden_size // config.num_attention_heads + attn_layer_idx = [ + i for i in range(config.num_hidden_layers) + if config.get_layer_config(i).attention.needs_kv_cache + ] + + if enable_paged_kv_cache: + num_blocks = batch_size * beam_width * math.ceil( + total_length / tokens_per_block) + + memory_pools_allocator = MemoryPoolsAllocator( + num_blocks=num_blocks, + tokens_per_block=tokens_per_block, + head_size=head_size) + if config.num_kv_heads_per_layer is None: + num_kv_heads = config.get_layer_config( + attn_layer_idx[0]).attention.num_key_value_heads + num_kv_heads_per_layer = MemoryPoolsAllocator.prepare_num_kv_heads_per_layer( + num_kv_heads, len(attn_layer_idx)) + else: + num_kv_heads_per_layer = config.num_kv_heads_per_layer + + memory_pools_allocator.allocate(dtype, num_kv_heads_per_layer) + max_blocks_per_seq = math.ceil(total_length / tokens_per_block) + num_blocks = batch_size * beam_width * max_blocks_per_seq + + pools_kv_cache_manager = PoolsKVCacheManager( + memory_pools_allocator.pools_metadata, + max_blocks_per_seq, + num_blocks, + tokens_per_block, + head_size, + max_attention_window_size=total_length, + beam_width=beam_width, + sink_token_len=0) + # Add sequences to the manager + for bi in range(batch_size): + generation_sequence = GenerationSequence(seq_idx=bi, + batch_idx=bi) + pools_kv_cache_manager.add_sequence(generation_sequence, + seq_len) + + # Pre allocate the kv cache for the generated tokens. 
+ pools_kv_cache_manager.step([False] * batch_size) + + else: + for layer_idx in attn_layer_idx: + layer_config = config.get_layer_config(layer_idx) + new_cache = torch.zeros(( + batch_size, + 2, + layer_config.attention.num_key_value_heads, + total_length, + head_size, + ), + dtype=str_dtype_to_torch(dtype), + device='cuda') + key_value_cache_buffers.append(new_cache) + + cache_indirections = [ + torch.full(( + batch_size, + beam_width, + total_length, + ), + 0, + dtype=torch.int32, + device='cuda'), + torch.full(( + batch_size, + beam_width, + total_length, + ), + 0, + dtype=torch.int32, + device='cuda') + ] # ping-pong buffers + + def run_engine(context, + input_ids, + context_lengths, + host_request_types, + position_ids, + last_token_ids, + cache_indirection, + host_past_key_value_lengths, + host_max_attention_window_sizes, + host_sink_token_length, + host_runtime_perf_knobs, + sequence_length=None, + host_context_lengths=None): + + ctx_buffer = { + 'input_ids': input_ids, + 'context_lengths': context_lengths, + 'host_request_types': host_request_types, + 'position_ids': position_ids, + 'last_token_ids': last_token_ids, + 'cache_indirection': cache_indirection, + 'host_past_key_value_lengths': host_past_key_value_lengths, + 'sequence_length': sequence_length, + 'host_sink_token_length': host_sink_token_length, + 'host_runtime_perf_knobs': host_runtime_perf_knobs + } + + assert host_request_types is not None + if enable_remove_input_padding: + assert host_context_lengths is not None, "host_context_lengths is required for ragged input" + ctx_buffer['host_context_lengths'] = host_context_lengths + + if enable_paged_kv_cache: + assert beam_width == 1 + # for beam_width > 1 the argument must be '1' in ctx phase and 'beam_width' in gen phase + host_kv_cache_block_offsets = pools_kv_cache_manager.get_block_offsets( + beam_width=1) + kv_cache_block_offsets = host_kv_cache_block_offsets.to('cuda') + shape = kv_cache_block_offsets.shape + target_shape = [shape[0], shape[1] * shape[2], *shape[3:]] + ctx_buffer[ + f'kv_cache_block_offsets'] = kv_cache_block_offsets.reshape( + target_shape) + ctx_buffer[ + f'host_kv_cache_block_offsets'] = host_kv_cache_block_offsets.reshape( + target_shape) + ctx_buffer[ + f'host_kv_cache_pool_pointers'] = memory_pools_allocator.get_kv_cache_pool_pointers( + ).contiguous() + ctx_buffer[ + f'host_kv_cache_pool_mapping'] = memory_pools_allocator.pool_mapping.contiguous( + ) + ctx_buffer[ + f'host_max_attention_window_sizes'] = host_max_attention_window_sizes + else: + for layer_idx, buf in zip(attn_layer_idx, + key_value_cache_buffers): + ctx_buffer[f'past_key_value_{layer_idx}'] = buf + ctx_buffer[f'present_key_value_{layer_idx}'] = buf + ctx_buffer[ + f'host_max_attention_window_sizes'] = host_max_attention_window_sizes + + ctx_shape = { + key: buffer.shape + for key, buffer in ctx_buffer.items() + } + + runtime_handle.runtime._set_shape(context, ctx_shape) + runtime_handle.runtime._set_buffer(context, ctx_buffer) + runtime_handle.runtime._run(context) + torch.cuda.synchronize() + res = ctx_buffer['logits'] + return res + + step0_ids = torch.randint(100, (batch_size, seq_len)).int().cuda() + step1_ids = torch.randint(100, (batch_size, 1)).int().cuda() + + def tllm() -> Tuple[np.ndarray, np.ndarray]: + ctx_ids = step0_ids.clone() + + ctx_context_lengths = seq_len * torch.ones( + (batch_size), dtype=torch.int32, device='cuda') + ctx_position_ids = torch.tensor(range(seq_len), + dtype=torch.int32).reshape([ + 1, seq_len + ]).expand([batch_size, + seq_len]).cuda() + 
ctx_last_token_ids = ctx_context_lengths.clone() + + if enable_remove_input_padding: + ctx_ids = ctx_ids.view([batch_size * seq_len]) + ctx_position_ids = ctx_position_ids.view([batch_size * seq_len]) + ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, + dim=0).int() + + host_max_attention_window_sizes = torch.tensor([total_length] * + len(attn_layer_idx), + dtype=torch.int32) + host_sink_token_length = torch.tensor([0], dtype=torch.int32) + + host_context_lengths = ctx_context_lengths.cpu( + ) if enable_remove_input_padding else None + host_request_types = torch.tensor([0 for i in range(batch_size)], + dtype=torch.int32).cpu() + + host_past_key_value_lengths = ctx_context_lengths.detach().clone( + ).cpu() + # We need sequence_lengths start as context_lengths for step 0 (context), + # and it will be added one after each step. + sequence_length = ctx_context_lengths.detach().clone() + + perf_knob_tensor_size = 16 + ctx_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size, + dtype=torch.int64) + + step0 = run_engine( + context=runtime_handle.runtime.ctx_context, + input_ids=ctx_ids, + context_lengths=ctx_context_lengths, + position_ids=ctx_position_ids, + last_token_ids=ctx_last_token_ids, + cache_indirection=cache_indirections[0], + host_past_key_value_lengths=host_past_key_value_lengths, + host_max_attention_window_sizes=host_max_attention_window_sizes, + host_sink_token_length=host_sink_token_length, + sequence_length=sequence_length, + host_context_lengths=host_context_lengths, + host_request_types=host_request_types, + host_runtime_perf_knobs=ctx_runtime_perf_knobs) + + step = 1 + gen_ids = step1_ids.clone() + + gen_context_lengths = seq_len * torch.ones( + (batch_size), dtype=torch.int32, device='cuda') + gen_position_ids = torch.ones_like(gen_ids).int().cuda() * seq_len + gen_last_token_ids = torch.zeros_like( + gen_context_lengths).int().cuda() + + if enable_remove_input_padding: + gen_ids = gen_ids.view([batch_size]) + gen_position_ids = gen_position_ids.view([batch_size]) + gen_last_token_ids = torch.ones_like( + gen_context_lengths).int().cuda() + gen_last_token_ids = torch.cumsum(gen_last_token_ids, + dim=0).int() + + host_past_key_value_lengths = torch.tensor([seq_len + step - 1] * + batch_size, + dtype=torch.int32) + host_max_attention_window_sizes = torch.tensor([seq_len + step] * + len(attn_layer_idx), + dtype=torch.int32) + host_sink_token_length = torch.tensor([0], dtype=torch.int32) + + host_context_lengths = gen_context_lengths.cpu( + ) if enable_remove_input_padding else None + host_request_types = torch.tensor([1 for i in range(batch_size)], + dtype=torch.int32).cpu() + + # For step 1, the sequence_lengths = context_lengths + 1. 
+ sequence_length = torch.add(gen_context_lengths.detach().clone(), 1) + + perf_knob_tensor_size = 16 + gen_runtime_perf_knobs = torch.tensor([-1] * perf_knob_tensor_size, + dtype=torch.int64) + + step1 = run_engine( + context=runtime_handle.runtime.context_1, + input_ids=gen_ids, + context_lengths=gen_context_lengths, + position_ids=gen_position_ids, + last_token_ids=gen_last_token_ids, + cache_indirection=cache_indirections[1], + host_past_key_value_lengths=host_past_key_value_lengths, + host_max_attention_window_sizes=host_max_attention_window_sizes, + host_sink_token_length=host_sink_token_length, + sequence_length=sequence_length, + host_context_lengths=host_context_lengths, + host_request_types=host_request_types, + host_runtime_perf_knobs=gen_runtime_perf_knobs) + + return step0, step1 + + def hf() -> Tuple[np.ndarray, np.ndarray]: + with torch.no_grad(): + hf_model = obtain_hf_model() + step0_outputs = hf_model.forward(step0_ids.clone()) + torch.cuda.synchronize() + step0 = step0_outputs.logits[:, -1, :] + step1_outputs = hf_model.forward( + step1_ids.clone(), + past_key_values=step0_outputs.past_key_values, + use_cache=True, + ) + torch.cuda.synchronize() + step1 = step1_outputs.logits[:, -1, :] + + return step0, step1 + + res_step0, res_step1 = tllm() + del runtime_handle.runtime + ref_step0, ref_step1 = hf() + np.testing.assert_allclose(ref_step0.cpu().numpy().flatten(), + res_step0.cpu().numpy().flatten(), + atol=1e-1) + np.testing.assert_allclose(ref_step1.cpu().numpy().flatten(), + res_step1.cpu().numpy().flatten(), + atol=1e-1) + + @parameterized.expand(get_loader_test_cases, name_func=unittest_name_func) + @pytest.mark.skipif( + os.environ.get("TEST_NEMOTRON_NAS_FP8_ALLCLOSE") is None, + reason="fp8 accuracy is low.") + def test_allclose_to_hf_fp8(self, hf_model_dir: str, params: TestParams): + with tempfile.TemporaryDirectory("quantize_dir") as quantize_dir: + runtime, config = self._from_fp8_quantized_engine( + model_dir=hf_model_dir, + quantize_dir=quantize_dir, + params=params) + self.allclose( + runtime, + config=config, + params=params, + obtain_hf_model=lambda: transformers.AutoModelForCausalLM. + from_pretrained( + hf_model_dir, + trust_remote_code=True, + torch_dtype=tensorrt_llm._utils.str_dtype_to_torch(params.dtype + ), + ).cuda(), + ) + + @pytest.mark.skipif( + os.environ.get("NEMOTRON_NAS_CKPT") is None + or os.environ.get("NEMOTRON_NAS_OUTPUT_DIR") is None, + reason="You must define NEMOTRON_NAS_CKPT, NEMOTRON_NAS_OUTPUT_DIR", + ) + def test_allclose_to_hf_fp8_accelerate(self): + hf_model_dir = os.environ["NEMOTRON_NAS_CKPT"] + output_dir = os.environ["NEMOTRON_NAS_OUTPUT_DIR"] + params = TestParams(enable_paged_kv_cache=True, + enable_remove_input_padding=True, + dtype="float16", + seq_len=2048) + runtime, config = self._from_fp8_quantized_engine( + model_dir=hf_model_dir, quantize_dir=str(output_dir), params=params) + self.allclose( + runtime, + config=config, + params=params, + obtain_hf_model=lambda: transformers.AutoModelForCausalLM. 
+ from_pretrained( + hf_model_dir, + trust_remote_code=True, + torch_dtype=tensorrt_llm._utils.str_dtype_to_torch(params.dtype + ), + device_map="auto", + ), + ) + + @parameterized.expand( + itertools.product(("nvidia/Llama-3_1-Nemotron-51B-Instruct", ), + (True, False), (1, 2), (1, 2), + ("auto", "float16", "bfloat16"))) + def test_convert_config_from_hf(self, ckpt_path: Optional[str], + preloaded: bool, tp_size: int, pp_size: int, + dtype: str) -> None: + hf_config = transformers.AutoConfig.from_pretrained( + ckpt_path, trust_remote_code=True) + + mapping = Mapping(world_size=(tp_size * pp_size), + rank=0, + gpus_per_node=1, + tp_size=tp_size, + pp_size=pp_size) + + config = DeciConfig.from_hugging_face( + hf_config if preloaded else ckpt_path, + dtype=dtype, + mapping=mapping, + trust_remote_code=not preloaded) + + if getattr(hf_config, "num_key_value_heads_per_layer", + None) is not None: + # verify layers for old config + for layer_idx, num_kv_heads in enumerate( + hf_config.num_key_value_heads_per_layer): + layer_config = config.get_layer_config(layer_idx) + self.assertEqual(layer_config.attention.impl, + AttentionImplementation.ATTENTION) + self.assertEqual(num_kv_heads, + layer_config.attention.num_key_value_heads) + self.assertEqual(layer_config.ffn.impl, FFNImplementation.MLP) + self.assertEqual(layer_config.ffn.intermediate_size, + config.intermediate_size) + + elif getattr(hf_config, "block_configs", None) is not None: + # verify layers for new config + for layer_idx, block_config in enumerate(hf_config.block_configs): + layer_config = config.get_layer_config(layer_idx) + if layer_config.attention.impl == AttentionImplementation.ATTENTION: + self.assertFalse(block_config.attention.no_op) + self.assertFalse(block_config.attention.replace_with_linear) + self.assertEqual( + config.num_attention_heads // + block_config.attention.n_heads_in_group, + layer_config.attention.num_key_value_heads) + elif layer_config.attention.impl == AttentionImplementation.NO_OP: + self.assertTrue(block_config.attention.no_op) + elif layer_config.attention.impl == AttentionImplementation.LINEAR: + self.assertTrue(block_config.attention.replace_with_linear) + + if layer_config.ffn.impl == FFNImplementation.MLP: + self.assertFalse(block_config.ffn.no_op) + self.assertFalse(block_config.ffn.replace_with_linear) + self.assertEqual( + _ffn_mult_to_intermediate_size( + block_config.ffn.ffn_mult, config.hidden_size), + layer_config.ffn.intermediate_size) + elif layer_config.ffn.impl == FFNImplementation.NO_OP: + self.assertTrue(block_config.ffn.no_op) + elif layer_config.ffn.impl == FFNImplementation.LINEAR: + self.assertTrue(block_config.ffn.replace_with_linear) + + # verify config is valid enough for model creation + DeciLMForCausalLM(config) + + @parameterized.expand( + itertools.product( + os.listdir( + Path(llm_models_root(check=True), "nvsmall/tests").as_posix()), + (True, False), (1, 2), (1, 2), ("auto", "float16", "bfloat16"))) + def test_convert_model_from_hf(self, model_dir: Optional[str], + preloaded: bool, tp_size: int, pp_size: int, + dtype: str) -> None: + ckpt_path = Path(llm_models_root(check=True), "nvsmall/tests", + model_dir) + + if preloaded: + hf_model_or_dir = transformers.AutoModelForCausalLM.from_pretrained( + ckpt_path, trust_remote_code=True) + else: + hf_model_or_dir = ckpt_path + + mapping = Mapping(world_size=(tp_size * pp_size), + rank=0, + gpus_per_node=1, + tp_size=tp_size, + pp_size=pp_size) + + DeciLMForCausalLM.from_hugging_face(hf_model_or_dir=hf_model_or_dir, + 
dtype=dtype, + mapping=mapping, + trust_remote_code=not preloaded) + + @parameterized.expand( + itertools.product( + os.listdir( + Path(llm_models_root(check=True), "nvsmall/tests").as_posix()), + (1, 2, 4))) + def test_weights_loader(self, model_dir: str, tp_size: int) -> None: + + ckpt_path = Path(llm_models_root(check=True), "nvsmall/tests", + model_dir) + config = DeciConfig.from_hugging_face(ckpt_path, trust_remote_code=True) + weights = load_weights_from_hf_safetensors(ckpt_path, config) + + shard_configs = [ + DeciConfig.from_hugging_face(ckpt_path, + trust_remote_code=True, + mapping=Mapping(world_size=tp_size, + tp_size=tp_size, + rank=rank)) + for rank in range(tp_size) + ] + shard_weights = [ + load_weights_from_hf_safetensors(ckpt_path, shard_config) + for shard_config in shard_configs + ] + + for name, param in weights.items(): + shards = [shard[name] for shard in shard_weights] + + if name.endswith("attention.weight"): + # linear attention + combined = torch.cat(shards, dim=0) + torch.testing.assert_close(combined, param, atol=0, rtol=0) + elif name.endswith("attention.qkv.weight"): + # proper attention + layer_idx = int( + re.match("transformer.layers.(\\d+).", name).groups()[0]) + layer_config = config.layer_configs[layer_idx] + num_kv_heads = int(layer_config.attention.num_key_value_heads) + num_kv_heads_tp = (num_kv_heads + tp_size - 1) // tp_size + dups = tp_size // num_kv_heads or 1 + q, k, v = torch.split(param, [ + config.num_attention_heads * config.head_size, + num_kv_heads * config.head_size, + num_kv_heads * config.head_size + ]) + + q_shards, k_shards, v_shards = [], [], [] + for rank, shard in enumerate(shards): + qt, kt, vt = torch.split( + shard, + [(config.num_attention_heads // tp_size) * + config.head_size, num_kv_heads_tp * config.head_size, + num_kv_heads_tp * config.head_size]) + q_shards.append(qt) + if rank % dups == 0: + k_shards.append(kt) + v_shards.append(vt) + + combined_q = torch.cat(q_shards, dim=0) + combined_k = torch.cat(k_shards, dim=0) + combined_v = torch.cat(v_shards, dim=0) + + torch.testing.assert_close(combined_q, q, atol=0, rtol=0) + torch.testing.assert_close(combined_k, k, atol=0, rtol=0) + torch.testing.assert_close(combined_v, v, atol=0, rtol=0) + + @parameterized.expand(itertools.product([True, False], + ["float16", "bfloat16"], [None], + [None]), + name_func=unittest_name_func) + def test_vgqa_model_runner_allclose(self, use_py_session, dtype, engine_dir, + hf_model_dir): + input_text = "Born in north-east France, Soyer trained as a" + tokenizer_dir = hf_model_dir + + if engine_dir is None or not Path(engine_dir).exists: + self.skipTest(f"Engine dir is either None or doesn't exist") + if hf_model_dir is None or not Path(hf_model_dir).exists: + self.skipTest( + f"Missing HF checkpoint, define a valid checkpoint path with the NEMOTRON_NAS_CKPT environment variable" + ) + + dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) + + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + hf_model_dir, trust_remote_code=True, torch_dtype=dtype).cuda() + + batch_size = 1 + max_seq_len = 30 + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, + padding_side="left", + truncation_side="left", + trust_remote_code=True, + use_fast=True) + batch_input_ids = [ + torch.tensor(tokenizer.encode(input_text, + add_special_tokens=True, + truncation=True), + dtype=torch.int32) + ] + + hf_batch_ids = batch_input_ids[0].unsqueeze(0).repeat(batch_size, + 1).cuda() + in_tokens = batch_input_ids[0].shape[0] + + with torch.no_grad(): + 
hf_outputs = hf_model.generate(hf_batch_ids, max_length=max_seq_len) + + torch.cuda.synchronize() + + if use_py_session: + runner = ModelRunner.from_dir(engine_dir=engine_dir, + rank=0, + debug_mode=False) + + else: + runner = ModelRunnerCpp.from_dir(engine_dir=engine_dir, + rank=0, + debug_mode=False) + + pad_token_id = tokenizer.pad_token_id + if tokenizer.pad_token_id is None: + pad_token_id = tokenizer.eos_token_id + + with torch.no_grad(): + runner_outputs = runner.generate(batch_input_ids=batch_input_ids, + max_new_tokens=max_seq_len - + in_tokens, + end_id=tokenizer.eos_token_id, + pad_id=pad_token_id, + output_sequence_lengths=True, + return_dict=False) + + torch.cuda.synchronize() + + del runner + + if not use_py_session: + np.testing.assert_allclose( + runner_outputs[0][0][:max_seq_len].cpu().numpy(), + hf_outputs[0].cpu().numpy()) + else: + np.testing.assert_allclose(runner_outputs[0].cpu().numpy(), + hf_outputs.cpu().numpy()) diff --git a/tests/test_export.py b/tests/test_export.py index e34cfa1a7..71b3c3445 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -17,12 +17,12 @@ import numpy as np import torch -from tensorrt_llm.models.gpt.convert import generate_int8 +from tensorrt_llm.models.convert_utils import generate_int8 def dist(x, y): - x = x.flatten().astype(float) - y = y.flatten().astype(float) + x = x.flatten().to(float) + y = y.flatten().to(float) l2_x = np.linalg.norm(x) l2_y = np.linalg.norm(y) @@ -47,6 +47,9 @@ def setUp(self): "y": torch.from_numpy(np.abs(y).max(axis=0)), "w": torch.from_numpy(np.abs(w).max(axis=0)), } + x = torch.from_numpy(x) + y = torch.from_numpy(y) + w = torch.from_numpy(w) values = generate_int8(w, ranges) self.x, self.y, self.w = x, y, w @@ -67,13 +70,15 @@ def test_weight_quantization(self): def test_e2e_gemm_quantization(self): # mimic what CUTLASS would do x_i8 = (self.x * self.values["scale_x_orig_quant"]).round().clip( - -127, 127) - y_i32 = x_i8 @ self.values["weight.int8"].astype(np.int32) + -127, 127).to(torch.int32) + # import pdb + # pdb.set_trace() + y_i32 = x_i8 @ self.values["weight.int8"].to(torch.int32) y_quant = y_i32 * self.values["scale_y_accum_quant"] * self.values[ "scale_y_quant_orig"] y_angle = dist(self.y, y_quant)[1] - y_i32_col = x_i8 @ self.values["weight.int8.col"].astype(np.int32) + y_i32_col = x_i8 @ self.values["weight.int8.col"].to(torch.int32) y_quant_col = y_i32_col * self.values[ "scale_y_accum_quant.col"] * self.values["scale_y_quant_orig"] y_angle_col = dist(self.y, y_quant_col)[1] diff --git a/tests/utils/util.py b/tests/utils/util.py index eedf3aac4..74ac97fe8 100644 --- a/tests/utils/util.py +++ b/tests/utils/util.py @@ -1,6 +1,7 @@ import os import unittest from difflib import SequenceMatcher +from pathlib import Path import pytest import tensorrt as trt @@ -13,7 +14,7 @@ from tensorrt_llm.hlapi.utils import get_total_gpu_memory from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import TensorInfo +from tensorrt_llm.runtime import Session, TensorInfo def ASSERT_DRV(err): @@ -229,11 +230,11 @@ def create_session(builder, builder_config.trt_builder_config.clear_flag(trt.BuilderFlag.TF32) engine = builder.build_engine(network, builder_config) assert engine is not None, "Failed to build engine" - session = tensorrt_llm.runtime.Session.from_serialized_engine(engine) + session = Session.from_serialized_engine(engine) return session -def run_session(session, inputs, outputs={}, override_shapes={}): +def 
run_session(session: Session, inputs, outputs={}, override_shapes={}): """ The current session object needs to pass in both inputs and outputs bindings. For test convenience, create a function that infers output shapes automatically, @@ -279,3 +280,8 @@ def similarity_score(a, b): def similar(a, b, threshold=0.8): "similar compare a and b " return similarity_score(a, b) >= threshold + + +def get_project_root(test_file: str) -> Path: + return next(p for p in Path(test_file).resolve().parents + if (p / 'tests').is_dir() and (p / "tensorrt_llm").is_dir()) diff --git a/windows/setup_build_env.ps1 b/windows/setup_build_env.ps1 index e2de12356..b9148da2c 100644 --- a/windows/setup_build_env.ps1 +++ b/windows/setup_build_env.ps1 @@ -45,21 +45,21 @@ if (-not $skipVSBuildTools) { Write-Output "Skipping Visual Studio Build Tools installation" } -# Install TensorRT 10.3.0.26 for TensorRT-LLM +# Install TensorRT 10.4.0.26 for TensorRT-LLM if (-not $skipTRT) { Write-Output "Downloading TensorRT" - Invoke-WebRequest -Uri 'https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip' -OutFile 'TensorRT-10.3.0.26.zip' + Invoke-WebRequest -Uri 'https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip' -OutFile 'TensorRT-10.4.0.26.zip' Write-Output "Extracting TensorRT" # Get path $absolutePath = Resolve-Path $TRTPath - Expand-Archive -Path '.\TensorRT-10.3.0.26.zip' -DestinationPath $absolutePath + Expand-Archive -Path '.\TensorRT-10.4.0.26.zip' -DestinationPath $absolutePath Write-Output "Removing TensorRT zip" - Remove-Item -Path 'TensorRT-10.3.0.26.zip' -Force + Remove-Item -Path 'TensorRT-10.4.0.26.zip' -Force Write-Output "Adding TensorRT to system Path" - [Environment]::SetEnvironmentVariable('Path', "$env:Path;$absolutePath\TensorRT-10.3.0.26\lib", [EnvironmentVariableTarget]::Machine) + [Environment]::SetEnvironmentVariable('Path', "$env:Path;$absolutePath\TensorRT-10.4.0.26\lib", [EnvironmentVariableTarget]::Machine) Write-Output "Installing TensorRT Python wheel" - python3 -m pip install $absolutePath\TensorRT-10.3.0.26\python\tensorrt-10.3.0-cp310-none-win_amd64.whl - Write-Output "Done TensorRT installation at '$absolutePath\TensorRT-10.3.0.26'" + python3 -m pip install $absolutePath\TensorRT-10.4.0.26\python\tensorrt-10.4.0-cp310-none-win_amd64.whl + Write-Output "Done TensorRT installation at '$absolutePath\TensorRT-10.4.0.26'" } else { Write-Output "Skipping TensorRT installation" } diff --git a/windows/setup_env.ps1 b/windows/setup_env.ps1 index dba449d8d..cc79ab2f2 100644 --- a/windows/setup_env.ps1 +++ b/windows/setup_env.ps1 @@ -43,6 +43,8 @@ if (-not $skipCUDA){ $cudaUri = 'https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_551.78_windows.exe' } elseif ($cudaVersion -eq "12.5.1"){ $cudaUri = 'https://developer.download.nvidia.com/compute/cuda/12.5.1/local_installers/cuda_12.5.1_555.85_windows.exe' + } elseif ($cudaVersion -eq "12.6.0"){ + $cudaUri = 'https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.76_windows.exe' } else { $cudaUri = Read-Host "Please go to https://developer.nvidia.com/cuda-downloads and input the url of the CUDA version you wish to use" } @@ -146,7 +148,7 @@ if(-not $skipCUDNN){ Add-Content -Path $env:LOCALAPPDATA\trt_env_outlog.txt -Value "0" New-Item -Path $env:LOCALAPPDATA\CUDNN -ItemType Directory -Force $ProgressPreference = 'SilentlyContinue' - 
Invoke-WebRequest -Uri 'https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.1.0.70_cuda12-archive.zip' -OutFile $env:LOCALAPPDATA\CUDNN\cudnn.zip + Invoke-WebRequest -Uri 'https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/windows-x86_64/cudnn-windows-x86_64-9.2.1.18_cuda12-archive.zip' -OutFile $env:LOCALAPPDATA\CUDNN\cudnn.zip Expand-Archive -Path $env:LOCALAPPDATA\CUDNN\cudnn.zip -DestinationPath $env:LOCALAPPDATA\CUDNN\cudnn_unzip New-Item -Path ".\" -Name "CUDNN" -ItemType "directory" @@ -156,9 +158,9 @@ if(-not $skipCUDNN){ New-Item -Path $binPath -ItemType Directory New-Item -Path $includePath -ItemType Directory New-Item -Path $libPath -ItemType Directory - Copy-Item -Path "$env:LOCALAPPDATA\CUDNN\cudnn_unzip\cudnn-windows-x86_64-9.1.0.70_cuda12-archive\bin\*" -Destination $binPath - Copy-Item -Path "$env:LOCALAPPDATA\CUDNN\cudnn_unzip\cudnn-windows-x86_64-9.1.0.70_cuda12-archive\include\*" -Destination $includePath - Copy-Item -Path "$env:LOCALAPPDATA\CUDNN\cudnn_unzip\cudnn-windows-x86_64-9.1.0.70_cuda12-archive\lib\x64\*" -Destination $libPath + Copy-Item -Path "$env:LOCALAPPDATA\CUDNN\cudnn_unzip\cudnn-windows-x86_64-9.2.1.18_cuda12-archive\bin\*" -Destination $binPath + Copy-Item -Path "$env:LOCALAPPDATA\CUDNN\cudnn_unzip\cudnn-windows-x86_64-9.2.1.18_cuda12-archive\include\*" -Destination $includePath + Copy-Item -Path "$env:LOCALAPPDATA\CUDNN\cudnn_unzip\cudnn-windows-x86_64-9.2.1.18_cuda12-archive\lib\x64\*" -Destination $libPath [Environment]::SetEnvironmentVariable("CUDNN", "$PWD;$binPath;$includePath;$libPath", [EnvironmentVariableTarget]::Machine) @@ -181,10 +183,10 @@ if (-not ($skipTRT)) { Write-Output "Grabbing TensorRT..." $ProgressPreference = 'SilentlyContinue' New-Item -Path .\TensorRT -ItemType Directory - Invoke-WebRequest -Uri 'https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip' -OutFile .\TensorRT\trt.zip + Invoke-WebRequest -Uri 'https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip' -OutFile .\TensorRT\trt.zip Expand-Archive -Path .\TensorRT\trt.zip -DestinationPath .\TensorRT\ Remove-Item -Path .\TensorRT\trt.zip -Force - $trtPath = Join-Path $TRT_BASE TensorRT-10.3.0.26 + $trtPath = Join-Path $TRT_BASE TensorRT-10.4.0.26 Write-Output "TensorRT installed at ${trtPath}" $trtSubPaths = @{