diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index f171fd0b7..95fb95a99 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -172,16 +172,16 @@ struct BenchmarkParams std::optional>> medusaChoices; }; -class InferenceRequestsSyncSend +class InferenceRequestsAsyncSend { public: - InferenceRequestsSyncSend(std::shared_ptr comm, + InferenceRequestsAsyncSend(std::shared_ptr comm, std::list> const& inferenceRequests, int const peer) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_LOG_DEBUG("start send requests to rank %d", peer); mNumNewWorkItems = static_cast(inferenceRequests.size()); - comm->send(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0); + mRequest1 = comm->sendAsync(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0); if (mNumNewWorkItems > 0) { for (auto const& infReq : inferenceRequests) @@ -191,16 +191,31 @@ class InferenceRequestsSyncSend mPacked.insert(mPacked.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); } mVecSize = static_cast(mPacked.size()); - comm->send(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1); - comm->send(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2); + mRequest2 = comm->sendAsync(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1); + mRequest3 = comm->sendAsync(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } + ~InferenceRequestsAsyncSend() + { + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mRequest1->wait(); + if (mRequest2) + mRequest2->wait(); + if (mRequest3) + mRequest3->wait(); + TLLM_LOG_DEBUG("end send requests"); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + } + private: int64_t mNumNewWorkItems; int64_t mVecSize; std::vector mPacked; + std::shared_ptr mRequest1; + std::shared_ptr mRequest2; + std::shared_ptr mRequest3; }; } // namespace @@ -930,7 +945,6 @@ class GptServer , mStaticEmulatedBatchSize(staticEmulatedBatchSize) , mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0})) , mActiveCount(0) - , mInferReqSyncSndHdl(nullptr) { auto const jsonConfig = GptJsonConfig::parse(trtEnginePath / "config.json"); mWorldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(), @@ -966,6 +980,12 @@ class GptServer ~GptServer() { + if (mInferReqWaitThread) + { + mInferReqWaitThread->join(); + mInferReqWaitThread.reset(nullptr); + } + mWorkItemsQueue.clear(); } @@ -1031,7 +1051,11 @@ class GptServer // Return up to max_num_requests inference requests. std::list> getInferenceRequests(int const max_num_requests) { - mInferReqSyncSndHdl = nullptr; + if (mInferReqWaitThread) + { + mInferReqWaitThread->join(); + mInferReqWaitThread.reset(nullptr); + } std::list> inferenceRequests; auto& comm = COMM_SESSION; if (max_num_requests > 0) @@ -1134,8 +1158,9 @@ class GptServer if (!mWorldConfig.isLastPipelineParallelRank()) { auto const peer = mWorldConfig.getPipelineParallelRank() + 1; - mInferReqSyncSndHdl - = std::make_shared(mCommPipelineParallel, inferenceRequests, peer); + auto inferReqAsyncSndHdl + = std::make_unique(mCommPipelineParallel, inferenceRequests, peer); + mInferReqWaitThread = std::make_unique([handle = std::move(inferReqAsyncSndHdl)]() {}); } } return inferenceRequests; @@ -1184,7 +1209,7 @@ class GptServer WorldConfig mWorldConfig; std::shared_ptr mCommTensorParallel; std::shared_ptr mCommPipelineParallel; - std::shared_ptr mInferReqSyncSndHdl; + std::unique_ptr mInferReqWaitThread; }; // class GptServer diff --git a/benchmarks/python/allowed_configs.py b/benchmarks/python/allowed_configs.py index 3b0550abb..9f8ff7ce2 100644 --- a/benchmarks/python/allowed_configs.py +++ b/benchmarks/python/allowed_configs.py @@ -60,14 +60,20 @@ class BuildConfig: parallel_attention: bool = None new_decoder_architecture: bool = None state_size: int = 0 - state_dtype: Optional[str] = None + state_dtype: Optional[str] = "" conv_kernel: int = 0 layer_types: List[str] = field(default_factory=list) rnn_hidden_size: int = 0 + rnn_head_size: int = 0 + rnn_conv_dim_size: int = 0 logits_soft_cap: float = 0.0 opt_batch_size: Optional[int] = None opt_num_tokens: Optional[int] = None use_bias: bool = None + mamba_version: str = 'Mamba1' + ssm_rmsnorm: bool = True + ngroups: int = 1 + chunk_size: int = 256 @dataclass @@ -1218,6 +1224,7 @@ class ModelConfig: state_size=16, conv_kernel=4, rnn_hidden_size=5120, + rnn_conv_dim_size=5120, layer_types=["recurrent"], use_bias=False, )), @@ -1238,6 +1245,7 @@ class ModelConfig: state_size=16, conv_kernel=4, rnn_hidden_size=4096, + rnn_conv_dim_size=4096, layer_types=["recurrent"], use_bias=False, )), @@ -1258,6 +1266,7 @@ class ModelConfig: state_size=16, conv_kernel=4, rnn_hidden_size=3072, + rnn_conv_dim_size=3072, layer_types=["recurrent"], use_bias=False, )), @@ -1278,6 +1287,7 @@ class ModelConfig: state_size=16, conv_kernel=4, rnn_hidden_size=2048, + rnn_conv_dim_size=2048, layer_types=["recurrent"], use_bias=False, )), @@ -1298,9 +1308,62 @@ class ModelConfig: state_size=16, conv_kernel=4, rnn_hidden_size=1536, + rnn_conv_dim_size=1536, layer_types=["recurrent"], use_bias=False, )), + "mamba2_2.7b": + ModelConfig(name="mamba2_2.7b", + family="mamba", + benchmark_type="gpt", + build_config=BuildConfig( + num_layers=64, + num_heads=1, + hidden_size=2560, + vocab_size=50288, + hidden_act="silu", + n_positions=8192, + max_batch_size=64, + max_input_len=1024, + max_seq_len=2048, + state_size=128, + conv_kernel=4, + rnn_hidden_size=5120, + rnn_conv_dim_size=5376, + rnn_head_size=64, + layer_types=["recurrent"], + use_bias=False, + mamba_version='Mamba2', + ssm_rmsnorm=True, + ngroups=1, + chunk_size=256, + )), + "mamba2_130m": + ModelConfig(name="mamba2_130m", + family="mamba", + benchmark_type="gpt", + build_config=BuildConfig( + num_layers=24, + num_heads=1, + hidden_size=768, + vocab_size=50288, + hidden_act="silu", + n_positions=8192, + max_batch_size=64, + max_input_len=1024, + max_seq_len=2048, + state_size=128, + conv_kernel=4, + rnn_hidden_size=1536, + rnn_conv_dim_size=1792, + rnn_head_size=64, + layer_types=["recurrent"], + use_bias=False, + mamba_version='Mamba2', + ssm_rmsnorm=True, + ngroups=1, + chunk_size=256, + )), "whisper_large_v3": ModelConfig(name="whisper_large_v3", family="whisper", @@ -1344,6 +1407,7 @@ class ModelConfig: state_size=1, layer_types=["recurrent", "recurrent", "attention"], rnn_hidden_size=2560, + rnn_conv_dim_size=2560, logits_soft_cap=30.0, state_dtype="float32", )), diff --git a/benchmarks/python/build.py b/benchmarks/python/build.py index 7adcc0dfa..b7fa4dd3f 100644 --- a/benchmarks/python/build.py +++ b/benchmarks/python/build.py @@ -295,7 +295,8 @@ def build_gpt(args): builder_config_extra_kwargs = {} extra_items = [ 'layer_types', 'conv_kernel', 'rnn_hidden_size', 'logits_soft_cap', - 'state_size', 'use_bias' + 'state_size', 'use_bias', 'rnn_head_size', 'rnn_conv_dim_size', + 'mamba_version', 'ssm_rmsnorm', 'ngroups', 'chunk_size' ] for item in extra_items: if item in build_config: @@ -876,10 +877,16 @@ def build_gpt(args): 'state_size': build_config['state_size'], 'conv_kernel': build_config['conv_kernel'], 'rnn_hidden_size': build_config['rnn_hidden_size'], + 'rnn_head_size': build_config['rnn_head_size'], + 'rnn_conv_dim_size': build_config['rnn_conv_dim_size'], 'rms_norm': True, 'residual_in_fp32': True, 'pad_vocab_size_multiple': 8, 'use_bias': build_config['use_bias'], + 'mamba_version': build_config['mamba_version'], + 'ssm_rmsnorm': build_config['ssm_rmsnorm'], + 'ngroups': build_config['ngroups'], + 'chunk_size': build_config['chunk_size'], } config = PretrainedConfig.from_dict(config) tensorrt_llm_model = tensorrt_llm.models.MambaForCausalLM(config) @@ -912,6 +919,8 @@ def build_gpt(args): 'state_size': build_config['state_size'], 'layer_types': build_config['layer_types'], 'rnn_hidden_size': build_config['rnn_hidden_size'], + 'rnn_head_size': build_config['rnn_head_size'], + 'rnn_conv_dim_size': build_config['rnn_conv_dim_size'], 'logits_soft_cap': build_config['logits_soft_cap'], 'rotary_pct': build_config['rotary_pct'], } diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py index d0dea2855..6d8a840aa 100644 --- a/benchmarks/python/gpt_benchmark.py +++ b/benchmarks/python/gpt_benchmark.py @@ -126,7 +126,7 @@ def __init__(self, args, batch_sizes, in_out_lens, gpu_weights_percents, rnn_config_items = [ 'conv_kernel', 'layer_types', 'rnn_hidden_size', 'state_size', - 'state_dtype' + 'state_dtype', 'rnn_head_size', 'rnn_conv_dim_size' ] rnn_configs_kwargs = {} for item in rnn_config_items: diff --git a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h index 65d12b388..a79ce5840 100644 --- a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h @@ -116,7 +116,7 @@ class GenericInferenceRequest uint64_t requestId, std::optional logitsPostProcessor = std::nullopt) : mRequestId{requestId} , mIsStreaming{false} - , mlogitsPostProcessor(logitsPostProcessor) + , mLogitsPostProcessor(logitsPostProcessor) { } @@ -125,7 +125,7 @@ class GenericInferenceRequest : mRequestId{requestId} , mIsStreaming{false} , mInputTensors{std::move(tensorMap)} - , mlogitsPostProcessor(logitsPostProcessor) + , mLogitsPostProcessor(logitsPostProcessor) { for (auto const& [name, tensor] : mInputTensors) { @@ -161,12 +161,12 @@ class GenericInferenceRequest void setLogitsPostProcessor(std::optional cb) { - mlogitsPostProcessor = cb; + mLogitsPostProcessor = cb; } std::optional getLogitsPostProcessor() { - return mlogitsPostProcessor; + return mLogitsPostProcessor; } static std::array constexpr kTensorNames = { @@ -280,7 +280,7 @@ class GenericInferenceRequest uint64_t mRequestId; bool mIsStreaming; TensorMap mInputTensors; - std::optional mlogitsPostProcessor; + std::optional mLogitsPostProcessor; }; class InferenceRequest : public GenericInferenceRequest diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 589cd280b..b7a1cf639 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -248,16 +248,6 @@ class GenerationRequest } } - void setNumPrepopulatedTokens(std::vector numPrepopulatedTokens) - { - mNumPrepopulatedTokens = std::move(numPrepopulatedTokens); - } - - [[nodiscard]] std::vector const& getNumPrepopulatedTokens() const - { - return mNumPrepopulatedTokens; - } - private: // Slot id of the sequence SizeType32 mSeqSlotIdx; @@ -267,10 +257,6 @@ class GenerationRequest SizeType32 mBeamWidth; // List of blocks allocated for each beam of the sequence std::vector> mCacheBlockIds; - // Number of tokens already in kv cache before context phase. - // A value > 0 indicates cached kv cache blocks were reused. - // One value per beam. - std::vector mNumPrepopulatedTokens; }; // BlockManager manages overall metadata of KVCacheBlocks in a layer of the @@ -400,7 +386,10 @@ class BlockManager private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. - void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx, SizeType32 seqSlotIdx); + void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx); + + //! \brief Add single block to all beams of sequence. + void addBlockToAllBeams(BlockPtr& block, GenerationRequest& sequence); //! \brief Store blocks in cached blocks. //! \param blockedTokens Tokens of each block. @@ -410,11 +399,8 @@ class BlockManager //! \brief Try to load blocks from cache. Allocate new blocks if necessary. //! \param blockedTokens Tokens of each block. //! \param sequence Sequence to which blocks are assigned. - //! \param beamIdx Beam of sequence to which blocks are assigned. - //! \param seqSlotIdx Batch slot of sequence to which blocks are assigned. //! \return Number of matched tokens from loaded blocks. - SizeType32 loadOrAllocateBlocks(std::list const& blockedTokens, GenerationRequest& sequence, - SizeType32 beamIdx, SizeType32 seqSlotIdx); + SizeType32 loadOrAllocateBlocks(std::list const& blockedTokens, GenerationRequest& sequence); //! \brief Find best primary block to free. //! \details The best primary block to free is the primary block that appears first in the queue and have no primary @@ -598,12 +584,6 @@ class KVCacheManager nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager); - [[nodiscard]] SizeType32 getNumPrepopulatedTokens(SizeType32 batchSlotIdx, SizeType32 beamIdx) const - { - auto const& prepopulatedTokens = mSequences.at(batchSlotIdx)->getNumPrepopulatedTokens(); - return prepopulatedTokens.size() > 0 ? prepopulatedTokens.at(beamIdx) : 0; - } - [[nodiscard]] bool isEnableBlockReuse() const { return mEnableBlockReuse; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index eb959234f..41c170f48 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -84,7 +84,6 @@ class GenericLlmRequest , mSamplingConfig(samplingConfig) , mState(REQUEST_STATE_CONTEXT_INIT) , mIsStreaming(isStreaming) - , mReturnAllGeneratedTokens(isStreaming && (samplingConfig.beamWidth > 1)) , mEndId(endId) , mPadId(padId) , mLogitsPostProcessor(logitsPostProcessor) @@ -127,7 +126,6 @@ class GenericLlmRequest , mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig()) , mState(REQUEST_STATE_CONTEXT_INIT) , mIsStreaming(req.getStreaming()) - , mReturnAllGeneratedTokens(req.getReturnAllGeneratedTokens()) , mEndId(req.getEndId()) , mPadId(req.getPadId()) , mOrigPromptLen(mPromptLen) @@ -154,16 +152,6 @@ class GenericLlmRequest , mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput) , mDecodingIter(0) { - if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnAllGeneratedTokens == false) - { - TLLM_LOG_WARNING( - "Setting mReturnAllGeneratedTokens to True since streaming AND beam search are done simultaneously. " - "Returning the full beams at each streaming step is needed because beam search + streaming can change " - "previous outputs. Initialize request with mReturnAllGeneratedTokens = True to dismiss this error." - "WARNING: using this option may increase network usage significantly (quadratically w.r.t output " - "length)."); - mReturnAllGeneratedTokens = true; - } if (req.getEncoderInputTokenIds()) { mState = REQUEST_STATE_ENCODER_INIT; @@ -575,6 +563,16 @@ class GenericLlmRequest return mOrigPromptLen; } + void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen) + { + mPrepopulatedPromptLen = prepopulatedPromptLen; + } + + [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const + { + return mPrepopulatedPromptLen; + } + void setDraftTokens(std::shared_ptr const& draftTokens) { mDraftTokens = draftTokens; @@ -585,7 +583,7 @@ class GenericLlmRequest mDraftLogits = draftLogits; } - SizeType32 getNumDraftTokens() const + [[nodiscard]] SizeType32 getNumDraftTokens() const { return mDraftTokens->size(); } @@ -604,7 +602,7 @@ class GenericLlmRequest mNumTokensPerIteration = numTokensPerIteration; } - SizeType32 getNumTokensPerIteration() const + [[nodiscard]] SizeType32 getNumTokensPerIteration() const { return mNumTokensPerIteration; } @@ -883,22 +881,16 @@ class GenericLlmRequest // FIXME(nkorobov): For streaming we do not allow beam search and // streaming index calculation here applies only for sampling // getNumTokensPerIteration takes accepted draft tokens into account - auto nbTokensOut - = (mReturnAllGeneratedTokens || !mIsStreaming) ? maxNbTokens : std::max(getNumTokensPerIteration(), 1); - + int nbTokensOut = mIsStreaming ? std::max(getNumTokensPerIteration(), 1) : maxNbTokens; if (mExcludeInputFromOutput && !mIsStreaming) { nbTokensOut -= getOrigPromptLen(); } result.outputTokenIds.resize(nbBeams); + SizeType32 tokenPos = maxNbTokens - nbTokensOut; - // in the case of streaming + beam search - // we need to return the full beams at all iterations - - SizeType32 tokenPos{maxNbTokens - nbTokensOut}; - auto const shouldSendResponse = isGenerationCompleteState() - || (mIsStreaming && tokenPos > getMaxSentTokenPos()) || mReturnAllGeneratedTokens; + bool shouldSendResponse = isGenerationCompleteState() || (mIsStreaming && tokenPos > getMaxSentTokenPos()); if (!shouldSendResponse) { @@ -909,8 +901,7 @@ class GenericLlmRequest for (SizeType32 beam = 0; beam < nbBeams; ++beam) { auto tokens = getTokens(beam); - auto nbTokens = (mReturnAllGeneratedTokens || !mIsStreaming) ? tokens.size() - : (tokenPos - getMaxSentTokenPos()); + auto nbTokens = mIsStreaming ? (tokenPos - getMaxSentTokenPos()) : tokens.size(); // Take accepted draft tokens into account when streaming auto const numAcceptedTokens = std::max(0, getNumTokensPerIteration() - 1); @@ -982,8 +973,6 @@ class GenericLlmRequest runtime::SamplingConfig mSamplingConfig; LlmRequestState_t mState; bool mIsStreaming; - // whether to return the full beams on each iteration. True when doing streaming + beamsearch - bool mReturnAllGeneratedTokens; std::optional mEndId; std::optional mPadId; std::optional mSeqSlot; @@ -993,6 +982,10 @@ class GenericLlmRequest protected: BeamTokens mTokens; SizeType32 mOrigPromptLen; + // Number of tokens already in KV cache before context phase. + // A value > 0 indicates cached KV cache blocks were reused. + // Up to inputLen - 1 tokens can be reused. + SizeType32 mPrepopulatedPromptLen{0}; SizeType32 mMaxSentTokenPos; std::optional mEmbeddingBias; diff --git a/cpp/include/tensorrt_llm/common/mpiUtils.h b/cpp/include/tensorrt_llm/common/mpiUtils.h index 9c8a24765..243cff7cd 100644 --- a/cpp/include/tensorrt_llm/common/mpiUtils.h +++ b/cpp/include/tensorrt_llm/common/mpiUtils.h @@ -385,7 +385,7 @@ class MpiComm bool mFreeComm; }; -void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_FUNNELED, bool forwardAbortToParent = false); +void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_MULTIPLE, bool forwardAbortToParent = false); } // namespace tensorrt_llm::mpi diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index f154e7917..c6f1769b8 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -252,8 +252,6 @@ class Request /// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor /// name provided to the ExecutorConfig. /// @param encoderInputTokenIds The encoder input token ids for encoder-decoder models, or encoder-only models - /// @param returnAllGeneratedTokens Indicates whether to return the full beams or just the newly generated tokens - /// after every streaming step. Request(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming = false, SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(), std::optional const& endId = std::nullopt, std::optional const& padId = std::nullopt, @@ -264,7 +262,7 @@ class Request std::optional pTuningConfig = std::nullopt, std::optional loraConfig = std::nullopt, std::optional logitsPostProcessorName = std::nullopt, - std::optional encoderInputTokenIds = std::nullopt, bool returnAllGeneratedTokens = false); + std::optional encoderInputTokenIds = std::nullopt); /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor static auto constexpr kBatchedPostProcessorName = "batched"; @@ -290,7 +288,6 @@ class Request [[nodiscard]] std::optional getLoraConfig() const; [[nodiscard]] std::optional getLogitsPostProcessorName() const; [[nodiscard]] std::optional getEncoderInputTokenIds() const; - [[nodiscard]] bool getReturnAllGeneratedTokens() const; void setStreaming(bool streaming); void setSamplingConfig(SamplingConfig const& config); @@ -305,7 +302,6 @@ class Request void setLoraConfig(LoraConfig const& loraConfig); void setLogitsPostProcessorName(std::string const& logitsPostProcessorName); void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds); - void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens); private: friend class Serialization; diff --git a/cpp/include/tensorrt_llm/executor/tensor.h b/cpp/include/tensorrt_llm/executor/tensor.h index 22e78d661..ccce2379e 100644 --- a/cpp/include/tensorrt_llm/executor/tensor.h +++ b/cpp/include/tensorrt_llm/executor/tensor.h @@ -175,7 +175,7 @@ class Tensor { TLLM_CHECK(data.size() <= std::numeric_limits::max()); } - return of(data.data(), {static_cast(data.size())}); + return of(data.data(), {static_cast(data.size())}); } Tensor() noexcept = default; diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h index 67be816a9..5c349fa05 100644 --- a/cpp/include/tensorrt_llm/runtime/modelConfig.h +++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h @@ -51,6 +51,8 @@ class ModelConfig SizeType32 stateSize = 0; SizeType32 convKernel = 0; SizeType32 rnnHiddenSize = 0; + SizeType32 rnnHeadSize = 0; + SizeType32 rnnConvDimSize = 0; }; enum class LayerType : std::int32_t diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index c6e8dad75..aac51e69e 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:33f2d6b3e871b0a0e651883607887777fe03d6822f06e4154ffc7e35a8d5cc70 -size 3938416 +oid sha256:5804fde474d6489db29204259b7e6c368117acadb7fb6dc807868ee0391c458b +size 3953206 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index 731820756..5334df1df 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8412aa4ca15c232ced1cd4bdfcc54177c7b257aef493d50650c960e0fb527cfc -size 4002178 +oid sha256:85802a0e66148acb17d017a64dd982287775ce7bf5aa4e8bb7e5466b3736c7ee +size 4019734 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index 979f9c627..7e3984878 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -d4aa7db860caf8feedb79a280aa70da3 libtensorrt_llm_batch_manager_static.a -02b4363342ccea3e2abccc474f3506bb libtensorrt_llm_batch_manager_static.pre_cxx11.a -0e1417f27d93de67940c1062cf230017cd8be5f1 commit \ No newline at end of file +00fb525bdf4ff217c16940540b2357c4 libtensorrt_llm_batch_manager_static.a +97d2db7f62745001d871bc89fb38eed6 libtensorrt_llm_batch_manager_static.pre_cxx11.a +d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 5b09fa994..8da91110d 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:86f34c84883f1dfed04c6fb18811198da636e4457617a47db71f045cb3066eb4 -size 3825822 +oid sha256:33a724d7e9eabc358c0d674151d45cef8849ae702cc5f2f88b259299a8306574 +size 3842582 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index a151b44cf..924e7920e 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c07c30d986591bbe93bb30d67fc8ebbba3eb55c5875ce939c3265151747656ae -size 3782506 +oid sha256:490a93ff13a67949a30e279fc3df27456c7f5d4084158c3089befccf78118b7f +size 3799140 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib index 81d6c151f..f51e23a15 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib +++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e0190b794e437fa6a0e2140e9446195413abde0dfbc5109423c790397fbb95a6 -size 22445474 +oid sha256:663a163c3177644ed86fa7a2145fe5e9dbf6f2f0ed06c96d367236da323a3432 +size 22523526 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a index 29476045d..a1f3bb344 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8729077e2bfb9cf3f647cc6ca9be42a8953c0ddf58426485ae3bded76dc9d5c3 -size 1403008 +oid sha256:497b00031131c1dc705e848e52f3d43148f55505e37bdad97f4933b2c074469d +size 1400502 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index 871ecac8f..ebe9134cb 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2b68c06565f1b3f795e070420c73d085c620b42c1c2131f9895d2687178a6b54 -size 1427780 +oid sha256:417978bdb5c19f97d9758475acacfa18a4038fc3c5a83f981b02ee220104e0c7 +size 1425792 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt index 6b4481d70..a50698ac4 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -db98ffd911c3c1dde3413e934ce8deb8 libtensorrt_llm_executor_static.a -8dc57746aa2c29d8a2fa50196b552bc3 libtensorrt_llm_executor_static.pre_cxx11.a -0e1417f27d93de67940c1062cf230017cd8be5f1 commit \ No newline at end of file +1df55ac2948ca7b7fe2d5e79934e660e libtensorrt_llm_executor_static.a +ea1641928d184d117deec0696763b274 libtensorrt_llm_executor_static.pre_cxx11.a +d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a index af916f3ad..06ef09b57 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cbc3a279681e877a982c0ebbdd0c13d7792d67a87bad0be125ec81bfe3f87399 -size 1454684 +oid sha256:d0441d473852d11f50bcf23f4934b38d7e4c6d4a42f057eb04beb8aea4211cac +size 1451118 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index de53b7818..c8c4f0536 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aa15303c38a748c4bf7b82e1f9c58cb63418efbd60bfede62820f5a62d65710a -size 1381738 +oid sha256:dc8619f99cf5a2e04bdb1482f157a9852bd745e90cf9e03a7878f73ed07e5610 +size 1383936 diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib index 58c55281e..1d89e1a8d 100644 --- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib +++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1a3e774d6700444b7164e1b31e26936ea6fcddc73e3e17bba1d8492c65a57b78 -size 14036486 +oid sha256:772d1b83e739b926729b99999fbb81768569ffb172c2e120665b2d31b987bb47 +size 14071986 diff --git a/cpp/tensorrt_llm/kernels/chunkScan/Cn.h b/cpp/tensorrt_llm/kernels/chunkScan/Cn.h new file mode 100644 index 000000000..24b191e56 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/Cn.h @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#ifdef __CUDACC__ // for CUDA +#define FT_DEV_CEXPR __device__ __host__ inline constexpr +#else +#define FT_DEV_CEXPR inline constexpr +#endif + +//---------------------------------------------------------------------------- +// Cn: constant integer +//---------------------------------------------------------------------------- + +template +struct Cn : public std::integral_constant +{ +}; + +template +constexpr auto cn = Cn(); + +//---------------------------------------------------------------------------- +// Operators for Cn +//---------------------------------------------------------------------------- + +template +FT_DEV_CEXPR auto operator+(Cn) +{ + return cn<+value_>; +} + +template +FT_DEV_CEXPR auto operator-(Cn) +{ + return cn<-value_>; +} + +template +FT_DEV_CEXPR auto operator!(Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator~(Cn) +{ + return cn<~value_>; +} + +template +FT_DEV_CEXPR auto operator+(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator-(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator*(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator/(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator%(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator<<(Cn, Cn) +{ + return cn<(a_ << b_)>; +} + +template +FT_DEV_CEXPR auto operator>>(Cn, Cn) +{ + return cn<(a_ >> b_)>; +} + +template +FT_DEV_CEXPR auto operator<(Cn, Cn) +{ + return cn<(a_ < b_)>; +} + +template +FT_DEV_CEXPR auto operator<=(Cn, Cn) +{ + return cn<(a_ <= b_)>; +} + +template +FT_DEV_CEXPR auto operator>(Cn, Cn) +{ + return cn<(a_ > b_)>; +} + +template +FT_DEV_CEXPR auto operator>=(Cn, Cn) +{ + return cn<(a_ >= b_)>; +} + +template +FT_DEV_CEXPR auto operator==(Cn, Cn) +{ + return cn<(a_ == b_)>; +} + +template +FT_DEV_CEXPR auto operator!=(Cn, Cn) +{ + return cn<(a_ != b_)>; +} + +template +FT_DEV_CEXPR auto operator^(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator&(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator&&(Cn, Cn) +{ + return cn < a_ && b_ > ; +} + +template +FT_DEV_CEXPR auto operator|(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto operator||(Cn, Cn) +{ + return cn < a_ || b_ > ; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator*(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator/(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator%(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator<<(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator>>(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator&(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator&&(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator*(A_, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator%(A_, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator%(A_, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator&(A_, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator&&(A_, Cn) +{ + return cn; +} + +//---------------------------------------------------------------------------- +// div_up & round_up +//---------------------------------------------------------------------------- + +template +FT_DEV_CEXPR auto cexpr_abs(T_ a_) // abs is not constexpr until C++20 +{ + return a_ >= cn<0> ? +a_ : -a_; +} + +template +FT_DEV_CEXPR auto div_up(T_ a_, U_ b_) +{ + auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>); + + return tmp / b_; +} + +template +FT_DEV_CEXPR auto round_up(T_ a_, U_ b_) +{ + auto tmp = a_ >= cn<0> ? a_ + (cexpr_abs(b_) - cn<1>) : a_ - (cexpr_abs(b_) - cn<1>); + + return tmp - tmp % b_; +} + +template +FT_DEV_CEXPR auto div_up(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR auto round_up(Cn, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> div_up(Cn, B_) +{ + return cn; +} + +template +FT_DEV_CEXPR std::enable_if_t> round_up(Cn, B_) +{ + return cn; +} + +//---------------------------------------------------------------------------- +// IsTuple: std::tuple, but not std::pair, std::array, etc. +//---------------------------------------------------------------------------- + +template +struct IsTuple : public std::false_type +{ +}; + +template +struct IsTuple> : public std::true_type +{ +}; + +template +struct IsTuple : public IsTuple +{ +}; + +template +struct IsTuple : public IsTuple +{ +}; + +template +struct IsTuple : public IsTuple +{ +}; + +template +constexpr bool IsTuple_v = IsTuple::value; + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/chunkScan/Common.h b/cpp/tensorrt_llm/kernels/chunkScan/Common.h new file mode 100644 index 000000000..403f08734 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/Common.h @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +extern "C" __device__ unsigned __nvvm_get_smem_pointer(void* ptr); + +template +__device__ inline int swizzle(int x_) +{ + return x_ ^ x_ / line_ % (mode_ / 16) * (16 / sizeof(T_)); +} + +template +__device__ inline int swizzle(int x_, int y_) +{ + return x_ ^ y_ * (16 / sizeof(T_)); +} + +template +__device__ inline void cp_shared_global(unsigned s_ptr, void const* g_ptr) +{ + static_assert(size_ == 4 || size_ == 8 || size_ == 16); + +#if __CUDA_ARCH__ >= 800 + if constexpr (size_ == 16) + asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_)); + else if constexpr (size_ == 8) + asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_)); + else if constexpr (size_ == 4) + asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_)); +#else + register unsigned tmp[size_ / 4]; + + if constexpr (size_ == 16) + { + asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3]) + : "l"(g_ptr)); + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]), + "r"(tmp[3])); + } + else if constexpr (size_ == 8) + { + asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" : "=r"(tmp[0]), "=r"(tmp[1]) : "l"(g_ptr)); + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1])); + } + else if constexpr (size_ == 4) + { + asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(g_ptr)); + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "r"(tmp[0])); + } +#endif +} + +template +__device__ inline void cp_shared_global(unsigned s_ptr, void const* g_ptr, bool valid_) +{ + static_assert(size_ == 4 || size_ == 8 || size_ == 16); + +#if __CUDA_ARCH__ >= 800 + if constexpr (size_ == 16) + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_), + "r"(valid_ ? size_ : 0)); + else if constexpr (size_ == 8) + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_), + "r"(valid_ ? size_ : 0)); + else if constexpr (size_ == 4) + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(s_ptr), "l"(g_ptr), "n"(size_), + "r"(valid_ ? size_ : 0)); +#else + register unsigned tmp[size_ / 4]; + + if constexpr (size_ == 16) + { + if (valid_) + { + asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3]) + : "l"(g_ptr)); + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1]), + "r"(tmp[2]), "r"(tmp[3])); + } + else + { + asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(s_ptr), "n"(0), "n"(0), "n"(0), "n"(0)); + } + } + else if constexpr (size_ == 8) + { + if (valid_) + { + asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n" : "=r"(tmp[0]), "=r"(tmp[1]) : "l"(g_ptr)); + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "r"(tmp[0]), "r"(tmp[1])); + } + else + { + asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(s_ptr), "n"(0), "n"(0)); + } + } + else if constexpr (size_ == 4) + { + if (valid_) + { + asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(g_ptr)); + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "r"(tmp[0])); + } + else + { + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(s_ptr), "n"(0)); + } + } +#endif +} + +__device__ inline void cp_commit_group() +{ +#if __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n"); +#endif +} + +template +__device__ inline void cp_wait_group() +{ +#if __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" ::"n"(remain_)); +#endif +} + +template +__device__ inline void ldmatrix(unsigned& r0_, unsigned& r1_, unsigned& r2_, unsigned& r3_, unsigned addr_) +{ + if (trans_) + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r0_), "=r"(r1_), "=r"(r2_), "=r"(r3_) + : "r"(addr_)); + else + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r0_), "=r"(r1_), "=r"(r2_), "=r"(r3_) + : "r"(addr_)); +} + +typedef __nv_bfloat16 bf16; +typedef __nv_bfloat162 bf162; + +template +__device__ int swz(int x_) +{ + return x_ ^ x_ / line_ % (mode_ / 16) * 8; +} + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/chunkScan/Poly.h b/cpp/tensorrt_llm/kernels/chunkScan/Poly.h new file mode 100644 index 000000000..20d606f70 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/Poly.h @@ -0,0 +1,1277 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Cn.h" + +//---------------------------------------------------------------------------- +// Rn: ranged integer (with size and multiplier) +//---------------------------------------------------------------------------- + +enum Kind +{ + NONE, + UNROLL, + ID, +}; + +template +class Ranged +{ + static_assert(multiplier_); + +public: + typedef decltype(size_) type; + + FT_DEV_CEXPR Ranged(type var_ = 0) + : var(var_) + { + } + + static constexpr auto min = type(size_ && multiplier_ < 0 ? size_ - 1 : 0) * multiplier_; + static constexpr auto max = type(size_ && multiplier_ > 0 ? size_ - 1 : 0) * multiplier_; + static constexpr auto abs = max - min; + + static constexpr bool minInf = (size_ == 0 && multiplier_ < 0); + static constexpr bool maxInf = (size_ == 0 && multiplier_ > 0); + static constexpr bool inf = (size_ == 0); + + static constexpr auto zero = decltype(size_)(0); + static constexpr auto ZERO = decltype(size_ * multiplier_)(0); + + static constexpr Kind kind = kind_; + static constexpr type size = size_; + static constexpr auto multiplier = multiplier_; + + type var; + + FT_DEV_CEXPR + auto operator+() const + { + return *this; + } + + FT_DEV_CEXPR + auto operator-() const + { + return Ranged{var}; + } +}; + +template ? 0 : decltype(kind_)(1), + auto multiplier_ = decltype(size_)(1)> +using Rn = std::conditional_t, Ranged, + Ranged>; + +//---------------------------------------------------------------------------- +// Poly: polynomial integer +//---------------------------------------------------------------------------- + +template +class Poly +{ +public: + typedef std::tuple Terms; + + FT_DEV_CEXPR Poly(Cn, Terms ts_) + : terms(ts_) + { + } + + FT_DEV_CEXPR Poly(Cn, Ts_... ts_) + : terms(ts_...) + { + } + + FT_DEV_CEXPR Poly(Terms ts_) + : terms(ts_) + { + } + + FT_DEV_CEXPR Poly(Ts_... ts_) + : terms(ts_...) + { + } + + static constexpr auto min = (bias_ + ... + Ts_::min); + static constexpr auto max = (bias_ + ... + Ts_::max); + + static constexpr bool minInf = (false || ... || Ts_::minInf); + static constexpr bool maxInf = (false || ... || Ts_::maxInf); + + static constexpr auto zero = decltype(bias_)(0); + static constexpr auto ZERO = decltype((bias_ + ... + Ts_::zero))(0); + + static constexpr auto bias = bias_; + + Terms terms; + + FT_DEV_CEXPR + auto operator+() const + { + return *this; + } + + FT_DEV_CEXPR + auto operator-() const + { + return negateImp(std::index_sequence_for()); + } + + template + FT_DEV_CEXPR auto mul(Cn) const + { + return mulImp(cn, cn); + } + + template + FT_DEV_CEXPR auto operator/(Cn) const + { + return divImp(cn, cn); + } + + template + FT_DEV_CEXPR auto operator%(Cn) const + { + return modImp(cn, cn); + } + + template + FT_DEV_CEXPR auto filter(Cn) const + { + return filterImp(cn, cn); + } + + template + FT_DEV_CEXPR auto filterDiv(Cn) const + { + if constexpr (b_ == 0) + return *this; // return itself if indivisible + else if constexpr (!divisible(cn)) + return *this; // return itself if indivisible + else + return filterDivImp(cn, cn); + } + + template + FT_DEV_CEXPR static bool divisible(Cn) + { + static_assert(b_); + + constexpr auto dMin = divisibleMin(cn, std::index_sequence_for()); + constexpr auto dMax = divisibleMax(cn, std::index_sequence_for()); + constexpr auto iMin = indivisibleMin(cn, std::index_sequence_for()); + constexpr auto iMax = indivisibleMax(cn, std::index_sequence_for()); + + constexpr auto iBigTerm = indivisibleBigTerm(cn, cn); + constexpr auto iBig = iBigTerm.abs + iBigTerm.inf * cexpr_abs(b_); + + if constexpr (bias_ % b_) + { + return dMin == 0 && iMin == 0 && bias_ >= 0 && iMax + bias_ % b_ < cexpr_abs(b_) + || dMax == 0 && iMax == 0 && bias_ <= 0 && iMin + bias_ % b_ > -cexpr_abs(b_); + } + else if constexpr (!std::is_same_v>) + { + return dMin == 0 && iMin == 0 && bias_ >= 0 + && (iMax < cexpr_abs(b_) || iMax - iBig < iBigTerm.multiplier && b_ % iBigTerm.multiplier == 0) + || dMax == 0 && iMax == 0 && bias_ <= 0 + && (iMin > -cexpr_abs(b_) || iMin + iBig > iBigTerm.multiplier && b_ % iBigTerm.multiplier == 0); + } + + return true; + } + + template + FT_DEV_CEXPR static bool hasOnly(Cn) + { + return hasOnlyImp(cn, cn); + } + +private: + template + FT_DEV_CEXPR auto negateImp(std::index_sequence) const + { + return Poly<-bias_, decltype(-std::get(terms))...>{cn<-bias_>, std::tuple{-std::get(terms)...}}; + } + + template + FT_DEV_CEXPR static auto divisibleMin(Cn, std::index_sequence) + { + return (zero + ... + + (std::tuple_element_t::multiplier % b_ && std::tuple_element_t::size != 1 + ? std::tuple_element_t::ZERO + : std::tuple_element_t::min + + std::tuple_element_t::minInf * -cexpr_abs(b_))); + } + + template + FT_DEV_CEXPR static auto divisibleMax(Cn, std::index_sequence) + { + return (zero + ... + + (std::tuple_element_t::multiplier % b_ && std::tuple_element_t::size != 1 + ? std::tuple_element_t::ZERO + : std::tuple_element_t::max + + std::tuple_element_t::maxInf * cexpr_abs(b_))); + } + + template + FT_DEV_CEXPR static auto indivisibleMin(Cn, std::index_sequence) + { + return (zero + ... + + (std::tuple_element_t::multiplier % b_ && std::tuple_element_t::size != 1 + ? std::tuple_element_t::min + std::tuple_element_t::minInf * -cexpr_abs(b_) + : std::tuple_element_t::ZERO)); + } + + template + FT_DEV_CEXPR static auto indivisibleMax(Cn, std::index_sequence) + { + return (zero + ... + + (std::tuple_element_t::multiplier % b_ && std::tuple_element_t::size != 1 + ? std::tuple_element_t::max + std::tuple_element_t::maxInf * cexpr_abs(b_) + : std::tuple_element_t::ZERO)); + } + + template + FT_DEV_CEXPR static auto indivisibleBigTerm(Cn, Cn) + { + if constexpr (i_ == 0) + return Rn{false}; + else + { + constexpr auto prev = indivisibleBigTerm(cn, cn); + constexpr auto curr = std::tuple_element_t{0}; + + if constexpr (curr.multiplier % b_ == 0 || curr.size == 1) + return prev; + else if constexpr (std::is_same_v>) + return curr; + else if constexpr (curr.max - curr.min > prev.max - prev.min || curr.inf) + return curr; + else + return prev; + } + } + + template + FT_DEV_CEXPR static bool hasOnlyImp(Cn, Cn) + { + if constexpr (i_ == 0) + return true; + else if constexpr (std::tuple_element_t::kind != kind_ + && std::tuple_element_t::size != 1) + return false; + else + return hasOnlyImp(cn, cn); + } + + template + FT_DEV_CEXPR auto mulImp(Cn, Cn) const + { + if constexpr (i_ == 0) + return cn; + else + return mulImp(cn, cn) + std::get(terms) * cn; + } + + template + FT_DEV_CEXPR auto divImp(Cn, Cn) const + { + static_assert(b_); + static_assert(divisible(cn)); + + if constexpr (i_ == 0) + return cn; + else if constexpr (std::tuple_element_t::abs >= cexpr_abs(b_) + || std::tuple_element_t::inf) + return divImp(cn, cn) + std::get(terms) / cn; + else + return divImp(cn, cn); + } + + template + FT_DEV_CEXPR auto modImp(Cn, Cn) const + { + static_assert(b_); + static_assert(divisible(cn)); + + if constexpr (i_ == 0) + return cn; + else if constexpr (std::tuple_element_t::multiplier % b_ + && std::tuple_element_t::size != 1) + return modImp(cn, cn) + std::get(terms) % cn; + else + return modImp(cn, cn); + } + + template + FT_DEV_CEXPR auto filterImp(Cn, Cn) const + { + if constexpr (i_ == 0) + return Poly{}; + else if constexpr (std::tuple_element_t::kind == kind_ + && std::tuple_element_t::size != 1) + return filterImp(cn, cn) + std::get(terms); + else + return filterImp(cn, cn); + } + + template + FT_DEV_CEXPR auto filterDivImp(Cn, Cn) const + { + static_assert(b_); + static_assert(divisible(cn)); + + if constexpr (i_ == 0) + return Poly{}; + else if constexpr (std::tuple_element_t::abs >= cexpr_abs(b_) + || std::tuple_element_t::inf) + return filterDivImp(cn, cn) + std::get(terms); + else + return filterDivImp(cn, cn); + } +}; + +// constructs Poly from Cn and Rns +template +Poly(Cn, Ranged...) -> Poly...>; +// constructs Poly from Rns +template +Poly(Ranged...) -> Poly...>; + +//---------------------------------------------------------------------------- +// Operators for Rn and Poly +//---------------------------------------------------------------------------- + +/* We should never use int * Rn +template FT_DEV_CEXPR std::enable_if_t, +Rn> operator * (T_ a_, Ranged x_) { return Rn{x_.var * +a_}; } template FT_DEV_CEXPR std::enable_if_t, +Rn> operator * (Ranged x_, T_ b_) { return Rn{x_.var * +b_}; } +*/ + +template +FT_DEV_CEXPR std::enable_if_t> operator*(Cn, Ranged x_) +{ + return Rn{x_.var}; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator*(Ranged x_, Cn) +{ + return Rn{x_.var}; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator/(Ranged x_, Cn) +{ + return Rn{x_.var}; +} + +template +FT_DEV_CEXPR std::enable_if_t> operator%(Ranged x_, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR Rn operator<<(Ranged x_, Cn) +{ + return Rn{x_.var}; +} + +template +FT_DEV_CEXPR std::enable_if_t> b_)>> operator>>(Ranged x_, Cn) +{ + return Rn> b_)>{x_.var}; +} + +template +FT_DEV_CEXPR + std::enable_if_t<(Rn::abs < cexpr_abs(b_) && !Rn::inf && m_ % b_ != 0), Cn> + operator/(Ranged x_, Cn) +{ + return cn; +} + +template +FT_DEV_CEXPR + std::enable_if_t<(Rn::abs < cexpr_abs(b_) && !Rn::inf && m_ % b_ != 0), Rn> + operator%(Ranged x_, Cn) +{ + return Rn{x_.var}; +} + +template +FT_DEV_CEXPR + std::enable_if_t<(Rn::abs < (1 << b_) && !Rn::inf && m_ % (1 << b_) != 0), Cn<(m_ >> b_)>> + operator>>(Ranged x_, Cn) +{ + return cn<(m_ >> b_)>; +} + +template +FT_DEV_CEXPR + std::enable_if_t<(Rn::abs >= cexpr_abs(b_) || Rn::inf) && (m_ % b_ != 0 && b_ % m_ == 0), + Rn> + operator/(Ranged x_, Cn) +{ + return Rn{x_.var / cexpr_abs(b_ / m_)}; +} + +template +FT_DEV_CEXPR + std::enable_if_t<(Rn::abs >= cexpr_abs(b_) || Rn::inf) && (m_ % b_ != 0 && b_ % m_ == 0), + Rn> + operator%(Ranged x_, Cn) +{ + return Rn{x_.var % cexpr_abs(b_ / m_)}; +} + +template +FT_DEV_CEXPR std::enable_if_t<(m_ > 0) && // correct only when positive + (Rn::abs >= (1 << b_) || Rn::inf) && (m_ % (1 << b_) != 0 && (1 << b_) % m_ == 0), + Rn> +operator>>(Ranged x_, Cn) +{ + return Rn{ + x_.var / cexpr_abs((1 << b_) / m_)}; +} + +template +FT_DEV_CEXPR auto operator+(Ranged a_, Cn) +{ + return Poly{cn, std::tuple{a_}}; +} + +template +FT_DEV_CEXPR auto operator-(Ranged a_, Cn) +{ + return Poly{cn<-b_>, std::tuple{a_}}; +} + +template +FT_DEV_CEXPR auto operator+(Cn, Ranged b_) +{ + return Poly{cn, std::tuple{b_}}; +} + +template +FT_DEV_CEXPR auto operator-(Cn, Ranged b_) +{ + return Poly{cn, std::tuple{-b_}}; +} + +template +FT_DEV_CEXPR auto operator+(Poly a_, Cn) +{ + return Poly{cn, a_.terms}; +} + +template +FT_DEV_CEXPR auto operator-(Poly a_, Cn) +{ + return Poly{cn, a_.terms}; +} + +template +FT_DEV_CEXPR auto operator+(Cn, Poly b_) +{ + return Poly{cn, b_.terms}; +} + +template +FT_DEV_CEXPR auto operator-(Cn, Poly b_) +{ + return Poly{cn, (-b_).terms}; +} + +template +FT_DEV_CEXPR auto operator+(Ranged a_, Ranged b_) +{ + return Poly{cn, std::tuple{a_, b_}}; +} + +template +FT_DEV_CEXPR auto operator-(Ranged a_, Ranged b_) +{ + return Poly{cn, std::tuple{a_, -b_}}; +} + +template +FT_DEV_CEXPR auto operator+(Poly a_, Ranged b_) +{ + return Poly{cn, std::tuple_cat(a_.terms, std::tuple{b_})}; +} + +template +FT_DEV_CEXPR auto operator-(Poly a_, Ranged b_) +{ + return Poly{cn, std::tuple_cat(a_.terms, std::tuple{-b_})}; +} + +template +FT_DEV_CEXPR auto operator+(Ranged a_, Poly b_) +{ + return Poly{cn, std::tuple_cat(std::tuple{a_}, b_.terms)}; +} + +template +FT_DEV_CEXPR auto operator-(Ranged a_, Poly b_) +{ + return Poly{cn<-B_>, std::tuple_cat(std::tuple{a_}, (-b_).terms)}; +} + +template +FT_DEV_CEXPR auto operator+(Poly a_, Poly b_) +{ + return Poly{cn, std::tuple_cat(a_.terms, b_.terms)}; +} + +template +FT_DEV_CEXPR auto operator-(Poly a_, Poly b_) +{ + return Poly{cn, std::tuple_cat(a_.terms, (-b_).terms)}; +} + +template +FT_DEV_CEXPR std::enable_if_t>().mul(cn))> operator*( + Cn, Poly x_) +{ + return x_.mul(cn); +} + +template +FT_DEV_CEXPR std::enable_if_t>().mul(cn))> operator*( + Poly x_, Cn) +{ + return x_.mul(cn); +} + +template +FT_DEV_CEXPR auto operator<<(Poly x_, Cn) +{ + return x_ * cn<(1 << b_)>; +} + +template +FT_DEV_CEXPR auto operator>>(Poly x_, Cn) +{ + return x_ / cn<(1 << b_)>; +} + +template +FT_DEV_CEXPR auto operator*(Ranged a_, Ranged b_) +{ + return Ranged < kA == kB ? kA : NONE, zA * zB, mA * mB > {a_.var * b_.var}; +} + +template +FT_DEV_CEXPR auto operator*(Poly a_, Ranged b_) +{ + return Ranged < Poly::hasOnly(cn) ? k_ : NONE > {get(a_) * b_.var} * cn; +} + +template +FT_DEV_CEXPR auto operator*(Ranged a_, Poly b_) +{ + return Ranged < Poly::hasOnly(cn) ? k_ : NONE > {a_.var * get(b_)} * cn; +} + +/* We should never use Poly * Poly +template +FT_DEV_CEXPR +auto operator * (Poly a_, Poly b_) +{ + return Ranged::hasOnly(cn) && + Poly::hasOnly(cn) ? UNROLL : ( + Poly::hasOnly(cn) && + Poly::hasOnly(cn) ? ID : NONE)>{get(a_) * get(b_)}; +} +*/ + +//---------------------------------------------------------------------------- +// get() for Cn, Rn and Poly +//---------------------------------------------------------------------------- + +template +FT_DEV_CEXPR auto get(Cn) +{ + return value_; +} + +template +FT_DEV_CEXPR auto get(Ranged x_) +{ + if constexpr (size_ == 1) + return Rn::ZERO; + else + return x_.var * multiplier_; +} + +template +FT_DEV_CEXPR auto getImp(Poly x_, std::index_sequence) +{ + return (bias_ + ... + get(std::get(x_.terms))); +} + +template +FT_DEV_CEXPR auto get(Poly x_) +{ + return getImp(x_.filter(cn), + std::make_index_sequence))::Terms>>()); +} + +template +FT_DEV_CEXPR auto get(Poly x_) +{ + return bias_ + get(x_) + get(x_) + get(x_); +} + +//---------------------------------------------------------------------------- +// Comparison operators for Rn and Poly +//---------------------------------------------------------------------------- + +template +FT_DEV_CEXPR auto ltzeroImp(Poly x_) +{ + constexpr decltype(A_) p2 = ((-A_) ^ (-A_ - 1)) / 2 + 1; + + constexpr auto n1 = std::tuple_size_v) .filterDiv(cn<1>))::Terms>; + constexpr auto n2 = std::tuple_size_v) .filterDiv(cn))::Terms>; + constexpr auto nA = std::tuple_size_v) .filterDiv(cn))::Terms>; + + if constexpr (Poly::min >= 0 && !Poly::minInf) + return cn; + else if constexpr (Poly::max < 0 && !Poly::maxInf) + return cn; + + else if constexpr (A_ < 0 && nA < n2 && nA < n1) + return ltzeroImp((x_ - cn) .filterDiv(cn) + cn); + else if constexpr (A_ < 0 && n2 < n1) + return ltzeroImp((x_ - cn) .filterDiv(cn) + cn); + + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min >= 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf) + return false; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max < 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return true; + + else if constexpr (A_ < 0) + return get(x_) + get(x_) + get(x_) < -A_; + else + return get(-x_) + get(-x_) + get(-x_) > A_; +} + +template +FT_DEV_CEXPR auto lezeroImp(Poly x_) +{ + constexpr decltype(A_) p2 = ((+A_) ^ (+A_ - 1)) / 2 + 1; + + constexpr auto n1 = std::tuple_size_v) .filterDiv(cn<1>))::Terms>; + constexpr auto n2 = std::tuple_size_v) .filterDiv(cn))::Terms>; + constexpr auto nA = std::tuple_size_v) .filterDiv(cn))::Terms>; + + if constexpr (Poly::min > 0 && !Poly::minInf) + return cn; + else if constexpr (Poly::max <= 0 && !Poly::maxInf) + return cn; + + else if constexpr (A_ > 0 && nA < n2 && nA < n1) + return lezeroImp((x_ - cn) .filterDiv(cn) + cn); + else if constexpr (A_ > 0 && n2 < n1) + return lezeroImp((x_ - cn) .filterDiv(cn) + cn); + + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min > 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf) + return false; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max <= 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return true; + + else if constexpr (A_ < 0) + return get(x_) + get(x_) + get(x_) <= -A_; + else + return get(-x_) + get(-x_) + get(-x_) >= A_; +} + +template +FT_DEV_CEXPR auto gtzeroImp(Poly x_) +{ + constexpr decltype(A_) p2 = ((+A_) ^ (+A_ - 1)) / 2 + 1; + + constexpr auto n1 = std::tuple_size_v) .filterDiv(cn<1>))::Terms>; + constexpr auto n2 = std::tuple_size_v) .filterDiv(cn))::Terms>; + constexpr auto nA = std::tuple_size_v) .filterDiv(cn))::Terms>; + + if constexpr (Poly::max <= 0 && !Poly::maxInf) + return cn; + else if constexpr (Poly::min > 0 && !Poly::minInf) + return cn; + + else if constexpr (A_ > 0 && nA < n2 && nA < n1) + return gtzeroImp((x_ - cn) .filterDiv(cn) + cn); + else if constexpr (A_ > 0 && n2 < n1) + return gtzeroImp((x_ - cn) .filterDiv(cn) + cn); + + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max <= 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return false; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min > 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf) + return true; + + else if constexpr (A_ < 0) + return get(x_) + get(x_) + get(x_) > -A_; + else + return get(-x_) + get(-x_) + get(-x_) < A_; +} + +template +FT_DEV_CEXPR auto gezeroImp(Poly x_) +{ + constexpr decltype(A_) p2 = ((-A_) ^ (-A_ - 1)) / 2 + 1; + + constexpr auto n1 = std::tuple_size_v) .filterDiv(cn<1>))::Terms>; + constexpr auto n2 = std::tuple_size_v) .filterDiv(cn))::Terms>; + constexpr auto nA = std::tuple_size_v) .filterDiv(cn))::Terms>; + + if constexpr (Poly::max < 0 && !Poly::maxInf) + return cn; + else if constexpr (Poly::min >= 0 && !Poly::minInf) + return cn; + + else if constexpr (A_ < 0 && nA < n2 && nA < n1) + return gezeroImp((x_ - cn) .filterDiv(cn) + cn); + else if constexpr (A_ < 0 && n2 < n1) + return gezeroImp((x_ - cn) .filterDiv(cn) + cn); + + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max < 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return false; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min >= 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf) + return true; + + else if constexpr (A_ < 0) + return get(x_) + get(x_) + get(x_) >= -A_; + else + return get(-x_) + get(-x_) + get(-x_) <= A_; +} + +template +FT_DEV_CEXPR auto eqzeroImp(Poly x_) +{ + if constexpr (Poly::min > 0 && !Poly::minInf) + return cn; + else if constexpr (Poly::max < 0 && !Poly::maxInf) + return cn; + else if constexpr (Poly::min == 0 && !Poly::minInf && Poly::max == 0 + && !Poly::maxInf) + return cn; + + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min > 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf) + return false; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max < 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return false; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min == 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf + && A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max == 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return true; + + else if constexpr (A_ < 0) + return get(x_) + get(x_) + get(x_) == -A_; + else + return get(-x_) + get(-x_) + get(-x_) == A_; +} + +template +FT_DEV_CEXPR auto nezeroImp(Poly x_) +{ + if constexpr (Poly::min > 0 && !Poly::minInf) + return cn; + else if constexpr (Poly::max < 0 && !Poly::maxInf) + return cn; + else if constexpr (Poly::min == 0 && !Poly::minInf && Poly::max == 0 + && !Poly::maxInf) + return cn; + + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min > 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf) + return true; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max < 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return true; + else if (A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::min == 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::minInf + && A_ + get(x_) + decltype(x_.filter(cn) + x_.filter(cn))::max == 0 + && !decltype(x_.filter(cn) + x_.filter(cn))::maxInf) + return false; + + else if constexpr (A_ < 0) + return get(x_) + get(x_) + get(x_) != -A_; + else + return get(-x_) + get(-x_) + get(-x_) != A_; +} + +template +FT_DEV_CEXPR auto operator<(Cn, Ranged x_) +{ + return ltzeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator<=(Cn, Ranged x_) +{ + return lezeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator>(Cn, Ranged x_) +{ + return gtzeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator>=(Cn, Ranged x_) +{ + return gezeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator==(Cn, Ranged x_) +{ + return eqzeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator!=(Cn, Ranged x_) +{ + return nezeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator<(Ranged x_, Cn) +{ + return ltzeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator<=(Ranged x_, Cn) +{ + return lezeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator>(Ranged x_, Cn) +{ + return gtzeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator>=(Ranged x_, Cn) +{ + return gezeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator==(Ranged x_, Cn) +{ + return eqzeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator!=(Ranged x_, Cn) +{ + return nezeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator<(Cn, Poly x_) +{ + return ltzeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator<=(Cn, Poly x_) +{ + return lezeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator>(Cn, Poly x_) +{ + return gtzeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator>=(Cn, Poly x_) +{ + return gezeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator==(Cn, Poly x_) +{ + return eqzeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator!=(Cn, Poly x_) +{ + return nezeroImp(cn - x_); +} + +template +FT_DEV_CEXPR auto operator<(Poly x_, Cn) +{ + return ltzeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator<=(Poly x_, Cn) +{ + return lezeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator>(Poly x_, Cn) +{ + return gtzeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator>=(Poly x_, Cn) +{ + return gezeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator==(Poly x_, Cn) +{ + return eqzeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator!=(Poly x_, Cn) +{ + return nezeroImp(x_ - cn); +} + +template +FT_DEV_CEXPR auto operator<(Ranged a_, Ranged b_) +{ + return ltzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator<=(Ranged a_, Ranged b_) +{ + return lezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>(Ranged a_, Ranged b_) +{ + return gtzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>=(Ranged a_, Ranged b_) +{ + return gezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator==(Ranged a_, Ranged b_) +{ + return eqzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator!=(Ranged a_, Ranged b_) +{ + return nezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator<(Poly a_, Ranged b_) +{ + return ltzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator<=(Poly a_, Ranged b_) +{ + return lezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>(Poly a_, Ranged b_) +{ + return gtzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>=(Poly a_, Ranged b_) +{ + return gezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator==(Poly a_, Ranged b_) +{ + return eqzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator!=(Poly a_, Ranged b_) +{ + return nezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator<(Ranged a_, Poly b_) +{ + return ltzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator<=(Ranged a_, Poly b_) +{ + return lezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>(Ranged a_, Poly b_) +{ + return gtzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>=(Ranged a_, Poly b_) +{ + return gezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator==(Ranged a_, Poly b_) +{ + return eqzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator!=(Ranged a_, Poly b_) +{ + return nezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator<(Poly a_, Poly b_) +{ + return ltzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator<=(Poly a_, Poly b_) +{ + return lezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>(Poly a_, Poly b_) +{ + return gtzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator>=(Poly a_, Poly b_) +{ + return gezeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator==(Poly a_, Poly b_) +{ + return eqzeroImp(a_ - b_); +} + +template +FT_DEV_CEXPR auto operator!=(Poly a_, Poly b_) +{ + return nezeroImp(a_ - b_); +} + +//---------------------------------------------------------------------------- +// swizzle() for Poly +//---------------------------------------------------------------------------- + +template +FT_DEV_CEXPR auto swizzle(A_ a_) +{ + return swizzle(Poly{a_}); +} + +template +FT_DEV_CEXPR auto swizzle(A_ a_, B_ b_) +{ + return swizzle(Poly{a_}, Poly{b_}); +} + +template +FT_DEV_CEXPR auto swizzle(A_ a_) +{ + return swizzle(Poly{a_}); +} + +template +FT_DEV_CEXPR auto swizzle(A_ a_, B_ b_) +{ + return swizzle(Poly{a_}, Poly{b_}); +} + +template +FT_DEV_CEXPR auto swizzle(A_ a_) +{ + return swizzle(Poly{a_}); +} + +template +FT_DEV_CEXPR auto swizzle(A_ a_, B_ b_) +{ + return swizzle(Poly{a_}, Poly{b_}); +} + +template +FT_DEV_CEXPR auto swizzle(Poly a_) +{ + static_assert((mode_ & (mode_ - 1)) == 0); + static_assert((unit_ & (unit_ - 1)) == 0); + static_assert(mode_ >= unit_); + + if constexpr (decltype(a_)::divisible(cn)) + { + if constexpr (decltype(a_ / cn)::divisible(cn) + && decltype(a_ % cn)::divisible(cn)) + { + if constexpr (mode_ == unit_) + return get(a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasA_ + get(a_) + (get(a_) ^ get(a_ / cn % cn * cn)) + + get(a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasA_ + (get(a_) ^ get(a_ / cn % cn * cn)) + get(a_) + + get(a_); + else + return get(a_) ^ get(a_ / cn % cn * cn); + } +#if 1 + else if constexpr (decltype(a_ % cn)::divisible(cn)) + { + if constexpr (mode_ == unit_) + return get(a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasA_ + get(a_) + + (get(a_) ^ get(a_ / cn) % cn * cn) +get(a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasA_ + (get(a_) ^ get(a_ / cn) % cn * cn) +get(a_) + + get(a_); +#endif + else + return get(a_) ^ get(a_ / cn) % cn * cn; +#if 1 + } + else + { + return get(a_) ^ get(a_) / cn % cn * cn; + } +#endif + } + else + { + return get(a_) ^ get(a_) / cn % cn * cn; + } +} + +template +FT_DEV_CEXPR auto swizzle(Poly a_, Poly b_) +{ + static_assert((mode_ & (mode_ - 1)) == 0); + static_assert((unit_ & (unit_ - 1)) == 0); + static_assert(mode_ >= unit_); + + if constexpr (decltype(a_)::divisible(cn)) + { + if constexpr (decltype(a_ / cn)::divisible(cn) + && decltype(a_ % cn)::divisible(cn)) + { + if constexpr (mode_ == unit_) + return get(b_ + a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasB_ + biasA_ + get(b_ + a_) + + (get(b_) + (get(a_) ^ get(a_ / cn % cn * cn))) + + get(b_ + a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasB_ + biasA_ + + (get(b_) + (get(a_) ^ get(a_ / cn % cn * cn))) + + get(b_ + a_) + get(b_ + a_); + else + return get(b_) + (get(a_) ^ get(a_ / cn % cn * cn)); + } +#if 1 + else if constexpr (decltype(a_ % cn)::divisible(cn)) + { + if constexpr (mode_ == unit_) + return get(b_ + a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasB_ + biasA_ + get(b_ + a_) + + (get(b_) + (get(a_) ^ get(a_ / cn) % cn * cn) ) + + get(b_ + a_); + else if constexpr (decltype(a_ % cn / cn)::hasOnly(cn)) + return biasB_ + biasA_ + + (get(b_) + (get(a_) ^ get(a_ / cn) % cn * cn) ) + + get(b_ + a_) + get(b_ + a_); +#endif + else + return get(b_) + (get(a_) ^ get(a_ / cn) % cn * cn); +#if 1 + } + else + { + if constexpr (mode_ == unit_) + return get(b_ + a_); + else + return get(b_) + (get(a_) ^ get(a_) / cn % cn * cn); + } +#endif + } + else + { + if constexpr (mode_ == unit_) + return get(b_ + a_); + else + return get(b_) + (get(a_) ^ get(a_) / cn % cn * cn); + } +} + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/chunkScan/bmmchunk.h b/cpp/tensorrt_llm/kernels/chunkScan/bmmchunk.h new file mode 100644 index 000000000..68498cd66 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/bmmchunk.h @@ -0,0 +1,438 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" + +#include "Common.h" +#include "Poly.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +typedef void (*BmmChunkKernelFuncFp16)(int B_, int L_, int G_, int N_, + // const half *g_mxY_, // B*L*H*P + // const half *g_mxOs_, // B*C*H*N*P + // const half *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + // const float *g_mxdc_, // B*C*H*Q + // const float *g_mxdA_, // B*C*H*Q + // const half *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + half* g_mxCB_, // B*C*G*Q*Q + half const* g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + // const half *g_mxX_, // B*L*H*P + // const half *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +typedef void (*BmmChunkKernelFuncBf16)(int B_, int L_, int G_, int N_, + // const bf16 *g_mxY_, // B*L*H*P + // const bf16 *g_mxOs_, // B*C*H*N*P + // const bf16 *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + // const float *g_mxdc_, // B*C*H*Q + // const float *g_mxdA_, // B*C*H*Q + // const bf16 *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + bf16* g_mxCB_, // B*C*G*Q*Q + bf16 const* g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + // const bf16 *g_mxX_, // B*L*H*P + // const bf16 *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +template +__global__ std::enable_if_t || std::is_same_v> bmm_chunk_kernel(int B_, + int L_, int G_, int N_, + // const Tp_ *g_mxY_, // B*L*H*P + // const Tp_ *g_mxOs_, // B*C*H*N*P + // const Tp_ *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + // const float *g_mxdc_, // B*C*H*Q + // const float *g_mxdA_, // B*C*H*Q + // const Tp_ *g_mxdt_, // B*L*H + // const Wt_ *g_mxdb_, // H + // const Wt_ *g_mxA_, // H + Tp_* g_mxCB_, // B*C*G*Q*Q + Tp_ const* g_mxBC_, // B*L*2*G*N + // const Wt_ *g_mxD_, // H + // const Tp_ *g_mxX_, // B*L*H*P + // const Tp_ *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_) +{ +#if __CUDA_ARCH__ >= 800 + using namespace tensorrt_llm::common; + + auto blockIdx_x = Rn{int(blockIdx.x)}; + auto blockIdx_y = Rn{int(blockIdx.y)}; + auto blockIdx_z = Rn{int(blockIdx.z)}; + + auto threadIdx_x = Rn{int(threadIdx.x)}; + auto threadIdx_y = Rn{int(threadIdx.y)}; + auto threadIdx_z = Rn{int(threadIdx.z)}; + + // auto B = Rn{B_}; + auto L = Rn{L_}; + // auto H = Rn{H_}; + // auto P = Rn{P_}; + auto G = Rn{G_}; + auto N = Rn{N_}; + auto Q = cn; + auto C = Rn{div_up(L.var, Q_)}; + + auto aStart = blockIdx_z * L; + auto cStart = blockIdx_z * C; + + if (removePadding_) + { + aStart = Rn{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)}; + cStart = Rn{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)}; + L = Rn{lastTokenIdsPtr_[blockIdx.z] - aStart.var}; + C = Rn{div_up(L.var, Q_)}; + } + else + { + L = Rn{lastTokenIdsPtr_[blockIdx.z]}; + C = Rn{div_up(L.var, Q_)}; + } + + if (blockIdx_y * Q >= L) + return; + + auto gStart = blockIdx_x / (Q / cn) / (Q / cn); + auto mStart = blockIdx_x / (Q / cn) % (Q / cn); + auto nStart = blockIdx_x % (Q / cn); + + extern __shared__ float smem[]; + + Tp_* s_mxC = (Tp_*) smem; + Tp_* s_mxB = (Tp_*) smem + tileM_ * tileK_ * pipeS_; + Tp_* s_mxCB = (Tp_*) smem; + + unsigned b_base = __nvvm_get_smem_pointer(smem); + + unsigned b_mxC = b_base; + unsigned b_mxB = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_); + unsigned b_mxCB = b_base; + + using std::array; + + register array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxCB + = array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>(); + register array, tileM_ / wmmaM_ / warpM_> r_mxC; + register array, tileN_ / wmmaN_ / warpN_> r_mxB; + + constexpr int step = std::max( + 1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_)); + + auto baseC = [](auto iK) { return iK % cn * cn * cn; }; + auto baseB = [](auto iK) { return iK % cn * cn * cn; }; + + auto thread = [=](auto iStep) + { + return iStep * cn + threadIdx_z * cn + threadIdx_y * cn<256> + + threadIdx_x * cn<8>; + }; + +#pragma unroll + for (Rn iK; iK.var < iK.size; iK.var++) + { +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - mStart * cn) + cp_shared_global<16>(b_mxC + swizzle(thread(iStep) * cn<2>, baseC(iK) * cn<2>), + g_mxBC_ + + get( + (aStart + blockIdx_y * Q + mStart * cn + thread(iStep) / cn) *cn<2> * G * N + + cn<1> * G * N + gStart * N + iK * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(iK) * cn<2>)) + = int4{0, 0, 0, 0}; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - nStart * cn) + cp_shared_global<16>(b_mxB + swizzle(thread(iStep) * cn<2>, baseB(iK) * cn<2>), + g_mxBC_ + + get( + (aStart + blockIdx_y * Q + nStart * cn + thread(iStep) / cn) *cn<2> * G * N + + cn<0> * G * N + gStart * N + iK * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(iK) * cn<2>)) + = int4{0, 0, 0, 0}; + + cp_commit_group(); + } + + asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1)); + + __syncthreads(); + + for (int iK = pipeS_; iK < N_ / tileK_ + pipeS_; iK++) + { +#pragma unroll + for (int k = 0; k < tileK_ / wmmaK_; k++) + { +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0) + { + int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step; + int y1 = x1 - tileN_ / wmmaN_ / warpN_ + + (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1); + + if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_) + { + if (wmmaK_ == 16) + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3]) + : "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2) + + 2 + * swz(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_ + + threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_ + + threadIdx.x / 16 * 8))); + } + + if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_) + { + if (wmmaK_ == 16 && x1 % 2 == 0) + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r_mxB[x1][0]), "=r"(r_mxB[x1][1]), "=r"(r_mxB[x1 + 1][0]), + "=r"(r_mxB[x1 + 1][1]) + : "r"(b_mxB + iK % pipeS_ * (tileK_ * tileN_ * 2) + + 2 + * swz(x1 * warpN_ * wmmaN_ * tileK_ + + k * wmmaK_ + threadIdx.y * wmmaN_ * tileK_ + + threadIdx.x % 8 * tileK_ + threadIdx.x / 8 % 2 * 8 + + threadIdx.x / wmmaK_ * warpN_ * wmmaN_ * tileK_))); + } + } + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if (wmmaK_ == 16) + { + if (std::is_same_v) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(r_mxCB[y][x][0]), "+f"(r_mxCB[y][x][1]), "+f"(r_mxCB[y][x][2]), + "+f"(r_mxCB[y][x][3]) + : "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]), + "r"(r_mxB[x][0]), "r"(r_mxB[x][1])); + else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(r_mxCB[y][x][0]), "+f"(r_mxCB[y][x][1]), "+f"(r_mxCB[y][x][2]), + "+f"(r_mxCB[y][x][3]) + : "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]), + "r"(r_mxB[x][0]), "r"(r_mxB[x][1])); + } + } + } + + __syncthreads(); + + if (iK * tileK_ < N_) + { + + auto jK = Rn<>{iK}; +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - mStart * cn) + cp_shared_global<16>( + b_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>), + g_mxBC_ + + get((aStart + blockIdx_y * Q + mStart * cn + thread(iStep) / cn) *cn<2> + * G * N + + cn<1> * G * N + gStart * N + jK * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>)) + = int4{0, 0, 0, 0}; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - nStart * cn) + cp_shared_global<16>( + b_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>), + g_mxBC_ + + get((aStart + blockIdx_y * Q + nStart * cn + thread(iStep) / cn) *cn<2> + * G * N + + cn<0> * G * N + gStart * N + jK * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>)) + = int4{0, 0, 0, 0}; + } + + asm volatile("cp.async.commit_group;\n" ::); + + asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1)); + + __syncthreads(); + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if (std::is_same_v) + { + *(half2*) &r_mxCB[y][x][0] = __floats2half2_rn(r_mxCB[y][x][0], r_mxCB[y][x][1]); + *(half2*) &r_mxCB[y][x][2] = __floats2half2_rn(r_mxCB[y][x][2], r_mxCB[y][x][3]); + } + else + { + *(bf162*) &r_mxCB[y][x][0] = __floats2bfloat162_rn(r_mxCB[y][x][0], r_mxCB[y][x][1]); + *(bf162*) &r_mxCB[y][x][2] = __floats2bfloat162_rn(r_mxCB[y][x][2], r_mxCB[y][x][3]); + } + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxCB + + 2 + * swz(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_ + + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_ + + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))), + "r"(*(unsigned*) &r_mxCB[y][x][0])); + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxCB + + 2 + * swz(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_ + + x * warpN_ * wmmaN_ + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_ + + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))), + "r"(*(unsigned*) &r_mxCB[y][x][2])); + } + + __syncthreads(); + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + *(int4*) (g_mxCB_ + + get(cStart * G * Q * Q + blockIdx_y * G * Q * Q + gStart * Q * Q + + (mStart * cn + thread(iStep) / cn) *Q + nStart * cn + + thread(iStep) % cn)) + = *(int4*) ((char*) s_mxCB + swizzle(thread(iStep) * cn<2>)); + + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); +#endif +} + +BmmChunkKernelFuncFp16 getBmmChunkKernelFp16( + int B_, int L_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + // int H = H_; + // int P = P_; + int G = G_; + // int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileM = 128; + int tileN = 64; + int tileK = 32; + int warpM = 2; + int warpN = 1; + int pipeS = 2; + + auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2, (tileM * tileN) * 2); + + *blockDims_ = dim3(G * Q / tileN * Q / tileM, C, B); + *threadDims_ = dim3(32, warpN, warpM); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return bmm_chunk_kernel<128, 128, 64, 32, 16, 8, 16, 2, 1, 2, half>; + else if (Q_ == 256) + return bmm_chunk_kernel<256, 128, 64, 32, 16, 8, 16, 2, 1, 2, half>; + else + return nullptr; +} + +BmmChunkKernelFuncBf16 getBmmChunkKernelBf16( + int B_, int L_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + // int H = H_; + // int P = P_; + int G = G_; + // int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileM = 128; + int tileN = 64; + int tileK = 32; + int warpM = 2; + int warpN = 1; + int pipeS = 2; + + auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2, (tileM * tileN) * 2); + + *blockDims_ = dim3(G * Q / tileN * Q / tileM, C, B); + *threadDims_ = dim3(32, warpN, warpM); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return bmm_chunk_kernel<128, 128, 64, 32, 16, 8, 16, 2, 1, 2, bf16>; + else if (Q_ == 256) + return bmm_chunk_kernel<256, 128, 64, 32, 16, 8, 16, 2, 1, 2, bf16>; + else + return nullptr; +} + +} // namespace kernels +} // namespace tensorrt_llm + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/chunkScan/chunkcumsum.h b/cpp/tensorrt_llm/kernels/chunkScan/chunkcumsum.h new file mode 100644 index 000000000..ab12309e9 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/chunkcumsum.h @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" + +#include "Common.h" +#include "Poly.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +typedef void (*ChunkCumsumKernelFuncFp16)(int B_, int L_, int H_, + // const half *g_mxY_, // B*L*H*P + // const half *g_mxOs_, // B*C*H*N*P + // const half *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + float* g_mxdc_, // B*C*H*Q + float* g_mxdA_, // B*C*H*Q + half const* g_mxdt_, // B*L*H + float const* g_mxdb_, // H + float const* g_mxA_, // H + // const half *g_mxCB_, // B*C*G*Q*Q + // const half *g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + // const half *g_mxX_, // B*L*H*P + // const half *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +typedef void (*ChunkCumsumKernelFuncBf16)(int B_, int L_, int H_, + // const bf16 *g_mxY_, // B*L*H*P + // const bf16 *g_mxOs_, // B*C*H*N*P + // const bf16 *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + float* g_mxdc_, // B*C*H*Q + float* g_mxdA_, // B*C*H*Q + bf16 const* g_mxdt_, // B*L*H + float const* g_mxdb_, // H + float const* g_mxA_, // H + // const bf16 *g_mxCB_, // B*C*G*Q*Q + // const bf16 *g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + // const bf16 *g_mxX_, // B*L*H*P + // const bf16 *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +template +__global__ std::enable_if_t || std::is_same_v> chunk_cumsum_kernel(int B_, + int L_, int H_, + // const Tp_ *g_mxY_, // B*L*H*P + // const Tp_ *g_mxOs_, // B*C*H*N*P + // const Tp_ *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + float* g_mxdc_, // B*C*H*Q + float* g_mxdA_, // B*C*H*Q + Tp_ const* g_mxdt_, // B*L*H + Wt_ const* g_mxdb_, // H + Wt_ const* g_mxA_, // H + // const Tp_ *g_mxCB_, // B*C*G*Q*Q + // const Tp_ *g_mxBC_, // B*L*2*G*N + // const Wt_ *g_mxD_, // H + // const Tp_ *g_mxX_, // B*L*H*P + // const Tp_ *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_) +{ + using namespace tensorrt_llm::common; + + auto blockIdx_x = Rn{int(blockIdx.x)}; + auto blockIdx_y = Rn{int(blockIdx.y)}; + auto blockIdx_z = Rn{int(blockIdx.z)}; + + auto threadIdx_x = Rn{int(threadIdx.x)}; + auto threadIdx_y = Rn{int(threadIdx.y)}; + + // auto B = Rn{B_}; + auto L = Rn{L_}; + auto H = Rn{H_}; + // auto P = Rn{P_}; + // auto G = Rn{G_}; + // auto N = Rn{N_}; + auto Q = cn; + auto C = Rn{div_up(L.var, Q_)}; + + auto aStart = blockIdx_z * L; + auto cStart = blockIdx_z * C; + + if (removePadding_) + { + aStart = Rn{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)}; + cStart = Rn{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)}; + L = Rn{lastTokenIdsPtr_[blockIdx.z] - aStart.var}; + C = Rn{div_up(L.var, Q_)}; + } + else + { + L = Rn{lastTokenIdsPtr_[blockIdx.z]}; + C = Rn{div_up(L.var, Q_)}; + } + + if (blockIdx_y * Q >= L) + return; + + auto thread = [=](auto iStep) { return iStep * cn + threadIdx_y * cn<32> + threadIdx_x; }; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + { + float r_A = 0.f, r_db = 0.f, sum = 0.f; + + if (thread(iStep) < cn) + r_A = g_mxA_[get(blockIdx_x * cn + thread(iStep))]; + if (thread(iStep) < cn && g_mxdb_) + r_db = g_mxdb_[get(blockIdx_x * cn + thread(iStep))]; + +#pragma unroll + for (Rn iQ; iQ.var < iQ.size; iQ.var++) + { + float r_dt = 0.f; + + if (thread(iStep) < cn && blockIdx_y * Q + iQ < L) + { + r_dt = float(g_mxdt_[get((aStart + blockIdx_y * Q + iQ) * H + blockIdx_x * cn + thread(iStep))]) + + r_db; + + if (dtSoftplus_) + r_dt = r_dt > 32.f ? r_dt : log1p(expf(r_dt)); + + sum += r_dt; + } + + if (thread(iStep) < cn) + { + g_mxdc_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn + thread(iStep)) * Q + iQ)] = r_dt; + g_mxdA_[get((cStart + blockIdx_y) * H * Q + (blockIdx_x * cn + thread(iStep)) * Q + iQ)] + = sum * r_A; + } + } + } +} + +ChunkCumsumKernelFuncFp16 getChunkCumsumKernelFp16( + int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + // int P = P_; + // int G = G_; + // int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileH = 1; + int warpH = 1; + + auto sharedMem = 0; + + *blockDims_ = dim3(H / tileH, C, B); + *threadDims_ = dim3(32, warpH); + *sharedMem_ = sharedMem; + + if (dtSoftPlus_) + { + if (Q_ == 128) + return chunk_cumsum_kernel<128, 1, 1, true, half>; + else if (Q_ == 256) + return chunk_cumsum_kernel<256, 1, 1, true, half>; + else + return nullptr; + } + else + { + if (Q_ == 128) + return chunk_cumsum_kernel<128, 1, 1, false, half>; + else if (Q_ == 256) + return chunk_cumsum_kernel<256, 1, 1, false, half>; + else + return nullptr; + } +} + +ChunkCumsumKernelFuncBf16 getChunkCumsumKernelBf16( + int B_, int L_, int H_, int Q_, bool dtSoftPlus_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + // int P = P_; + // int G = G_; + // int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileH = 1; + int warpH = 1; + + auto sharedMem = 0; + + *blockDims_ = dim3(H / tileH, C, B); + *threadDims_ = dim3(32, warpH); + *sharedMem_ = sharedMem; + + if (dtSoftPlus_) + { + if (Q_ == 128) + return chunk_cumsum_kernel<128, 1, 1, true, bf16>; + else if (Q_ == 256) + return chunk_cumsum_kernel<256, 1, 1, true, bf16>; + else + return nullptr; + } + else + { + if (Q_ == 128) + return chunk_cumsum_kernel<128, 1, 1, false, bf16>; + else if (Q_ == 256) + return chunk_cumsum_kernel<256, 1, 1, false, bf16>; + else + return nullptr; + } +} + +} // namespace kernels +} // namespace tensorrt_llm + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/chunkScan/chunkscan.h b/cpp/tensorrt_llm/kernels/chunkScan/chunkscan.h new file mode 100644 index 000000000..40b3bf4fa --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/chunkscan.h @@ -0,0 +1,639 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" + +#include "Common.h" +#include "Poly.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +typedef void (*ChunkScanKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_, + half* g_mxY_, // B*L*H*P + half const* g_mxOs_, // B*C*H*N*P + // const half *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + float const* g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const half *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + half const* g_mxCB_, // B*C*G*Q*Q + half const* g_mxBC_, // B*L*2*G*N + float const* g_mxD_, // H + half const* g_mxX_, // B*L*H*P + half const* g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +typedef void (*ChunkScanKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_, + bf16* g_mxY_, // B*L*H*P + bf16 const* g_mxOs_, // B*C*H*N*P + // const bf16 *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + float const* g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const bf16 *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + bf16 const* g_mxCB_, // B*C*G*Q*Q + bf16 const* g_mxBC_, // B*L*2*G*N + float const* g_mxD_, // H + bf16 const* g_mxX_, // B*L*H*P + bf16 const* g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +template +__global__ std::enable_if_t || std::is_same_v> chunk_scan_kernel(int B_, + int L_, int H_, int P_, int G_, int N_, + Tp_* g_mxY_, // B*L*H*P + Tp_ const* g_mxOs_, // B*C*H*N*P + // const Tp_ *g_mxFs_, // B *H*N*P + // const float *g_mxSt_, // B*C*H*N*P + float const* g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const Tp_ *g_mxdt_, // B*L*H + // const Wt_ *g_mxdb_, // H + // const Wt_ *g_mxA_, // H + Tp_ const* g_mxCB_, // B*C*G*Q*Q + Tp_ const* g_mxBC_, // B*L*2*G*N + Wt_ const* g_mxD_, // H + Tp_ const* g_mxX_, // B*L*H*P + Tp_ const* g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_) +{ +#if __CUDA_ARCH__ >= 800 + using namespace tensorrt_llm::common; + + auto blockIdx_x = Rn{int(blockIdx.x)}; + auto blockIdx_y = Rn{int(blockIdx.y)}; + auto blockIdx_z = Rn{int(blockIdx.z)}; + + auto threadIdx_x = Rn{int(threadIdx.x)}; + auto threadIdx_y = Rn{int(threadIdx.y)}; + auto threadIdx_z = Rn{int(threadIdx.z)}; + + // auto B = Rn{B_}; + auto L = Rn{L_}; + auto H = Rn{H_}; + auto P = Rn{P_}; + auto G = Rn{G_}; + auto N = Rn{N_}; + auto Q = cn; + auto C = Rn{div_up(L.var, Q_)}; + + auto aStart = blockIdx_z * L; + auto cStart = blockIdx_z * C; + + if (removePadding_) + { + aStart = Rn{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)}; + cStart = Rn{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)}; + L = Rn{lastTokenIdsPtr_[blockIdx.z] - aStart.var}; + C = Rn{div_up(L.var, Q_)}; + } + else + { + L = Rn{lastTokenIdsPtr_[blockIdx.z]}; + C = Rn{div_up(L.var, Q_)}; + } + + if (blockIdx_y * Q >= L) + return; + + auto hStart = Rn{blockIdx_x.var / (P_ / cn) / (Q / cn) }; + auto mStart = Rn{blockIdx_x.var / (P_ / cn) % (Q / cn) }; + auto nStart = Rn{blockIdx_x.var % (P_ / cn) }; + auto gStart = Rn{hStart.var / (H_ / G_)}; + + extern __shared__ float smem[]; + + Tp_* s_mxC = (Tp_*) smem; + Tp_* s_mxOs = (Tp_*) smem + tileM_ * tileK_ * pipeS_; + Tp_* s_mxY = (Tp_*) smem; + + float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2; + float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_; + + unsigned b_base = __nvvm_get_smem_pointer(smem); + + unsigned b_mxC = b_base; + unsigned b_mxOs = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_); + unsigned b_mxY = b_base; + + using std::array; + + register array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxY + = array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>(); + register array, tileM_ / wmmaM_ / warpM_> r_mxC; + register array, tileN_ / wmmaN_ / warpN_> r_mxOs; + + constexpr int step = std::max( + 1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_)); + + auto baseC = [](auto iK) { return iK % cn * cn * cn; }; + auto baseOs = [](auto iK) { return iK % cn * cn * cn; }; + + auto thread = [=](auto iStep) + { + return iStep * cn + threadIdx_z * cn + threadIdx_y * cn<256> + + threadIdx_x * cn<8>; + }; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + { +#pragma unroll + for (int i = 0; i < 8; i += 4) + { + *(int4*) (s_mxdc + get(thread(iStep)) + i) + = *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i); + *(int4*) (s_mxdA + get(thread(iStep)) + i) + = *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i); + } + } + +#pragma unroll + for (Rn iK; iK.var < iK.size; iK.var++) + { +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - mStart * cn) + cp_shared_global<16>(b_mxC + swizzle(thread(iStep) * cn<2>, baseC(iK) * cn<2>), + g_mxBC_ + + get( + (aStart + blockIdx_y * Q + mStart * cn + thread(iStep) / cn) *cn<2> * G * N + + cn<1> * G * N + gStart * N + iK * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(iK) * cn<2>)) + = int4{0, 0, 0, 0}; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + cp_shared_global<16>( + b_mxOs + swizzle(thread(iStep) * cn<2>, baseOs(iK) * cn<2>), + g_mxOs_ + + get((cStart + blockIdx_y) * H * N * P + hStart * N * P + + (iK * cn + thread(iStep) / cn) *P + nStart * cn + + thread(iStep) % cn)); + + cp_commit_group(); + } + + asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1)); + + __syncthreads(); + + for (int iK = pipeS_; iK < (N_ + Q_) / tileK_ + pipeS_; iK++) + { + auto jK = Rn<>{iK}; + if ((iK - pipeS_) * cn == N_) + { + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + float2 tmp2 = float2{expf(s_mxdA[get(mStart * cn + Rn{y} * cn + + threadIdx_z * cn + threadIdx_x / cn<4>)]), + expf(s_mxdA[get(mStart * cn + Rn{y} * cn + cn<8> + + threadIdx_z * cn + threadIdx_x / cn<4>)])}; + + r_mxY[y][x][0] *= tmp2.x; + r_mxY[y][x][1] *= tmp2.x; + r_mxY[y][x][2] *= tmp2.y; + r_mxY[y][x][3] *= tmp2.y; + } + } + + if ((iK - pipeS_) * cn >= N_) + { + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + { + register Tp_ tmpCB[8]; + + *(int4*) &tmpCB[0] = *(int4*) ((char*) s_mxC + + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>)); + +#pragma unroll + for (int i = 0; i < 8; i += 2) + { + float2 tmp2 = std::is_same_v ? __half22float2(*(half2*) &tmpCB[i]) + : bf1622float2(*(bf162*) &tmpCB[i]); + + int kStart = (iK - pipeS_) * cn - N_; + + tmp2.x *= expf(s_mxdA[get(mStart * cn + thread(iStep) / cn)] + - s_mxdA[kStart + get(thread(iStep) % cn + Rn{i})]) + * s_mxdc[kStart + get(thread(iStep) % cn + Rn{i})]; + tmp2.y *= expf(s_mxdA[get(mStart * cn + thread(iStep) / cn)] + - s_mxdA[kStart + get(thread(iStep) % cn + Rn{i + 1})]) + * s_mxdc[kStart + get(thread(iStep) % cn + Rn{i + 1})]; + + if (get(mStart * cn + thread(iStep) / cn) + < kStart + get(thread(iStep) % cn + Rn{i})) + tmp2.x = 0; + if (get(mStart * cn + thread(iStep) / cn) + < kStart + get(thread(iStep) % cn + Rn{i + 1})) + tmp2.y = 0; + + if (std::is_same_v) + *(half2*) &tmpCB[i] = __float22half2_rn(tmp2); + else + *(bf162*) &tmpCB[i] = __float22bfloat162_rn(tmp2); + } + + *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>)) + = *(int4*) &tmpCB[0]; + } + + __syncthreads(); + } + +#pragma unroll + for (int k = 0; k < tileK_ / wmmaK_; k++) + { +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0) + { + int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step; + int y1 = x1 - tileN_ / wmmaN_ / warpN_ + + (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1); + + if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_) + { + if (wmmaK_ == 16) + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r_mxC[y1][0]), "=r"(r_mxC[y1][1]), "=r"(r_mxC[y1][2]), "=r"(r_mxC[y1][3]) + : "r"(b_mxC + iK % pipeS_ * (tileM_ * tileK_ * 2) + + 2 + * swz(y1 * warpM_ * wmmaM_ * tileK_ + k * wmmaK_ + + threadIdx.z * wmmaM_ * tileK_ + threadIdx.x % 16 * tileK_ + + threadIdx.x / 16 * 8))); + } + + if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_) + { + if (wmmaK_ == 16 && x1 % 2 == 0) + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r_mxOs[x1][0]), "=r"(r_mxOs[x1][1]), "=r"(r_mxOs[x1 + 1][0]), + "=r"(r_mxOs[x1 + 1][1]) + : "r"(b_mxOs + iK % pipeS_ * (tileK_ * tileN_ * 2) + + 2 + * swz(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_ + + threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_ + + threadIdx.x / wmmaK_ * warpN_ * wmmaN_))); + } + } + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if (wmmaK_ == 16) + { + if (std::is_same_v) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3]) + : "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]), + "r"(r_mxOs[x][0]), "r"(r_mxOs[x][1])); + else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(r_mxY[y][x][0]), "+f"(r_mxY[y][x][1]), "+f"(r_mxY[y][x][2]), "+f"(r_mxY[y][x][3]) + : "r"(r_mxC[y][0]), "r"(r_mxC[y][1]), "r"(r_mxC[y][2]), "r"(r_mxC[y][3]), + "r"(r_mxOs[x][0]), "r"(r_mxOs[x][1])); + } + } + } + + __syncthreads(); + + if (iK * cn < N_) + { + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - mStart * cn) + cp_shared_global<16>( + b_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>), + g_mxBC_ + + get((aStart + blockIdx_y * Q + mStart * cn + thread(iStep) / cn) *cn<2> + * G * N + + cn<1> * G * N + gStart * N + jK * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>)) + = int4{0, 0, 0, 0}; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + cp_shared_global<16>( + b_mxOs + swizzle(thread(iStep) * cn<2>, baseOs(jK) * cn<2>), + g_mxOs_ + + get((cStart + blockIdx_y) * H * N * P + hStart * N * P + + (jK * cn + thread(iStep) / cn) *P + nStart * cn + + thread(iStep) % cn)); + } + else if (iK * cn < N_ + Q_) + { + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + cp_shared_global<16>( + b_mxC + swizzle(thread(iStep) * cn<2>, baseC(jK) * cn<2>), + g_mxCB_ + + get((cStart + blockIdx_y) * G * Q * Q + gStart * Q * Q + + (mStart * cn + thread(iStep) / cn) *Q + jK * cn + - N + thread(iStep) % cn)); + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - jK * cn + N) + cp_shared_global<16>( + b_mxOs + swizzle(thread(iStep) * cn<2>, baseOs(jK) * cn<2>), + g_mxX_ + + get((aStart + blockIdx_y * Q + jK * cn - N + thread(iStep) / cn) *H * P + + hStart * P + nStart * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxOs + + swizzle(thread(iStep) * cn<2>, baseOs(jK) * cn<2>)) + = int4{0, 0, 0, 0}; + } + + asm volatile("cp.async.commit_group;\n" ::); + + asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1)); + + __syncthreads(); + } + + if (g_mxD_) + { + float r_D = g_mxD_[hStart.var]; + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + Tp_ tmp16[4] = {0}; + float tmp32[4] = {0}; + + if (blockIdx_y * Q + mStart * cn + Rn{y} * cn + + threadIdx_z * cn + threadIdx_x / cn<4> + < L) + { + *(int*) &tmp16[0] = *(int*) (g_mxX_ + + get((aStart + blockIdx_y * Q + mStart * cn + Rn{y} * cn + + threadIdx_z * cn + threadIdx_x / cn<4>) *H + * P + + hStart * P + nStart * cn + Rn{x} * cn + + threadIdx_y * cn + threadIdx_x % cn<4> * cn<2>)); + + *(float2*) &tmp32[0] = std::is_same_v ? __half22float2(*(half2*) &tmp16[0]) + : bf1622float2(*(bf162*) &tmp16[0]); + + r_mxY[y][x][0] += r_D * tmp32[0]; + r_mxY[y][x][1] += r_D * tmp32[1]; + } + + if (blockIdx_y * Q + mStart * cn + Rn{y} * cn + cn<8> + + threadIdx_z * cn + threadIdx_x / cn<4> + < L) + { + *(int*) &tmp16[2] = *(int*) (g_mxX_ + + get((aStart + blockIdx_y * Q + mStart * cn + Rn{y} * cn + + cn<8> + threadIdx_z * cn + threadIdx_x / cn<4>) *H + * P + + hStart * P + nStart * cn + Rn{x} * cn + + threadIdx_y * cn + threadIdx_x % cn<4> * cn<2>)); + + *(float2*) &tmp32[2] = std::is_same_v ? __half22float2(*(half2*) &tmp16[2]) + : bf1622float2(*(bf162*) &tmp16[2]); + + r_mxY[y][x][2] += r_D * tmp32[2]; + r_mxY[y][x][3] += r_D * tmp32[3]; + } + } + } + + if (g_mxZ_) + { +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + Tp_ tmp16[4] = {0}; + float tmp32[4] = {0}; + + if (blockIdx_y * Q + mStart * cn + Rn{y} * cn + + threadIdx_z * cn + threadIdx_x / cn<4> + < L) + { + *(int*) &tmp16[0] = *(int*) (g_mxZ_ + + get((aStart + blockIdx_y * Q + mStart * cn + Rn{y} * cn + + threadIdx_z * cn + threadIdx_x / cn<4>) *H + * P + + hStart * P + nStart * cn + Rn{x} * cn + + threadIdx_y * cn + threadIdx_x % cn<4> * cn<2>)); + + *(float2*) &tmp32[0] = std::is_same_v ? __half22float2(*(half2*) &tmp16[0]) + : bf1622float2(*(bf162*) &tmp16[0]); + + r_mxY[y][x][0] *= tmp32[0] > 32.f ? tmp32[0] : tmp32[0] / (1.f + expf(-tmp32[0])); + r_mxY[y][x][1] *= tmp32[1] > 32.f ? tmp32[1] : tmp32[1] / (1.f + expf(-tmp32[1])); + } + + if (blockIdx_y * Q + mStart * cn + Rn{y} * cn + cn<8> + + threadIdx_z * cn + threadIdx_x / cn<4> + < L) + { + *(int*) &tmp16[2] = *(int*) (g_mxZ_ + + get((aStart + blockIdx_y * Q + mStart * cn + Rn{y} * cn + + cn<8> + threadIdx_z * cn + threadIdx_x / cn<4>) *H + * P + + hStart * P + nStart * cn + Rn{x} * cn + + threadIdx_y * cn + threadIdx_x % cn<4> * cn<2>)); + + *(float2*) &tmp32[2] = std::is_same_v ? __half22float2(*(half2*) &tmp16[2]) + : bf1622float2(*(bf162*) &tmp16[2]); + + r_mxY[y][x][2] *= tmp32[2] > 32.f ? tmp32[2] : tmp32[2] / (1.f + expf(-tmp32[2])); + r_mxY[y][x][3] *= tmp32[3] > 32.f ? tmp32[3] : tmp32[3] / (1.f + expf(-tmp32[3])); + } + } + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if (std::is_same_v) + { + *(half2*) &r_mxY[y][x][0] = __floats2half2_rn(r_mxY[y][x][0], r_mxY[y][x][1]); + *(half2*) &r_mxY[y][x][2] = __floats2half2_rn(r_mxY[y][x][2], r_mxY[y][x][3]); + } + else + { + *(bf162*) &r_mxY[y][x][0] = __floats2bfloat162_rn(r_mxY[y][x][0], r_mxY[y][x][1]); + *(bf162*) &r_mxY[y][x][2] = __floats2bfloat162_rn(r_mxY[y][x][2], r_mxY[y][x][3]); + } + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY + + 2 + * swz(y * warpM_ * wmmaM_ * tileN_ + x * warpN_ * wmmaN_ + + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_ + + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))), + "r"(*(unsigned*) &r_mxY[y][x][0])); + asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(b_mxY + + 2 + * swz(y * warpM_ * wmmaM_ * tileN_ + 8 * tileN_ + + x * warpN_ * wmmaN_ + (threadIdx.z * wmmaM_ + threadIdx.x / 4) * tileN_ + + (threadIdx.y * wmmaN_ + threadIdx.x % 4 * 2))), + "r"(*(unsigned*) &r_mxY[y][x][2])); + } + + __syncthreads(); + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - mStart * cn) + *(int4*) (g_mxY_ + + get((aStart + blockIdx_y * Q + mStart * cn + thread(iStep) / cn) *H * P + hStart * P + + nStart * cn + thread(iStep) % cn)) + = *(int4*) ((char*) s_mxY + swizzle(thread(iStep) * cn<2>)); + + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); +#endif +} + +ChunkScanKernelFuncFp16 getChunkScanKernelFp16( + int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + int P = P_; + // int G = G_; + // int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileM = 128; + int tileN = 64; + int tileK = 32; + int warpM = 4; + int warpN = 1; + int pipeS = 2; + + auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2); + + *blockDims_ = dim3(H * P / tileN * Q / tileM, C, B); + *threadDims_ = dim3(32, warpN, warpM); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>; + else if (Q_ == 256) + return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, half>; + else + return nullptr; +} + +ChunkScanKernelFuncBf16 getChunkScanKernelBf16( + int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + int P = P_; + // int G = G_; + // int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileM = 128; + int tileN = 64; + int tileK = 32; + int warpM = 4; + int warpN = 1; + int pipeS = 2; + + auto sharedMem = std::max((tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8, (tileM * tileN) * 2); + + *blockDims_ = dim3(H * P / tileN * Q / tileM, C, B); + *threadDims_ = dim3(32, warpN, warpM); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return chunk_scan_kernel<128, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>; + else if (Q_ == 256) + return chunk_scan_kernel<256, 128, 64, 32, 16, 8, 16, 4, 1, 2, bf16>; + else + return nullptr; +} + +} // namespace kernels +} // namespace tensorrt_llm + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/chunkScan/chunkstate.h b/cpp/tensorrt_llm/kernels/chunkScan/chunkstate.h new file mode 100644 index 000000000..b25442795 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/chunkstate.h @@ -0,0 +1,453 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" + +#include "Common.h" +#include "Poly.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +typedef void (*ChunkStateKernelFuncFp16)(int B_, int L_, int H_, int P_, int G_, int N_, + // const half *g_mxY_, // B*L*H*P + // const half *g_mxOs_, // B*C*H*N*P + // const half *g_mxFs_, // B *H*N*P + float* g_mxSt_, // B*C*H*N*P + float const* g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const half *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + // const half *g_mxCB_, // B*C*G*Q*Q + half const* g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + half const* g_mxX_, // B*L*H*P + // const half *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +typedef void (*ChunkStateKernelFuncBf16)(int B_, int L_, int H_, int P_, int G_, int N_, + // const bf16 *g_mxY_, // B*L*H*P + // const bf16 *g_mxOs_, // B*C*H*N*P + // const bf16 *g_mxFs_, // B *H*N*P + float* g_mxSt_, // B*C*H*N*P + float const* g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const bf16 *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + // const bf16 *g_mxCB_, // B*C*G*Q*Q + bf16 const* g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + bf16 const* g_mxX_, // B*L*H*P + // const bf16 *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_); + +template +__global__ std::enable_if_t || std::is_same_v> chunk_state_kernel(int B_, + int L_, int H_, int P_, int G_, int N_, + // const Tp_ *g_mxY_, // B*L*H*P + // const Tp_ *g_mxOs_, // B*C*H*N*P + // const Tp_ *g_mxFs_, // B *H*N*P + float* g_mxSt_, // B*C*H*N*P + float const* g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const Tp_ *g_mxdt_, // B*L*H + // const Wt_ *g_mxdb_, // H + // const Wt_ *g_mxA_, // H + // const Tp_ *g_mxCB_, // B*C*G*Q*Q + Tp_ const* g_mxBC_, // B*L*2*G*N + // const Wt_ *g_mxD_, // H + Tp_ const* g_mxX_, // B*L*H*P + // const Tp_ *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_) +{ +#if __CUDA_ARCH__ >= 800 + using namespace tensorrt_llm::common; + + auto blockIdx_x = Rn{int(blockIdx.x)}; + auto blockIdx_y = Rn{int(blockIdx.y)}; + auto blockIdx_z = Rn{int(blockIdx.z)}; + + auto threadIdx_x = Rn{int(threadIdx.x)}; + auto threadIdx_y = Rn{int(threadIdx.y)}; + auto threadIdx_z = Rn{int(threadIdx.z)}; + + // auto B = Rn{B_}; + auto L = Rn{L_}; + auto H = Rn{H_}; + auto P = Rn{P_}; + auto G = Rn{G_}; + auto N = Rn{N_}; + auto Q = cn; + auto C = Rn{div_up(L.var, Q_)}; + + auto aStart = blockIdx_z * L; + auto cStart = blockIdx_z * C; + + if (removePadding_) + { + aStart = Rn{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)}; + cStart = Rn{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)}; + L = Rn{lastTokenIdsPtr_[blockIdx.z] - aStart.var}; + C = Rn{div_up(L.var, Q_)}; + } + else + { + L = Rn{lastTokenIdsPtr_[blockIdx.z]}; + C = Rn{div_up(L.var, Q_)}; + } + + if (blockIdx_y * Q >= L) + return; + + auto hStart = Rn{blockIdx_x.var / (P_ / cn) / (N_ / cn) }; + auto mStart = Rn{blockIdx_x.var / (P_ / cn) % (N_ / cn) }; + auto nStart = Rn{blockIdx_x.var % (P_ / cn) }; + auto gStart = Rn{hStart.var / (H_ / G_)}; + + extern __shared__ float smem[]; + + Tp_* s_mxB = (Tp_*) smem; + Tp_* s_mxX = (Tp_*) smem + tileM_ * tileK_ * pipeS_; + + float* s_mxdc = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2; + float* s_mxdA = smem + (tileM_ + tileN_) * tileK_ * pipeS_ / 2 + Q_; + + unsigned b_base = __nvvm_get_smem_pointer(smem); + + unsigned b_mxB = b_base; + unsigned b_mxX = b_base + tileM_ * tileK_ * pipeS_ * sizeof(Tp_); + + using std::array; + + register array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_> r_mxSt + = array, tileN_ / wmmaN_ / warpN_>, tileM_ / wmmaM_ / warpM_>(); + register array, tileM_ / wmmaM_ / warpM_> r_mxB; + register array, tileN_ / wmmaN_ / warpN_> r_mxX; + + constexpr int step = std::max( + 1, tileM_ / wmmaM_ / warpM_ * tileN_ / wmmaN_ / warpN_ / (tileM_ / wmmaM_ / warpM_ + tileN_ / wmmaN_ / warpN_)); + + auto baseB = [](auto iK) { return iK % cn * cn * cn; }; + auto baseX = [](auto iK) { return iK % cn * cn * cn; }; + + auto thread = [=](auto iStep) + { + return iStep * cn + threadIdx_z * cn + threadIdx_y * cn<256> + + threadIdx_x * cn<8>; + }; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + { +#pragma unroll + for (int i = 0; i < 8; i += 4) + { + *(int4*) (s_mxdc + get(thread(iStep)) + i) + = *(int4*) (g_mxdc_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i); + *(int4*) (s_mxdA + get(thread(iStep)) + i) + = *(int4*) (g_mxdA_ + get((cStart + blockIdx_y) * H * Q + hStart * Q + thread(iStep)) + i); + } + } + +#pragma unroll + for (Rn iK; iK.var < iK.size; iK.var++) + { +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - iK * cn) + cp_shared_global<16>(b_mxB + swizzle(thread(iStep) * cn<2>, baseB(iK) * cn<2>), + g_mxBC_ + + get((aStart + blockIdx_y * Q + iK * cn + thread(iStep) / cn) *cn<2> * G * N + + gStart * N + mStart * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(iK) * cn<2>)) + = int4{0, 0, 0, 0}; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn + && thread(iStep) / cn < L - blockIdx_y * Q - iK * cn) + cp_shared_global<16>(b_mxX + swizzle(thread(iStep) * cn<2>, baseX(iK) * cn<2>), + g_mxX_ + + get((aStart + blockIdx_y * Q + iK * cn + thread(iStep) / cn) *H * P + + hStart * P + nStart * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn) + *(int4*) ((char*) s_mxX + swizzle(thread(iStep) * cn<2>, baseX(iK) * cn<2>)) + = int4{0, 0, 0, 0}; + + cp_commit_group(); + } + + asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1)); + + __syncthreads(); + + for (int iK = pipeS_; iK < Q_ / tileK_ + pipeS_; iK++) + { + auto jK = Rn<>{iK}; +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn) + { + register Tp_ tmpB[8]; + + *(int4*) &tmpB[0] = *( + int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>)); + +#pragma unroll + for (int i = 0; i < 8; i += 2) + { + float2 tmp2 = std::is_same_v ? __half22float2(*(half2*) &tmpB[i]) + : bf1622float2(*(bf162*) &tmpB[i]); + + int kStart = (iK - pipeS_) * cn; + + tmp2.x *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn)]) + * s_mxdc[kStart + get(thread(iStep) / cn)]; + tmp2.y *= expf(s_mxdA[Q_ - 1] - s_mxdA[kStart + get(thread(iStep) / cn)]) + * s_mxdc[kStart + get(thread(iStep) / cn)]; + + if (std::is_same_v) + *(half2*) &tmpB[i] = __float22half2_rn(tmp2); + else + *(bf162*) &tmpB[i] = __float22bfloat162_rn(tmp2); + } + + *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>)) + = *(int4*) &tmpB[0]; + } + + __syncthreads(); + +#pragma unroll + for (int k = 0; k < tileK_ / wmmaK_; k++) + { +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if ((y * tileN_ / wmmaN_ / warpN_ + x) % step == 0) + { + int x1 = (y * tileN_ / wmmaN_ / warpN_ + x) / step; + int y1 = x1 - tileN_ / wmmaN_ / warpN_ + + (tileM_ / wmmaM_ / warpM_ == 1 || tileN_ / wmmaN_ / warpN_ == 1); + + if (y1 >= 0 && y1 < tileM_ / wmmaM_ / warpM_) + { + if (wmmaK_ == 16) + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r_mxB[y1][0]), "=r"(r_mxB[y1][1]), "=r"(r_mxB[y1][2]), "=r"(r_mxB[y1][3]) + : "r"(b_mxB + iK % pipeS_ * (tileM_ * tileK_ * 2) + + 2 + * swz(y1 * warpM_ * wmmaM_ + k * wmmaK_ * tileM_ + + threadIdx.z * wmmaM_ + threadIdx.x % 8 * tileM_ + + threadIdx.x / 8 % 2 * 8 + threadIdx.x / wmmaK_ * 8 * tileM_))); + } + + if (x1 >= 0 && x1 < tileN_ / wmmaN_ / warpN_) + { + if (wmmaK_ == 16 && x1 % 2 == 0) + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(r_mxX[x1][0]), "=r"(r_mxX[x1][1]), "=r"(r_mxX[x1 + 1][0]), + "=r"(r_mxX[x1 + 1][1]) + : "r"(b_mxX + iK % pipeS_ * (tileK_ * tileN_ * 2) + + 2 + * swz(x1 * warpN_ * wmmaN_ + k * wmmaK_ * tileN_ + + threadIdx.y * wmmaN_ + threadIdx.x % wmmaK_ * tileN_ + + threadIdx.x / wmmaK_ * warpN_ * wmmaN_))); + } + } + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + if (wmmaK_ == 16) + { + if (std::is_same_v) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]), + "+f"(r_mxSt[y][x][3]) + : "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]), + "r"(r_mxX[x][0]), "r"(r_mxX[x][1])); + else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(r_mxSt[y][x][0]), "+f"(r_mxSt[y][x][1]), "+f"(r_mxSt[y][x][2]), + "+f"(r_mxSt[y][x][3]) + : "r"(r_mxB[y][0]), "r"(r_mxB[y][1]), "r"(r_mxB[y][2]), "r"(r_mxB[y][3]), + "r"(r_mxX[x][0]), "r"(r_mxX[x][1])); + } + } + } + + __syncthreads(); + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn && thread(iStep) / cn < L - blockIdx_y * Q - jK * cn + && jK * cn < Q) + cp_shared_global<16>(b_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>), + g_mxBC_ + + get((aStart + blockIdx_y * Q + jK * cn + thread(iStep) / cn) *cn<2> * G * N + + gStart * N + mStart * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn && jK * cn < Q) + *(int4*) ((char*) s_mxB + swizzle(thread(iStep) * cn<2>, baseB(jK) * cn<2>)) + = int4{0, 0, 0, 0}; + +#pragma unroll + for (Rn iStep; iStep.var < iStep.size; iStep.var++) + if (thread(iStep) < cn && thread(iStep) / cn < L - blockIdx_y * Q - jK * cn + && jK * cn < Q) + cp_shared_global<16>(b_mxX + swizzle(thread(iStep) * cn<2>, baseX(jK) * cn<2>), + g_mxX_ + + get((aStart + blockIdx_y * Q + jK * cn + thread(iStep) / cn) *H * P + + hStart * P + nStart * cn + thread(iStep) % cn)); + else if (thread(iStep) < cn && jK * cn < Q) + *(int4*) ((char*) s_mxX + swizzle(thread(iStep) * cn<2>, baseX(jK) * cn<2>)) + = int4{0, 0, 0, 0}; + + asm volatile("cp.async.commit_group;\n" ::); + + asm volatile("cp.async.wait_group %0;\n" ::"n"(pipeS_ - 1)); + + __syncthreads(); + } + +#pragma unroll + for (int y = 0; y < tileM_ / wmmaM_ / warpM_; y++) +#pragma unroll + for (int x = 0; x < tileN_ / wmmaN_ / warpN_; x++) + { + *(float2*) (g_mxSt_ + + get((cStart + blockIdx_y) * H * N * P + hStart * N * P + + (mStart * cn + Rn{y} * cn + threadIdx_z * cn + + threadIdx_x / cn<4>) *P + + nStart * cn + Rn{x} * cn + threadIdx_y * cn + + threadIdx_x % cn<4> * cn<2>)) + = *(float2*) &r_mxSt[y][x][0]; + + *(float2*) (g_mxSt_ + + get((cStart + blockIdx_y) * H * N * P + hStart * N * P + + (mStart * cn + Rn{y} * cn + cn<8> + threadIdx_z * cn + + threadIdx_x / cn<4>) *P + + nStart * cn + Rn{x} * cn + threadIdx_y * cn + + threadIdx_x % cn<4> * cn<2>)) + = *(float2*) &r_mxSt[y][x][2]; + } + + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); +#endif +} + +ChunkStateKernelFuncFp16 getChunkStateKernelFp16( + int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + int P = P_; + // int G = G_; + int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileM = 64; + int tileN = 64; + int tileK = 32; + int warpM = 1; + int warpN = 2; + int pipeS = 3; + + auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8; + + *blockDims_ = dim3(H * P / tileN * N / tileM, C, B); + *threadDims_ = dim3(32, warpN, warpM); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>; + else if (Q_ == 256) + return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, half>; + else + return nullptr; +} + +ChunkStateKernelFuncBf16 getChunkStateKernelBf16( + int B_, int L_, int H_, int P_, int G_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + int P = P_; + // int G = G_; + int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileM = 64; + int tileN = 64; + int tileK = 32; + int warpM = 1; + int warpN = 2; + int pipeS = 3; + + auto sharedMem = (tileM * tileK + tileK * tileN) * pipeS * 2 + Q * 8; + + *blockDims_ = dim3(H * P / tileN * N / tileM, C, B); + *threadDims_ = dim3(32, warpN, warpM); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return chunk_state_kernel<128, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>; + else if (Q_ == 256) + return chunk_state_kernel<256, 64, 64, 32, 16, 8, 16, 1, 2, 3, bf16>; + else + return nullptr; +} + +} // namespace kernels +} // namespace tensorrt_llm + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/chunkScan/statepassing.h b/cpp/tensorrt_llm/kernels/chunkScan/statepassing.h new file mode 100644 index 000000000..0445ddcb1 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/chunkScan/statepassing.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" + +#include "Common.h" +#include "Poly.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +typedef void (*StatePassingKernelFuncFp16)(int B_, int L_, int H_, int P_, int N_, + // const half *g_mxY_, // B*L*H*P + half* g_mxOs_, // B*C*H*N*P + half* g_mxFs_, // B *H*N*P + float const* g_mxSt_, // B*C*H*N*P + // const float *g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const half *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + // const half *g_mxCB_, // B*C*G*Q*Q + // const half *g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + // const half *g_mxX_, // B*L*H*P + // const half *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_); + +typedef void (*StatePassingKernelFuncBf16)(int B_, int L_, int H_, int P_, int N_, + // const bf16 *g_mxY_, // B*L*H*P + bf16* g_mxOs_, // B*C*H*N*P + bf16* g_mxFs_, // B *H*N*P + float const* g_mxSt_, // B*C*H*N*P + // const float *g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const bf16 *g_mxdt_, // B*L*H + // const float *g_mxdb_, // H + // const float *g_mxA_, // H + // const bf16 *g_mxCB_, // B*C*G*Q*Q + // const bf16 *g_mxBC_, // B*L*2*G*N + // const float *g_mxD_, // H + // const bf16 *g_mxX_, // B*L*H*P + // const bf16 *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_); + +template +__global__ std::enable_if_t || std::is_same_v> state_passing_kernel( + int B_, int L_, int H_, int P_, int N_, + // const Tp_ *g_mxY_, // B*L*H*P + Tp_* g_mxOs_, // B*C*H*N*P + Tp_* g_mxFs_, // B *H*N*P + float const* g_mxSt_, // B*C*H*N*P + // const float *g_mxdc_, // B*C*H*Q + float const* g_mxdA_, // B*C*H*Q + // const Tp_ *g_mxdt_, // B*L*H + // const Wt_ *g_mxdb_, // H + // const Wt_ *g_mxA_, // H + // const Tp_ *g_mxCB_, // B*C*G*Q*Q + // const Tp_ *g_mxBC_, // B*L*2*G*N + // const Wt_ *g_mxD_, // H + // const Tp_ *g_mxX_, // B*L*H*P + // const Tp_ *g_mxZ_, // B*L*H*P + bool removePadding_, int const* lastTokenIdsPtr_, int const* stateSlotMappingPtr_) +{ + using namespace tensorrt_llm::common; + + auto blockIdx_x = Rn{int(blockIdx.x)}; + auto blockIdx_y = Rn{int(blockIdx.y)}; + auto blockIdx_z = Rn{int(blockIdx.z)}; + + auto threadIdx_x = Rn{int(threadIdx.x)}; + auto threadIdx_y = Rn{int(threadIdx.y)}; + + // auto B = Rn{B_}; + auto L = Rn{L_}; + auto H = Rn{H_}; + auto P = Rn{P_}; + // auto G = Rn{G_}; + auto N = Rn{N_}; + auto Q = cn; + auto C = Rn{div_up(L.var, Q_)}; + + auto aStart = blockIdx_z * L; + auto cStart = blockIdx_z * C; + + if (removePadding_) + { + aStart = Rn{int(blockIdx.z ? lastTokenIdsPtr_[blockIdx.z - 1] : 0)}; + cStart = Rn{int(blockIdx.z ? div_up(aStart.var, Q_) + blockIdx.z - 1 : 0)}; + L = Rn{lastTokenIdsPtr_[blockIdx.z] - aStart.var}; + C = Rn{div_up(L.var, Q_)}; + } + else + { + L = Rn{lastTokenIdsPtr_[blockIdx.z]}; + C = Rn{div_up(L.var, Q_)}; + } + + if (stateSlotMappingPtr_) + { + g_mxFs_ += stateSlotMappingPtr_[blockIdx.z] * H_ * N_ * P_; + } + else + { + g_mxFs_ += blockIdx.z * H_ * N_ * P_; + } + + auto hStart = Rn{blockIdx_x.var * tileH_ / N_ / P_}; + + register Tp_ r_mxOs[tileH_ / (warpH_ * 32)] = {0}; + register float r_mxSt[tileH_ / (warpH_ * 32)] = {0}; + + for (int iC = 0; iC < C.var; iC++) + { + if (std::is_same_v) +#pragma unroll + for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2) + *(half2*) &r_mxOs[i] = __float22half2_rn(*(float2*) &r_mxSt[i]); + else +#pragma unroll + for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2) + *(bf162*) &r_mxOs[i] = __float22bfloat162_rn(*(float2*) &r_mxSt[i]); + +#pragma unroll + for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2) + *(int*) (g_mxOs_ + + get((cStart + Rn<>{iC}) * H * N * P + blockIdx_x * cn + + (threadIdx_y * cn<32> + threadIdx_x) * cn + Rn{i})) + = *(int*) &r_mxOs[i]; + + float scale = expf(g_mxdA_[get((cStart + Rn<>{iC}) * H * Q + hStart * Q + Q - cn<1>)]); + +#pragma unroll + for (int i = 0; i < tileH_ / (warpH_ * 32); i++) + { + float tmp = g_mxSt_[get((cStart + Rn<>{iC}) * H * N * P + blockIdx_x * cn + + (threadIdx_y * cn<32> + threadIdx_x) * cn + Rn{i})]; + + r_mxSt[i] = scale * r_mxSt[i] + tmp; + } + } + + if (std::is_same_v) +#pragma unroll + for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2) + *(half2*) &r_mxOs[i] = __float22half2_rn(*(float2*) &r_mxSt[i]); + else +#pragma unroll + for (int i = 0; i < tileH_ / (warpH_ * 32); i += 2) + *(bf162*) &r_mxOs[i] = __float22bfloat162_rn(*(float2*) &r_mxSt[i]); + +#pragma unroll + for (int i = 0; i < tileH_ / (warpH_ * 32); i += 8) + *(int4*) (g_mxFs_ + + get(blockIdx_x * cn + (threadIdx_y * cn<32> + threadIdx_x) * cn + + Rn{i})) + = *(int4*) &r_mxOs[i]; +} + +StatePassingKernelFuncFp16 getStatePassingKernelFp16( + int B_, int L_, int H_, int P_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + int P = P_; + // int G = G_; + int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileH = 1024; + int warpH = 8; + + auto sharedMem = 0; + + *blockDims_ = dim3(H * N * P / tileH, 1, B); + *threadDims_ = dim3(32, warpH); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return state_passing_kernel<128, 1024, 8, half>; + else if (Q_ == 256) + return state_passing_kernel<256, 1024, 8, half>; + else + return nullptr; +} + +StatePassingKernelFuncBf16 getStatePassingKernelBf16( + int B_, int L_, int H_, int P_, int N_, int Q_, dim3* blockDims_, dim3* threadDims_, int* sharedMem_) +{ + int B = B_; + int L = L_; + int H = H_; + int P = P_; + // int G = G_; + int N = N_; + int Q = Q_; + int C = div_up(L, Q); + + int tileH = 1024; + int warpH = 8; + + auto sharedMem = 0; + + *blockDims_ = dim3(H * N * P / tileH, 1, B); + *threadDims_ = dim3(32, warpH); + *sharedMem_ = sharedMem; + + if (Q_ == 128) + return state_passing_kernel<128, 1024, 8, bf16>; + else if (Q_ == 256) + return state_passing_kernel<256, 1024, 8, bf16>; + else + return nullptr; +} + +} // namespace kernels +} // namespace tensorrt_llm + +// vim: ts=2 sw=2 sts=2 et sta diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp index 36fc94a8d..29e3ea2c4 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp @@ -61,9 +61,19 @@ bool supportConfigCommon(XQAParams const& xqaParams, bool forConfigurePlugin) { return false; } - if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads == xqaParams.num_kv_heads) + if (xqaParams.num_kv_heads != 0 && xqaParams.num_q_heads % xqaParams.num_kv_heads != 0) { - // Do not use XQA kernel for MHA. + return false; + } + bool is_vanilla_mha = xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads == xqaParams.num_kv_heads; + if (is_vanilla_mha && xqaParams.beam_width == 1) + { + // Do not use XQA kernel for vanilla MHA case for performance reasons. + return false; + } + if (is_vanilla_mha && xqaParams.head_size <= 128) + { + // TODO(yaoy): remove this when the kernel bug for num_kv_heads <= 128 gets fixed. return false; } if (xqaParams.multi_block_mode) @@ -108,11 +118,7 @@ bool supportConfigQGMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlu { return false; } - if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads % xqaParams.num_kv_heads != 0) - { - return false; - } - int32_t head_grp_size = xqaParams.num_q_heads / xqaParams.num_kv_heads; + int32_t head_grp_size = xqaParams.num_kv_heads == 0 ? 1 : xqaParams.num_q_heads / xqaParams.num_kv_heads; if (head_grp_size * xqaParams.beam_width > 32) { return false; @@ -150,11 +156,7 @@ bool supportConfigHMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlug { return false; } - if (xqaParams.num_kv_heads == 0 || xqaParams.num_q_heads % xqaParams.num_kv_heads != 0) - { - return false; - } - int32_t head_grp_size = xqaParams.num_q_heads / xqaParams.num_kv_heads; + int32_t head_grp_size = xqaParams.num_kv_heads == 0 ? 1 : xqaParams.num_q_heads / xqaParams.num_kv_heads; if (head_grp_size * xqaParams.beam_width > 32) { return false; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt index 6a59d0857..8833c1c28 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ 8b0f8deb35940359b39f876fc5e94e4f libtensorrt_llm_nvrtc_wrapper.so -0e1417f27d93de67940c1062cf230017cd8be5f1 commit \ No newline at end of file +d5f5542d2f1e10c4a6b60be56838ac79a9668665 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll index f6b3237f1..455b8d923 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:53746a0351295accb650f9e509303914ae8d8dc3c2605baf680f30cfc40d96f6 +oid sha256:78209a1351f9f21f635bf9f763f4947031ea12b7526c5782094e9869b667a23f size 1091072 diff --git a/cpp/tensorrt_llm/kernels/decodingKernels.cu b/cpp/tensorrt_llm/kernels/decodingKernels.cu index 8b795f185..e27534d62 100644 --- a/cpp/tensorrt_llm/kernels/decodingKernels.cu +++ b/cpp/tensorrt_llm/kernels/decodingKernels.cu @@ -482,12 +482,11 @@ __global__ void finalizeKernel(BeamHypotheses bh) void invokeFinalize(BeamHypotheses& bh, cudaStream_t stream) { - TLLM_LOG_TRACE("%s %s start", __FILE__, __PRETTY_FUNCTION__); + TLLM_LOG_DEBUG("%s %s start", __FILE__, __PRETTY_FUNCTION__); int const nBM = bh.nBeamWidth; size_t const smem_size = sizeof(int) * nBM * 2 + sizeof(float) * nBM * 2; finalizeKernel<<>>(bh); - TLLM_LOG_TRACE("%s %s stop", __FILE__, __PRETTY_FUNCTION__); } __global__ void initializeOutput(TokenIdType* finalOutputIds, TokenIdType const* endIds, SizeType32 const nMaxSeqLen) diff --git a/cpp/tensorrt_llm/kernels/selectiveScan.cu b/cpp/tensorrt_llm/kernels/selectiveScan.cu index b8f94f160..a4a5f12d5 100644 --- a/cpp/tensorrt_llm/kernels/selectiveScan.cu +++ b/cpp/tensorrt_llm/kernels/selectiveScan.cu @@ -28,6 +28,12 @@ #include "selectiveScan.h" +#include "chunkScan/bmmchunk.h" +#include "chunkScan/chunkcumsum.h" +#include "chunkScan/chunkscan.h" +#include "chunkScan/chunkstate.h" +#include "chunkScan/statepassing.h" + namespace tensorrt_llm { namespace kernels @@ -319,8 +325,6 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream) int samples = params.batch; int channels = params.dim; - TLLM_CHECK(params.is_variable_B); - TLLM_CHECK(params.is_variable_C); TLLM_CHECK(params.dstate == 16); int const threads = 128; @@ -331,6 +335,107 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream) selective_scan_loop_kernel<<>>(params); } +template +void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream) +{ + int B = params.batch; + int L = params.max_seqlen; + int H = params.nheads; + int P = params.dim / H; + int G = params.ngroups; + int N = params.dstate; + int Q = params.chunk_size; + + bool dtsp = params.delta_softplus; + + if constexpr (std::is_same_v) + { + dim3 bds[5], tds[5]; + int shms[5]; + + ChunkCumsumKernelFuncFp16 chunk_cumsum = getChunkCumsumKernelFp16(B, L, H, Q, dtsp, &bds[0], &tds[0], &shms[0]); + ChunkStateKernelFuncFp16 chunk_state = getChunkStateKernelFp16(B, L, H, P, G, N, Q, &bds[1], &tds[1], &shms[1]); + StatePassingKernelFuncFp16 state_passing + = getStatePassingKernelFp16(B, L, H, P, N, Q, &bds[2], &tds[2], &shms[2]); + BmmChunkKernelFuncFp16 bmm_chunk = getBmmChunkKernelFp16(B, L, G, N, Q, &bds[3], &tds[3], &shms[3]); + ChunkScanKernelFuncFp16 chunk_scan = getChunkScanKernelFp16(B, L, H, P, G, N, Q, &bds[4], &tds[4], &shms[4]); + + half* mxY = (half*) params.out_ptr; + half* mxOs = (half*) params.Os_ptr; + half* mxFs = (half*) params.x_ptr; + float* mxSt = (float*) params.St_ptr; + float* mxdc = (float*) params.dc_ptr; + float* mxdA = (float*) params.dA_ptr; + half const* mxdt = (half const*) params.delta_ptr; + float const* mxdb = (float const*) params.delta_bias_ptr; + float const* mxA = (float const*) params.A_ptr; + half* mxCB = (half*) params.CB_ptr; + half const* mxBC = (half const*) params.BC_ptr; + float const* mxD = (float const*) params.D_ptr; + half const* mxX = (half const*) params.u_ptr; + half const* mxZ = (half const*) params.z_ptr; + + auto rp = params.remove_padding; + auto ltip = params.last_token_ids_ptr; + auto ssmp = params.slot_mapping_ptr; + + cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]); + chunk_cumsum<<>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip); + cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]); + chunk_state<<>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip); + cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]); + state_passing<<>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp); + cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]); + bmm_chunk<<>>(B, L, G, N, mxCB, mxBC, rp, ltip); + cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]); + chunk_scan<<>>( + B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip); + } + else if constexpr (std::is_same_v) + { + dim3 bds[5], tds[5]; + int shms[5]; + + ChunkCumsumKernelFuncBf16 chunk_cumsum = getChunkCumsumKernelBf16(B, L, H, Q, dtsp, &bds[0], &tds[0], &shms[0]); + ChunkStateKernelFuncBf16 chunk_state = getChunkStateKernelBf16(B, L, H, P, G, N, Q, &bds[1], &tds[1], &shms[1]); + StatePassingKernelFuncBf16 state_passing + = getStatePassingKernelBf16(B, L, H, P, N, Q, &bds[2], &tds[2], &shms[2]); + BmmChunkKernelFuncBf16 bmm_chunk = getBmmChunkKernelBf16(B, L, G, N, Q, &bds[3], &tds[3], &shms[3]); + ChunkScanKernelFuncBf16 chunk_scan = getChunkScanKernelBf16(B, L, H, P, G, N, Q, &bds[4], &tds[4], &shms[4]); + + __nv_bfloat16* mxY = (__nv_bfloat16*) params.out_ptr; + __nv_bfloat16* mxOs = (__nv_bfloat16*) params.Os_ptr; + __nv_bfloat16* mxFs = (__nv_bfloat16*) params.x_ptr; + float* mxSt = (float*) params.St_ptr; + float* mxdc = (float*) params.dc_ptr; + float* mxdA = (float*) params.dA_ptr; + __nv_bfloat16 const* mxdt = (__nv_bfloat16 const*) params.delta_ptr; + float const* mxdb = (float const*) params.delta_bias_ptr; + float const* mxA = (float const*) params.A_ptr; + __nv_bfloat16* mxCB = (__nv_bfloat16*) params.CB_ptr; + __nv_bfloat16 const* mxBC = (__nv_bfloat16 const*) params.BC_ptr; + float const* mxD = (float const*) params.D_ptr; + __nv_bfloat16 const* mxX = (__nv_bfloat16 const*) params.u_ptr; + __nv_bfloat16 const* mxZ = (__nv_bfloat16 const*) params.z_ptr; + + auto rp = params.remove_padding; + auto ltip = params.last_token_ids_ptr; + auto ssmp = params.slot_mapping_ptr; + + cudaFuncSetAttribute(chunk_cumsum, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[0]); + chunk_cumsum<<>>(B, L, H, mxdc, mxdA, mxdt, mxdb, mxA, rp, ltip); + cudaFuncSetAttribute(chunk_state, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[1]); + chunk_state<<>>(B, L, H, P, G, N, mxSt, mxdc, mxdA, mxBC, mxX, rp, ltip); + cudaFuncSetAttribute(state_passing, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[2]); + state_passing<<>>(B, L, H, P, N, mxOs, mxFs, mxSt, mxdA, rp, ltip, ssmp); + cudaFuncSetAttribute(bmm_chunk, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[3]); + bmm_chunk<<>>(B, L, G, N, mxCB, mxBC, rp, ltip); + cudaFuncSetAttribute(chunk_scan, cudaFuncAttributeMaxDynamicSharedMemorySize, shms[4]); + chunk_scan<<>>( + B, L, H, P, G, N, mxY, mxOs, mxdc, mxdA, mxCB, mxBC, mxD, mxX, mxZ, rp, ltip); + } +} + #define INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE(input_t, weight_t) \ template void invokeSelectiveScan(SSMParamsBase & params, cudaStream_t stream); @@ -341,9 +446,19 @@ INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE(__nv_bfloat16, float); #endif #undef INSTANTIATE_SELECTIVE_SCAN_DATA_TYPE +#define INSTANTIATE_CHUNK_SCAN_DATA_TYPE(input_t, weight_t) \ + template void invokeChunkScan(SSMParamsBase & params, cudaStream_t stream); + +INSTANTIATE_CHUNK_SCAN_DATA_TYPE(float, float); +INSTANTIATE_CHUNK_SCAN_DATA_TYPE(half, float); +#ifdef ENABLE_BF16 +INSTANTIATE_CHUNK_SCAN_DATA_TYPE(__nv_bfloat16, float); +#endif +#undef INSTANTIATE_CHUNK_SCAN_DATA_TYPE + //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParamsBase params) { @@ -359,15 +474,21 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams weight_t* dt_bias = reinterpret_cast(params.delta_bias_ptr); bool dt_softplus = params.delta_softplus; int num_channels = params.dim; + int nheads = params.nheads; + int ngroups = params.ngroups; int const channel = blockIdx.x * blockDim.x + threadIdx.x; if (channel >= num_channels) return; int const sample = blockIdx.y; + int const head_dim = num_channels / nheads; + int const head = channel / head_dim; + int const head_chl = channel % head_dim; + int const group = head / (nheads / ngroups); int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample]; - int const bc_cols = DSTATE * 2 + params.dt_rank; - int const b_offset = params.dt_rank; - int const c_offset = params.dt_rank + DSTATE; + int const bc_offset = MAMBA_V1 ? sample * (DSTATE * 2 + params.dt_rank) : sample * DSTATE * ngroups * 2; + int const b_offset = MAMBA_V1 ? params.dt_rank : DSTATE * group; + int const c_offset = MAMBA_V1 ? params.dt_rank + DSTATE : DSTATE * (ngroups + group); input_t* my_state = &state[slot_idx * num_channels * DSTATE]; input_t* my_output = &output[sample * num_channels]; @@ -375,30 +496,45 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams float rA[DSTATE]; float rB[DSTATE]; float rC[DSTATE]; - float rState[DSTATE]; + float my_x, my_dt, my_z, my_dt_bias, my_D; + my_x = toFloat(x[sample * num_channels + channel]); + my_z = z ? toFloat(z[sample * num_channels + channel]) : 0.f; + if (MAMBA_V1) + { #pragma unroll - for (int i = 0; i < DSTATE; i++) + for (int i = 0; i < DSTATE; i++) + { + rA[i] = toFloat(A[i * num_channels + channel]); + rB[i] = toFloat(B[bc_offset + b_offset + i]); + rC[i] = toFloat(C[bc_offset + c_offset + i]); + rState[i] = toFloat(my_state[i * num_channels + channel]); + } + my_dt = toFloat(dt[sample * num_channels + channel]); + my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f; + my_D = D ? toFloat(D[channel]) : 0.f; + } + else { - rA[i] = toFloat(A[i * num_channels + channel]); - rB[i] = toFloat(B[sample * bc_cols + b_offset + i]); - rC[i] = toFloat(C[sample * bc_cols + c_offset + i]); - rState[i] = toFloat(my_state[i * num_channels + channel]); + float A_tmp = toFloat(A[head]); +#pragma unroll + for (int i = 0; i < DSTATE; i++) + { + rA[i] = A_tmp; + rB[i] = toFloat(B[bc_offset + b_offset + i]); + rC[i] = toFloat(C[bc_offset + c_offset + i]); + rState[i] = toFloat(my_state[(head * DSTATE + i) * head_dim + head_chl]); + } + my_dt = toFloat(dt[sample * nheads + head]); + my_dt_bias = dt_bias ? toFloat(dt_bias[head]) : 0.f; + my_D = D ? toFloat(D[head]) : 0.f; } - float my_x, my_dt, my_z, my_dt_bias, my_D; - my_x = toFloat(x[sample * num_channels + channel]); - my_dt = toFloat(dt[sample * num_channels + channel]); - my_z = z ? toFloat(z[sample * num_channels + channel]) : 0.f; - my_dt_bias = dt_bias ? toFloat(dt_bias[channel]) : 0.f; - my_D = D ? toFloat(D[channel]) : 0.f; - float dt_b = my_dt + my_dt_bias; float dt_b_sp; if (dt_softplus) { - // dt_b_sp = dt_b <= 20.f ? logf(1.f + expf(dt_b)) : dt_b; // softplus dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus } @@ -407,19 +543,21 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams #pragma unroll for (int i = 0; i < DSTATE; i++) { - // float dA = expf(rA[i] * dt_b_sp); float dA = __expf(rA[i] * dt_b_sp); float dB = rB[i] * dt_b_sp; float sdA = rState[i] * dA; float dBx = dB * my_x; float newState = sdA + dBx; - convertAndStore(&my_state[i * num_channels + channel], newState); // Write the new state back out to the cache + // Write the new state back out to the cache + if (MAMBA_V1) + convertAndStore(&my_state[i * num_channels + channel], newState); + else + convertAndStore(&my_state[(head * DSTATE + i) * head_dim + head_chl], newState); out += newState * rC[i]; } if (z) { - // float sig_z = 1.0 / (1.0 + exp(0.f - my_z)); float sig_z = __fdividef(1.f, (1.f + __expf(0.f - my_z))); float silu_z = my_z * sig_z; out *= silu_z; @@ -433,16 +571,25 @@ void invokeSelectiveScanUpdate(SSMParamsBase& params, cudaStream_t stream) { int samples = params.batch; int channels = params.dim; + int nheads = params.nheads; + int ngroups = params.ngroups; int const threads = 128; int const blocks = (channels + threads - 1) / threads; dim3 block(threads, 1); dim3 grid(blocks, samples); - TLLM_CHECK(params.is_variable_B); - TLLM_CHECK(params.is_variable_C); - TLLM_CHECK(params.dstate == 16); - selective_scan_update_kernel<<>>(params); + TLLM_CHECK_WITH_INFO(nheads % ngroups == 0, "nheads must be divisible by ngroups"); + if (params.is_mamab2) + { + TLLM_CHECK(params.dstate == 128); + selective_scan_update_kernel<<>>(params); + } + else + { + TLLM_CHECK(params.dstate == 16); + selective_scan_update_kernel<<>>(params); + } } #define INSTANTIATE_SELECTIVE_SCAN_UPDATE_DATA_TYPE(input_t, weight_t) \ diff --git a/cpp/tensorrt_llm/kernels/selectiveScan.h b/cpp/tensorrt_llm/kernels/selectiveScan.h index fec892923..0f02a88c2 100644 --- a/cpp/tensorrt_llm/kernels/selectiveScan.h +++ b/cpp/tensorrt_llm/kernels/selectiveScan.h @@ -40,13 +40,11 @@ namespace kernels struct SSMParamsBase { - int batch, dim, dstate, dt_rank; + int batch, dim, dstate, dt_rank, nheads, ngroups, chunk_size; int max_seqlen; // only valid for padded input. bool remove_padding; - bool is_variable_B; - bool is_variable_C; - bool delta_softplus; + bool is_mamab2; // Common data pointers. void* __restrict__ A_ptr; @@ -58,6 +56,12 @@ struct SSMParamsBase void* __restrict__ out_ptr; void* __restrict__ x_ptr; void* __restrict__ z_ptr; + // Workspace data pointers. + void* __restrict__ Os_ptr; + void* __restrict__ St_ptr; + void* __restrict__ dc_ptr; + void* __restrict__ dA_ptr; + void* __restrict__ CB_ptr; int const* __restrict__ last_token_ids_ptr; int const* __restrict__ slot_mapping_ptr; }; @@ -67,6 +71,9 @@ struct SSMParamsBase template void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream); +template +void invokeChunkScan(SSMParamsBase& params, cudaStream_t stream); + template void invokeSelectiveScanUpdate(SSMParamsBase& params, cudaStream_t stream); } // namespace kernels diff --git a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp index 3c8075bed..402d1375d 100644 --- a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp +++ b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp @@ -158,7 +158,7 @@ void GemmPluginProfiler::profileT common::check_cuda_error(cudaStreamCreate(&mStream)); int const startMinMRounded = nextPowerOfTwo(dims.minM); - for (int m = startMinMRounded; m < maxM; m *= 2) + for (int m = std::max(1, startMinMRounded); m < maxM; m *= 2) { profileTactics(m, dims.n, dims.k); } @@ -184,7 +184,7 @@ std::optional GemmPluginProfilergetMProfileMap(gemmId)->at(mRounded); } diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index a716b039e..73d95cd44 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -275,8 +275,23 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c { TLLM_CHECK(mHeadSize > 0); - int const beamWidth - = isCrossAttention() ? 1 : (useKVCache() ? in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1] : 1); + int beamWidth = -1; + if (!isCrossAttention() && useKVCache()) + { + // desc_val == -1 means beam_width is not static, we should look at min/max/opt. + // + // In prepareEnqueueGeneration, we'll prepare for all cases where beam_width doesn't exceed max. + // TODO(minwei): pass min AND max to prepareEnqueueGeneration instead of max only. + int desc_val = in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1]; + int max_val = in[getIdx(IdxEntry::CACHE_INDIR)].max.d[1]; + beamWidth = desc_val == -1 ? max_val : desc_val; + } + else + { + beamWidth = 1; + } + TLLM_CHECK(beamWidth != -1); + // Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same // unless each layer has different attention window sizes. // the kv_cache capacity. diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp index 5df036834..40f7c67b5 100644 --- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp @@ -29,18 +29,22 @@ static char const* SELECTIVE_SCAN_PLUGIN_NAME{"SelectiveScan"}; PluginFieldCollection SelectiveScanPluginCreator::mFC{}; std::vector SelectiveScanPluginCreator::mPluginAttributes; -SelectiveScanPlugin::SelectiveScanPlugin(int dim, int dstate, int dt_rank, bool isVariableB, bool isVariableC, - bool deltaSoftplus, nvinfer1::DataType type, bool removePadding, bool pagedState) +SelectiveScanPlugin::SelectiveScanPlugin(int dim, int dstate, int dtRank, int nHeads, int nGroups, int chunkSize, + bool deltaSoftplus, nvinfer1::DataType type, bool removePadding, bool pagedState, bool zEnabled, bool isMamba2) : mDim(dim) , mDState(dstate) - , mDtRank(dt_rank) - , mIsVariableB(isVariableB) - , mIsVariableC(isVariableC) + , mDtRank(dtRank) + , mNHeads(nHeads) + , mNGroups(nGroups) + , mChunkSize(chunkSize) , mDeltaSoftplus(deltaSoftplus) , mType(type) , mRemovePadding(removePadding) , mPagedState(pagedState) + , mZEnabled(zEnabled) + , mIsMamba2(isMamba2) { + TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (!mIsMamba2), "Pre SM 80 GPUs do not support Mamba2"); TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16), "Unsupported data type, pre SM 80 GPUs do not support bfloat16"); TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF), @@ -54,12 +58,15 @@ SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length) read(d, mDim); read(d, mDState); read(d, mDtRank); - read(d, mIsVariableB); - read(d, mIsVariableC); + read(d, mNHeads); + read(d, mNGroups); + read(d, mChunkSize); read(d, mDeltaSoftplus); read(d, mType); read(d, mRemovePadding); read(d, mPagedState); + read(d, mZEnabled); + read(d, mIsMamba2); TLLM_CHECK(d == a + length); TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16), "Unsupported data type"); TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF), @@ -69,8 +76,8 @@ SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length) // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* SelectiveScanPlugin::clone() const noexcept { - auto* plugin = new SelectiveScanPlugin( - mDim, mDState, mDtRank, mIsVariableB, mIsVariableC, mDeltaSoftplus, mType, mRemovePadding, mPagedState); + auto* plugin = new SelectiveScanPlugin(mDim, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, mDeltaSoftplus, mType, + mRemovePadding, mPagedState, mZEnabled, mIsMamba2); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; } @@ -91,7 +98,8 @@ nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions( bool SelectiveScanPlugin::supportsFormatCombination( int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { - if (pos == getHostRequestTypesIdx() || pos == getLastTokenIdsIdx() || (mPagedState && pos == getSlotMappingIdx())) + if (pos == getHostRequestTypesIdx() || pos == getLastTokenIdsIdx() + || (mRemovePadding && pos == getHostContextLengthIdx()) || (mPagedState && pos == getSlotMappingIdx())) { return inOut[pos].type == nvinfer1::DataType::kINT32; } @@ -117,14 +125,56 @@ void SelectiveScanPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc cons size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { - return 0; + if (!mIsMamba2) + return 0; + + int const NUM_BUFFERS = 5; + size_t workspaces[NUM_BUFFERS]; + + if (mRemovePadding) + { + int B = inputs[getLastTokenIdsIdx()].dims.d[0]; + int BxL = inputs[getInputTensorIdx()].dims.d[0]; // num_tokens + int H = mNHeads; + int P = inputs[getInputTensorIdx()].dims.d[1] / H; + int G = mNGroups; + int N = inputs[getBCIdx()].dims.d[1] / G / 2; + int Q = mChunkSize; + int BxC = (BxL + Q - 1) / Q + B; + + workspaces[0] = BxC * H * N * P * 2; // g_mxOs_ + workspaces[1] = BxC * H * N * P * 4; // g_mxSt_ in float + workspaces[2] = BxC * H * Q * 4; // g_mxdc_ in float + workspaces[3] = BxC * H * Q * 4; // g_mxdA_ in float + workspaces[4] = BxC * G * Q * Q * 2; // g_mxCB_ + } + else + { + int B = inputs[getInputTensorIdx()].dims.d[0]; + int L = inputs[getInputTensorIdx()].dims.d[1]; + int H = mNHeads; + int P = inputs[getInputTensorIdx()].dims.d[2] / H; + int G = mNGroups; + int N = inputs[getBCIdx()].dims.d[2] / G / 2; + int Q = mChunkSize; + int C = (L + Q - 1) / Q; + + workspaces[0] = B * C * H * N * P * 2; // g_mxOs_ + workspaces[1] = B * C * H * N * P * 4; // g_mxSt_ in float + workspaces[2] = B * C * H * Q * 4; // g_mxdc_ in float + workspaces[3] = B * C * H * Q * 4; // g_mxdA_ in float + workspaces[4] = B * C * G * Q * Q * 2; // g_mxCB_ + } + + return calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS); } void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch, const size_t dim, - const size_t maxSeqLen, const size_t dstate, const size_t dtRank, bool const isVariableB, bool const isVariableC, - void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* BC, - void const* D, void const* z, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus, - bool removePadding) + const size_t maxSeqLen, const size_t dstate, const size_t dtRank, const size_t nHeads, const size_t nGroups, + const size_t chunkSize, void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, + void const* BC, void const* D, void const* z, void const* osPtr, void const* stPtr, void const* dcPtr, + void const* dAPtr, void const* cbPtr, int const* lastTokenIds, int const* slotMapping, void* out, + bool deltaSoftplus, bool removePadding) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -134,12 +184,13 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch params.max_seqlen = maxSeqLen; params.dstate = dstate; params.dt_rank = dtRank; + params.nheads = nHeads; + params.ngroups = nGroups; + params.chunk_size = chunkSize; params.delta_softplus = deltaSoftplus; params.remove_padding = removePadding; - - params.is_variable_B = isVariableB; - params.is_variable_C = isVariableC; + params.is_mamab2 = mIsMamba2; // Set the pointers and strides. params.u_ptr = const_cast(x); @@ -151,6 +202,11 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch params.out_ptr = out; params.x_ptr = statePtr; params.z_ptr = const_cast(z); + params.Os_ptr = const_cast(osPtr); + params.St_ptr = const_cast(stPtr); + params.dc_ptr = const_cast(dcPtr); + params.dA_ptr = const_cast(dAPtr); + params.CB_ptr = const_cast(cbPtr); params.last_token_ids_ptr = lastTokenIds; params.slot_mapping_ptr = slotMapping; } @@ -162,24 +218,30 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc { // inputs // 0. input_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim] - // 1. state [batch_size, dstate, dim] or host [1] containing only pointer for paged_state - // 2. delta [batch_size, max_seq_len, dim] or [num_tokens, dim] - // 3. delta_bias [dim] - // 4. A [dstate, dim] - // 5. BC [batch_size, max_seq_len, dt_rank + dstate * 2] or [num_tokens, dt_rank + dstate * 2] - // 6. D [dim] - // 7. z [batch_size, max_seq_len, dim] or [num_tokens, dim] - // 8. host_request_types [batch_size] int32. 0: context; 1: generation. - // 9. last_token_ids [batch_size] int32 + // 1. state mamba: [batch_size, dstate, dim] or host [1] containing only pointer for paged_state + // mamba2: [batch_size, nheads, dstate, dim] or host [1] containing only pointer for paged_state + // 2. delta, mamba: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding + // mamba2: [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding + // 3. delta_bias, [dim] for mamba, [nheads] for mamba2 + // 4. A, [dstate, dim] for mamba, [nheads] for mamba2 + // 5. BC, mamba: [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding + // mamba2: [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for + // remove_input_padding + // 6. D, [dim] for mamba, [nheads] for mamba2 + // 7. host_request_types [batch_size] int32. 0: context; 1: generation. + // 8. last_token_ids [batch_size] int32 + // 9. host_context_lengths [batch_size] int32, optional for remove_input_padding // 10. state_slot_mapping [batch_size] int32, optional for paged state + // 11. z [batch_size, max_seq_len, dim] or [num_tokens, dim] // outputs // 0. output_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim] - // 1. state [batch_size, dstate, dim] + // 1. state, [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2 auto const batch_size = inputDesc[getHostRequestTypesIdx()].dims.d[0]; int max_seq_len; if (mRemovePadding) { - max_seq_len = -1; + int const* host_context_length = static_cast(inputs[getHostContextLengthIdx()]); + max_seq_len = *std::max_element(host_context_length, host_context_length + batch_size); } else { @@ -192,17 +254,72 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc SSMParamsBase ssm_params; int const* slotMapping = mPagedState ? static_cast(inputs[getSlotMappingIdx()]) : nullptr; + void const* z = mZEnabled ? inputs[getZIdx()] : nullptr; void* statePtr = mPagedState ? *reinterpret_cast(const_cast(inputs[getStateIdx()])) : outputs[1]; - setSSMParams(ssm_params, batch_size, mDim, max_seq_len, mDState, mDtRank, mIsVariableB, mIsVariableC, statePtr, + // Workspace pointer shift + int8_t* workspace_byte_ptr = reinterpret_cast(workspace); + size_t offset = 0; + + T* mxOs = nullptr; + float* mxSt = nullptr; + float* mxdc = nullptr; + float* mxdA = nullptr; + T* mxCB = nullptr; + + if (!mIsMamba2) /* no workspace needed */ + ; + else if (mRemovePadding) + { + int B = inputDesc[getLastTokenIdsIdx()].dims.d[0]; + int BxL = inputDesc[getInputTensorIdx()].dims.d[0]; // num_tokens + int H = mNHeads; + int P = inputDesc[getInputTensorIdx()].dims.d[1] / H; + int G = mNGroups; + int N = inputDesc[getBCIdx()].dims.d[1] / G / 2; + int Q = mChunkSize; + int BxC = (BxL + Q - 1) / Q + B; + + mxOs = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * N * P * 2)); + mxSt = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * N * P * 4)); + mxdc = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * Q * 4)); + mxdA = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * H * Q * 4)); + mxCB = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, BxC * G * Q * Q * 2)); + } + else + { + int B = inputDesc[getInputTensorIdx()].dims.d[0]; + int L = inputDesc[getInputTensorIdx()].dims.d[1]; + int H = mNHeads; + int P = inputDesc[getInputTensorIdx()].dims.d[2] / H; + int G = mNGroups; + int N = inputDesc[getBCIdx()].dims.d[2] / G / 2; + int Q = mChunkSize; + int C = (L + Q - 1) / Q; + + mxOs = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * N * P * 2)); + mxSt = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * N * P * 4)); + mxdc = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * Q * 4)); + mxdA = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * H * Q * 4)); + mxCB = reinterpret_cast(nextWorkspacePtr(workspace_byte_ptr, offset, B * C * G * Q * Q * 2)); + } + + setSSMParams(ssm_params, batch_size, mDim, max_seq_len, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, statePtr, inputs[getInputTensorIdx()], inputs[getDeltaIdx()], inputs[getDeltaBiasIdx()], inputs[getAIdx()], - inputs[getBCIdx()], inputs[getDIdx()], inputs[getZIdx()], static_cast(inputs[getLastTokenIdsIdx()]), - slotMapping, outputs[0], mDeltaSoftplus, mRemovePadding); + inputs[getBCIdx()], inputs[getDIdx()], z, mxOs, mxSt, mxdc, mxdA, mxCB, + static_cast(inputs[getLastTokenIdsIdx()]), slotMapping, outputs[0], mDeltaSoftplus, mRemovePadding); if (reqTypes[0] == RequestType::kCONTEXT) { - invokeSelectiveScan(ssm_params, stream); + if (mIsMamba2) + { + invokeChunkScan(ssm_params, stream); + } + else + { + invokeSelectiveScan(ssm_params, stream); + } } else if (reqTypes[0] == RequestType::kGENERATION) { @@ -276,8 +393,9 @@ void SelectiveScanPlugin::terminate() noexcept {} size_t SelectiveScanPlugin::getSerializationSize() const noexcept { - return sizeof(mDim) + sizeof(mDState) + sizeof(mDtRank) + sizeof(mIsVariableB) + sizeof(mIsVariableC) - + sizeof(mDeltaSoftplus) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState); + return sizeof(mDim) + sizeof(mDState) + sizeof(mDtRank) + sizeof(mNHeads) + sizeof(mNGroups) + sizeof(mChunkSize) + + sizeof(mDeltaSoftplus) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState) + sizeof(mZEnabled) + + sizeof(mIsMamba2); } void SelectiveScanPlugin::serialize(void* buffer) const noexcept @@ -286,12 +404,15 @@ void SelectiveScanPlugin::serialize(void* buffer) const noexcept write(d, mDim); write(d, mDState); write(d, mDtRank); - write(d, mIsVariableB); - write(d, mIsVariableC); + write(d, mNHeads); + write(d, mNGroups); + write(d, mChunkSize); write(d, mDeltaSoftplus); write(d, mType); write(d, mRemovePadding); write(d, mPagedState); + write(d, mZEnabled); + write(d, mIsMamba2); assert(d == a + getSerializationSize()); } @@ -306,15 +427,18 @@ SelectiveScanPluginCreator::SelectiveScanPluginCreator() { // Fill PluginFieldCollection with PluginField arguments metadata mPluginAttributes.clear(); - mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 16)); - mPluginAttributes.emplace_back(PluginField("dstate", nullptr, PluginFieldType::kINT32, 16)); - mPluginAttributes.emplace_back(PluginField("dt_rank", nullptr, PluginFieldType::kINT32, 16)); - mPluginAttributes.emplace_back(PluginField("is_variable_B", nullptr, PluginFieldType::kINT8, 1)); - mPluginAttributes.emplace_back(PluginField("is_variable_C", nullptr, PluginFieldType::kINT8, 1)); + mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("dstate", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("dt_rank", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("nheads", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("ngroups", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("chunk_size", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("delta_softplus", nullptr, PluginFieldType::kINT8, 1)); mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 0)); - mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 0)); + mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 1)); + mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 1)); + mPluginAttributes.emplace_back(PluginField("z_enabled", nullptr, PluginFieldType::kINT8, 1)); + mPluginAttributes.emplace_back(PluginField("is_mamba2", nullptr, PluginFieldType::kINT8, 1)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); } @@ -337,8 +461,8 @@ PluginFieldCollection const* SelectiveScanPluginCreator::getFieldNames() noexcep IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { PluginField const* fields = fc->fields; - int dim, dstate, dtRank; - bool isVariableB, isVariableC, deltaSoftplus, removePadding, pagedState; + int dim, dstate, dtRank, nHeads, nGroups, chunkSize; + bool deltaSoftplus, removePadding, pagedState, zEnabled, isMamab2; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) @@ -359,15 +483,20 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFiel TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); dtRank = static_cast(*(static_cast(fields[i].data))); } - else if (!strcmp(attrName, "is_variable_B")) + else if (!strcmp(attrName, "nheads")) { - TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - isVariableB = static_cast(*(static_cast(fields[i].data))); + TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); + nHeads = static_cast(*(static_cast(fields[i].data))); } - else if (!strcmp(attrName, "is_variable_C")) + else if (!strcmp(attrName, "ngroups")) { - TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - isVariableC = static_cast(*(static_cast(fields[i].data))); + TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); + nGroups = static_cast(*(static_cast(fields[i].data))); + } + else if (!strcmp(attrName, "chunk_size")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); + chunkSize = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "delta_softplus")) { @@ -389,11 +518,21 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFiel TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); pagedState = static_cast(*(static_cast(fields[i].data))); } + else if (!strcmp(attrName, "z_enabled")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); + zEnabled = static_cast(*(static_cast(fields[i].data))); + } + else if (!strcmp(attrName, "is_mamba2")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); + isMamab2 = static_cast(*(static_cast(fields[i].data))); + } } try { - auto* obj = new SelectiveScanPlugin( - dim, dstate, dtRank, isVariableB, isVariableC, deltaSoftplus, type, removePadding, pagedState); + auto* obj = new SelectiveScanPlugin(dim, dstate, dtRank, nHeads, nGroups, chunkSize, deltaSoftplus, type, + removePadding, pagedState, zEnabled, isMamab2); obj->setPluginNamespace(mNamespace.c_str()); return obj; } diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h index 97a0fefcf..06ce457ed 100644 --- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h +++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h @@ -30,25 +30,30 @@ namespace tensorrt_llm::plugins // inputs // 0. input_tensor [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding -// 1. state [batch_size, dstate, dim] or host [1] containing only pointer for paged_state -// 2. delta [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding -// 3. delta_bias [dim] -// 4. A [dstate, dim] -// 5. BC [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding -// 6. D [dim] -// 7. z [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding -// 8. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none. -// 9. last_token_ids [batch_size] int32 +// 1. state, mamba: [batch_size, dstate, dim] or host [1] containing only pointer for paged_state +// mamba2: [batch_size, nheads, dstate, dim] or host [1] containing only pointer for paged_state +// 2. delta, mamba: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding +// mamba2: [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding +// 3. delta_bias, [dim] for mamba, [nheads] for mamba2 +// 4. A, [dstate, dim] for mamba, [nheads] for mamba2 +// 5. BC, mamba: [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding +// mamba2: [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for +// remove_input_padding +// 6. D, [dim] for mamba, [nheads] for mamba2 +// 7. host_request_types [batch_size] int32. 0: context; 1: generation; 2: none. +// 8. last_token_ids [batch_size] int32 +// 9. host_context_lengths [batch_size] int32, optional for remove_input_padding // 10. state_slot_mapping [batch_size] int32, optional for paged state +// 11. z [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding // outputs // 0. output_tensor [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding -// 1. state [batch_size, dstate, dim] +// 1. state, [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2 class SelectiveScanPlugin : public BasePlugin { public: - SelectiveScanPlugin(int dim, int dstate, int dt_rank, bool isVariableB, bool isVariableC, bool deltaSoftplus, - nvinfer1::DataType type, bool removePadding, bool pagedState); + SelectiveScanPlugin(int dim, int dstate, int dtRank, int nHeads, int nGroups, int chunkSize, bool deltaSoftplus, + nvinfer1::DataType type, bool removePadding, bool pagedState, bool zEnabled, bool isMamba2); SelectiveScanPlugin(void const* data, size_t length); @@ -128,45 +133,63 @@ class SelectiveScanPlugin : public BasePlugin return 6; }; - IndexType getZIdx() const + IndexType getHostRequestTypesIdx() const { return 7; }; - IndexType getHostRequestTypesIdx() const + IndexType getLastTokenIdsIdx() const { return 8; }; - IndexType getLastTokenIdsIdx() const + IndexType getHostContextLengthIdx() const { - return 9; + if (mRemovePadding) + return 9; + else + return 8; }; IndexType getSlotMappingIdx() const { - return 10; + if (mPagedState) + return getHostContextLengthIdx() + 1; + else + return getHostContextLengthIdx(); + }; + + IndexType getZIdx() const + { + if (mZEnabled) + return getSlotMappingIdx() + 1; + else + return getSlotMappingIdx(); }; void setSSMParams(tensorrt_llm::kernels::SSMParamsBase& params, // sizes const size_t batch, const size_t dim, const size_t maxSeqLen, const size_t dstate, const size_t dtRank, - bool const isVariableB, bool const isVariableC, + const size_t nHeads, const size_t nGroups, const size_t chunkSize, // device pointers void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* BC, - void const* D, void const* z, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus, + void const* D, void const* z, void const* osPtr, void const* stPtr, void const* dcPtr, void const* dAPtr, + void const* cbPtr, int const* lastTokenIds, int const* slotMapping, void* out, bool deltaSoftplus, bool removePadding); private: int mDim; int mDState; int mDtRank; - bool mIsVariableB; - bool mIsVariableC; + int mNHeads; + int mNGroups; + int mChunkSize; bool mDeltaSoftplus; nvinfer1::DataType mType; bool mRemovePadding = false; bool mPagedState = false; + bool mZEnabled = true; + bool mIsMamba2 = false; }; class SelectiveScanPluginCreator : public BaseCreator diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp index 9583bcfaa..eafea714f 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp @@ -300,6 +300,9 @@ int WeightOnlyQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc const* input int const n = TLLM_INT32_CAST(inputDesc[1].dims.d[1]); int const k = TLLM_INT32_CAST(inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]); + if (m == 0) + return 0; + bool const use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled; #if defined(ENABLE_BF16) TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16, diff --git a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp index 7fc401128..9db429d71 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp @@ -69,9 +69,9 @@ std::shared_ptr InferenceRequest::toTrtLlm() const auto inferenceRequest = std::make_shared(std::move(tensorMap), mRequestId); inferenceRequest->setIsStreaming(isStreaming()); - if (mlogitsPostProcessor) + if (mLogitsPostProcessor) { - inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mlogitsPostProcessor)); + inferenceRequest->setLogitsPostProcessor(LlmRequest::callbackAdapter(mLogitsPostProcessor)); } return inferenceRequest; @@ -79,7 +79,7 @@ std::shared_ptr InferenceRequest::toTrtLlm() const std::string InferenceRequest::serialize() const { - TLLM_CHECK_WITH_INFO(mlogitsPostProcessor == std::nullopt, + TLLM_CHECK_WITH_INFO(mLogitsPostProcessor == std::nullopt, "Serializing InferenceRequest with logitsPostProcessor set is not supported." "Please set the callback after de-serialization"); std::vector serialized{toTrtLlm()->serialize()}; diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp index 627a8678b..b04114ee3 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -228,15 +228,14 @@ void InitBindings(pybind11::module_& m) std::optional const&, std::optional const&, std::optional>, std::optional>, std::optional, std::optional, std::optional, - std::optional, std::optional, std::optional, bool>(), + std::optional, std::optional, std::optional>(), py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false, py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"), py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(), py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(), py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(), py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(), - py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(), - py::arg("return_all_generated_tokens") = false) + py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none()) .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds) .def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens) .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -255,9 +254,7 @@ void InitBindings(pybind11::module_& m) .def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName, &tle::Request::setLogitsPostProcessorName) .def_property( - "encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds) - .def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens, - &tle::Request::setReturnAllGeneratedTokens); + "encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds); request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName; py::class_(m, "Result") diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp index 87fb874d6..608e1ccbc 100644 --- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp +++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp @@ -46,8 +46,8 @@ FieldType parseJsonFieldOr(Json const& json, std::string_view name, FieldType de } catch (nlohmann::json::out_of_range& e) { - TLLM_LOG_INFO("Parameter %s cannot be read from json:", std::string(name).c_str()); - TLLM_LOG_INFO(e.what()); + TLLM_LOG_DEBUG("Parameter %s cannot be read from json:", std::string(name).c_str()); + TLLM_LOG_DEBUG(e.what()); } return value; } @@ -62,13 +62,13 @@ std::optional parseJsonFieldOptional(Json const& json, std::string_vi } catch (nlohmann::json::out_of_range const& e) { - TLLM_LOG_INFO(e.what()); - TLLM_LOG_INFO("Optional value for parameter %s will not be set.", std::string(name).c_str()); + TLLM_LOG_DEBUG(e.what()); + TLLM_LOG_DEBUG("Optional value for parameter %s will not be set.", std::string(name).c_str()); } catch (nlohmann::json::type_error const& e) { - TLLM_LOG_INFO(e.what()); - TLLM_LOG_INFO("Optional value for parameter %s will not be set.", std::string(name).c_str()); + TLLM_LOG_DEBUG(e.what()); + TLLM_LOG_DEBUG("Optional value for parameter %s will not be set.", std::string(name).c_str()); } return value; } @@ -427,10 +427,17 @@ GptJsonConfig parseJson(InputType&& input) auto const& stateSize = pretrainedConfig.at("state_size").template get(); auto const& convKernel = pretrainedConfig.at("conv_kernel").template get(); auto const& rnnHiddenSize = pretrainedConfig.at("rnn_hidden_size").template get(); + auto const& rnnConvDimSize = pretrainedConfig.at("rnn_conv_dim_size").template get(); ModelConfig::RnnConfig rnnConfig{}; rnnConfig.stateSize = stateSize; rnnConfig.convKernel = convKernel; rnnConfig.rnnHiddenSize = rnnHiddenSize; + rnnConfig.rnnConvDimSize = rnnConvDimSize; + if (pretrainedConfig.contains("rnn_head_size")) + { + auto const& rnnHeadSize = pretrainedConfig.at("rnn_head_size").template get(); + rnnConfig.rnnHeadSize = rnnHeadSize; + } modelConfig.setRnnConfig(rnnConfig); } } @@ -449,10 +456,17 @@ GptJsonConfig parseJson(InputType&& input) auto const& stateSize = builderConfig.at("state_size").template get(); auto const& convKernel = builderConfig.at("conv_kernel").template get(); auto const& rnnHiddenSize = builderConfig.at("rnn_hidden_size").template get(); + auto const& rnnConvDimSize = builderConfig.at("rnn_conv_dim_size").template get(); ModelConfig::RnnConfig rnnConfig{}; rnnConfig.stateSize = stateSize; rnnConfig.convKernel = convKernel; rnnConfig.rnnHiddenSize = rnnHiddenSize; + rnnConfig.rnnConvDimSize = rnnConvDimSize; + if (builderConfig.contains("rnn_head_size")) + { + auto const& rnnHeadSize = builderConfig.at("rnn_head_size").template get(); + rnnConfig.rnnHeadSize = rnnHeadSize; + } modelConfig.setRnnConfig(rnnConfig); } } diff --git a/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp b/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp index aa086892f..c4f9d888b 100644 --- a/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/rnnStateBuffers.cpp @@ -37,7 +37,7 @@ RnnStateBuffers::RnnStateBuffers( { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(modelConfig.isRnnBased()); - TLLM_CHECK_WITH_INFO(modelConfig.hasRnnConfig(), "RNN only support Mamba now."); + TLLM_CHECK_WITH_INFO(modelConfig.hasRnnConfig(), "RNN only support Mamba1/Mamba2/RecurrentGemma now."); auto maxBatchSize = modelConfig.getMaxBatchSize(); auto maxBeamWidth = modelConfig.getMaxBeamWidth(); auto maxBatchBeam = maxBatchSize * maxBeamWidth; @@ -46,23 +46,37 @@ RnnStateBuffers::RnnStateBuffers( mConvKernel = rnnConfig->convKernel; mStateSize = rnnConfig->stateSize; mRnnHiddenSize = rnnConfig->rnnHiddenSize; + mRnnHeadSize = rnnConfig->rnnHeadSize; + mRnnConvDimSize = rnnConfig->rnnConvDimSize; auto dType = modelConfig.getDataType(); auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism()); mLocalNbLayers = localNbLayers; mMaxBeamWidth = maxBeamWidth; mUseMambaConv1dPlugin = modelConfig.useMambaConv1dPlugin(); - auto rnnStatesShape = ITensor::makeShape({localNbLayers * maxBatchBeam, mStateSize, mRnnHiddenSize}); + auto const rnnStatesShape = [&]() + { + if (mRnnHeadSize > 0) + { + return tensorrt_llm::runtime::ITensor::makeShape( + {localNbLayers * maxBatchBeam, mRnnHiddenSize / mRnnHeadSize, mStateSize, mRnnHeadSize}); + } + else + { + return tensorrt_llm::runtime::ITensor::makeShape( + {localNbLayers * maxBatchBeam, mStateSize, mRnnHiddenSize}); + } + }(); auto const convStatesShape = [&]() { if (mUseMambaConv1dPlugin) { return tensorrt_llm::runtime::ITensor::makeShape( - {localNbLayers * maxBatchBeam, mConvKernel - 1, mRnnHiddenSize}); + {localNbLayers * maxBatchBeam, mConvKernel - 1, mRnnConvDimSize}); } else { return tensorrt_llm::runtime::ITensor::makeShape( - {localNbLayers * maxBatchBeam, mRnnHiddenSize, mConvKernel - 1}); + {localNbLayers * maxBatchBeam, mRnnConvDimSize, mConvKernel - 1}); } }(); auto& bufferManager = runtime.getBufferManager(); @@ -96,18 +110,30 @@ RnnStateBuffers::RnnStateBuffers( void RnnStateBuffers::reshape(SizeType32 batchSize) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); - auto rnnStatesShape = ITensor::makeShape({mLocalNbLayers * batchSize * mMaxBeamWidth, mStateSize, mRnnHiddenSize}); + auto const rnnStatesShape = [&]() + { + if (mRnnHeadSize > 0) + { + return tensorrt_llm::runtime::ITensor::makeShape( + {mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnHiddenSize / mRnnHeadSize, mStateSize, mRnnHeadSize}); + } + else + { + return tensorrt_llm::runtime::ITensor::makeShape( + {mLocalNbLayers * batchSize * mMaxBeamWidth, mStateSize, mRnnHiddenSize}); + } + }(); auto const convStatesShape = [&]() { if (mUseMambaConv1dPlugin) { return tensorrt_llm::runtime::ITensor::makeShape( - {mLocalNbLayers * batchSize * mMaxBeamWidth, mConvKernel - 1, mRnnHiddenSize}); + {mLocalNbLayers * batchSize * mMaxBeamWidth, mConvKernel - 1, mRnnConvDimSize}); } else { return tensorrt_llm::runtime::ITensor::makeShape( - {mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnHiddenSize, mConvKernel - 1}); + {mLocalNbLayers * batchSize * mMaxBeamWidth, mRnnConvDimSize, mConvKernel - 1}); } }(); rnnStates->reshape(rnnStatesShape); diff --git a/cpp/tensorrt_llm/runtime/rnnStateBuffers.h b/cpp/tensorrt_llm/runtime/rnnStateBuffers.h index fc29ba6f0..7f9d3ac7a 100644 --- a/cpp/tensorrt_llm/runtime/rnnStateBuffers.h +++ b/cpp/tensorrt_llm/runtime/rnnStateBuffers.h @@ -39,7 +39,8 @@ class RnnStateBuffers TensorPtr convStates; // [layer_count * batch_beam, conv_kernel - 1, rnn_hidden_size] TensorPtr convStatesAlt; // [layer_count * batch_beam, conv_kernel - 1, rnn_hidden_size] - std::vector rnnState; // [batch_beam, state_size, rnn_hidden_size] + std::vector rnnState; // [batch_beam, state_size, rnn_hidden_size] or + // [batch_beam, num_heads, rnn_hidden_size, rnn_head_size] std::vector convState; // [batch_beam, conv_kernel - 1, rnn_hidden_size] std::vector convStateAlt; // [batch_beam, conv_kernel - 1, rnn_hidden_size] @@ -83,6 +84,8 @@ class RnnStateBuffers SizeType32 mConvKernel = 0; SizeType32 mStateSize = 0; SizeType32 mRnnHiddenSize = 0; + SizeType32 mRnnHeadSize = 0; + SizeType32 mRnnConvDimSize = 0; int mLocalNbLayers = 0; int mMaxBeamWidth = 0; diff --git a/cpp/tests/resources/scripts/generate_test_lora_weights.py b/cpp/tests/resources/scripts/generate_test_lora_weights.py index 9a5fe779b..f93ebc768 100644 --- a/cpp/tests/resources/scripts/generate_test_lora_weights.py +++ b/cpp/tests/resources/scripts/generate_test_lora_weights.py @@ -120,13 +120,27 @@ def main(): parser.add_argument('--tp-size', type=int, default=1) parser.add_argument('--out-dir', type=Path, required=True) parser.add_argument('--num-loras', type=int, default=1) + parser.add_argument('--num-layers', type=int, default=2) + parser.add_argument('--adapter-size', type=int, default=8) + parser.add_argument('--hidden-size', type=int, default=16) + parser.add_argument('--mlp-hidden-size', type=int, default=32) + parser.add_argument('--no-generate-cache-pages', + action='store_true', + default=False) + parser.add_argument( + '--config-ids-filter', + type=str, + default=None, + help= + "Comma separated list of ids to include. For example, use --config-ids-filter=0 for attn_qkv only." + ) args = parser.parse_args() - num_layers = 2 - adapter_size = 8 - hidden_size = 16 - mlp_hidden_size = 32 + num_layers = args.num_layers + adapter_size = args.adapter_size + hidden_size = args.hidden_size + mlp_hidden_size = args.mlp_hidden_size configs = [ (0, num_layers, adapter_size, hidden_size, 3 * hidden_size), # attn_qkv (1, num_layers, adapter_size // 2, hidden_size, hidden_size), # attn_q @@ -149,6 +163,9 @@ def main(): (12, num_layers, adapter_size, hidden_size, hidden_size), # cross_attn_dense ] + if args.config_ids_filter: + config_ids_filter = [int(x) for x in args.config_ids_filter.split(",")] + configs = [c for c in configs if c[0] in config_ids_filter] for lora_idx in range(args.num_loras): all_source = [] @@ -178,19 +195,20 @@ def main(): os.makedirs(output_dir, exist_ok=True) # copy weights into cache pages - for rank in range(args.tp_size): - page_block = torch.zeros((8, 18, 128), - dtype=torch.float32, - device='cpu') - copy_to_cache_pages(all_source, - all_config, - page_block, - configs, - tp_rank=rank, - tp_size=args.tp_size) - - out_path = output_dir / f'cache_pages_rank{rank}.npy' - np.save(out_path, page_block) + if not args.no_generate_cache_pages: + for rank in range(args.tp_size): + page_block = torch.zeros((8, 18, 128), + dtype=torch.float32, + device='cpu') + copy_to_cache_pages(all_source, + all_config, + page_block, + configs, + tp_rank=rank, + tp_size=args.tp_size) + + out_path = output_dir / f'cache_pages_rank{rank}.npy' + np.save(out_path, page_block) source_out_path = output_dir / 'source.npy' config_out_path = output_dir / 'config.npy' diff --git a/cpp/tests/resources/scripts/test_cpp.py b/cpp/tests/resources/scripts/test_cpp.py index 07820fcae..9a3d97744 100755 --- a/cpp/tests/resources/scripts/test_cpp.py +++ b/cpp/tests/resources/scripts/test_cpp.py @@ -15,6 +15,7 @@ # limitations under the License. import argparse as _arg +import copy import logging as _log import os as _os import pathlib as _pl @@ -135,9 +136,18 @@ def run_tests(build_dir: _pl.Path, "--num-loras=128", ] + generate_gpt2_lora_data_args_tp1 = [ + python_exe, + str(resources_dir / "scripts" / "generate_test_lora_weights.py"), + "--out-dir=cpp/tests/resources/data/lora-test-weights-gpt2-tp1", + "--tp-size=1", "--hidden-size=768", "--num-layers=12", + "--config-ids-filter=0", "--no-generate-cache-pages" + ] + run_command(generate_lora_data_args_tp1, cwd=root_dir, timeout=100) run_command(generate_lora_data_args_tp2, cwd=root_dir, timeout=100) run_command(generate_multi_lora_tp2_args, cwd=root_dir, timeout=100) + run_command(generate_gpt2_lora_data_args_tp1, cwd=root_dir, timeout=100) if not skip_unit_tests: run_unit_tests(build_dir=build_dir, timeout=test_timeout) @@ -484,9 +494,15 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500): ] run_command(trt_model_test, cwd=tests_dir, env=cpp_env, timeout=timeout) # expecting ~ 1200s + cpp_blocking_env = copy.copy(cpp_env) + cpp_blocking_env["CUDA_LAUNCH_BLOCKING"] = '1' + run_command(trt_model_test, + cwd=tests_dir, + env=cpp_blocking_env, + timeout=timeout) # expecting ~ 1200s #Executor test in leader mode - new_env = cpp_env + new_env = copy.copy(cpp_env) xml_output_file = build_dir / "results-multi-gpu-llama-exec-leader-mode.xml" new_env["RUN_LLAMA_MULTI_GPU"] = "true" trt_model_test = [ @@ -507,7 +523,7 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500): run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500) #EncDec test in leader mode - new_env = cpp_env + new_env = copy.copy(cpp_env) xml_output_file = build_dir / "results-multi-gpu-t5-exec-leader-mode.xml" trt_model_test = [ "mpirun", "-n", "4", "--allow-run-as-root", "executor/executorTest", diff --git a/docs/source/executor.md b/docs/source/executor.md index 5d556590f..f30bf9745 100644 --- a/docs/source/executor.md +++ b/docs/source/executor.md @@ -52,7 +52,7 @@ The `awaitResponses` method of the `Executor` class returns a vector of response ### The Result Class -The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = false` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request. +The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request. ## C++ Executor API Example diff --git a/docs/source/quick-start-guide.md b/docs/source/quick-start-guide.md index e1d26579d..37eaa643f 100644 --- a/docs/source/quick-start-guide.md +++ b/docs/source/quick-start-guide.md @@ -23,7 +23,7 @@ git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf (quick-start-guide-compile)= ## Compile the Model into a TensorRT Engine -Use the included [Llama model definition](https://nvidia.github.io/TensorRT-LLM/_modules/tensorrt_llm/models/llama/model.html#LLaMAModel). This is a minimal example that includes some of the optimizations available in TensorRT-LLM. +Use the included [Llama model definition](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama). This is a minimal example that includes some of the optimizations available in TensorRT-LLM. ```bash # Launch the Tensorrt-LLM container @@ -138,3 +138,7 @@ In this Quick Start Guide, you: For more examples, refer to: - [examples/](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples) for showcases of how to run a quick benchmark on latest LLMs. + +## Links + - [Best Practices Guide](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/performance/perf-best-practices.md) + - [Support Matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html) diff --git a/docs/source/reference/support-matrix.md b/docs/source/reference/support-matrix.md index 12b75cdf6..f0db748cb 100644 --- a/docs/source/reference/support-matrix.md +++ b/docs/source/reference/support-matrix.md @@ -83,7 +83,7 @@ The following table shows the supported software for TensorRT-LLM. - [mT5](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/enc_dec) - [OPT](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/opt) - [Phi-1.5/Phi-2/Phi-3](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/phi) - - [Qwen](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen) + - [Qwen/Qwen1.5/Qwen2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen) - [Qwen-VL](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwenvl) - [RecurrentGemma](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/recurrentgemma) - [Replit Code](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/mpt) @@ -103,6 +103,7 @@ The following table shows the supported software for TensorRT-LLM. - [Fuyu](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal) - [Kosmos](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal) - [LLaVA-v1.5](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal) + - [LLaVa-Next](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal) - [NeVA](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal) - [Nougat](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal) - [Phi-3-vision](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal) diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt index 663a3c11b..01fb6963c 100644 --- a/examples/baichuan/requirements.txt +++ b/examples/baichuan/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.15.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt index 19db45403..67ff309f0 100644 --- a/examples/bloom/requirements.txt +++ b/examples/bloom/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt index ef30de859..3f729c739 100644 --- a/examples/chatglm/requirements.txt +++ b/examples/chatglm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 protobuf diff --git a/examples/dbrx/requirements.txt b/examples/dbrx/requirements.txt index 44e7c6b65..537b594db 100644 --- a/examples/dbrx/requirements.txt +++ b/examples/dbrx/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/enc_dec/run.py b/examples/enc_dec/run.py index 6837e1539..d0d829987 100644 --- a/examples/enc_dec/run.py +++ b/examples/enc_dec/run.py @@ -29,8 +29,10 @@ import tensorrt_llm from tensorrt_llm import logger +from tensorrt_llm._ipc_utils import set_peer_access from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch from tensorrt_llm.lora_manager import LoraManager +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper from tensorrt_llm.runtime import ModelConfig, SamplingConfig @@ -387,19 +389,27 @@ def encoder_run(self, (max_input_length, ), dtype=hidden_states_dtype('max_input_length'), device=self.device).contiguous() - batch_size = input_lengths.size(0) - inputs['host_request_types'] = torch.IntTensor([0] * - batch_size).to('cpu') - if self.encoder_model_config.remove_input_padding: - inputs['host_context_lengths'] = input_lengths.to('cpu') - if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None: + if self.encoder_model_config.use_custom_all_reduce and self.encoder_runtime_mapping.tp_size > 1: + set_peer_access(self.encoder_runtime_mapping) + ipc_buffers, all_reduce_workspace = CustomAllReduceHelper.allocate_workspace( + self.encoder_runtime_mapping, + CustomAllReduceHelper.max_workspace_size_auto( + self.encoder_runtime_mapping.tp_size)) + inputs['all_reduce_workspace'] = all_reduce_workspace + + if self.encoder_model_config.lora_plugin: inputs.update( self.encoder_lora_manager.input_buffers( self.lora_task_uids, self.encoder_runtime_mapping, self.encoder_model_config.num_layers, )) + batch_size = input_lengths.size(0) + inputs['host_request_types'] = torch.IntTensor([0] * + batch_size).to('cpu') + if self.encoder_model_config.remove_input_padding: + inputs['host_context_lengths'] = input_lengths.to('cpu') # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape self.encoder_session.set_shapes(inputs) diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt index 44c2c4d93..452945de7 100644 --- a/examples/falcon/requirements.txt +++ b/examples/falcon/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 transformers>=4.31.0 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt index 888837259..72742429d 100644 --- a/examples/gemma/requirements.txt +++ b/examples/gemma/requirements.txt @@ -3,7 +3,7 @@ # WAR the new posting of "nvidia-cudnn-cu12~=9.0". # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9". nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64" -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 flax~=0.8.0 # jax[cuda12_pip]~=0.4.19; platform_system != "Windows" jax~=0.4.19; platform_system == "Windows" diff --git a/examples/gpt/README.md b/examples/gpt/README.md index a750a2827..ef5a08b32 100644 --- a/examples/gpt/README.md +++ b/examples/gpt/README.md @@ -414,6 +414,21 @@ trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8-sq-ptpc/1-gpu \ Note that GPT attention plugin is required to be enabled for SmoothQuant for now. +User can also use `ModelOpt` to do INT8 quantization. Especially for gpt variant Starcoder2. +```bash +python3 example/quantization/quantize.py --model_dir starcoder2 \ + --dtype float16 \ + --qformat int8_sq \ + --output_dir starcoder2/trt_ckpt/int8-sq/ +``` +Then, use `trtllm-build` to build engine(s). + +```bash +trtllm-build --checkpoint_dir starcoder2/trt_ckpt/int8-sq/ \ + --output_dir starcoder2/trt_engine/int8-sq/ \ + --builder_opt 4 +``` + ### INT8 KV Cache diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index 9c1091653..5d33d28d0 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptj/requirements.txt b/examples/gptj/requirements.txt index 7800b0ea6..5e5ac6a15 100644 --- a/examples/gptj/requirements.txt +++ b/examples/gptj/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt index aa242a331..4e0e428dd 100644 --- a/examples/gptneox/requirements.txt +++ b/examples/gptneox/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 rouge_score~=0.1.2 evaluate~=0.4.1 diff --git a/examples/grok/requirements.txt b/examples/grok/requirements.txt index 9f62ca868..33dc6547f 100644 --- a/examples/grok/requirements.txt +++ b/examples/grok/requirements.txt @@ -1,6 +1,6 @@ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/high-level-api/README.md b/examples/high-level-api/README.md index 6e46bb669..8c0d10eed 100644 --- a/examples/high-level-api/README.md +++ b/examples/high-level-api/README.md @@ -4,12 +4,49 @@ Here we show you a preview of how it works and how to use it. Note that the APIs are not stable and only support the LLaMA model. We appreciate your patience and understanding as we improve this API. +## Quick start + Please install the required packages first: ```bash pip install -r requirements.txt ``` +Here is a simple example to show how to use the HLAPI: + +Firstly, import the `LLM` and `SamplingParams` from the `tensorrt_llm` package, and create an LLM object with a HuggingFace (HF) model directly. Here we use the TinyLlama model as an example, `LLM` will download the model from the HuggingFace model hub automatically. You can also specify local models, either in HF format, TensorRT-LLM engine format or TensorRT-LLM checkpoint format. + +```python +from tensorrt_llm import LLM, SamplingParams + +llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") +``` + +Secondly, generate text with the `generate` method of the `LLM` object directly with a batch of prompts, the `sampling_params` is optional, and you can customize the sampling strategy with it. + +```python +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +outputs = llm.generate(prompts, sampling_params) + +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +``` + +Please refer to the [LLM quickstart](./quickstart_example.py) for the complete example. + +## Examples + You can refer to [llm_examples.py](llm_examples.py) for all of the examples, and run it with the [run_examples.py](./run_examples.py) script, the command is as follows: ```sh @@ -34,15 +71,16 @@ python3 llm_examples.py --task run_llm_on_tensor_parallel \ ``` ## Model preparation -The HLAPI supports three kinds of model formats: +The `LLM` class supports four kinds of model inputs: -1. HuggingFace models -2. TensorRT-LLM engine built by trtllm-build tool or saved by the HLAPI -3. TensorRT-LLM checkpoints, converted by `convert_checkpoint.py` in examples +1. **HuggingFace model name**: triggers a download from the HuggingFace model hub, e.g. `TinyLlama/TinyLlama-1.1B-Chat-v1.0` in the quickstart. +1. **Local HuggingFace models**: uses a locally stored HuggingFace model. +2. **Local TensorRT-LLM engine**: built by `trtllm-build` tool or saved by the HLAPI +3. **Local TensorRT-LLM checkpoints**: converted by `convert_checkpoint.py` script in the examples -All kinds of models could be used directly by the HLAPI, and the `LLM(model=)` could accept any kind of them. +All kinds of the model inputs can be seamlessly integrated with the HLAPI, and the `LLM(model=)` construcotr can accommodate models in any of the above formats. -Let's elaborate on the preparation of the three kinds of model formats. +Let's delve into the preparation of the three kinds of local model formats. ### Option 1: From HuggingFace models @@ -143,16 +181,16 @@ It is easy to enable Tensor Parallelism in the HLAPI. For example, setting `para ```python from tensorrt_llm.hlapi import LLM -llm = LLM(, tensor_parallel_size=2) +llm = LLM(, + tensor_parallel_size=2) ``` ### Pipeline Parallelism Similar to Tensor Parallelism, you can enable Pipeline Parallelism in the HLAPI with following code: ```python -config.parallel_config.pp_size = 4 -# you can also mix TP and PP -# config.parallel_config.tp_size = 2 +llm = LLM(, + pipeline_parallel_size=4) ``` ### Automatic Parallelism (in preview) @@ -266,17 +304,28 @@ Please refer to these classes for more details. ## LLM pipeline configuration +### Build configuration +Apart from the arguments mentioned above, you can also customize the build configuration with the `build_config` class and other arguments borrowed from the lower-level APIs. For example: + +```python +llm = LLM(, + build_config=BuildConfig( + max_new_tokens=4096, + max_batch_size=128, + max_beam_width=4)) +``` + ### Runtime customization +Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other arguments borrowed from the lower-level APIs. For example: -For `kv_cache_config` and `streaming_llm` features, please refer to LLaMA's [README](../llama/README.md) for more details, the high-level API supports these features as well by setting the corresponding fields in the `LLM()` constructor. ```python from tensorrt_llm.hlapi import LLM, KvCacheConfig llm = LLM(, kv_cache_config=KvCacheConfig( - max_new_tokens=128, - free_gpu_memory_fraction=0.8)) + max_new_tokens=128, + free_gpu_memory_fraction=0.8)) ``` ### Tokenizer customization @@ -313,3 +362,13 @@ RequestOutput(request_id=1, prompt=None, prompt_token_ids=[1, 15043, 29892, 590, ``` Note that the `text` field in `CompletionOutput` is empty since the tokenizer is deactivated. + +### Build caching +Although the HLAPI runs the engine building in the background, you can also cache the built engine to disk and load it in the next run to save the engine building time. + +To enable the build cache, there are two ways to do it: + +1. Use the environment variable: `export TLLM_HLAPI_BUILD_CACHE=1` to enable the build cache globally, and optionally export `TLLM_HLAPI_BUILD_CACHE_ROOT` to specify the cache root directory. +2. Pass the `build_cache_config` to the `LLM` constructor + +The build cache will reuse the built engine if all the building settings are the same, or it will rebuild the engine. diff --git a/examples/high-level-api/requirements.txt b/examples/high-level-api/requirements.txt index 476362335..ba11d627d 100644 --- a/examples/high-level-api/requirements.txt +++ b/examples/high-level-api/requirements.txt @@ -1,2 +1,2 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt index 2dd11374a..07f8580bf 100644 --- a/examples/internlm/requirements.txt +++ b/examples/internlm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets==2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/llama/README.md b/examples/llama/README.md index 97cb61177..562200adb 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -440,10 +440,16 @@ expected results: #### 1M long context test case -```bash -git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/ +- Prepare 1M needle-in-a-haystack datasets +```bash python examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey --test_level 7 +``` + +- Llama-3-8B example + +```bash +git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/ python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ --output_dir /tmp/llama-3-8B-1048k/trt_ckpts \ @@ -454,8 +460,8 @@ python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-8B-1048k/trt --output_dir /tmp/llama-3-8B-1048k/trt_engines \ --gemm_plugin float16 \ --max_num_tokens 4096 \ - --max_input_len 1048576 \ - --max_output_len 10 \ + --max_input_len 1048566 \ + --max_seq_len 1048576 \ --use_paged_context_fmha enable \ --workers 4 @@ -463,7 +469,37 @@ mpirun -n 4 --allow-run-as-root python examples/eval_long_context.py --task pas --engine_dir /tmp/llama-3-8B-1048k/trt_engines \ --tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ --stop_idx 1 \ - --max_input_length 1048576 \ + --max_input_length 1048566 \ + --enable_chunked_context \ + --max_tokens_in_paged_kv_cache 1100000 +``` + +- Llama-3-70B example + +For the 70B model, at least 8 A100 80GB GPUs are required. + +```bash +git-lfs clone https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k/ + +python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-70B-Instruct-Gradient-1048k/ \ + --output_dir /tmp/llama-3-70B-1048k/trt_ckpts \ + --dtype float16 \ + --tp_size 8 + +python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-70B-1048k/trt_ckpts \ + --output_dir /tmp/llama-3-70B-1048k/trt_engines \ + --gemm_plugin float16 \ + --max_num_tokens 4096 \ + --max_input_len 1048566 \ + --max_seq_len 1048576 \ + --use_paged_context_fmha enable \ + --workers 8 + +mpirun -n 8 --allow-run-as-root python examples/eval_long_context.py --task passkey \ + --engine_dir /tmp/llama-3-70B-1048k/trt_engines \ + --tokenizer_dir ./Llama-3-70B-Instruct-Gradient-1048k/ \ + --stop_idx 1 \ + --max_input_length 1048566 \ --enable_chunked_context \ --max_tokens_in_paged_kv_cache 1100000 ``` diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index f9f50349b..7c473affe 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/mamba/README.md b/examples/mamba/README.md index d3cfd78d4..5475b3e4d 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -2,6 +2,15 @@ This document shows how to build and run a [Mamba](https://github.com/state-spaces/mamba) model in TensorRT-LLM on a single GPU. +- [Mamba](#mamba) + - [Overview](#overview) + - [Support Matrix](#support-matrix) + - [Usage](#usage) + - [1. Download weights from HuggingFace Transformers](#1-download-weights-from-huggingface-transformers) + - [2. Convert weights from HF Transformers to TensorRT-LLM format](#2-convert-dweights-from-hf-transformers-to-tensorrt-llm-format) + - [3. Build TensorRT engine(s)](#3-build-tensorrt-engines) + - [4. Run summarization task with the TensorRT engine(s)](#4-run-summarization-task-with-the-tensorrt-engines) + ## Overview The TensorRT-LLM Mamba implementation can be found in [`tensorrt_llm/models/mamba/model.py`](../../tensorrt_llm/models/mamba/model.py). The TensorRT-LLM Mamba example code is located in [`examples/mamba`](./). There is one main file: @@ -15,8 +24,13 @@ In addition, there are two shared files in the parent folder [`examples`](../) f ## Support Matrix - * FP16 - * BF16 + +| Model Name | FP16 | BF16 | +| :--------------: | :---: | :---: | +| Mamba1 | Y | Y | +| Mamba2 | Y | Y | + +* Mamba2: TensorRT-LLM can only support the pure Mamba model for now, will support the hybrid models later. ## Usage @@ -32,23 +46,20 @@ pip install -r requirements.txt git lfs install ``` -There are five HF checkpoints available. Use one of the following commands to fetch the checkpoint you are interested in. +There are different HF checkpoints available. For Mamba1, TensorRT-LLM can support those Transformers compatible models. Here're some examples to fetch the checkpoint. ```bash # mamba-2.8b git clone https://huggingface.co/state-spaces/mamba-2.8b-hf ./mamba_model/mamba-2.8b -# mamba-1.4b -git clone https://huggingface.co/state-spaces/mamba-1.4b-hf ./mamba_model/mamba-1.4b - -# mamba-790m -git clone https://huggingface.co/state-spaces/mamba-790m-hf ./mamba_model/mamba-790m - -# mamba-370m -git clone https://huggingface.co/state-spaces/mamba-370m-hf ./mamba_model/mamba-370m - # mamba-130m git clone https://huggingface.co/state-spaces/mamba-130m-hf ./mamba_model/mamba-130m + +# mamba2-2.7b +git clone https://huggingface.co/state-spaces/mamba2-2.7b ./mamba_model/mamba2-2.7b + +# mamba2-130m +git clone https://huggingface.co/state-spaces/mamba2-130m ./mamba_model/mamba2-130m ``` Since mamba models use tokenizer from gpt-neox-20b model, use the following command to fetch the checkpoint of gpt-neox-20b. @@ -67,25 +78,20 @@ python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \ --dtype bfloat16 \ --output_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ -# mamba-1.4b -python convert_checkpoint.py --model_dir ./mamba_model/mamba-1.4b/ \ - --dtype float16 \ - --output_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/ - -# mamba-790m -python convert_checkpoint.py --model_dir ./mamba_model/mamba-790m/ \ +# mamba-130m +python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \ --dtype float16 \ - --output_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ -# mamba-370m -python convert_checkpoint.py --model_dir ./mamba_model/mamba-370m/ \ +# mamba2-2.7b +python convert_checkpoint.py --model_dir ./mamba_model/mamba2-2.7b/ \ --dtype float16 \ - --output_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/ + --output_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/ -# mamba-130m -python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \ +# mamba2-130m +python convert_checkpoint.py --model_dir ./mamba_model/mamba2-130m/ \ --dtype float16 \ - --output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ + --output_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/ ``` ### 3. Build TensorRT engine(s) @@ -101,41 +107,32 @@ trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \ --max_seq_len 1024 \ --output_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/ -# mamba-1.4b -trtllm-build --checkpoint_dir ./mamba_model/mamba-1.4b/trt_ckpt/fp16/1-gpu/ \ - --paged_kv_cache disable \ - --gemm_plugin auto \ - --max_batch_size 8 \ - --max_input_len 924 \ - --max_seq_len 1024 \ - --output_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/ - -# mamba-790m -trtllm-build --checkpoint_dir ./mamba_model/mamba-790m/trt_ckpt/fp16/1-gpu/ \ +# mamba-130m +trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \ --paged_kv_cache disable \ --gemm_plugin auto \ --max_batch_size 8 \ --max_input_len 924 \ --max_seq_len 1024 \ - --output_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/ + --output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/ -# mamba-370m -trtllm-build --checkpoint_dir ./mamba_model/mamba-370m/trt_ckpt/fp16/1-gpu/ \ +# mamba2-2.7b +trtllm-build --checkpoint_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/ \ --paged_kv_cache disable \ --gemm_plugin auto \ --max_batch_size 8 \ --max_input_len 924 \ --max_seq_len 1024 \ - --output_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/ + --output_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/ -# mamba-130m -trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \ +# mamba2-130m +trtllm-build --checkpoint_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/ \ --paged_kv_cache disable \ --gemm_plugin auto \ --max_batch_size 8 \ --max_input_len 924 \ --max_seq_len 1024 \ - --output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/ + --output_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/ ``` Note that when building Mamba models, you need to disable the `paged_kv_cache` as it is used for @@ -148,7 +145,6 @@ The following section describes how to run a TensorRT-LLM Mamba model to summari [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. For each summary, the script can compute the [ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) scores and use the `ROUGE-1` score to validate the implementation. -### Run ```bash # mamba-2.8b python ../summarize.py --test_trt_llm \ @@ -157,31 +153,24 @@ python ../summarize.py --test_trt_llm \ --data_type bf16 \ --engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/ -# mamba-1.4b -python ../summarize.py --test_trt_llm \ - --hf_model_dir ./mamba_model/mamba-1.4b/ \ - --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ - --data_type fp16 \ - --engine_dir ./mamba_model/mamba-1.4b/trt_engines/fp16/1-gpu/ - -# mamba-790m +# mamba-130m python ../summarize.py --test_trt_llm \ - --hf_model_dir ./mamba_model/mamba-790m/ \ + --hf_model_dir ./mamba_model/mamba-130m/ \ --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type fp16 \ - --engine_dir ./mamba_model/mamba-790m/trt_engines/fp16/1-gpu/ + --engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/ -# mamba-370m +# mamba2-2.7b python ../summarize.py --test_trt_llm \ - --hf_model_dir ./mamba_model/mamba-370m/ \ + --hf_model_dir ./mamba_model/mamba2-2.7b/ \ --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type fp16 \ - --engine_dir ./mamba_model/mamba-370m/trt_engines/fp16/1-gpu/ + --engine_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/ -# mamba-130m +# mamba2-130m python ../summarize.py --test_trt_llm \ - --hf_model_dir ./mamba_model/mamba-130m/ \ + --hf_model_dir ./mamba_model/mamba2-130m/ \ --tokenizer_dir ./mamba_model/gpt-neox-20b/ \ --data_type fp16 \ - --engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/ + --engine_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/ ``` diff --git a/examples/mamba/convert_checkpoint.py b/examples/mamba/convert_checkpoint.py index 9b4aa5898..ae5c03504 100644 --- a/examples/mamba/convert_checkpoint.py +++ b/examples/mamba/convert_checkpoint.py @@ -1,13 +1,17 @@ import argparse import copy import json +import re import time +from dataclasses import dataclass, field from pathlib import Path from typing import Union import safetensors.torch import torch from transformers import AutoConfig, AutoModelForCausalLM +from transformers.utils import CONFIG_NAME +from transformers.utils.hub import cached_file import tensorrt_llm from tensorrt_llm import logger @@ -55,7 +59,10 @@ def get_tllm_linear_weight(weight, prefix, bias=None): return results -def convert_hf_mamba(hf_mamba, rank=0, dtype='float32'): +def convert_hf_mamba(hf_mamba, + rank=0, + dtype='float32', + mamba_version: str = 'Mamba1'): weights = {} tik = time.time() @@ -130,9 +137,12 @@ def rename_hf_to_tllm(name: str): # change layer name if 'embeddings.' in name: name = name.replace('embeddings', 'vocab_embedding') + elif 'embedding.' in name: + name = name.replace('embedding', 'vocab_embedding') + norm_pattern = r'\d\.norm\.' if 'mixer.' in name: name = name.replace('mixer.', 'ssm.') - elif 'norm.' in name: + elif re.search(norm_pattern, name): name = name.replace('norm.', 'input_layernorm.') elif 'norm_f.' in name: name = name.replace('norm_f.', 'ln_f.') @@ -147,7 +157,8 @@ def rename_hf_to_tllm(name: str): def convert_from_hf_checkpoint(model_dir: Union[str, Path], rank=0, - dtype: Union[str, torch.dtype] = torch.float32): + dtype: Union[str, torch.dtype] = torch.float32, + mamba_version: str = 'Mamba1'): logger.info('Loading weights from HF Mamba...') tik = time.time() @@ -164,15 +175,19 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path], param = param.detach().cpu() if 'A_log' in name: param = -torch.exp(param.float()) - param = param.permute(1, 0).contiguous() + if mamba_version == 'Mamba1': + param = param.permute(1, 0).contiguous() elif 'D' in name: param = param.float() elif 'dt_proj.bias' in name: param = param.float() + elif 'dt_bias' in name: + param = param.float() elif 'conv1d.weight' in name: param = param.unsqueeze(3) - if 'in_proj' in name: + # split in_proj in Mamba1 + if 'in_proj' in name and mamba_version == 'Mamba1': in_proj_params = torch.split(param, param.size(0) // 2, dim=0) weights[tllm_name.replace('proj', 'proj_x')] = in_proj_params[0] weights[tllm_name.replace('proj', 'proj_z')] = in_proj_params[1] @@ -181,9 +196,10 @@ def convert_from_hf_checkpoint(model_dir: Union[str, Path], del model_params # lm_head - if 'lm_head.weight' not in weights: - weights['lm_head.weight'] = copy.deepcopy( - weights['backbone.vocab_embedding.weight']) + emb = weights['backbone.vocab_embedding.weight'] + if 'lm_head.weight' not in weights or weights['lm_head.weight'].data_ptr( + ) == emb.data_ptr(): + weights['lm_head.weight'] = copy.deepcopy(emb) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) @@ -208,6 +224,72 @@ def convert(worker_rank, args, convert_args): args.output_dir / f'rank{rank}.safetensors') +@dataclass +class MambaConfig: + + d_model: int = 2560 + d_intermediate: int = 0 + n_layer: int = 64 + vocab_size: int = 50277 + ssm_cfg: dict = field(default_factory=dict) + attn_layer_idx: list = field(default_factory=list) + attn_cfg: dict = field(default_factory=dict) + rms_norm: bool = True + residual_in_fp32: bool = True + fused_add_norm: bool = True + pad_vocab_size_multiple: int = 8 + tie_embeddings: bool = True + hidden_size: int = 2560 + num_hidden_layers: int = 64 + intermediate_size: int = 0 + state_size: int = 128 + conv_kernel: int = 4 + use_bias: bool = False + headdim: int = 64 + ngroups: int = 1 + chunk_size: int = 256 + ssm_rmsnorm: bool = True + + def update(self, data_dict): + self.__dict__.update(data_dict) + + +def load_config_hf(model_name): + resolved_archive_file = cached_file( + model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) + config = json.load(open(resolved_archive_file)) + if 'transformers_version' in config: # transformer compatible models + hf_config = AutoConfig.from_pretrained(model_name, + trust_remote_code=True) + # TODO: change mamba_version when transformers can support Mamba2 models + mamba_version = 'Mamba1' + else: # state-spaces/mamba models + hf_config = MambaConfig(**config) + hf_config.hidden_size = hf_config.d_model + hf_config.num_hidden_layers = hf_config.n_layer + if 'expand' in hf_config.ssm_cfg: + expand = hf_config.ssm_cfg['hf_config'] + hf_config.intermediate_size = expand * hf_config.d_model + else: + hf_config.intermediate_size = 2 * hf_config.d_model + ssm_cfg_to_hf_cfg = { + 'd_state': 'state_size', + 'd_conv': 'conv_kernel', + 'bias': 'use_bias', + 'headdim': 'headdim', + 'ngroups': 'ngroups', + 'chunk_size': 'chunk_size', + 'rmsnorm': 'ssm_rmsnorm', + } + cfg_dict = {} + for k, v in hf_config.ssm_cfg.items(): + if k in ssm_cfg_to_hf_cfg: + cfg_dict[ssm_cfg_to_hf_cfg[k]] = v + hf_config.update(cfg_dict) + mamba_version = hf_config.ssm_cfg.pop("layer", "Mamba1") + return hf_config, mamba_version + + def main(): print(tensorrt_llm.__version__) @@ -217,8 +299,8 @@ def main(): args.output_dir.mkdir(exist_ok=True, parents=True) - hf_config = AutoConfig.from_pretrained(args.model_dir, - trust_remote_code=True) + hf_config, mamba_version = load_config_hf(args.model_dir) + vocab_size = hf_config.vocab_size pad_vocab_size_multiple = hf_config.pad_vocab_size_multiple if vocab_size % pad_vocab_size_multiple != 0: @@ -239,15 +321,29 @@ def main(): 'hidden_act': 'silu', 'num_attention_heads': 1, 'rnn_hidden_size': hf_config.intermediate_size, + 'rnn_conv_dim_size': hf_config.intermediate_size, 'state_size': hf_config.state_size, 'conv_kernel': hf_config.conv_kernel, 'use_bias': hf_config.use_bias, + 'mamba_version': mamba_version, } + if mamba_version == 'Mamba2': + conv_dim = hf_config.intermediate_size + 2 * hf_config.ngroups * hf_config.state_size + mamba2_cfg = { + 'rnn_head_size': hf_config.headdim, + 'rnn_conv_dim_size': conv_dim, + 'ngroups': hf_config.ngroups, + 'chunk_size': hf_config.chunk_size, + 'ssm_rmsnorm': hf_config.ssm_rmsnorm, + } + config.update(mamba2_cfg) with (args.output_dir / 'config.json').open('w') as f: json.dump(config, f, indent=4) convert_from_ckpt = do_convert_from_ckpt(args) + # TODO: Add convert_hf_mamba support for Mamba2 when transformers can support Mamba2 models + assert convert_from_ckpt or mamba_version == 'Mamba2', "Mamba2 can only support convert from checkpoints." if not convert_from_ckpt: logger.info(f'Convert by using model') hf_mamba = AutoModelForCausalLM.from_pretrained(args.model_dir, @@ -264,6 +360,7 @@ def main(): convert_args['model_dir'] = args.model_dir else: convert_args['hf_mamba'] = hf_mamba + convert_args['mamba_version'] = mamba_version convert(0, args, convert_args) diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt index 4ffa84ff5..bde94c2fd 100644 --- a/examples/mamba/requirements.txt +++ b/examples/mamba/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 transformers>=4.39.0 datasets~=2.14.5 evaluate diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt index 21925c333..77510bad2 100644 --- a/examples/medusa/requirements.txt +++ b/examples/medusa/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt index bdf3bc52d..38b5831c6 100644 --- a/examples/mixtral/requirements.txt +++ b/examples/mixtral/requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 transformers==4.38.2 accelerate==0.25.0 diff --git a/examples/mmlu.py b/examples/mmlu.py index c32afface..f9d72e524 100644 --- a/examples/mmlu.py +++ b/examples/mmlu.py @@ -255,6 +255,7 @@ def __init__(self, tokenizer, model, model_name, pad_id, end_id, self.pad_id = pad_id self.end_id = end_id self.max_attention_window_size = max_attention_window_size + self.output_len = 2 def __call__(self, prompt): rank = tensorrt_llm.mpi_rank() @@ -263,7 +264,7 @@ def __call__(self, prompt): batch_input_ids = [inputs] # For multi-choice tasks like MMLU, we don't need to adjust following parameters - output_len = 2 + output_len = self.output_len top_k = 1 top_p = 0.0 @@ -313,7 +314,8 @@ def __call__(self, prompt): def check_valid_length(self, prompt): if isinstance(self.model, nn.Module): return True - return len(self.tokenizer.encode(prompt)) <= self.model.max_input_len + input_len = len(self.tokenizer.encode(prompt)) + return input_len <= self.model.max_input_len and input_len + self.output_len <= self.model.max_seq_len def parse_args(): @@ -391,7 +393,7 @@ def main(): model = auto_model_cls.from_pretrained( args.hf_model_dir, trust_remote_code=True, - torch_dtype=DTYPE_STR_MAPPING[args.data_type], + torch_dtype=DTYPE_STR_MAPPING[args.hf_data_type], device_map="auto" if args.hf_device_map_auto else None, ) if not args.hf_device_map_auto: diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt index 7800b0ea6..5e5ac6a15 100644 --- a/examples/mpt/requirements.txt +++ b/examples/mpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md index 38166a793..079fee2a9 100644 --- a/examples/multimodal/README.md +++ b/examples/multimodal/README.md @@ -12,7 +12,7 @@ We first describe how to run each model on a single GPU. We then provide general - [Deplot](#deplot) - [Fuyu](#fuyu) - [Kosmos-2](#kosmos-2) -- [LLaVA and VILA](#llava-and-vila) +- [LLaVA, LLaVa-NeXT and VILA](#llava-llava-next-and-vila) - [NeVA](#neva) - [Nougat](#nougat) - [Phi-3-vision](#phi-3-vision) @@ -361,9 +361,9 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu ``` -## LLaVA and VILA +## LLaVA, LLaVa-NeXT and VILA -[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options. +[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options. [LLaVA-NeXT](https://huggingface.co/collections/llava-hf/llava-next-65f75c4afac77fd37dbbe6cf) is an extension of LLaVA. TRT-LLM currently supports [Mistral-7b](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and [ Nous-Hermes-2-Yi-34B](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) variant of LLaVA-NeXT. 1. Download Huggingface model weights. These models have both visual and LLM components unlike BLIP2 example which downloads only LLM components from Huggingface. @@ -374,6 +374,12 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in export MODEL_NAME="llava-1.5-7b-hf" # also llava-1.5-13b-hf git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME} ``` + For LLaVA-NeXT, + + ```bash + export MODEL_NAME="llava-v1.6-mistral-7b-hf" #for 34b variant "llava-v1.6-34b-hf" + git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME} + ``` For VILA, we need a few more steps until it is added to HF model zoo @@ -408,6 +414,18 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --max_seq_len 2560 \ --max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features) for LLaVA + trtllm-build \ + --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ + --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ + --gpt_attention_plugin float16 \ + --gemm_plugin float16 \ + --max_batch_size 1 \ + --max_input_len 4096 \ + --max_seq_len 5120 \ + --max_num_tokens 4096 \ # 1 (max_batch_size) * 4096 (max_input_len) + --max_multimodal_len 4096 \ # 1 (max_batch_size) * 4096 (max_input_len) + --use_fused_mlp + trtllm-build \ --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ @@ -426,6 +444,8 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in ```bash python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava # for LLaVA + python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava_next --model_path tmp/hf_models/${MODEL_NAME} --max_batch_size 5 # 1 (max_batch_size) * 5 (because LLAVA-NeXT visual encoder can have at most 5 patches) # for LLaVA-NeXT + python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type vila --vila_path ${VILA_PATH} # for VILA ``` @@ -435,7 +455,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --hf_model_dir tmp/hf_models/${MODEL_NAME} \ --visual_engine_dir visual_engines/${MODEL_NAME} \ --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ - --input_text "Question: which city is this? Answer:" # for LLaVA + --input_text "Question: which city is this? Answer:" # for LLaVA and for LLaVA-NeXT ``` For VILA, you can use either local file or web url as input images. diff --git a/examples/multimodal/build_visual_engine.py b/examples/multimodal/build_visual_engine.py index 440d8572f..70bab2ed9 100644 --- a/examples/multimodal/build_visual_engine.py +++ b/examples/multimodal/build_visual_engine.py @@ -11,13 +11,6 @@ import torch import tensorrt as trt from tensorrt_llm.builder import Builder -# isort: on -import json -import math - -import torch.nn.functional as F -from PIL import Image -from safetensors.torch import save_file from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, Blip2ForConditionalGeneration, Blip2Processor, @@ -25,6 +18,13 @@ LlavaForConditionalGeneration, NougatProcessor, Pix2StructForConditionalGeneration, VisionEncoderDecoderModel) +# isort: on +import json +import math + +import torch.nn.functional as F +from PIL import Image +from safetensors.torch import save_file def parse_arguments(): @@ -33,8 +33,8 @@ def parse_arguments(): type=str, default=None, choices=[ - 'blip2', 'llava', 'vila', 'nougat', 'cogvlm', - 'fuyu', 'pix2struct', 'neva', 'kosmos-2', + 'blip2', 'llava', 'llava_next', 'vila', 'nougat', + 'cogvlm', 'fuyu', 'pix2struct', 'neva', 'kosmos-2', 'video-neva', 'phi-3-vision' ], help="Model type") @@ -80,7 +80,7 @@ def build(self): build_blip2_engine(args) elif args.model_type == 'pix2struct': build_pix2struct_engine(args) - elif args.model_type == 'llava': + elif 'llava' in args.model_type: build_llava_engine(args) elif args.model_type == 'vila': assert args.vila_path is not None, "Please clone and provide VILA source code path" @@ -305,30 +305,59 @@ def forward(self, image, attention_mask): def build_llava_engine(args): processor = AutoProcessor.from_pretrained(args.model_path) - raw_image = Image.new('RGB', [10, 10]) # dummy image - image = processor(text="dummy", images=raw_image, - return_tensors="pt")['pixel_values'].to( - args.device, torch.float16) - - class LlavaVisionWrapper(torch.nn.Module): - - def __init__(self, tower, projector, feature_layer): - super().__init__() - self.tower = tower - self.projector = projector - self.feature_layer = feature_layer - - def forward(self, image): - all_hidden_states = self.tower( - image, output_hidden_states=True).hidden_states - features = all_hidden_states[self.feature_layer][:, 1:] - return self.projector(features) - - model = LlavaForConditionalGeneration.from_pretrained( - args.model_path, torch_dtype=torch.float16) - wrapper = LlavaVisionWrapper(model.vision_tower.to(args.device), - model.multi_modal_projector.to(args.device), - model.config.vision_feature_layer) + if args.model_type == "llava": + raw_image = Image.new('RGB', [10, 10]) # dummy image + image = processor(text="dummy", images=raw_image, + return_tensors="pt")['pixel_values'].to( + args.device, torch.float16) + + class LlavaVisionWrapper(torch.nn.Module): + + def __init__(self, tower, projector, feature_layer): + super().__init__() + self.tower = tower + self.projector = projector + self.feature_layer = feature_layer + + def forward(self, image): + all_hidden_states = self.tower( + image, output_hidden_states=True).hidden_states + features = all_hidden_states[self.feature_layer][:, 1:] + return self.projector(features) + + model = LlavaForConditionalGeneration.from_pretrained( + args.model_path, torch_dtype=torch.float16) + wrapper = LlavaVisionWrapper( + model.vision_tower.to(args.device), + model.multi_modal_projector.to(args.device), + model.config.vision_feature_layer) + elif args.model_type == "llava_next": + from transformers import LlavaNextForConditionalGeneration + raw_image = Image.new('RGB', [512, 512]) + image = processor(text="dummy", images=raw_image, + return_tensors="pt")['pixel_values'].to( + args.device, torch.float16)[0] + + class LlavaNextVisionWrapper(torch.nn.Module): + + def __init__(self, vision_tower, projector): + super().__init__() + self.vision_tower = vision_tower + self.projector = projector + + def forward(self, pixel_values): + image_features = self.vision_tower(pixel_values, + output_hidden_states=True) + selected_image_feature = image_features.hidden_states[-2][:, 1:] + image_features = self.projector(selected_image_feature) + return image_features # (bs, 576, c) + + model = LlavaNextForConditionalGeneration.from_pretrained( + args.model_path, torch_dtype=torch.float16) + wrapper = LlavaNextVisionWrapper( + model.vision_tower.vision_model.to(args.device), + model.multi_modal_projector.to(args.device), + ) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( @@ -336,6 +365,11 @@ def forward(self, image): [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size) + if args.model_type == "llava_next": + image_newline = model.image_newline.data + tensor_img_newline = {"image_newline": image_newline} + save_file(tensor_img_newline, + os.path.join(args.output_dir, "image_newline.safetensors")) def build_vila_engine(args): @@ -517,7 +551,12 @@ def forward(self, images): vision_x = self.connector(vision_x) return vision_x - encoder = AutoModel.from_pretrained(vision_config["from_pretrained"], + vision_path = vision_config["from_pretrained"] + joined_path = os.path.join(os.path.dirname(args.model_path), + os.path.basename(vision_path)) + if os.path.isdir(joined_path): + vision_path = joined_path + encoder = AutoModel.from_pretrained(vision_path, torch_dtype=torch.bfloat16, trust_remote_code=True) vision_encoder = encoder.vision_model diff --git a/examples/multimodal/run.py b/examples/multimodal/run.py index 8c55905de..cbaed938a 100644 --- a/examples/multimodal/run.py +++ b/examples/multimodal/run.py @@ -95,6 +95,130 @@ def trt_dtype_to_torch(dtype): raise TypeError("%s is not supported" % dtype) +class LlavaNextUtils: + # https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py + + @staticmethod + def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int( + original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, + original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + @staticmethod + def get_anyres_image_grid_shape(image_size, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + IMAGE_GRID_PINPOINTS = [[336, 672], [672, 336], [672, 672], [1008, 336], + [336, 1008]] + width, height = LlavaNextUtils.select_best_resolution( + image_size, IMAGE_GRID_PINPOINTS) + return width // patch_size, height // patch_size + + @staticmethod + def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of the image (width, height). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + + @staticmethod + def rearrange_image_features(image_feature, image_newline, image_size): + """ + Combine PyTorch feature grids from image patches. + + Args: + image_feature (torch.Tensor): The feature grids, assumed to be in NxCxHxW format. + image_newline (torch.Tensor): The newline embedding. + image_size (tuple): Size of the original image (width, height). + """ + CLIP_IMAGE_SIZE = 336 + CLIP_PATCH_SIZE = 14 + NUM_PATCHES_PER_SIDE = CLIP_IMAGE_SIZE // CLIP_PATCH_SIZE + if image_feature.shape[0] == 1: + return torch.cat((image_feature, image_newline[None]), dim=0) + + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = NUM_PATCHES_PER_SIDE + assert height * width == base_image_feature.shape[0] + + num_patch_width, num_patch_height = LlavaNextUtils.get_anyres_image_grid_shape( + image_size, CLIP_IMAGE_SIZE) + image_feature = image_feature.view(num_patch_height, num_patch_width, + height, width, -1) + + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = LlavaNextUtils.unpad_image(image_feature, image_size) + image_feature = torch.cat( + (image_feature, image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1)), + dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + return image_feature + + class MultimodalModelRunner: def __init__(self, args): @@ -123,7 +247,9 @@ def __init__(self, args): if self.model_type == 'video-neva': self.num_frames = config['builder_config'].get('num_frames', None) - + if self.model_type == "llava_next": + self.llm_name = AutoConfig.from_pretrained( + args.hf_model_dir).text_config._name_or_path self.profiling_iterations = 20 self.init_image_encoder() @@ -203,6 +329,14 @@ def init_image_encoder(self): device="cuda") as f: for k in f.keys(): self.image_newlines[k] = f.get_tensor(k) + if self.model_type == "llava_next": + self.image_newlines = {} + image_newlines_path = os.path.join(self.args.visual_engine_dir, + 'image_newline.safetensors') + with safe_open(image_newlines_path, framework="pt", + device="cuda") as f: + for k in f.keys(): + self.image_newlines[k] = f.get_tensor(k) def init_llm(self): if self.decoder_llm: @@ -276,6 +410,12 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, image = input['pixel_values'] bs = image.shape[0] image = image.flatten(0, 1) + elif self.model_type == 'llava_next': + input = image + image = input['pixel_values'] + bs = image.shape[0] + image = image[0] + image_size = input['image_sizes'][0].cpu() if not warmup: profiler.start("Vision") @@ -366,6 +506,13 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, input_ids = self.ptuning_setup_phi3(visual_features, input_ids, num_img_tokens) length = input_ids.shape[1] + elif self.model_type == 'llava_next': + visual_features = LlavaNextUtils.rearrange_image_features( + visual_features, self.image_newlines["image_newline"], + image_size) + input_ids = self.ptuning_setup_llava_next(visual_features, + pre_prompt, post_prompt) + length = input_ids.shape[1] else: pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", @@ -387,7 +534,9 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, input_lengths = torch.IntTensor([length] * args.batch_size).to( torch.int32) - if self.model_type in ['fuyu', 'kosmos-2', 'phi-3-vision']: + if self.model_type in [ + 'fuyu', 'kosmos-2', 'phi-3-vision', 'llava_next' + ]: return input_ids, input_lengths, [visual_features], visual_features input_ids, ptuning_args = self.setup_fake_prompts( @@ -667,6 +816,19 @@ def ptuning_setup_fuyu(self, input_ids, image_patches_indices): res_input_ids.append(cur_input_ids) return res_input_ids + def ptuning_setup_llava_next(self, visual_features, pre_prompt, + post_prompt): + input_ids = [] + fake_prompt_ids = list( + range(self.model_config.vocab_size, + self.model_config.vocab_size + visual_features.shape[0])) + input_ids = self.tokenizer.encode( + pre_prompt[0]) + fake_prompt_ids + self.tokenizer.encode( + post_prompt[0])[self.tokenizer.add_bos_token:] + input_ids = [input_ids] * len(pre_prompt) + input_ids = torch.tensor(input_ids) + return input_ids + def ptuning_setup_phi3(self, visual_features, input_ids, num_img_tokens): fake_prompt_id = torch.arange( self.model_config.vocab_size, @@ -869,6 +1031,32 @@ def setup_inputs(self, input_text, raw_image): pre_prompt = """System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUser""" post_prompt = f"\n{input_text}\nAssistant\nquality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n" "" + elif self.model_type == "llava_next": + if self.llm_name == "mistralai/Mistral-7B-Instruct-v0.2": + pre_prompt = "[INST] " + if input_text is None: + input_text = "Question: which city is this? Answer:" + post_prompt = f"\n{input_text} [/INST]" + prompt = pre_prompt + post_prompt + + elif self.llm_name == "NousResearch/Nous-Hermes-2-Yi-34B": + pre_prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n" + if input_text is None: + input_text = "Question: which city is this? Answer:" + post_prompt = f"\n{input_text}<|im_end|><|im_start|>assistant\n" + prompt = pre_prompt + post_prompt + + else: + raise Exception( + f"Prompt template for {self.llm_name} for not included currently" + ) + + processor = AutoProcessor.from_pretrained(args.hf_model_dir, + trust_remote_code=True) + image = processor(text=prompt, + images=raw_image, + return_tensors="pt") + elif self.model_type in ['llava', 'vila', 'fuyu', 'kosmos-2']: # LLaVA and VILA if self.model_type == "llava": @@ -924,7 +1112,8 @@ def setup_inputs(self, input_text, raw_image): pre_prompt = [pre_prompt] * self.args.batch_size post_prompt = [post_prompt] * self.args.batch_size if self.model_type not in [ - 'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision' + 'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision', + 'llava_next' ]: if image.dim() == 5: image = image.expand(args.batch_size, -1, -1, -1, @@ -932,7 +1121,6 @@ def setup_inputs(self, input_text, raw_image): else: image = image.expand(args.batch_size, -1, -1, -1).contiguous() image = image.to(self.device) - # Generate decoder_input_ids for enc-dec models # Custom prompts can be added as: # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids @@ -955,7 +1143,6 @@ def setup_inputs(self, input_text, raw_image): def run(self, input_text, input_image, max_new_tokens): input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = model.setup_inputs( input_text, input_image) - model.generate(pre_prompt, post_prompt, processed_image, @@ -999,7 +1186,9 @@ def print_result(self, input_text, output_text): elif self.model_type == "pix2struct": assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[ 0][0].lower() - elif self.model_type in ['blip2', 'neva', 'phi-3-vision']: + elif self.model_type in [ + 'blip2', 'neva', 'phi-3-vision', 'llava_next' + ]: assert 'singapore' in output_text[0][0].lower() elif self.model_type == 'video-neva': assert 'robot' in output_text[0][0].lower() diff --git a/examples/nemotron/requirements.txt b/examples/nemotron/requirements.txt index baeae63c0..dbb0a4dac 100644 --- a/examples/nemotron/requirements.txt +++ b/examples/nemotron/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 transformers==4.40.2 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt index 7800b0ea6..5e5ac6a15 100644 --- a/examples/opt/requirements.txt +++ b/examples/opt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt index 065885745..bc040d1b5 100644 --- a/examples/phi/requirements.txt +++ b/examples/phi/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt index 1d8afb40a..1be698aa0 100644 --- a/examples/quantization/requirements.txt +++ b/examples/quantization/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets>=2.14.4 nemo-toolkit[all]<=1.20.0,>=1.18.0 rouge_score~=0.1.2 diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt index ea746afdd..3c65e3978 100644 --- a/examples/qwen/requirements.txt +++ b/examples/qwen/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt index 4440b456a..2d4837d7d 100644 --- a/examples/qwenvl/requirements.txt +++ b/examples/qwenvl/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/recurrentgemma/convert_checkpoint.py b/examples/recurrentgemma/convert_checkpoint.py index 7b12bd756..3e7800b9e 100644 --- a/examples/recurrentgemma/convert_checkpoint.py +++ b/examples/recurrentgemma/convert_checkpoint.py @@ -496,6 +496,7 @@ def main(): rnn_hidden_size=ckpt_config["lru_width"], logits_soft_cap=ckpt_config["logits_soft_cap"], emb_scale_by_sqrt_dim=ckpt_config["embeddings_scale_by_sqrt_dim"], + rnn_conv_dim_size=ckpt_config["lru_width"], ) trt_llm_config_dict = trt_llm_config.to_dict() diff --git a/examples/recurrentgemma/requirements.txt b/examples/recurrentgemma/requirements.txt index 9d81522ce..20ca84af1 100644 --- a/examples/recurrentgemma/requirements.txt +++ b/examples/recurrentgemma/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 git+https://github.com/google-deepmind/recurrentgemma.git flax>=0.8.2 jax~=0.4.23 diff --git a/examples/run.py b/examples/run.py index 80f8b603c..99981b6ac 100644 --- a/examples/run.py +++ b/examples/run.py @@ -36,7 +36,6 @@ def parse_arguments(args=None): - # see `add_common_args` for extended list of arguments parser = argparse.ArgumentParser() parser.add_argument('--max_input_length', type=int, default=923) parser.add_argument('--max_output_len', type=int, required=True) @@ -319,18 +318,6 @@ def main(args): "Debug mode is not supported in C++ session for now, fallback to Python session." ) args.use_py_session = True - if args.return_all_generated_tokens and args.use_py_session: - raise ValueError( - "Returning all the generated tokens at each step is not supported in the Python session, use C++ session instead." - ) - if (not args.return_all_generated_tokens) and args.streaming and ( - args.num_beams > 1): - logger.warning( - "Setting return_all_generated_tokens to True since streaming AND beam search are done simultaneously. " - "Returning the full beams at each streaming step is needed because beam search + streaming can change previous outputs. " - "WARNING: using this option may increase network usage significantly (quadratically w.r.t output length)." - ) - args.return_all_generated_tokens = True runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp runner_kwargs = dict( engine_dir=args.engine_dir, @@ -360,7 +347,8 @@ def main(args): kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse, kv_cache_free_gpu_memory_fraction=args. kv_cache_free_gpu_memory_fraction, - enable_chunked_context=args.enable_chunked_context) + enable_chunked_context=args.enable_chunked_context, + ) runner = runner_cls.from_dir(**runner_kwargs) with torch.no_grad(): @@ -394,8 +382,7 @@ def main(args): output_sequence_lengths=True, no_repeat_ngram_size=args.no_repeat_ngram_size, return_dict=True, - medusa_choices=args.medusa_choices, - return_all_generated_tokens=args.return_all_generated_tokens) + medusa_choices=args.medusa_choices) torch.cuda.synchronize() if args.streaming: @@ -482,9 +469,7 @@ def main(args): prompt_tasks=args.prompt_tasks, streaming=args.streaming, output_sequence_lengths=True, - return_dict=True, - return_all_generated_tokens=args.return_all_generated_tokens - ) + return_dict=True) torch.cuda.synchronize() tensorrt_llm.profiler.start("tmp") @@ -516,9 +501,7 @@ def main(args): prompt_tasks=args.prompt_tasks, streaming=args.streaming, output_sequence_lengths=True, - return_dict=True, - return_all_generated_tokens=args.return_all_generated_tokens - ) + return_dict=True) torch.cuda.synchronize() tensorrt_llm.profiler.stop("tmp") diff --git a/examples/skywork/requirements.txt b/examples/skywork/requirements.txt index b56872e4f..6dd7c5d2f 100644 --- a/examples/skywork/requirements.txt +++ b/examples/skywork/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets~=2.16.1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/smaug/requirements.txt b/examples/smaug/requirements.txt index f9f50349b..7c473affe 100644 --- a/examples/smaug/requirements.txt +++ b/examples/smaug/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/summarize.py b/examples/summarize.py index 2be8c4703..bfdfaefa6 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -415,10 +415,6 @@ def eval_hf(datapoint, "Python bindings of C++ session is unavailable, fallback to Python session." ) args.use_py_session = True - if args.return_all_generated_tokens: - raise ValueError( - "Returning all the generated tokens at each step is not supported in summarize.py" - ) runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp runner_kwargs = dict(engine_dir=args.engine_dir, rank=runtime_rank, diff --git a/examples/utils.py b/examples/utils.py index 2305ef0e1..98e21a10d 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -329,16 +329,4 @@ def add_common_args(parser): action='store_true', help="Use device map 'auto' to load a pretrained HF model. This may " "help to test a large model that cannot fit into a singlue GPU.") - - parser.add_argument( - "--return_all_generated_tokens", - default=False, - action="store_true", - help="if false, return only generated tokens at each streaming step." - "If true, return the full beams/outputs at each step" - "Overwritten to True if num_beams>1 and streaming" - "(only available with cpp session). " - "WARNING: using this option may increase network usage significantly (quadratically w.r.t output length)." - ) - return parser diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index 7f7b9e603..0c0bac6ae 100644 --- a/examples/whisper/requirements.txt +++ b/examples/whisper/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.12.0.dev2024070200 +tensorrt_llm==0.12.0.dev2024070900 tiktoken datasets kaldialign diff --git a/requirements-windows.txt b/requirements-windows.txt index cb9f553d9..04e6d6917 100644 --- a/requirements-windows.txt +++ b/requirements-windows.txt @@ -20,7 +20,7 @@ tokenizers>=0.14 # Default torch is CPU-only on Windows, so need to specify a torch version with GPU support torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-win_amd64.whl nvidia-modelopt~=0.13,<0.14 -transformers==4.38.2 +transformers>=4.38.2 wheel optimum evaluate diff --git a/tensorrt_llm/_ipc_utils.py b/tensorrt_llm/_ipc_utils.py index fea348dcb..84c04c947 100644 --- a/tensorrt_llm/_ipc_utils.py +++ b/tensorrt_llm/_ipc_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import array import struct +import sys from contextlib import contextmanager from typing import List, Tuple @@ -83,7 +84,7 @@ def __init__(self, mapping: Mapping, size: int): self.local_ptr = 0 def __del__(self): - if self.open_ipc: + if not sys.is_finalizing() and self.open_ipc: IpcMemory.close_ipc_memory(self.mapping, self.peer_ptrs) def serialize(self) -> List[int]: diff --git a/tensorrt_llm/auto_parallel/cluster_info.py b/tensorrt_llm/auto_parallel/cluster_info.py index 09eeaee88..903e4f2ec 100644 --- a/tensorrt_llm/auto_parallel/cluster_info.py +++ b/tensorrt_llm/auto_parallel/cluster_info.py @@ -69,23 +69,6 @@ class ClusterInfo(DictConversion): "PCIe-5": 64, } -_templates = { - "H100-SXM": - dict( - inter_node_bw_per_device=50, - intra_node_bw_per_device=450, - intra_node_sharp=True, - memory_bw=3350, - math_throughput=MathThroughput( - int8=1979, - fp8=1979, - float16=989, - bfloat16=989, - float32=495, - ), - ), -} - cluster_infos = { # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf "A100-SXM-80GB": @@ -119,18 +102,18 @@ class ClusterInfo(DictConversion): # from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet "H100-SXM": ClusterInfo( - **_templates["H100-SXM"], + inter_node_bw_per_device=50, + intra_node_bw_per_device=450, + intra_node_sharp=True, + memory_bw=3350, memory_budget_per_device=80, - ), - "H100-SXM-64G": - ClusterInfo( - **_templates["H100-SXM"], - memory_budget_per_device=64, - ), - "H100-SXM-94G": - ClusterInfo( - **_templates["H100-SXM"], - memory_budget_per_device=94, + math_throughput=MathThroughput( + int8=1979, + fp8=1979, + float16=989, + bfloat16=989, + float32=495, + ), ), "H100-PCIe": ClusterInfo( @@ -369,12 +352,6 @@ def is_32gb(): return "H100-SXM" else: return "H100-PCIe" - elif match("H100XS", device_name): - return "H100-SXM-64G" - elif match("H100XM", device_name): - return "H100-SXM" - elif match("H100XL", device_name): - return "H100-SXM-94G" elif match("L40S", device_name): return "L40S" elif match("L40", device_name): diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index ef71cc6aa..ce5dd1317 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -27,11 +27,11 @@ from ..auto_parallel import infer_cluster_config from ..auto_parallel.cluster_info import cluster_infos from ..builder import BuildConfig, Engine, build +from ..functional import PositionEmbeddingType from ..logger import logger from ..lora_manager import LoraConfig, LoraManager from ..models import MODEL_MAP, PretrainedConfig -from ..models.modeling_utils import (WEIGHT_LOADER_MODELS, - SpeculativeDecodingMode) +from ..models.modeling_utils import SpeculativeDecodingMode from ..plugin import PluginConfig, add_plugin_argument @@ -248,13 +248,6 @@ def parse_arguments(): return args -def preprocess_model_config(model_config, **kwargs): - if model_config.architecture in WEIGHT_LOADER_MODELS: - model_config.mapping.tp_size = kwargs['tp_size'] - model_config.mapping.pp_size = kwargs['pp_size'] - model_config.mapping.world_size = kwargs['tp_size'] * kwargs['pp_size'] - - def build_model( build_config: BuildConfig, rank: int = 0, @@ -428,7 +421,6 @@ def main(): ckpt_dir = ckpt_dir_or_model_config model_config = PretrainedConfig.from_json_file(config_path) - preprocess_model_config(model_config, **kwargs) if args.build_config is None: if args.multiple_profiles == "enable" and args.opt_num_tokens is not None: @@ -472,7 +464,6 @@ def main(): deduced_max_seq_len = model_config.max_position_embeddings # Step 2: Scale max_seq_len with rotary scaling - rotary_scaling = getattr(model_config, "rotary_scaling", None) if rotary_factor != 1: deduced_max_seq_len *= rotary_factor logger.warning( @@ -485,8 +476,18 @@ def main(): f'max_seq_len is not specified, using value {deduced_max_seq_len}' ) else: - if not plugin_config.streamingllm and model_config.max_position_embeddings is not None: - assert args.max_seq_len <= model_config.max_position_embeddings * rotary_factor, f'max_seq_len {args.max_seq_len} can\'t be larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}' + if not plugin_config.streamingllm and model_config.max_position_embeddings is not None \ + and model_config.position_embedding_type != PositionEmbeddingType.relative: + if args.max_seq_len > model_config.max_position_embeddings * rotary_factor: + logger.warning( + f'max_seq_len {args.max_seq_len} is larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}, ' + 'the model accuracy might be affected') + + if args.max_input_len > args.max_seq_len: + logger.warning( + f'max_input_len is {args.max_input_len} is larger than max_seq_len {args.max_seq_len}, clipping it to max_seq_len' + ) + args.max_input_len = args.max_seq_len build_config = BuildConfig.from_dict( { diff --git a/tensorrt_llm/executor.py b/tensorrt_llm/executor.py index 1b26b128b..658108b9c 100644 --- a/tensorrt_llm/executor.py +++ b/tensorrt_llm/executor.py @@ -76,9 +76,9 @@ def as_executor_request(self) -> tllm.Request: # The following options in the Executor API are not yet exposed by the HLAPI: # https://jirasw.nvidia.com/browse/TRTLLM-489 "bad_words": - self.sampling_params.bad_words or [], + self.sampling_params._get_bad_words(), "stop_words": - self.sampling_params.stop_words or [], + self.sampling_params._get_stop_words(), "embedding_bias": self.sampling_params.embedding_bias, "external_draft_tokens_config": @@ -182,6 +182,15 @@ def handle_generation_msg(self, tensors: tuple, error: str): self.outputs[i].generation_logits = generation_logits[ i, :self.outputs[i].length] + if self.finished and not self._generation_request.sampling_params.include_stop_str_in_output: + for beam_output in self.outputs: + for stop_ids in self._generation_request.sampling_params._get_stop_words( + ): + if beam_output.token_ids[-len(stop_ids):] == stop_ids: + beam_output.token_ids = beam_output.token_ids[:-len( + stop_ids)] + break + if context_logits is not None: self.context_logits = context_logits diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 08e768d9c..b89ffff99 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -5637,17 +5637,20 @@ def selective_scan(input: Tensor, A: Tensor, BC: Tensor, D: Tensor, - z: Tensor, host_request_types: Tensor, last_token_ids: Tensor, dim: int, dstate: int, dt_rank: int, - is_variable_B: bool, - is_variable_C: bool, delta_softplus: bool, dtype: str, - slot_mapping: Optional[Tensor] = None): + z: Optional[Tensor] = None, + host_context_lengths: Optional[Tensor] = None, + slot_mapping: Optional[Tensor] = None, + nheads: int = 1, + ngroups: int = 1, + chunk_size: int = 256, + mamba_version: str = 'Mamba1'): ''' Parameters: input : Tensor (On GPU) @@ -5658,27 +5661,34 @@ def selective_scan(input: Tensor, Or the CPU tensor of shape [1] for the pointer of paged states. delta : Tensor (On GPU) - The delta tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding + The delta tensor. + mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding + mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding delta_bias : Tensor (On GPU) - The delta bias tensor. Its shape is [dim] + The delta bias tensor. + mamba: Its shape is [dim] + mamba2: Its shape is [nheads] A : Tensor (On GPU) - A matrix. Its shape is [dstate, dim] + A matrix. + mamba: Its shape is [dstate, dim] + mamba2: Its shape is [nheads] BC : Tensor (On GPU) - B matrix. Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding + B and C matrix. + mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding + mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding D : Tensor (On GPU) - D matrix. Its shape is [dim] - - z : Tensor (On GPU) - The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding + D matrix. + mamba: Its shape is [dim] + mamba2: Its shape is [nheads] host_request_types : Tensor (On CPU) The tensor on the host that indicates if a request is in context or generation phase. Its shape is [batch_size]. See Inflight Batching - in docs/gpt_attention.md, + in docs/gpt_attention.md last_token_ids : Tensor (On GPU) The inclusive prefix-sum of the lengths or the lengths of the @@ -5693,22 +5703,32 @@ def selective_scan(input: Tensor, dt_rank: int The rank dimension of dt_proj - is_variable_B : bool - Is the matrix B a variable? Set to 'True' if B is a dynamic matrix - during inference, 'False' otherwise - - is_variable_C : bool - Is the matrix C a variable? Set to 'True' if C is a dynamic matrix - during inference, 'False' otherwise - delta_softplus : bool Do we apply softplus to the delta. dtype: str data type + z : Tensor (On GPU) (Optional) + The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding + + host_context_lengths: Tensor (On CPU) (Optional) + A host tensor that contains the lengths of the different inputs, + slot_mapping: Tensor (On GPU) (Optional) Real page index in state. Its shape is [dim], used for paged state, each page shape is [dstate, dim] + + nheads: int (Optional) + The number of heads. + + ngroups: int (Optional) + The number of groups. + + chunk_size: int (Optional) + The chunk_size is used for the chunk_scan kernel. + + mamba_version: int (Optional) + Mamba version, support Mamba1 as default. ''' assert host_request_types is not None selective_scan_plg_creator = trt.get_plugin_registry().get_plugin_creator( @@ -5721,12 +5741,13 @@ def selective_scan(input: Tensor, trt.PluginFieldType.INT32) dt_rank = trt.PluginField("dt_rank", np.array(dt_rank, dtype=np.int32), trt.PluginFieldType.INT32) - is_variable_B = trt.PluginField( - "is_variable_B", np.array(np.int8(is_variable_B), dtype=np.int8), - trt.PluginFieldType.INT8) - is_variable_C = trt.PluginField( - "is_variable_C", np.array(np.int8(is_variable_C), dtype=np.int8), - trt.PluginFieldType.INT8) + nheads = trt.PluginField("nheads", np.array(nheads, dtype=np.int32), + trt.PluginFieldType.INT32) + ngroups = trt.PluginField("ngroups", np.array(ngroups, dtype=np.int32), + trt.PluginFieldType.INT32) + chunk_size = trt.PluginField("chunk_size", + np.array(chunk_size, dtype=np.int32), + trt.PluginFieldType.INT32) delta_softplus = trt.PluginField( "delta_softplus", np.array(np.int8(delta_softplus), dtype=np.int8), trt.PluginFieldType.INT8) @@ -5741,20 +5762,34 @@ def selective_scan(input: Tensor, "paged_state", np.array(np.int8(default_net().plugin_config.paged_state), dtype=np.int8), trt.PluginFieldType.INT8) + if z is None: + z_enabled = trt.PluginField("z_enabled", np.array(0, dtype=np.int8), + trt.PluginFieldType.INT8) + else: + z_enabled = trt.PluginField("z_enabled", np.array(1, dtype=np.int8), + trt.PluginFieldType.INT8) + is_mamba2 = trt.PluginField( + "is_mamba2", + np.array(1 if mamba_version == 'Mamba2' else 0, dtype=np.int8), + trt.PluginFieldType.INT8) pfc = trt.PluginFieldCollection([ - dim, dstate, dt_rank, is_variable_B, is_variable_C, delta_softplus, - pf_type, remove_input_padding, paged_state + dim, dstate, dt_rank, nheads, ngroups, chunk_size, delta_softplus, + pf_type, remove_input_padding, paged_state, z_enabled, is_mamba2 ]) selective_scan_plug = selective_scan_plg_creator.create_plugin( "selective_scan", pfc) plug_inputs = [ - input, state_or_ptr, delta, delta_bias, A, BC, D, z, host_request_types, + input, state_or_ptr, delta, delta_bias, A, BC, D, host_request_types, last_token_ids ] + if default_net().plugin_config.remove_input_padding: + plug_inputs += [host_context_lengths] if default_net().plugin_config.paged_state: plug_inputs += [slot_mapping] + if z is not None: + plug_inputs += [z] plug_inputs = [i.trt_tensor for i in plug_inputs] layer = default_trtnet().add_plugin_v2(plug_inputs, selective_scan_plug) diff --git a/tensorrt_llm/hlapi/build_cache.py b/tensorrt_llm/hlapi/build_cache.py index 6fcb6cf29..586bf17cc 100644 --- a/tensorrt_llm/hlapi/build_cache.py +++ b/tensorrt_llm/hlapi/build_cache.py @@ -27,32 +27,59 @@ def get_build_cache_config_from_env() -> tuple[bool, str]: return build_cache_enabled, build_cache_root +class BuildCacheConfig: + """ + Configuration for the build cache. + + Attributes: + cache_root (str): The root directory for the build cache. + max_records (int): The maximum number of records to store in the cache. + max_cache_storage_gb (float): The maximum amount of storage (in GB) to use for the cache. + """ + + def __init__(self, + cache_root: Optional[Path] = None, + max_records: int = 10, + max_cache_storage_gb: float = 256): + self._cache_root = cache_root + self._max_records = max_records + self._max_cache_storage_gb = max_cache_storage_gb + + @property + def cache_root(self) -> Path: + _build_cache_enabled, _build_cache_root = get_build_cache_config_from_env( + ) + return self._cache_root or Path(_build_cache_root) + + @property + def max_records(self) -> int: + return self._max_records + + @property + def max_cache_storage_gb(self) -> float: + return self._max_cache_storage_gb + + class BuildCache: - ''' + """ The BuildCache class is a class that manages the intermediate products from the build steps. NOTE: currently, only engine-building is supported TODO[chunweiy]: add support for other build steps, such as quantization, convert_checkpoint, etc. - ''' + """ # The version of the cache, will be used to determine if the cache is compatible CACHE_VERSION = 0 - def __init__(self, - cache_root: Optional[Path] = None, - max_records: int = 10, - max_cache_storage_gb: int = 256): - ''' - Args: - cache_root (Path): The root directory of the cache - max_records (int): The maximum number of records to keep in the cache - max_cache_storage_gb (int): The maximum storage size of the cache - ''' + def __init__(self, config: Optional[BuildCacheConfig] = None): + _, default_cache_root = get_build_cache_config_from_env() - self.cache_root = cache_root or Path(default_cache_root) - self.max_records = max_records - self.max_cache_storage_gb = max_cache_storage_gb + config = config or BuildCacheConfig() + + self.cache_root = config.cache_root or Path(default_cache_root) + self.max_records = config.max_records + self.max_cache_storage_gb = config.max_cache_storage_gb - if max_records < 1: + if config.max_records < 1: raise ValueError("max_records should be greater than 0") def get_engine_building_cache_stage(self, diff --git a/tensorrt_llm/hlapi/llm.py b/tensorrt_llm/hlapi/llm.py index dbba844b9..6e44ca435 100644 --- a/tensorrt_llm/hlapi/llm.py +++ b/tensorrt_llm/hlapi/llm.py @@ -14,13 +14,7 @@ external_mpi_comm_available) from .tokenizer import TokenizerBase # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import -from .utils import (SamplingParams, exception_handler, get_device_count, - init_log_level) - -# This should be called before importing the following cpp-runtime modules -init_log_level() - -from ..executor import GenerationExecutor, GenerationResult +from .utils import SamplingParams, exception_handler, get_device_count class RequestOutput(GenerationResult): @@ -83,10 +77,6 @@ def __init__(self, kwargs: Contains the optional arguments for expert users, please refer to `llm_utils.LlmArgs` for more details. ''' - # TODO[chunweiy]: Add API docs - - # TODO[chunweiy]: Deal with model_dir - try: self.args = LlmArgs.from_kwargs( model=model, @@ -230,12 +220,7 @@ def _prepare_sampling_params( raise ValueError( "tokenizer is required to reset end_id if it is None, or you can explicitly specify the end_id for sampling_params" ) - sampling_params.end_id = self.tokenizer.eos_token_id - if self.tokenizer.pad_token_id is not None: - sampling_params.pad_id = self.tokenizer.pad_token_id - else: - sampling_params.pad_id = self.tokenizer.eos_token_id - return sampling_params + return sampling_params.setup(self.tokenizer) else: raise TypeError( f"The sampling_params must be type SamplingParams or None, but got {type(sampling_params)}" diff --git a/tensorrt_llm/hlapi/llm_utils.py b/tensorrt_llm/hlapi/llm_utils.py index fec343611..5eaaa82cc 100644 --- a/tensorrt_llm/hlapi/llm_utils.py +++ b/tensorrt_llm/hlapi/llm_utils.py @@ -13,6 +13,7 @@ 'ContextChunkingPolicy', 'CapacitySchedulerPolicy', 'BuildConfig', + 'BuildCacheConfig', 'QuantConfig', 'CachedModelLoader', 'ConfigArbitrateError', @@ -51,7 +52,7 @@ from ..models import MODEL_MAP from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig from ..module import Module -from .build_cache import (BuildCache, CachedStage, +from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, get_build_cache_config_from_env) from .mpi_session import MPINodeState, MpiSession from .tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory @@ -223,7 +224,7 @@ class LlmArgs: batching_type (BatchingType, optional): The batching type for the model. Default is None. - enable_build_cache (str or bool, optional): Whether to enable build caching for the model. + enable_build_cache (bool or BuildCacheConfig, optional): Whether to enable build caching for the model. Default is None. enable_tqdm (bool, default=False): Whether to display a progress bar during model building. @@ -279,7 +280,7 @@ class LlmArgs: batching_type: Optional[BatchingType] = None # Once set, the model will reuse the build_cache - enable_build_cache: Optional[str | bool] = None + enable_cache_config: Union[BuildCacheConfig, bool] = False # Display the model building progress bar enable_tqdm: bool = False @@ -358,6 +359,13 @@ def setup(self): self._setup_embedding_parallel_mode() + if self.enable_cache_config: + self.enable_cache_config = BuildCacheConfig() if isinstance( + self.enable_cache_config, bool) else self.enable_cache_config + if not isinstance(self.enable_cache_config, BuildCacheConfig): + raise ValueError( + f"Invalid build_cache_config: {self.enable_cache_config}") + if self.is_local_model: # Load parallel_config from the engine. self.model_format = ModelLoader.get_model_format(self.model_dir) @@ -939,7 +947,6 @@ def workspace(self) -> str: def model_format(self) -> _ModelFormatKind: return self._model_format - # TODO[tali]: Replace this with a lower-level API @staticmethod def save( model: _ModelRuntimeContext, @@ -1131,9 +1138,6 @@ def load_extra_build_configs_from_engine( with open(Path(model_dir) / "config.json", "r") as f: engine_config = json.load(f) - # TODO[chunweiy]: Remove the following if-check after the engine config is unified. - if 'build_config' not in engine_config: - return None build_config = engine_config['build_config'] build_config.pop("plugin_config") return Namespace(**build_config) @@ -1223,18 +1227,16 @@ def get_engine_dir(self) -> Path: def build_cache_enabled(self) -> bool: _enable_build_cache, _ = get_build_cache_config_from_env() - return (self.llm_args.enable_build_cache or _enable_build_cache) and ( + return (self.llm_args.enable_cache_config or _enable_build_cache) and ( self.llm_args.model_format is _ModelFormatKind.HF) def _get_engine_cache_stage(self) -> CachedStage: ''' Get the cache stage fir engine building. ''' - _, _build_cache_root = get_build_cache_config_from_env() - build_cache_root = Path(self.llm_args.enable_build_cache if isinstance( - self.llm_args.enable_build_cache, str) else _build_cache_root) + assert self.llm_args.enable_cache_config - build_cache = BuildCache(build_cache_root) + build_cache = BuildCache(self.llm_args.enable_cache_config) assert self._hf_model_dir is not None, "HF model dir is required for cache key." dummy_build_config = CachedModelLoader.get_final_build_config( diff --git a/tensorrt_llm/hlapi/utils.py b/tensorrt_llm/hlapi/utils.py index 6f8fe9b11..de0b23dc1 100644 --- a/tensorrt_llm/hlapi/utils.py +++ b/tensorrt_llm/hlapi/utils.py @@ -5,10 +5,10 @@ import tempfile import traceback import weakref -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union import filelock import huggingface_hub @@ -17,7 +17,7 @@ from tqdm.auto import tqdm from tensorrt_llm.bindings import executor as tllme -from tensorrt_llm.logger import Singleton, set_level +from tensorrt_llm.logger import Singleton def print_traceback_on_error(func): @@ -42,8 +42,11 @@ class SamplingParams: end_id (int): The end token id. pad_id (int): The pad token id. max_new_tokens (int): The maximum number of tokens to generate. - bad_words (List[List[int]]): A list of bad words tokens. Each "word" can be composed of multiple tokens. - stop_words (List[List[int]]): A list of stop words tokens. Each "word" can be composed of multiple tokens. + bad (Union[str, List[str]]): A string or a list of strings that redirect the generation when they are generated, so that the bad strings are excluded from the returned output. + bad_token_ids (List[int]): A list of token ids that redirect the generation when they are generated, so that the bad ids are excluded from the returned output. + stop (Union[str, List[str]]): A string or a list of strings that stop the generation when they are generated. The returned output will not contain the stop strings unless include_stop_str_in_output is True. + stop_token_ids (List[int]): A list of token ids that stop the generation when they are generated. + include_stop_str_in_output (bool): Whether to include the stop strings in output text. Defaults to False. embedding_bias (torch.Tensor): The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size]. external_draft_tokens_config (ExternalDraftTokensConfig): The speculative decoding configuration. prompt_tuning_config (PromptTuningConfig): The prompt tuning configuration. @@ -85,8 +88,19 @@ class SamplingParams: end_id: Optional[int] = None pad_id: Optional[int] = None max_new_tokens: int = 32 - bad_words: Optional[List[List[int]]] = None - stop_words: Optional[List[List[int]]] = None + + bad: Optional[Union[str, List[str]]] = None + bad_token_ids: Optional[List[int]] = None + _bad_word_ids: Optional[List[List[int]]] = field(default=None, + init=False, + repr=False) + stop: Optional[Union[str, List[str]]] = None + stop_token_ids: Optional[List[int]] = None + include_stop_str_in_output: bool = False + _stop_word_ids: Optional[List[List[int]]] = field(default=None, + init=False, + repr=False) + embedding_bias: Optional[torch.Tensor] = None external_draft_tokens_config: Optional[ tllme.ExternalDraftTokensConfig] = None @@ -123,7 +137,60 @@ def __post_init__(self): if self.pad_id is None: self.pad_id = self.end_id - def _get_sampling_config(self): + def setup(self, + tokenizer, + add_special_tokens: bool = False) -> 'SamplingParams': + if self.end_id is None: + self.end_id = tokenizer.eos_token_id + self.pad_id = tokenizer.pad_token_id + if self.pad_id is None: + self.pad_id = self.end_id + + if self.bad is not None: + strs = [self.bad] if isinstance(self.bad, str) else self.bad + self._bad_word_ids = [ + tokenizer.encode(s, add_special_tokens=add_special_tokens) + for s in strs + ] + + if self.stop is not None: + strs = [self.stop] if isinstance(self.stop, str) else self.stop + self._stop_word_ids = [ + tokenizer.encode(s, add_special_tokens=add_special_tokens) + for s in strs + ] + + return self + + def _get_bad_words(self) -> List[List[int]]: + words = [] + if self.bad_token_ids is not None: + words = [[i] for i in self.bad_token_ids] + + if self.bad is None: + return words + else: + if self._bad_word_ids is None: + raise RuntimeError( + f"{self.__class__.__name__}.bad ({self.bad}) is not processed by tokenizer, " + "please call the setup method.") + return words + self._bad_word_ids + + def _get_stop_words(self) -> List[List[int]]: + words = [] + if self.stop_token_ids is not None: + words = [[i] for i in self.stop_token_ids] + + if self.stop is None: + return words + else: + if self._stop_word_ids is None: + raise RuntimeError( + f"{self.__class__.__name__}.stop ({self.stop}) is not processed by tokenizer, " + "please call the setup method.") + return words + self._stop_word_ids + + def _get_sampling_config(self) -> tllme.SamplingConfig: expected_fields = [ "beam_width", "top_k", "top_p", "top_p_min", "top_p_reset_ids", "top_p_decay", "random_seed", "temperature", "min_length", @@ -143,7 +210,7 @@ def _get_sampling_config(self): **{f: getattr(self, f) for f in expected_fields}) - def _get_output_config(self): + def _get_output_config(self) -> tllme.OutputConfig: expected_fields = [ "return_log_probs", "return_context_logits", "return_generation_logits", "exclude_input_from_output", @@ -239,13 +306,6 @@ def is_directory_empty(directory: Path) -> bool: return not any(directory.iterdir()) -def init_log_level(): - ''' Set the log level if the environment variable is not set. ''' - if "TLLM_LOG_LEVEL" not in os.environ: - set_level("warning") - os.environ["TLLM_LOG_LEVEL"] = "WARNING" - - class ExceptionHandler(metaclass=Singleton): def __init__(self): diff --git a/tensorrt_llm/layers/__init__.py b/tensorrt_llm/layers/__init__.py index 3511baf59..d37fe4842 100644 --- a/tensorrt_llm/layers/__init__.py +++ b/tensorrt_llm/layers/__init__.py @@ -28,7 +28,7 @@ from .normalization import GroupNorm, LayerNorm, RmsNorm from .pooling import AvgPool2d from .recurrent import FusedRgLru, GroupedLinear, Recurrent, RgLru -from .ssm import Mamba +from .ssm import Mamba, Mamba2 __all__ = [ 'LayerNorm', @@ -65,6 +65,7 @@ 'MOE', 'MoeConfig', 'Mamba', + 'Mamba2', 'Recurrent', 'GroupedLinear', 'RgLru', diff --git a/tensorrt_llm/layers/ssm.py b/tensorrt_llm/layers/ssm.py index 63510c5fc..06d8221fb 100644 --- a/tensorrt_llm/layers/ssm.py +++ b/tensorrt_llm/layers/ssm.py @@ -22,6 +22,7 @@ from ..module import Module from ..parameter import Parameter from .linear import Linear +from .normalization import RmsNorm class MambaConv1d(Module): @@ -197,17 +198,132 @@ def forward(self, self.A.value, x_dbl, self.D.value, - z, host_request_types, last_token_ids, self.d_inner, self.d_state, self.dt_rank, - is_variable_B=True, - is_variable_C=True, delta_softplus=True, dtype=self.dtype, + z=z, + host_context_lengths=host_context_lengths, slot_mapping=slot_mapping) # out_proj out = self.out_proj(y) return out, conv_state, ssm_state + + +class Mamba2(Module): + + def __init__(self, + d_model, + d_inner, + d_state=16, + d_conv=4, + headdim=64, + ngroups=1, + chunk_size=256, + bias=False, + rmsnorm=True, + dtype=None): + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.d_inner = d_inner + self.headdim = headdim + self.ngroups = ngroups + self.chunk_size = chunk_size + self.rmsnorm = rmsnorm + self.dtype = dtype + assert self.d_inner % self.headdim == 0 + self.nheads = self.d_inner // self.headdim + + self.A = Parameter(shape=(self.nheads, ), dtype="float32") + self.D = Parameter(shape=(self.nheads, ), dtype="float32") + self.dt_bias = Parameter(shape=(self.nheads, ), dtype="float32") + + d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + self.in_proj = Linear(self.d_model, + d_in_proj, + bias=bias, + dtype=dtype, + gather_output=False) + + self.conv_dim = self.d_inner + 2 * self.ngroups * self.d_state + self.conv1d = MambaConv1d(self.conv_dim, self.d_conv, self.dtype) + + if rmsnorm: + self.norm = RmsNorm(normalized_shape=self.d_inner, + eps=1e-5, + dtype=dtype) + + self.out_proj = Linear(self.d_inner, + self.d_model, + bias=bias, + dtype=dtype, + gather_output=False) + + def forward(self, + hidden_states: Tensor, + conv_state: Tensor, + ssm_state: Tensor, + host_request_types: Tensor, + last_token_ids: Tensor, + host_context_lengths: Optional[Tensor] = None, + slot_mapping: Optional[Tensor] = None, + conv_indices: Optional[Tensor] = None): + ''' + Parameters: + hidden_states: [B, L, D] or [T, D] + conv_state: [B, W, D_conv] or [1] of type int64 for paged state + ssm_state: [B, H, N, D] or [1] of type int64 for paged state + host_request_types: [B] + last_token_ids: [B] + host_context_lengths: [B] + slot_mapping: [B] + conv_indices: [B] + ''' + # in_proj + zxbcdt = self.in_proj(hidden_states) + z, xbc, dt = split(zxbcdt, [self.d_inner, self.conv_dim, self.nheads], + dim=-1) + + # conv1d + xbc_conv, conv_state = self.conv1d(xbc, conv_state, host_request_types, + last_token_ids, host_context_lengths, + slot_mapping, conv_indices) + x_conv, bc = split(xbc_conv, + [self.d_inner, 2 * self.ngroups * self.d_state], + dim=-1) + + # mamba scan + y, ssm_state = selective_scan(x_conv, + ssm_state, + dt, + self.dt_bias.value, + self.A.value, + bc, + self.D.value, + host_request_types, + last_token_ids, + self.d_inner, + self.d_state, + dt_rank=0, + delta_softplus=True, + dtype=self.dtype, + z=z, + host_context_lengths=host_context_lengths, + slot_mapping=slot_mapping, + nheads=self.nheads, + ngroups=self.ngroups, + chunk_size=self.chunk_size, + mamba_version='Mamba2') + + # norm + if self.rmsnorm: + y = self.norm(y) + + # out_proj + out = self.out_proj(y) + return out, conv_state, ssm_state diff --git a/tensorrt_llm/models/cogvlm/model.py b/tensorrt_llm/models/cogvlm/model.py index 3b2122785..db2856fe8 100644 --- a/tensorrt_llm/models/cogvlm/model.py +++ b/tensorrt_llm/models/cogvlm/model.py @@ -21,7 +21,6 @@ Embedding, GatedMLP, PromptTuningEmbedding, RmsNorm) from ...mapping import Mapping from ...module import Module -from ...plugin import init_all_reduce_helper # this is to use to module global algo string with a quant_algo prefix from ...quantization import QuantMode from ...top_model_mixin import TopModelMixin @@ -141,7 +140,6 @@ class CogvlmModel(Module): def __init__(self, config: CogVLMConfig) -> None: super().__init__() - init_all_reduce_helper() self.mapping = config.mapping self.use_prompt_tuning = config.use_prompt_tuning diff --git a/tensorrt_llm/models/gpt/model.py b/tensorrt_llm/models/gpt/model.py index f25c47f72..74a970333 100644 --- a/tensorrt_llm/models/gpt/model.py +++ b/tensorrt_llm/models/gpt/model.py @@ -20,6 +20,7 @@ Embedding, GatedMLP, LayerNorm, MoeConfig, PositionEmbeddingType) from ...lora_manager import LoraConfig, use_lora +from ...mapping import Mapping from ...module import Module from ...quantization import QuantMode from ..modeling_utils import DecoderLayerList, DecoderModelForCausalLM @@ -34,7 +35,7 @@ def MLPFactory(hidden_size, moe_config: MoeConfig = MoeConfig(), tp_group=None, tp_size=1, - tp_rank=0, + mapping=Mapping(), quant_mode=QuantMode(0), inner_layernorm=False, eps=1e-05): @@ -43,11 +44,11 @@ def MLPFactory(hidden_size, hidden_size, ffn_hidden_size, hidden_act, - bias, - dtype, - tp_group, - tp_size, - tp_rank, + mapping=mapping, + bias=bias, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size, quant_mode=quant_mode) MLPClass = GatedMLP if is_gated_activation(hidden_act) else MLP hidden_act = non_gated_version(hidden_act) @@ -120,7 +121,7 @@ def __init__(self, config: GPTConfig, layer_idx: int): moe_config=config.moe, tp_group=tp_group, tp_size=tp_size, - tp_rank=tp_rank, + mapping=config.mapping, quant_mode=config.quant_mode, inner_layernorm=inner_layernorm, eps=config.norm_epsilon) diff --git a/tensorrt_llm/models/grok/model.py b/tensorrt_llm/models/grok/model.py index 9034400e7..7b77873d7 100644 --- a/tensorrt_llm/models/grok/model.py +++ b/tensorrt_llm/models/grok/model.py @@ -21,7 +21,6 @@ from ...lora_manager import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module -from ...plugin import init_all_reduce_helper from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, PretrainedConfig, QuantConfig) @@ -69,8 +68,10 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): mlp_kwargs = {} assert config.moe_num_experts > 1, "Grok model is a MoE model." ClsMLP = MOE + moe_config = MoeConfig(config.moe_num_experts, config.moe_top_k, + config.moe_normalization_mode).validate() mlp_kwargs = { - "moe_config": config.moe, + "moe_config": moe_config, "mapping": config.mapping, } self.mlp = ClsMLP(hidden_size=config.hidden_size, @@ -130,7 +131,6 @@ class GrokModel(Module): def __init__(self, config: PretrainedConfig) -> None: super().__init__() - init_all_reduce_helper() self.mapping = config.mapping if self.mapping.is_first_pp_rank(): diff --git a/tensorrt_llm/models/llama/config.py b/tensorrt_llm/models/llama/config.py index 3a40e9232..8bfbe3917 100644 --- a/tensorrt_llm/models/llama/config.py +++ b/tensorrt_llm/models/llama/config.py @@ -98,8 +98,13 @@ def from_hugging_face( if hf_config.model_type == "llava": # LLaVA = Vision model + Llama LLM # We load a llava config and use its' text config as llama config + from transformers import LlavaConfig hf_config = LlavaConfig.from_pretrained( hf_config_dir).text_config + if hf_config.model_type == "llava_next": + from transformers import LlavaNextConfig + hf_config = LlavaNextConfig.from_pretrained( + hf_config_dir).text_config if hf_config.model_type == "llava_llama": hf_config.llm_cfg["architecture"] = hf_config.llm_cfg[ "architectures"] diff --git a/tensorrt_llm/models/llama/convert.py b/tensorrt_llm/models/llama/convert.py index 9b67bcafc..4d96570a6 100644 --- a/tensorrt_llm/models/llama/convert.py +++ b/tensorrt_llm/models/llama/convert.py @@ -649,13 +649,16 @@ def load_hf_llama(model_dir: str, load_model_on_cpu: bool = False): if hf_config.model_type == "llava": from transformers import LlavaForConditionalGeneration model_cls = LlavaForConditionalGeneration + if hf_config.model_type == "llava_next": + from transformers import LlavaNextForConditionalGeneration + model_cls = LlavaNextForConditionalGeneration model = model_cls.from_pretrained( model_dir, device_map='auto' if not load_model_on_cpu else 'cpu', torch_dtype='auto', trust_remote_code=True, ) - if hf_config.model_type == "llava": + if hf_config.model_type in ["llava", "llava_next"]: model = model.language_model return model @@ -1221,7 +1224,7 @@ def quantize(hf_model_dir: str, assert hf_model_dir is not None ## only load and call smooth quant routine once for all ranks hf_config = AutoConfig.from_pretrained(hf_model_dir, trust_remote_code=True) - assert "llava" not in hf_config.model_type, "Smooth quant llava/vila is not supported yet" + assert "llava" not in hf_config.model_type, "Smooth quant llava/vila/llava_next is not supported yet" hf_model = AutoModelForCausalLM.from_pretrained( hf_model_dir, device_map='auto', diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index 4f1c8aef2..8e7063a8f 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -23,7 +23,6 @@ from ...lora_manager import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module -from ...plugin import init_all_reduce_helper from ...quantization import W8A8_SQ_PLUGIN_LIST, QuantAlgo from ..convert_utils import has_safetensors from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, @@ -200,7 +199,6 @@ class LLaMAModel(Module): def __init__(self, config: LLaMAConfig) -> None: super().__init__() - init_all_reduce_helper() self.mapping = config.mapping if self.mapping.is_first_pp_rank(): diff --git a/tensorrt_llm/models/mamba/model.py b/tensorrt_llm/models/mamba/model.py index 32b2e1cc6..f984e99b4 100644 --- a/tensorrt_llm/models/mamba/model.py +++ b/tensorrt_llm/models/mamba/model.py @@ -21,7 +21,7 @@ from ..._utils import str_dtype_to_trt from ...functional import (Tensor, arange, cast, concat, expand, gather_last_token_logits, shape, unsqueeze) -from ...layers import Embedding, LayerNorm, Linear, Mamba, RmsNorm +from ...layers import Embedding, LayerNorm, Linear, Mamba, Mamba2, RmsNorm from ...module import Module, ModuleList from ...plugin import current_all_reduce_helper from ..generation_mixin import GenerationMixin @@ -30,18 +30,31 @@ class MambaLayer(Module): - def __init__(self, config: PretrainedConfig, last_layer=False): + def __init__(self, config: PretrainedConfig, layer_idx: int): super().__init__() self.dtype = config.dtype self.residual_in_fp32 = config.residual_in_fp32 - self.last_layer = last_layer - - self.ssm = Mamba(config.hidden_size, - config.rnn_hidden_size, - d_state=config.state_size, - d_conv=config.conv_kernel, - bias=config.use_bias, - dtype=config.dtype) + n_layer = config.num_hidden_layers + self.last_layer = layer_idx == n_layer - 1 + + if config.mamba_version == 'Mamba1': + self.ssm = Mamba(config.hidden_size, + config.rnn_hidden_size, + d_state=config.state_size, + d_conv=config.conv_kernel, + bias=config.use_bias, + dtype=config.dtype) + elif config.mamba_version == 'Mamba2': + self.ssm = Mamba2(config.hidden_size, + config.rnn_hidden_size, + d_state=config.state_size, + d_conv=config.conv_kernel, + headdim=config.rnn_head_size, + ngroups=config.ngroups, + chunk_size=config.chunk_size, + bias=config.use_bias, + rmsnorm=config.ssm_rmsnorm, + dtype=config.dtype) if config.rms_norm: self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, @@ -101,10 +114,8 @@ def __init__(self, config: PretrainedConfig): self.vocab_embedding = Embedding(config.vocab_size, config.hidden_size, dtype=config.dtype) - self.layers = ModuleList([ - MambaLayer(config, last_layer=i == n_layer - 1) - for i in range(n_layer) - ]) + self.layers = ModuleList( + [MambaLayer(config, i) for i in range(n_layer)]) if config.rms_norm: self.ln_f = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, @@ -166,9 +177,11 @@ def __init__(self, config: PretrainedConfig): self.dtype = dtype self.config = config + self.mamba_version = config.mamba_version self.d_inner = config.rnn_hidden_size self.d_conv = config.conv_kernel self.d_state = config.state_size + self.conv_dim = config.rnn_conv_dim_size self.gather_context_logits = False if isinstance(logits_dtype, str): @@ -294,20 +307,32 @@ def prepare_inputs( conv_state_dim_range = OrderedDict([ ('batch_size', ranges['bb_range']), ('kernel_size', [self.d_conv - 1] * num_profiles), - ('dim_size', [self.d_inner] * num_profiles), + ('dim_size', [self.conv_dim] * num_profiles), ]) else: conv_state_dim_range = OrderedDict([ ('batch_size', ranges['bb_range']), - ('dim_size', [self.d_inner] * num_profiles), + ('dim_size', [self.conv_dim] * num_profiles), ('kernel_size', [self.d_conv - 1] * num_profiles), ]) - ssm_state_dim_range = OrderedDict([ - ('batch_size', ranges['bb_range']), - ('state_size', [self.d_state] * num_profiles), - ('dim_size', [self.d_inner] * num_profiles), - ]) + if self.mamba_version == 'Mamba2': + headdim = self.config.rnn_head_size + nheads = self.d_inner // headdim + ssm_state_dim_range = OrderedDict([ + ('batch_size', ranges['bb_range']), + ('head_size', [nheads] * num_profiles), + ('state_size', [self.d_state] * num_profiles), + ('headdim_size', [headdim] * num_profiles), + ]) + ssm_state_shape = [-1, nheads, self.d_state, headdim] + else: + ssm_state_dim_range = OrderedDict([ + ('batch_size', ranges['bb_range']), + ('state_size', [self.d_state] * num_profiles), + ('dim_size', [self.d_inner] * num_profiles), + ]) + ssm_state_shape = [-1, self.d_state, self.d_inner] one_dim_range = OrderedDict([ ('buffer_count', [1] * num_profiles), ]) @@ -328,18 +353,18 @@ def prepare_inputs( conv_state = Tensor( name=f'past_conv_state_{i}', dtype=self.dtype, - shape=[-1, self.d_conv - 1, self.d_inner], + shape=[-1, self.d_conv - 1, self.conv_dim], dim_range=conv_state_dim_range) else: conv_state = Tensor( name=f'past_conv_state_{i}', dtype=self.dtype, - shape=[-1, self.d_inner, self.d_conv - 1], + shape=[-1, self.conv_dim, self.d_conv - 1], dim_range=conv_state_dim_range) ssm_state = Tensor(name=f'past_rnn_state_{i}', dtype=self.dtype, - shape=[-1, self.d_state, self.d_inner], + shape=ssm_state_shape, dim_range=ssm_state_dim_range) conv_states.append(conv_state) @@ -352,7 +377,7 @@ def prepare_inputs( dim_range=OrderedDict([('batch_size', ranges['bb_range'])]), ) - if use_mamba_conv1d_plugin and remove_input_padding: + if remove_input_padding: host_context_lengths = Tensor( name='host_context_lengths', dtype=trt.int32, diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index e0e2ef9f7..949186413 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -26,6 +26,7 @@ from ..mapping import Mapping from ..module import Module, ModuleList from ..parameter import Parameter +from ..plugin import init_all_reduce_helper from ..quantization import QuantMode from ..quantization.layers import (WeightOnlyGroupwiseQuantLinear, WeightOnlyGroupwiseQuantRowLinear, @@ -370,6 +371,7 @@ class PretrainedModel(Module, def __init__(self, config: PretrainedConfig): super().__init__() + init_all_reduce_helper() self.config = config def __post_init__(self): diff --git a/tensorrt_llm/models/phi/convert.py b/tensorrt_llm/models/phi/convert.py index 72f924e88..e72837a8d 100644 --- a/tensorrt_llm/models/phi/convert.py +++ b/tensorrt_llm/models/phi/convert.py @@ -3,7 +3,7 @@ from ..._utils import str_dtype_to_torch -def convert_hf_weights(hf_model, dtype, **kwargs): +def convert_hf_weights(hf_model, dtype, args=None): torch_dtype = str_dtype_to_torch(dtype) hf_state_dict = hf_model.state_dict() weights = {} @@ -45,7 +45,7 @@ def convert_hf_weights(hf_model, dtype, **kwargs): return weights -def convert_hf_config(hf_config, dtype, **kwargs): +def convert_hf_config(hf_config, dtype, args): config = { 'architecture': hf_config.architectures[0], 'dtype': dtype, @@ -59,5 +59,10 @@ def convert_hf_config(hf_config, dtype, **kwargs): 'max_position_embeddings': hf_config.max_position_embeddings, 'hidden_act': hf_config.hidden_act, 'share_embedding_table': False, + 'mapping': { + 'world_size': args.tp_size * args.pp_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + } } return config diff --git a/tensorrt_llm/models/phi/model.py b/tensorrt_llm/models/phi/model.py index 8a5d7f942..7f26488fd 100644 --- a/tensorrt_llm/models/phi/model.py +++ b/tensorrt_llm/models/phi/model.py @@ -166,15 +166,15 @@ def convert_hf_checkpoint(cls, hf_model_dir: str, dtype: Optional[str] = "float16", output_dir: Optional[str] = None, - **kwargs): + args=None): ''' Convert Huggingface checkpoint to TRT-LLM checkpoint ''' hf_model = AutoModelForCausalLM.from_pretrained(hf_model_dir, torch_dtype="auto", trust_remote_code=True) - config = convert_hf_config(hf_model.config, dtype=dtype, **kwargs) - weights = convert_hf_weights(hf_model, dtype=dtype, **kwargs) + config = convert_hf_config(hf_model.config, dtype, args) + weights = convert_hf_weights(hf_model, dtype, args) if output_dir: save_checkpoint(output_dir, config=config, weights=weights) diff --git a/tensorrt_llm/models/qwen/model.py b/tensorrt_llm/models/qwen/model.py index 5d4f1e8f5..05ae7623e 100644 --- a/tensorrt_llm/models/qwen/model.py +++ b/tensorrt_llm/models/qwen/model.py @@ -15,6 +15,7 @@ from typing import Optional +from tensorrt_llm.layers import MoeConfig from tensorrt_llm.lora_manager import LoraConfig, use_lora from ..._utils import pad_vocab_size @@ -62,10 +63,13 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): ClsMLP = GatedMLP mlp_kwargs = {} + + moe_config = MoeConfig(config.moe_num_experts, config.moe_top_k, + config.moe_normalization_mode).validate() if config.qwen_type == 'qwen2_moe': ClsMLP = MOE mlp_kwargs = { - "moe_config": config.moe, + "moe_config": moe_config, "mapping": config.mapping, } diff --git a/tensorrt_llm/models/recurrentgemma/model.py b/tensorrt_llm/models/recurrentgemma/model.py index 53a6c4502..1f91ebd0b 100644 --- a/tensorrt_llm/models/recurrentgemma/model.py +++ b/tensorrt_llm/models/recurrentgemma/model.py @@ -554,7 +554,7 @@ def prepare_inputs( if use_gpt_attention_plugin and remove_input_padding: host_context_lengths = attention_inputs['host_context_lengths'] - elif use_mamba_conv1d_plugin and remove_input_padding: + elif remove_input_padding: host_context_lengths = Tensor( name='host_context_lengths', dtype=trt.int32, diff --git a/tensorrt_llm/quantization/quantize_by_modelopt.py b/tensorrt_llm/quantization/quantize_by_modelopt.py index ee60e45c7..81d510252 100644 --- a/tensorrt_llm/quantization/quantize_by_modelopt.py +++ b/tensorrt_llm/quantization/quantize_by_modelopt.py @@ -177,14 +177,21 @@ def get_model(ckpt_path, dtype="fp16", device="cuda"): raise NotImplementedError(f"Unknown dtype {dtype}") # Note: VILA model is not in public HF model zoo yet. We need to explicitly import from the git repo + hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=True) + model_cls = AutoModelForCausalLM + if hf_config.model_type == "llava": + from transformers import LlavaForConditionalGeneration + model_cls = LlavaForConditionalGeneration if "vila" in ckpt_path: model = _get_vila_model(ckpt_path) else: - model = AutoModelForCausalLM.from_pretrained( + model = model_cls.from_pretrained( ckpt_path, device_map="auto" if device != "cpu" else "cpu", torch_dtype="auto", trust_remote_code=True) + if hf_config.model_type == "llava": + model = model.language_model model.eval() model_dtype = next(model.parameters()).dtype diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index cb03e3f66..986c6bd87 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -306,7 +306,7 @@ def _set_shape(self, context: trt.IExecutionContext, shape_dict: Dict[str, List[int]]): for i in range(self.engine.num_io_tensors): name = self.engine.get_tensor_name(i) - if not name in shape_dict: + if name not in shape_dict: # shape and buffer can be set by calling _set_tensors API continue if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: @@ -340,7 +340,7 @@ def _set_tensors(self, context: trt.IExecutionContext, for name in self.input_tensor_names: # it's allowed to call set_tensors multi times with different tensors # each time only set some of the engine tensors, so it is valid to skip the ones not in the current given tensors dict - if not name in tensors: + if name not in tensors: continue tensor = tensors[name] @@ -351,7 +351,7 @@ def _set_tensors(self, context: trt.IExecutionContext, context.set_input_shape(name, tensor.shape) for name in self.output_tensor_names: - if not name in tensors: + if name not in tensors: dtype = self.engine.get_tensor_dtype(name) shape = context.get_tensor_shape(name) tensors[name] = RuntimeTensor.from_torch( @@ -467,6 +467,8 @@ class ModelConfig: conv_kernel: int = 0 layer_types: List[str] = field(default_factory=list) rnn_hidden_size: int = 0 + rnn_head_size: int = 0 + rnn_conv_dim_size: int = 0 state_size: int = 0 state_dtype: str = "" gpu_weights_percent: float = 1.0 @@ -1014,6 +1016,14 @@ def conv_kernel(self): def rnn_hidden_size(self): return self._model_config.rnn_hidden_size + @property + def rnn_head_size(self): + return self._model_config.rnn_head_size + + @property + def rnn_conv_dim_size(self): + return self._model_config.rnn_conv_dim_size + @property def state_size(self): return self._model_config.state_size @@ -1540,7 +1550,9 @@ def setup(self, kv_cache_type = torch.int8 else: if self.has_attn_layers: - first_atten_layer = self.layer_types.index('attention') + first_atten_layer = self.layer_types[ + self.first_layer:self.last_layer].index( + 'attention') + self.first_layer kv_cache_type = self.dtype if self.paged_kv_cache else self._tensor_dtype( f'present_key_value_{first_atten_layer}') else: @@ -1633,20 +1645,28 @@ def setup(self, conv_state_shape = ( batch_size, self.conv_kernel - 1, - self.rnn_hidden_size, + self.rnn_conv_dim_size, ) else: conv_state_shape = ( batch_size, - self.rnn_hidden_size, + self.rnn_conv_dim_size, self.conv_kernel - 1, ) - rnn_state_shape = ( - batch_size, - self.state_size, - self.rnn_hidden_size, - ) + if self.rnn_head_size > 1: + rnn_state_shape = ( + batch_size, + self.rnn_hidden_size // self.rnn_head_size, + self.state_size, + self.rnn_head_size, + ) + else: + rnn_state_shape = ( + batch_size, + self.state_size, + self.rnn_hidden_size, + ) for i in range(self.first_layer, self.last_layer): if self.layer_types[i] == 'recurrent': @@ -1873,9 +1893,9 @@ def add_tensor_with_shape(x, name, shape): dtype = self._tensor_dtype(f'present_conv_state_{idx}') if self.use_mamba_conv1d_plugin: conv_state_shape = (batch_size, self.conv_kernel - 1, - self.rnn_hidden_size) + self.rnn_conv_dim_size) else: - conv_state_shape = (batch_size, self.rnn_hidden_size, + conv_state_shape = (batch_size, self.rnn_conv_dim_size, self.conv_kernel - 1) conv_state = torch.zeros(conv_state_shape, @@ -1920,7 +1940,7 @@ def add_tensor_with_shape(x, name, shape): host_request_types = torch.zeros_like(context_lengths, device='cpu').int() add_tensor(host_request_types, 'host_request_types') - if self.use_mamba_conv1d_plugin and self.remove_input_padding: + if self.remove_input_padding: add_tensor(host_context_lengths, 'host_context_lengths') if self.has_attn_layers: add_tensor(attention_mask, 'attention_mask') @@ -2151,9 +2171,9 @@ def add_tensor_with_shape(x, name, shape): # conv state if self.use_mamba_conv1d_plugin: conv_state_shape = (batch_size, self.conv_kernel - 1, - self.rnn_hidden_size) + self.rnn_conv_dim_size) else: - conv_state_shape = (batch_size, self.rnn_hidden_size, + conv_state_shape = (batch_size, self.rnn_conv_dim_size, self.conv_kernel - 1) if step % 2: add_tensor_with_shape( @@ -2211,7 +2231,7 @@ def add_tensor_with_shape(x, name, shape): host_request_types = torch.ones_like(context_lengths, device='cpu').int() add_tensor(host_request_types, 'host_request_types') - if self.use_mamba_conv1d_plugin and self.remove_input_padding: + if self.remove_input_padding: add_tensor(host_context_lengths_local, 'host_context_lengths') if self.has_attn_layers: diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 2b3eb1b8d..fe2ffbe32 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -439,7 +439,7 @@ def from_engine(cls, rnn_config_items = [ 'conv_kernel', 'layer_types', 'rnn_hidden_size', 'state_size', - 'state_dtype' + 'state_dtype', 'rnn_head_size', 'rnn_conv_dim_size' ] rnn_configs_kwargs = {} for item in rnn_config_items: diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index 812329a68..f1083c07d 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -63,26 +63,28 @@ def __init__(self, executor: trtllm.Executor, max_batch_size: int, self.world_config = world_config @classmethod - def from_dir(cls, - engine_dir: str, - *, - lora_dir: Optional[str] = None, - rank: int = 0, - max_batch_size: Optional[int] = None, - max_input_len: Optional[int] = None, - max_output_len: Optional[int] = None, - max_beam_width: Optional[int] = None, - max_attention_window_size: Optional[int] = None, - sink_token_length: Optional[int] = None, - kv_cache_free_gpu_memory_fraction: Optional[float] = None, - medusa_choices: list[list[int]] | None = None, - debug_mode: bool = False, - lora_ckpt_source: str = "hf", - gpu_weights_percent: float = 1, - max_tokens_in_paged_kv_cache: int | None = None, - kv_cache_enable_block_reuse: bool = False, - enable_chunked_context: bool = False, - is_enc_dec: bool = False) -> 'ModelRunnerCpp': + def from_dir( + cls, + engine_dir: str, + *, + lora_dir: Optional[str] = None, + rank: int = 0, + max_batch_size: Optional[int] = None, + max_input_len: Optional[int] = None, + max_output_len: Optional[int] = None, + max_beam_width: Optional[int] = None, + max_attention_window_size: Optional[int] = None, + sink_token_length: Optional[int] = None, + kv_cache_free_gpu_memory_fraction: Optional[float] = None, + medusa_choices: list[list[int]] | None = None, + debug_mode: bool = False, + lora_ckpt_source: str = "hf", + gpu_weights_percent: float = 1, + max_tokens_in_paged_kv_cache: int | None = None, + kv_cache_enable_block_reuse: bool = False, + enable_chunked_context: bool = False, + is_enc_dec: bool = False, + ) -> 'ModelRunnerCpp': """ Create a ModelRunnerCpp instance from an engine directory. @@ -182,15 +184,6 @@ def from_dir(cls, json_config = GptJsonConfig.parse_file(config_path) model_config = json_config.model_config - if max_batch_size is None: - max_batch_size = model_config.max_batch_size - if max_input_len is None: - max_input_len = model_config.max_input_len - if max_output_len is None: - max_output_len = model_config.max_seq_len - model_config.max_input_len - if max_beam_width is None: - max_beam_width = model_config.max_beam_width - # Note: Parallel configuration will be fetched automatically from trtllm.Executor constructor # by inspecting the json file. These lines serve the purpose of serving vocab_size_padded and # num_layers properties. @@ -221,8 +214,10 @@ def from_dir(cls, assert max_batch_size <= model_config.max_batch_size if max_input_len is None: max_input_len = model_config.max_input_len - else: - assert max_input_len <= model_config.max_input_len + # NOTE{pengyunl}: remove assertion here for temp fix, + # model_config.max_input_len is not the upper bound of input length. + # If runtime max_input_len is not properly set, + # C++ runtime will throw an error when fetching new requests if max_output_len is None: max_seq_len = model_config.max_seq_len else: @@ -341,7 +336,6 @@ def generate(self, output_cum_log_probs: bool = False, prompt_table: Optional[Union[str, torch.Tensor]] = None, prompt_tasks: Optional[str] = None, - return_all_generated_tokens: bool = False, **kwargs) -> Union[torch.Tensor, dict]: """ Generates sequences of token ids. @@ -367,8 +361,6 @@ def generate(self, Custom stopping criteria. logits_processor (LogitsProcessor): Custom logits processors. - return_all_generated_tokens (bool): - Whether the full output is returned at each streaming step kwargs (Dict[str, Any]: Ad hoc parametrization of sampling_config. The passed **kwargs matching the sampling_config's attributes will override them. @@ -443,20 +435,18 @@ def generate(self, len(batch_input_ids_list)) requests = [ - trtllm.Request( - input_token_ids=input_ids, - encoder_input_token_ids=encoder_input_ids_list[i] - if encoder_input_ids is not None else None, - max_new_tokens=max_new_tokens, - pad_id=pad_id, - end_id=end_id, - stop_words=stop_words, - bad_words=bad_words, - sampling_config=sampling_config, - streaming=streaming, - output_config=output_config, - prompt_tuning_config=prompt_tuning_config, - return_all_generated_tokens=return_all_generated_tokens) + trtllm.Request(input_token_ids=input_ids, + encoder_input_token_ids=encoder_input_ids_list[i] + if encoder_input_ids is not None else None, + max_new_tokens=max_new_tokens, + pad_id=pad_id, + end_id=end_id, + stop_words=stop_words, + bad_words=bad_words, + sampling_config=sampling_config, + streaming=streaming, + output_config=output_config, + prompt_tuning_config=prompt_tuning_config) for i, (input_ids, stop_words, bad_words, prompt_tuning_config) in enumerate( zip(batch_input_ids_list, stop_words_list, @@ -466,16 +456,17 @@ def generate(self, request_ids = self.session.enqueue_requests(requests) if not streaming: - return self._initialize_and_fill_output( - request_ids, end_id, return_dict, output_sequence_lengths, - output_log_probs, output_cum_log_probs, batch_input_ids, - streaming, return_all_generated_tokens) + return self._initialize_and_fill_output(request_ids, end_id, + return_dict, + output_sequence_lengths, + output_log_probs, + output_cum_log_probs, + batch_input_ids, streaming) else: return self._stream(request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, - streaming, batch_input_ids_list, - return_all_generated_tokens) + streaming, batch_input_ids_list) def _prepare_words_list(self, words_list: List[List[List[int]]], batch_size: int): @@ -509,7 +500,7 @@ def _prepare_ptuning_executor(self, batch_input_ids_list, prompt_table, def _initialize_and_fill_output(self, request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, - streaming, return_all_generated_tokens): + streaming): output_ids = [[] for _ in range(len(request_ids))] for reqid_pos in range(len(request_ids)): output_ids[reqid_pos] = [[] for _ in range(self.max_beam_width)] @@ -522,12 +513,11 @@ def _initialize_and_fill_output(self, request_ids, end_id, return_dict, return self._fill_output(responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, - streaming, request_ids, - return_all_generated_tokens) + streaming, request_ids) def _stream(self, request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, - streaming, batch_input_ids_list, return_all_generated_tokens): + streaming, batch_input_ids_list): output_ids = [[] for _ in range(len(request_ids))] for reqid_pos in range(len(request_ids)): output_ids[reqid_pos] = [ @@ -546,13 +536,12 @@ def _stream(self, request_ids, end_id, return_dict, output_sequence_lengths, yield self._fill_output(responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, - streaming, request_ids, - return_all_generated_tokens) + streaming, request_ids) def _fill_output(self, responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming, - request_ids, return_all_generated_tokens): + request_ids): cuda_device = torch.device("cuda") for response in responses: @@ -562,10 +551,7 @@ def _fill_output(self, responses, output_ids, end_id, return_dict, reqid_pos = request_ids.index(response.request_id) for beam, output_tokens in enumerate( response.result.output_token_ids): - if return_all_generated_tokens: - output_ids[reqid_pos][beam] = output_tokens - else: - output_ids[reqid_pos][beam] += output_tokens + output_ids[reqid_pos][beam] += output_tokens sequence_lengths = [] for output in output_ids: diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index ac57fa0b2..1ad5a5465 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.12.0.dev2024070200" +__version__ = "0.12.0.dev2024070900" diff --git a/tests/bindings/binding_test_utils.py b/tests/bindings/binding_test_utils.py index d3dda1d18..e34316005 100644 --- a/tests/bindings/binding_test_utils.py +++ b/tests/bindings/binding_test_utils.py @@ -40,6 +40,19 @@ def prepare_model_tests( run_command(generate_expected_output, cwd=llm_root, env=model_env) +def prepare_lora_configs(llm_root: _pl.Path, resource_path: _pl.Path, + lora_config_path: _pl.Path): + python_exe = _sys.executable + generate_lora_data_args_tp1 = [ + python_exe, + str(resource_path / "scripts" / "generate_test_lora_weights.py"), + f"--out-dir={str(lora_config_path)}", "--tp-size=1", + "--hidden-size=768", "--num-layers=12", "--config-ids-filter=0", + "--no-generate-cache-pages" + ] + run_command(generate_lora_data_args_tp1, cwd=llm_root) + + def sequence_lengths(sequences: _np.ndarray, pad_id: int) -> _np.ndarray: return _np.apply_along_axis(lambda x: _np.searchsorted(x, True), 1, sequences == pad_id).astype("int32") diff --git a/tests/bindings/test_executor_bindings.py b/tests/bindings/test_executor_bindings.py index a9b06201c..77b4045cf 100644 --- a/tests/bindings/test_executor_bindings.py +++ b/tests/bindings/test_executor_bindings.py @@ -22,7 +22,7 @@ @pytest.fixture -def model_files(llm_root: Path, resource_path: Path, results_data_path): +def model_files(llm_root: Path, resource_path: Path, results_data_path: Path): # Model engines and expected outputs need to be generated. if not results_data_path.exists(): model_cache = llm_models_root() @@ -31,6 +31,14 @@ def model_files(llm_root: Path, resource_path: Path, results_data_path): prepare_model_tests(llm_root, resource_path, "gpt", model_cache_arg) +@pytest.fixture +def lora_config_paths(llm_root: Path, resource_path: Path, + lora_config_path: Path): + if not lora_config_path.exists(): + prepare_lora_configs(llm_root, resource_path, lora_config_path) + return (lora_config_path / "source.npy", lora_config_path / "config.npy") + + def get_expected_num_tokens(prompt_len, max_new_tokens, streaming, exclude_input_from_output): if not streaming and not exclude_input_from_output: @@ -203,6 +211,61 @@ def test_single_request(streaming: bool, exclude_input_from_output: bool, executor.get_latest_request_stats() +@skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture +def test_single_request_lora(model_files, model_path_lora, lora_config_paths): + streaming = False + exclude_input_from_output = False + output_config = trtllm.OutputConfig() + output_config.exclude_input_from_output = exclude_input_from_output + + # Create executor + beam_width = 1 + + peft_cache_config = trtllm.PeftCacheConfig(num_put_workers=4, + num_ensure_workers=4) + executor_config = trtllm.ExecutorConfig(1, + peft_cache_config=peft_cache_config) + executor = trtllm.Executor(model_path_lora, trtllm.ModelType.DECODER_ONLY, + executor_config) + + # Create the request + max_new_tokens = 5 + input_tokens = [1, 2, 3, 4] + lora_weights = torch.tensor(np.load(lora_config_paths[0])).half() + lora_config = torch.tensor(np.load(lora_config_paths[1])) + request = trtllm.Request(input_tokens, + max_new_tokens, + streaming, + trtllm.SamplingConfig(), + output_config, + lora_config=trtllm.LoraConfig( + 0, lora_weights, lora_config)) + + # Enqueue the request + request_id = executor.enqueue_request(request) + + # Get the new tokens + tokens = [] + done = False + i = 0 + max_wait_ms = 10000 + while not done and i < max_wait_ms: + wait_time = datetime.timedelta(milliseconds=1) + responses = executor.await_responses(request_id, wait_time) + for response in responses: + assert not response.has_error( + ), f"Request id {request_id} failed with err {response.error_msg}" + result = response.result + done = result.is_final + new_tokens = result.output_token_ids[beam_width - 1] + tokens.extend(new_tokens) + i += 1 + assert i < max_wait_ms + assert len(tokens) == get_expected_num_tokens( + len(input_tokens), max_new_tokens, streaming, + exclude_input_from_output), f"{request_id}" + + @pytest.mark.parametrize("streaming", [False, True]) @pytest.mark.parametrize("exclude_input_from_output", [False]) @skip_pre_ampere # ContextFMHAType with fp32 acc is not supported in pre-ampere architecture @@ -1239,14 +1302,3 @@ def test_executor_config_pickle(): assert config.max_beam_width == config_copy.max_beam_width assert config.scheduler_config.capacity_scheduler_policy == config_copy.scheduler_config.capacity_scheduler_policy assert config.kv_cache_config.enable_block_reuse == config_copy.kv_cache_config.enable_block_reuse - - -def test_return_full_tokens(): - max_new_tokens = 5 - input_tokens = [1, 2, 3, 4] - request = trtllm.Request(input_tokens, max_new_tokens, False, - trtllm.SamplingConfig()) - request.return_all_generated_tokens = True - assert request.return_all_generated_tokens == True - request.return_all_generated_tokens = False - assert request.return_all_generated_tokens == False diff --git a/tests/functional/test_selective_scan.py b/tests/functional/test_selective_scan.py index a9c8b7201..e1b7aa6f7 100644 --- a/tests/functional/test_selective_scan.py +++ b/tests/functional/test_selective_scan.py @@ -18,16 +18,19 @@ from itertools import product import numpy as np +import pytest import torch +from einops import rearrange, repeat from parameterized import parameterized -from torch_ref import selective_scan_ref, selective_state_update_ref +from torch_ref import (selective_scan_ref, selective_state_update_ref, + ssd_chunk_scan_combined_ref) import tensorrt_llm from tensorrt_llm import Tensor from tensorrt_llm._utils import str_dtype_to_torch sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import skip_bf16_pre_ampere, unittest_name_func +from utils.util import getSMVersion, skip_bf16_pre_ampere, unittest_name_func class TestFunctional(unittest.TestCase): @@ -49,8 +52,6 @@ def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size, device = "cuda" seq_len = max_seq_len if req_type == 'context' else 1 dt_rank = 160 - is_variable_B = True - is_variable_C = True delta_softplus = True # test data @@ -59,6 +60,7 @@ def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size, last_token_ids = torch.randint(1, seq_len + 1, (batch_size, ), dtype=torch.int32) + host_context_lengths = last_token_ids.detach().clone().cpu() last_token_ids = torch.cumsum(last_token_ids, dim=0, dtype=torch.int32).to(device) @@ -66,6 +68,7 @@ def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size, else: last_token_ids = torch.ones( (batch_size, ), dtype=torch.int32, device=device) * seq_len + host_context_lengths = last_token_ids.detach().clone().cpu() total_num_tokens = batch_size * seq_len state = torch.randn(batch_size, dstate, @@ -152,11 +155,29 @@ def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size, name='last_token_ids', shape=last_token_ids.shape, dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths_tensor = None + if remove_padding: + host_context_lengths_tensor = Tensor( + name='host_context_lengths', + shape=host_context_lengths.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) outputs = tensorrt_llm.functional.selective_scan( - x_tensor, state_tensor, dt_tensor, dt_bias_tensor, A_tensor, - BC_tensor, D_tensor, z_tensor, host_request_types_tensor, - last_token_ids_tensor, dim, dstate, dt_rank, is_variable_B, - is_variable_C, delta_softplus, dtype) + x_tensor, + state_tensor, + dt_tensor, + dt_bias_tensor, + A_tensor, + BC_tensor, + D_tensor, + host_request_types_tensor, + last_token_ids_tensor, + dim, + dstate, + dt_rank, + delta_softplus, + dtype, + host_context_lengths=host_context_lengths_tensor, + z=z_tensor) net._mark_output(outputs[0], 'output', dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -177,6 +198,8 @@ def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size, 'host_request_types': host_request_types, 'last_token_ids': last_token_ids } + if remove_padding: + inputs['host_context_lengths'] = host_context_lengths outputs = {'output': output, 'present_state': state} stream = torch.cuda.current_stream() builder_config = builder.create_builder_config(precision=dtype, ) @@ -239,3 +262,290 @@ def test_selective_scan(self, dim, dstate, req_type, dtype, batch_size, np.testing.assert_allclose(state_ref.to(torch.float32).cpu().numpy(), present_state_cpu.numpy(), atol=dtype_atol[dtype]) + + @parameterized.expand(list( + product([2048], [64], ['context', 'generation'], + ['float32', 'float16', 'bfloat16'], [3], [16], [True, False], + [True, False])), + name_func=unittest_name_func) + def test_selective_scan_v2(self, dim, headdim, req_type, dtype, batch_size, + max_seq_len, has_z, remove_padding): + + # Skip tests that are not supported + skip_bf16_pre_ampere(dtype) + if dtype == 'float32' and req_type == 'context': + pytest.skip( + "Mamba2 chunk scan kernel only support float16 and bfloat16") + if getSMVersion() < 80: + pytest.skip("Mamba2 is not supported in pre-Ampere architecture") + + # configs + device = "cuda" + seq_len = max_seq_len if req_type == 'context' else 1 + dstate = 128 + chunk_size = 256 + nheads = dim // headdim + ngroups = 1 + delta_softplus = True + mean = 0.0 + std_dev = 0.5 if dtype == "float32" else 0.1 + + # test data + torch.random.manual_seed(0) + if remove_padding: + last_token_ids = torch.randint(1, + seq_len + 1, (batch_size, ), + dtype=torch.int32) + host_context_lengths = last_token_ids.detach().clone().cpu() + last_token_ids = torch.cumsum(last_token_ids, + dim=0, + dtype=torch.int32).to(device) + total_num_tokens = last_token_ids[batch_size - 1] + else: + last_token_ids = torch.ones( + (batch_size, ), dtype=torch.int32, device=device) * seq_len + host_context_lengths = last_token_ids.detach().clone().cpu() + total_num_tokens = batch_size * seq_len + state = torch.empty(batch_size, + nheads, + dstate, + headdim, + device=device, + dtype=str_dtype_to_torch(dtype)) + x = torch.empty(total_num_tokens, + dim, + device=device, + dtype=str_dtype_to_torch(dtype)) + x.normal_(mean, std_dev) + state.normal_(mean, std_dev) + dt = torch.randn(total_num_tokens, + nheads, + device=device, + dtype=str_dtype_to_torch(dtype)) + dt_bias = torch.rand(nheads, device=device) - 4.0 + A = -torch.rand(nheads, device=device) - 1.0 + BC = torch.randn(total_num_tokens, + ngroups * dstate * 2, + device=device, + dtype=str_dtype_to_torch(dtype)) + D = torch.randn(nheads, device=device) + if has_z: + z = torch.randn_like(x) + host_request_types = torch.tensor([0 if req_type == 'context' else 1] * + batch_size, + dtype=torch.int32) + if not remove_padding or req_type == 'generation': + x = x.view(-1, seq_len, dim) + dt = dt.view(-1, seq_len, nheads) + BC = BC.view(-1, seq_len, ngroups * dstate * 2) + if has_z: + z = z.view(-1, seq_len, dim) + output = torch.zeros(x.shape, + device=device, + dtype=str_dtype_to_torch(dtype)) + + state_ref = state.detach().clone() + x_ref = x.detach().clone() + dt_ref = dt.detach().clone() + dt_bias_ref = dt_bias.detach().clone() + A_ref = A.detach().clone() + B_ref = BC[..., 0:ngroups * dstate].detach().clone() + C_ref = BC[..., ngroups * dstate:].detach().clone() + D_ref = D.detach().clone() + z_ref = z.detach().clone() if has_z else None + + # construct trt network + builder = tensorrt_llm.Builder() + net = builder.create_network() + if remove_padding: + net.plugin_config.remove_input_padding = True + else: + net.plugin_config.remove_input_padding = False + net.plugin_config.paged_state = False + with tensorrt_llm.net_guard(net): + x_tensor = Tensor(name='input', + shape=x.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + state_tensor = Tensor(name='state', + shape=state.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + dt_tensor = Tensor(name='delta', + shape=dt.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + dt_bias_tensor = Tensor( + name='delta_bias', + shape=dt_bias.shape, + dtype=tensorrt_llm.str_dtype_to_trt('float32')) + A_tensor = Tensor(name='A', + shape=A.shape, + dtype=tensorrt_llm.str_dtype_to_trt('float32')) + BC_tensor = Tensor(name='BC', + shape=BC.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + D_tensor = Tensor(name='D', + shape=D.shape, + dtype=tensorrt_llm.str_dtype_to_trt('float32')) + if has_z: + z_tensor = Tensor(name='z', + shape=z.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + host_request_types_tensor = Tensor( + name='host_request_types', + shape=host_request_types.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + last_token_ids_tensor = Tensor( + name='last_token_ids', + shape=last_token_ids.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths_tensor = None + if remove_padding: + host_context_lengths_tensor = Tensor( + name='host_context_lengths', + shape=host_context_lengths.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + outputs = tensorrt_llm.functional.selective_scan( + x_tensor, + state_tensor, + dt_tensor, + dt_bias_tensor, + A_tensor, + BC_tensor, + D_tensor, + host_request_types_tensor, + last_token_ids_tensor, + dim, + dstate, + 0, + delta_softplus, + dtype, + z=z_tensor if has_z else None, + host_context_lengths=host_context_lengths_tensor, + nheads=nheads, + ngroups=ngroups, + chunk_size=chunk_size, + mamba_version='Mamba2') + net._mark_output(outputs[0], + 'output', + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + net._mark_output(outputs[1], + 'present_state', + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + + # trt run + inputs = { + 'input': x, + 'state': state, + 'delta': dt, + 'delta_bias': dt_bias, + 'A': A, + 'BC': BC, + 'D': D, + 'host_request_types': host_request_types, + 'last_token_ids': last_token_ids + } + if remove_padding: + inputs['host_context_lengths'] = host_context_lengths + if has_z: + inputs['z'] = z + inputs + outputs = {'output': output, 'present_state': state} + stream = torch.cuda.current_stream() + builder_config = builder.create_builder_config(precision=dtype, ) + engine = builder.build_engine(net, builder_config) + session = tensorrt_llm.runtime.Session.from_serialized_engine(engine) + session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream) + out_ref = torch.zeros(output.shape, + device=device, + dtype=str_dtype_to_torch(dtype)) + # pytorch run + if req_type == 'context': + if remove_padding: + for i in range(batch_size): + start = 0 if i == 0 else last_token_ids[i - 1] + end = last_token_ids[i] + x_reshaped = rearrange(x_ref[start:end].unsqueeze(0), + "b l (h p) -> b l h p", + p=headdim) + B_ref_reshaped = rearrange(B_ref[start:end].unsqueeze(0), + "b l (g n) -> b l g n", + g=ngroups) + C_ref_reshaped = rearrange(C_ref[start:end].unsqueeze(0), + "b l (g n) -> b l g n", + g=ngroups) + z_ref_reshaped = rearrange(z_ref[start:end].unsqueeze(0), + "b l (h p) -> b l h p", + p=headdim) if has_z else None + part_out_ref, part_state_ref = ssd_chunk_scan_combined_ref( + x_reshaped, + dt_ref[start:end].unsqueeze(0), + A_ref, + B_ref_reshaped, + C_ref_reshaped, + chunk_size, + D=D_ref, + z=z_ref_reshaped, + dt_bias=dt_bias_ref, + dt_softplus=delta_softplus) + part_out_ref = rearrange(part_out_ref, + "b l h p -> b l (h p)") + out_ref[start:end, ] = part_out_ref.squeeze(0) + state_ref[i, ] = part_state_ref.squeeze(0) + else: + x_reshaped = rearrange(x_ref, "b l (h p) -> b l h p", p=headdim) + B_ref_reshaped = rearrange(B_ref, + "b l (g n) -> b l g n", + g=ngroups) + C_ref_reshaped = rearrange(C_ref, + "b l (g n) -> b l g n", + g=ngroups) + z_ref_reshaped = rearrange( + z_ref, "b l (h p) -> b l h p", p=headdim) if has_z else None + out_ref, state_ref = ssd_chunk_scan_combined_ref( + x_reshaped, + dt_ref, + A_ref, + B_ref_reshaped, + C_ref_reshaped, + chunk_size, + D=D_ref, + z=z_ref_reshaped, + dt_bias=dt_bias_ref, + dt_softplus=delta_softplus) + out_ref = rearrange(out_ref, "b l h p -> b l (h p)") + elif req_type == 'generation': + A_ref = repeat(A_ref, "h -> h n p", p=headdim, + n=dstate).to(dtype=torch.float32) + dt_ref = repeat(dt_ref.squeeze(1), "b h -> b h p", p=headdim) + dt_bias_ref = repeat(dt_bias_ref, "h -> h p", p=headdim) + D_ref = repeat(D_ref, "h -> h p", p=headdim) + B_ref = rearrange(B_ref.squeeze(1), "b (g n) -> b g n", g=ngroups) + C_ref = rearrange(C_ref.squeeze(1), "b (g n) -> b g n", g=ngroups) + x_reshaped = rearrange(x_ref.squeeze(1), + "b (h p) -> b h p", + p=headdim) + if has_z: + z_ref = rearrange(z_ref.squeeze(1), + "b (h p) -> b h p", + p=headdim) + out_ref = selective_state_update_ref(state_ref, + x_reshaped, + dt_ref, + A_ref, + B_ref, + C_ref, + D=D_ref, + z=z_ref, + dt_bias=dt_bias_ref, + dt_softplus=delta_softplus) + out_ref = rearrange(out_ref, "b h p -> b (h p)").unsqueeze(1) + + dtype_atol = {"float16": 5e-3, "float32": 2e-3, "bfloat16": 5e-2} + + output_cpu = outputs['output'].to(torch.float32).cpu() + present_state_cpu = outputs['present_state'].to(torch.float32).cpu() + np.testing.assert_allclose(out_ref.to(torch.float32).cpu().numpy(), + output_cpu.numpy(), + atol=dtype_atol[dtype]) + np.testing.assert_allclose(state_ref.to(torch.float32).cpu().numpy(), + present_state_cpu.numpy(), + atol=dtype_atol[dtype]) diff --git a/tests/functional/torch_ref.py b/tests/functional/torch_ref.py index 1b6716ae7..decb37375 100644 --- a/tests/functional/torch_ref.py +++ b/tests/functional/torch_ref.py @@ -19,6 +19,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat +from transformers.models.llama.modeling_llama import LlamaRMSNorm def geglu(x): @@ -252,42 +253,261 @@ def selective_state_update_ref(state, dt_softplus=False): """ Argument: - state: (batch, dstate, dim) - x: (batch, dim) - dt: (batch, dim) - A: (dstate, dim) - B: (batch, dstate) - C: (batch, dstate) - D: (dim,) - z: (batch, dim) - dt_bias: (dim,) + state: (batch, dstate, dim) or (batch, nheads, dstate, dim) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dstate, dim) or (nheads, dstate, dim) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) Return: - out: (batch, dim) + out: (batch, dim) or (batch, nheads, dim) """ - batch, dstate, dim = state.shape - assert x.shape == (batch, dim) + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dstate, dim = state.shape + + assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape - assert A.shape == (dstate, dim) - assert B.shape == (batch, dstate) + assert A.shape == (nheads, dstate, dim) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) assert C.shape == B.shape + if D is not None: - assert D.shape == (dim, ) + assert D.shape == (nheads, dim) if z is not None: assert z.shape == x.shape if dt_bias is not None: - assert dt_bias.shape == (dim, ) + assert dt_bias.shape == (nheads, dim) dt = dt + dt_bias dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b d -> b 1 d") * A) # (batch, dstate, dim) - dB = rearrange(dt, "b d -> b 1 d") * rearrange( - B.float(), "b n -> b n 1") # (batch, dstate, dim) - state_new = state * dA + dB * rearrange( - x, "b d -> b 1 d") # (batch, dstate, dim) + dA = torch.exp(rearrange(dt, "b h d -> b h 1 d") * + A) # (batch, nheads, dstate, dim) + B = repeat(B, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h 1 d") * rearrange( + B.float(), "b h n -> b h n 1") # (batch, nheads, dstate, dim) + state_new = state.float() * dA + dB * rearrange( + x.float(), "b h d -> b h 1 d") # (batch, nheads, dstate, dim) state.copy_(state_new.to(state.dtype)) - out = torch.einsum("bnd,bn->bd", state_new, C.float()) + out = torch.einsum("bhnd,bhn->bhd", state_new, C.float()) if D is not None: - out += x * D - return (out if z is None else out * F.silu(z.float())).to(x.dtype) + out += x.float() * D + out = (out if z is None else out * F.silu(z.float())).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out + + +def chunk_state_ref(B, x, dt, dA_cumsum): + """ + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + ngroups = B.shape[2] + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) + B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) + decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), + decay_states.to(x.dtype), dt.to(x.dtype), x) + + +def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ + if initial_states is None: + initial_states = torch.zeros_like(states[:, 0]) + states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], + dim=1) + dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0)) + dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1) + nchunks = dA_chunk_cumsum.shape[-1] + # (batch, nheads, nchunks, nchunks) + dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, + None] - dA_chunk_cumsum[:, :, + None, :] + # (batch, nheads, nchunks, nchunks) + decay_chunk = torch.exp(dt_chunk_segment_sum) + causal_mask = torch.tril(torch.ones(nchunks, + nchunks, + device=states.device, + dtype=bool), + diagonal=0) + decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) + out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), + states) + return out[:, :-1], out[:, -1] + + +def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert C.shape == B.shape + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + C = F.pad(C, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + if z is not None: + z = F.pad(z, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + CB = torch.einsum("bclhn,bcshn->bchls", + rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, + chunk_size, + device=x.device, + dtype=bool), + diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), + dt.to(x.dtype), + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = torch.einsum('bclhn,bchpn->bclhp', + rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + prev_states.to(C.dtype)) * state_decay_out + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return (out if z is None else out * F.silu(z)).to(x.dtype) + + +def ssd_chunk_scan_combined_ref(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + final_states: (batch, nheads, dstate, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + if seqlen % chunk_size != 0: + dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size)) + mask = torch.zeros_like(dt) + mask[:, 0:seqlen, :] = 1 + dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) + mask = rearrange(mask, "b (c l) h -> b h c l", l=chunk_size) + dt = dt.float() # We want high precision for this before cumsum + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "h -> h 1 1") + if dt_softplus: + dt = F.softplus(dt) + dt = torch.clamp(dt, min=0) + dt = dt * mask + dA = dt * rearrange(A, "h -> h 1 1") + dA_cumsum = torch.cumsum(dA, dim=-1) + # 1. Compute the state for each chunk + states = chunk_state_ref(B, x, dt, dA_cumsum) + states_dtype = states.dtype + if states.dtype not in [torch.float32, torch.float64]: + states = states.to(torch.float32) + # 2. Pass the state to all the chunks by weighted cumsum. + # state_passing_ref is much less numerically stable + states, final_states = state_passing_ref( + rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]) + states, final_states = [ + rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states] + ] + states = states.to(states_dtype) + final_states = final_states.to(states_dtype) + final_states = final_states.permute(0, 1, 3, 2).contiguous() + # 3. Compute the output for each chunk + out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z) + if seqlen % chunk_size != 0: + out = out[:, 0:seqlen, :, :] + return out, final_states class mamba_ref(nn.Module): @@ -485,6 +705,191 @@ def step(self, hidden_states, conv_state, ssm_state): return out.unsqueeze(1), conv_state, ssm_state +class mamba2_ref(mamba_ref): + + def __init__(self, + d_model, + d_state=128, + d_conv=4, + expand=2, + headdim=64, + ngroups=1, + chunk_size=256, + conv_bias=True, + bias=False, + rmsnorm=True, + device=None, + dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(d_model, d_state, d_conv, expand, "auto", conv_bias, + bias, **factory_kwargs) + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.headdim = headdim + self.chunk_size = chunk_size + self.d_ssm = self.d_inner + self.ngroups = ngroups + assert self.d_ssm % self.headdim == 0 + self.nheads = self.d_ssm // self.headdim + self.rmsnorm = rmsnorm + self.group_d_state = self.ngroups * self.d_state + + d_in_proj = 2 * self.d_inner + 2 * self.group_d_state + self.nheads + self.in_proj = nn.Linear(self.d_model, + d_in_proj, + bias=bias, + **factory_kwargs) + + self.conv_dim = self.d_ssm + 2 * self.group_d_state + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=self.conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + self.act = nn.SiLU() + + self.out_proj = nn.Linear(self.d_inner, + self.d_model, + bias=bias, + **factory_kwargs) + + # dt_bias + dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4 + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * + (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + + # A + A_init_range = (1, 16) + A = torch.empty(self.nheads, dtype=torch.float32, + device=device).uniform_(*A_init_range) + A_log = torch.log(A) + self.A = nn.Parameter(-torch.exp(A_log.float())) + + # D + self.D = nn.Parameter(torch.ones(self.nheads, device=device)) + + # norm + if rmsnorm: + self.norm = LlamaRMSNorm(self.d_inner, eps=1e-5) + + def forward_impl(self, + hidden_states, + conv_state, + ssm_state, + seqlen_offset=0): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + _, seqlen, _ = hidden_states.shape + + if seqlen_offset > 0: + # The states are updated inplace + out, conv_state, ssm_state = self.step(hidden_states, conv_state, + ssm_state) + return out, conv_state, ssm_state + + # in_proj + zxbcdt = self.in_proj(hidden_states) + z, xBC, dt = torch.split(zxbcdt, + [self.d_ssm, self.conv_dim, self.nheads], + dim=-1) + + # Conv + if conv_state is not None: + xBC_t = rearrange(xBC, "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) + xBC = self.act( + self.conv1d(xBC.transpose(1, 2))[..., :seqlen].transpose(1, 2)) + x, B, C = torch.split( + xBC, [self.d_ssm, self.group_d_state, self.group_d_state], dim=-1) + + # chunk scan + y, last_state = ssd_chunk_scan_combined_ref( + rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + dt, + self.A, + rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=self.D, + z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) + if not self.rmsnorm else None, + dt_bias=self.dt_bias, + dt_softplus=True) + y = rearrange(y, "b l h p -> b l (h p)") + ssm_state.copy_(last_state) + + # norm + if self.rmsnorm: + y = (self.norm(y.float() * self.act(z.float()))).to(y.dtype) + + # out_proj + out = self.out_proj(y) + return out, conv_state, ssm_state + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1 + + # in_proj + zxbcdt = self.in_proj(hidden_states.squeeze(1)) + z, xBC, dt = torch.split(zxbcdt, + [self.d_ssm, self.conv_dim, self.nheads], + dim=-1) + + # Conv step + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * + rearrange(self.conv1d.weight, "d 1 w -> d w"), + dim=-1) + if self.conv1d.bias is not None: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(dtype=dtype) + x, B, C = torch.split( + xBC, [self.d_ssm, self.group_d_state, self.group_d_state], dim=-1) + + # SSM step + A = repeat(self.A, "h -> h n p", p=self.headdim, + n=self.d_state).to(dtype=torch.float32) + dt = repeat(dt, "b h -> b h p", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) + D = repeat(self.D, "h -> h p", p=self.headdim) + B = rearrange(B, "b (g n) -> b g n", g=self.ngroups) + C = rearrange(C, "b (g n) -> b g n", g=self.ngroups) + x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + y = selective_state_update_ref(ssm_state, + x_reshaped, + dt, + A, + B, + C, + D=D, + z=z if not self.rmsnorm else None, + dt_bias=dt_bias, + dt_softplus=True) + y = rearrange(y, "b h p -> b (h p)") + if self.rmsnorm: + y = (self.norm(y.float() * self.act(z.float()))).to(y.dtype) + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def rnn_scan(x: torch.Tensor, a: torch.Tensor, reset: torch.Tensor, h0: torch.Tensor): """Runs the recurrence of a linear RNN.""" diff --git a/tests/hlapi/test_build_cache.py b/tests/hlapi/test_build_cache.py index 8ad02342d..44b4c291f 100644 --- a/tests/hlapi/test_build_cache.py +++ b/tests/hlapi/test_build_cache.py @@ -13,7 +13,7 @@ def test_BuildStep(): with TemporaryDirectory() as tempdir: - build_cache = BuildCache(cache_root=Path(tempdir)) + build_cache = BuildCache(BuildCacheConfig(Path(tempdir))) build_step = build_cache.get_engine_building_cache_stage( build_config=BuildConfig(), hf_model_name="test") assert not build_step.cache_hitted() @@ -32,7 +32,7 @@ def test_BuildStep(): def test_BuildCache_clean_untracked_path(): # The BuildCache could cleanup the untracked files/dirs within the cache_root with TemporaryDirectory() as tempdir: - build_cache = BuildCache(cache_root=Path(tempdir)) + build_cache = BuildCache(BuildCacheConfig(Path(tempdir))) (build_cache.cache_root / 'untracked').mkdir() (build_cache.cache_root / 'untracked_file').touch() @@ -43,7 +43,7 @@ def test_BuildCache_clean_untracked_path(): def test_BuildCache_clean_cache_exceed_record_limit(): # The BuildCache could cleanup the cache if the number of records exceed the limit with TemporaryDirectory() as tempdir: - build_cache = BuildCache(cache_root=Path(tempdir), max_records=2) + build_cache = BuildCache(BuildCacheConfig(Path(tempdir), max_records=2)) build_config = BuildConfig() def create_cache(hf_model_name: str): @@ -81,7 +81,7 @@ def test_build_cache_prune_untracked_files(): # The BuildCache could cleanup the untracked files/dirs within the cache_root # The broken cache such as empty cache record directory should be pruned as well with TemporaryDirectory() as tempdir: - build_cache = BuildCache(cache_root=Path(tempdir)) + build_cache = BuildCache(BuildCacheConfig(cache_root=Path(tempdir))) (build_cache.cache_root / 'untracked').mkdir() (build_cache.cache_root / 'untracked_file').touch() (build_cache.cache_root / 'broken_cache').mkdir() diff --git a/tests/hlapi/test_llm.py b/tests/hlapi/test_llm.py index a3bb12775..69f4e47f6 100644 --- a/tests/hlapi/test_llm.py +++ b/tests/hlapi/test_llm.py @@ -442,15 +442,32 @@ def test_generate_with_stop_words(): model=llama_model_path, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), ) + stop_id = llm.tokenizer.encode("N", add_special_tokens=False)[-1] - sampling_params = SamplingParams(max_new_tokens=6, stop_words=[[11]]) + sampling_params = SamplingParams(stop_token_ids=[stop_id]) + for output in llm.generate(prompts, sampling_params=sampling_params): + assert output.outputs[0].text == "D E F G H I J K L M" + sampling_params = SamplingParams(stop_token_ids=[stop_id], + include_stop_str_in_output=True) for output in llm.generate(prompts, sampling_params=sampling_params): - print(output) - assert output.outputs[0].text == "D E F G H I" + assert output.outputs[0].text == "D E F G H I J K L M N" + + sampling_params = SamplingParams(stop="I J") + for output in llm.generate(prompts, sampling_params=sampling_params): + assert output.outputs[0].text == "D E F G H" + + sampling_params = SamplingParams(stop="I J", + include_stop_str_in_output=True) + for output in llm.generate(prompts, sampling_params=sampling_params): + assert output.outputs[0].text == "D E F G H I J" + + sampling_params = SamplingParams(stop=["F E", "I J"], + stop_token_ids=[stop_id]) + for output in llm.generate(prompts, sampling_params=sampling_params): + assert output.outputs[0].text == "D E F G H" -@pytest.mark.skip("waive") @force_ampere def test_generate_with_bad_words(): llm = LLM( @@ -458,24 +475,21 @@ def test_generate_with_bad_words(): kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), ) - sampling_params = SamplingParams(max_new_tokens=6) + bad_id = llm.tokenizer.encode("N", add_special_tokens=False)[-1] - tokenizer = AutoTokenizer.from_pretrained(llama_model_path, - add_prefix_space=False) + sampling_params = SamplingParams(max_new_tokens=15, bad_token_ids=[bad_id]) + for output in llm.generate(prompts, sampling_params=sampling_params): + assert output.outputs[0].text == "D E F G H I J K L M\n\nI hope this" - # TODO[chunweiy]: Consider to make the generate api accept bad_words as a list of strings - bad_words = tokenizer(["H", "I"]).input_ids - bad_words = [row[1] for row in tokenizer(["H", "I"]).input_ids] - bad_words = [bad_words] - print('bad_words:', bad_words) - sampling_params.bad_words = bad_words + sampling_params = SamplingParams(max_new_tokens=15, bad="I J") + for output in llm.generate(prompts, sampling_params=sampling_params): + assert output.outputs[0].text == "D E F G H I K L M N O P Q R S" + sampling_params = SamplingParams(max_new_tokens=15, bad=["F E", "I J"]) for output in llm.generate(prompts, sampling_params=sampling_params): - print(output) - assert output.outputs[0].text == "D E F G HI" + assert output.outputs[0].text == "D E F G H I K L M N O P Q R S" -@pytest.mark.skip("waive") @force_ampere def test_generate_with_embedding_bias(): llm = LLM( @@ -483,28 +497,23 @@ def test_generate_with_embedding_bias(): kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), ) - sampling_params = SamplingParams(max_new_tokens=6) - - tokenizer = AutoTokenizer.from_pretrained(llama_model_path, - add_prefix_space=False) - biased_word_id = tokenizer(["Z"]).input_ids[0][1] - + biased_word_id = llm.tokenizer.encode("Z", add_special_tokens=False)[-1] vocab_size_padded = 32000 embedding_bias = torch.zeros(vocab_size_padded) embedding_bias[biased_word_id] = torch.finfo(torch.float32).max - sampling_params.embedding_bias = embedding_bias + + sampling_params = SamplingParams(max_new_tokens=6, + embedding_bias=embedding_bias) for output in llm.generate(prompts, sampling_params=sampling_params): print(output) assert output.outputs[0].text == "Z Z Z Z Z Z" -@pytest.mark.skip("waive") @force_ampere def test_generate_with_logits_post_processor(): - tokenizer = AutoTokenizer.from_pretrained(llama_model_path, - add_prefix_space=False) - biased_word_id = tokenizer(["Z"]).input_ids[0][1] + tokenizer = AutoTokenizer.from_pretrained(llama_model_path) + biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1] def logits_post_processor(req_id: int, logits: torch.Tensor, ids: List[List[int]], stream_ptr: int): @@ -516,8 +525,8 @@ def logits_post_processor(req_id: int, logits: torch.Tensor, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), logits_post_processor_map={"my_logits_pp": logits_post_processor}) - sampling_params = SamplingParams(max_new_tokens=6) - sampling_params.logits_post_processor_name = "my_logits_pp" + sampling_params = SamplingParams(max_new_tokens=6, + logits_post_processor_name="my_logits_pp") for output in llm.generate(prompts, sampling_params=sampling_params): print(output) diff --git a/tests/hlapi/test_llm_download.py b/tests/hlapi/test_llm_download.py index 5d16f57c8..3159b9f4e 100644 --- a/tests/hlapi/test_llm_download.py +++ b/tests/hlapi/test_llm_download.py @@ -13,7 +13,7 @@ def test_llm_with_model_downloaded(): - llm = LLM(model=model_name, enable_build_cache=True) + llm = LLM(model=model_name, build_cache_config=True) for output in llm.generate(prompts): print(output) diff --git a/tests/model/test_mamba.py b/tests/model/test_mamba.py index 91c1cdfe7..0fc793908 100644 --- a/tests/model/test_mamba.py +++ b/tests/model/test_mamba.py @@ -63,9 +63,11 @@ def _gen_tensorrt_llm_mamba(self, hf_config, hf_path, hf_mamba, load_mode, 'hidden_act': 'silu', 'num_attention_heads': 1, 'rnn_hidden_size': hf_config.intermediate_size, + 'rnn_conv_dim_size': hf_config.intermediate_size, 'state_size': hf_config.state_size, 'conv_kernel': hf_config.conv_kernel, 'use_bias': hf_config.use_bias, + 'mamba_version': 'Mamba1', } config = tensorrt_llm.models.PretrainedConfig.from_dict(config) if load_mode == 'from_checkpoint': diff --git a/tests/model/test_phi.py b/tests/model/test_phi.py index 17fb126fa..9e5646c73 100644 --- a/tests/model/test_phi.py +++ b/tests/model/test_phi.py @@ -76,6 +76,7 @@ def initialize_network(self, network: tensorrt_llm.Network, hf_model, 'mapping': { 'world_size': tensor_parallel, 'tp_size': tensor_parallel, + 'world_size': tensor_parallel, }, 'use_parallel_embedding': False, 'embedding_sharding_dim': 0, diff --git a/tests/test_layer.py b/tests/test_layer.py index d407bd725..0d78606db 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -23,8 +23,8 @@ import torch import tensorrt as trt # isort: on -from functional.torch_ref import (attention_qkvpacked_ref, mamba_ref, - recurrent_ref) +from functional.torch_ref import (attention_qkvpacked_ref, mamba2_ref, + mamba_ref, recurrent_ref) from parameterized import parameterized from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile, TrtRunner) @@ -1035,26 +1035,10 @@ def test_attention(self, atol=a_tol, verbose=True) - @parameterized.expand([ - (3, 16, 1, 1024, 16, 'context', 'float32', False, True), - (3, 16, 1, 1024, 16, 'context', 'float16', False, True), - (3, 16, 1, 1024, 16, 'context', 'bfloat16', False, True), - (3, 16, 1, 1024, 16, 'generation', 'float32', False, True), - (3, 16, 1, 1024, 16, 'generation', 'float16', False, True), - (3, 16, 1, 1024, 16, 'generation', 'bfloat16', False, True), - (3, 16, 1, 1024, 16, 'context', 'float32', False, False), - (3, 16, 1, 1024, 16, 'context', 'float16', False, False), - (3, 16, 1, 1024, 16, 'context', 'bfloat16', False, False), - (3, 16, 1, 1024, 16, 'generation', 'float32', False, False), - (3, 16, 1, 1024, 16, 'generation', 'float16', False, False), - (3, 16, 1, 1024, 16, 'generation', 'bfloat16', False, False), - (3, 16, 1, 1024, 16, 'context', 'float32', True, True), - (3, 16, 1, 1024, 16, 'context', 'float16', True, True), - (3, 16, 1, 1024, 16, 'context', 'bfloat16', True, True), - (3, 16, 1, 1024, 16, 'generation', 'float32', True, True), - (3, 16, 1, 1024, 16, 'generation', 'float16', True, True), - (3, 16, 1, 1024, 16, 'generation', 'bfloat16', True, True), - ], + @parameterized.expand(list( + product([3], [16], [1], [1024], [16], ['context', 'generation'], + ["float32", "float16", "bfloat16"], [True, False], + [True, False])), name_func=unittest_name_func) def test_mamba(self, batch_size, in_seq_len, out_seq_len, d_model, d_state, req_type, dtype, remove_padding, use_plugin): @@ -1372,6 +1356,344 @@ def test_mamba(self, batch_size, in_seq_len, out_seq_len, d_model, d_state, ssm_state_trt_llm, atol=dtype_atol[dtype]) + @parameterized.expand(list( + product([3], [16], [1], [1024], [128], [64], [256], + ['context', 'generation'], ["float32", "float16", "bfloat16"], + [True, False], [True, False])), + name_func=unittest_name_func) + def test_mamba2(self, batch_size, in_seq_len, out_seq_len, d_model, d_state, + headdim, chunk_size, req_type, dtype, remove_padding, + use_plugin): + + # Skip tests that are not supported in pre-ampere architecture + skip_bf16_pre_ampere(dtype) + + if not use_plugin and remove_padding: + pytest.skip( + "Skipping remove input padding without mamba conv1d plugin") + if dtype == 'float32' and req_type == 'context': + pytest.skip( + "Mamba2 layer only support float16 and bfloat16 in context phase" + ) + if getSMVersion() < 80: + pytest.skip( + "Mamba2 layer is not supported in pre-Ampere architecture") + + # configs + device = "cuda" + d_conv = 4 + expand = 2 + ngroups = 1 + bias = False + rmsnorm = True + d_inner = int(expand * d_model) + nheads = d_inner // headdim + conv_dim = d_inner + 2 * ngroups * d_state + seqlen_offset = 0 if req_type == 'context' else in_seq_len + seq_len = in_seq_len if req_type == 'context' else out_seq_len + + # test data + torch_dtype = str_dtype_to_torch(dtype) + mean = 0.0 + std_dev = 0.05 if dtype == "float32" else 0.02 + torch.random.manual_seed(0) + + if req_type == 'context': + last_token_ids = torch.randint(1, + in_seq_len + 1, + size=(batch_size, ), + dtype=torch.int32, + device=device) + last_token_ids[0] = in_seq_len + host_context_lengths = last_token_ids.detach().clone().cpu() + else: + last_token_ids = torch.ones(size=[batch_size], + dtype=torch.int32, + device=device) + host_context_lengths = last_token_ids.detach().clone().cpu() + + if use_plugin: + trt_conv_state_shape = [batch_size, d_conv - 1, conv_dim] + conv_indices = torch.arange(0, + d_conv - 1, + dtype=torch.int32, + device=device).view([1, d_conv - 1, 1]) + else: + trt_conv_state_shape = [batch_size, conv_dim, d_conv - 1] + conv_indices = torch.arange(0, + d_conv - 1, + dtype=torch.int32, + device=device).view([1, 1, d_conv - 1]) + offsets = last_token_ids.view([batch_size, 1, 1]) + conv_indices = conv_indices.expand(trt_conv_state_shape) + offsets + + if remove_padding: + last_token_ids = torch.cumsum(last_token_ids, + dim=0, + dtype=torch.int32).to(device) + total_num_tokens = last_token_ids[batch_size - 1] + else: + total_num_tokens = batch_size * seq_len + + if remove_padding: + hidden_states = torch.empty(size=[total_num_tokens, d_model], + dtype=torch_dtype, + device=device) + output = torch.zeros(size=[total_num_tokens, d_model], + dtype=torch_dtype, + device=device) + else: + hidden_states = torch.empty(size=[batch_size, seq_len, d_model], + dtype=torch_dtype, + device=device) + output = torch.zeros(size=[batch_size, seq_len, d_model], + dtype=torch_dtype, + device=device) + hidden_states.normal_(mean, std_dev) + + if req_type == 'context': + conv_state = torch.zeros(size=[batch_size, conv_dim, d_conv - 1], + dtype=torch_dtype, + device=device) + else: + conv_state = torch.randn(size=[batch_size, conv_dim, d_conv - 1], + dtype=torch_dtype, + device=device) + if req_type == 'context': + ssm_state = torch.empty(size=[batch_size, nheads, d_state, headdim], + dtype=torch_dtype, + device=device) + else: + ssm_state = torch.randn(size=[batch_size, nheads, d_state, headdim], + dtype=torch_dtype, + device=device) + + host_request_types = torch.tensor([0 if req_type == 'context' else 1] * + batch_size, + dtype=torch.int32) + + present_conv_state = torch.zeros(size=trt_conv_state_shape, + dtype=torch_dtype, + device=device) + + hidden_states_ref = hidden_states.detach().clone() + out_ref = output.detach().clone() + if req_type == 'context': + conv_state_ref = torch.zeros(size=[batch_size, conv_dim, d_conv], + dtype=torch_dtype, + device=device).detach() + else: + conv_state_ref = torch.concat( + (torch.zeros(size=[batch_size, conv_dim, 1], + dtype=torch_dtype, + device=device), conv_state), + dim=2).detach().clone() + ssm_state_ref = ssm_state.detach().clone() + + # get torch layer + mamba2_torch = mamba2_ref(d_model, + d_state, + d_conv, + expand, + headdim, + ngroups, + chunk_size, + True, + bias, + rmsnorm=rmsnorm, + device=device, + dtype=torch_dtype) + + # init weights + for module in mamba2_torch.modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)): + if module.bias is not None: + torch.nn.init.normal_(module.bias, std=std_dev) + torch.nn.init.normal_(module.weight, std=std_dev) + + A = -torch.rand(nheads, device=device) - 1.0 + D = torch.randn(nheads, device=device) + dt_bias = torch.rand(nheads, device=device) - 4.0 + norm_weight = torch.randn(d_inner, device=device) + + mamba2_torch.A.data = A.detach().clone() + mamba2_torch.D.data = D.detach().clone() + mamba2_torch.dt_bias.data = dt_bias.detach().clone() + mamba2_torch.norm.weight.data = norm_weight.detach().clone() + + # construct trt network + builder = tensorrt_llm.Builder() + net = builder.create_network() + if use_plugin: + net.plugin_config.mamba_conv1d_plugin = dtype + else: + net.plugin_config.mamba_conv1d_plugin = None + if remove_padding: + net.plugin_config.remove_input_padding = True + else: + net.plugin_config.remove_input_padding = False + net.plugin_config.paged_state = False + + with tensorrt_llm.net_guard(net): + hidden_states_tensor = Tensor( + name='hidden_states', + shape=hidden_states.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + conv_state_tensor = Tensor( + name='conv_state', + shape=trt_conv_state_shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + ssm_state_tensor = Tensor( + name='ssm_state', + shape=ssm_state.shape, + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + host_request_types_tensor = Tensor( + name='host_request_types', + shape=host_request_types.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + last_token_ids_tensor = Tensor( + name='last_token_ids', + shape=last_token_ids.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths_tensor = Tensor( + name='host_context_lengths', + shape=host_context_lengths.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + conv_indices_tensor = Tensor( + name='conv_indices', + shape=trt_conv_state_shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + mamba2_layer = tensorrt_llm.layers.Mamba2(d_model=d_model, + d_inner=d_inner, + d_state=d_state, + d_conv=d_conv, + headdim=headdim, + ngroups=ngroups, + chunk_size=chunk_size, + bias=bias, + rmsnorm=rmsnorm, + dtype=dtype) + mamba2_layer.A.value = torch_to_numpy(A.detach().cpu()) + mamba2_layer.D.value = torch_to_numpy(D.detach().cpu()) + mamba2_layer.dt_bias.value = torch_to_numpy(dt_bias.detach().cpu()) + mamba2_layer.norm.weight.value = torch_to_numpy( + norm_weight.detach().cpu()) + mamba2_layer.in_proj.weight.value = torch_to_numpy( + mamba2_torch.in_proj.weight.detach().cpu()) + mamba2_layer.out_proj.weight.value = torch_to_numpy( + mamba2_torch.out_proj.weight.detach().cpu()) + if bias: + mamba2_layer.in_proj.bias.value = torch_to_numpy( + mamba2_torch.in_proj.bias.detach().cpu()) + mamba2_layer.out_proj.bias.value = torch_to_numpy( + mamba2_torch.out_proj.bias.detach().cpu()) + mamba2_layer.conv1d.weight.value = torch_to_numpy( + mamba2_torch.conv1d.weight.detach().unsqueeze(3).cpu()) + mamba2_layer.conv1d.bias.value = torch_to_numpy( + mamba2_torch.conv1d.bias.detach().cpu()) + if rmsnorm: + mamba2_layer.norm.weight.value = torch_to_numpy( + mamba2_torch.norm.weight.detach().cpu()) + + outputs = mamba2_layer( + hidden_states_tensor, + conv_state_tensor, + ssm_state_tensor, + host_request_types_tensor, + last_token_ids_tensor, + host_context_lengths=host_context_lengths_tensor, + conv_indices=conv_indices_tensor) + net._mark_output(outputs[0], + 'output', + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + net._mark_output(outputs[1], + 'present_conv_state', + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + net._mark_output(outputs[2], + 'present_ssm_state', + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + + if use_plugin: + trt_conv_state = conv_state.permute(0, 2, 1).contiguous() + else: + trt_conv_state = conv_state.clone().detach() + trt_conv_indices = conv_indices.clone().detach() + # trt run + inputs = { + 'hidden_states': hidden_states, + 'conv_state': trt_conv_state, + 'ssm_state': ssm_state, + 'host_request_types': host_request_types, + 'last_token_ids': last_token_ids, + 'host_context_lengths': host_context_lengths, + 'conv_indices': trt_conv_indices, + } + outputs = { + 'output': output, + 'present_conv_state': present_conv_state, + 'present_ssm_state': ssm_state, + } + + stream = torch.cuda.current_stream() + builder_config = builder.create_builder_config(name='mamba2', + opt_level=0, + precision=dtype) + engine = builder.build_engine(net, builder_config) + session = tensorrt_llm.runtime.Session.from_serialized_engine(engine) + session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream) + + # pytorch run + out_ref, conv_state_ref, ssm_state_ref = mamba2_torch( + hidden_states_ref, last_token_ids, conv_state_ref, ssm_state_ref, + remove_padding, batch_size, seqlen_offset) + + dtype_atol = {"float16": 5e-3, "float32": 5e-3, "bfloat16": 5e-2} + + if not remove_padding: + # get out_mask + if req_type == 'context': + out_mask = torch.zeros(batch_size, seq_len, device=device) + for i in range(batch_size): + for j in range(last_token_ids[i]): + out_mask[i, j] = 1 + out_mask = out_mask.unsqueeze(2).expand( + [batch_size, seq_len, d_model]) + else: + out_mask = torch.ones(batch_size, + seq_len, + d_model, + device=device) + + # compare out diff + out_ref = (out_ref * out_mask).detach().to( + torch.float32).cpu().numpy() + outputs['output'][out_mask == 0] = 0 + else: + out_ref = out_ref.detach().to(torch.float32).cpu().numpy() + + out_trt_llm = outputs['output'].to(torch.float32).cpu().numpy() + np.testing.assert_allclose(out_ref, out_trt_llm, atol=dtype_atol[dtype]) + + # compare conv state diff + conv_state_ref = conv_state_ref[:, :, 1:].detach().to( + torch.float32).cpu().numpy() + conv_state_trt_llm = outputs['present_conv_state'] + if use_plugin: + conv_state_trt_llm = conv_state_trt_llm.permute(0, 2, + 1).contiguous() + conv_state_trt_llm = conv_state_trt_llm.to(torch.float32).cpu().numpy() + np.testing.assert_allclose(conv_state_ref, + conv_state_trt_llm, + atol=dtype_atol[dtype]) + + # compare ssm state diff + ssm_state_ref = ssm_state_ref.detach().to(torch.float32).cpu().numpy() + ssm_state_trt_llm = outputs['present_ssm_state'] + ssm_state_trt_llm = ssm_state_trt_llm.to(torch.float32).cpu().numpy() + np.testing.assert_allclose(ssm_state_ref, + ssm_state_trt_llm, + atol=dtype_atol[dtype]) + @parameterized.expand(list( product([3], [16], [1], [1280], [1280], [10], ['context', 'generation'], ["float32", "float16", "bfloat16"], [True, False], diff --git a/tests/utils/cpp_paths.py b/tests/utils/cpp_paths.py index 7aef5fff7..1fff1e9fa 100644 --- a/tests/utils/cpp_paths.py +++ b/tests/utils/cpp_paths.py @@ -41,6 +41,16 @@ def model_path_return_logits(engine_path): return engine_path / "gpt2/fp16-plugin-packed-paged-gather/tp1-pp1-gpu" +@pytest.fixture +def model_path_lora(engine_path: _pl.Path) -> _pl.Path: + return engine_path / "gpt2/fp16-plugin-packed-paged-lora/tp1-pp1-gpu" + + +@pytest.fixture +def lora_config_path(data_path: _pl.Path) -> _pl.Path: + return data_path / "lora-test-weights-gpt2-tp1" + + @pytest.fixture(scope="module") def results_data_path(data_path: _pl.Path) -> _pl.Path: return data_path / "gpt2/sampling/output_tokens_fp16_plugin_packed_paged_tp1_pp1.npy"