diff --git a/.gitignore b/.gitignore index 3b1694b43..324dd2af1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ *.nsys-rep .VSCodeCounter build*/ +!builders/ *.egg-info/ .coverage *.onnx diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 02253b7c4..f54001cf0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,5 +46,5 @@ repos: args: - --skip=".git,3rdparty" - --exclude-file=examples/whisper/tokenizer.py - - --ignore-words-list=rouge,inout,atleast,strat,nd + - --ignore-words-list=rouge,inout,atleast,strat,nd,subtile exclude: 'tests/llm-test-defs/turtle/test_input_files' diff --git a/README.md b/README.md index 9d55c6108..df0e42f06 100644 --- a/README.md +++ b/README.md @@ -75,3 +75,6 @@ To get started with TensorRT-LLM, visit our documentation: - [Installation Guide for Linux](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) - [Installation Guide for Windows](https://nvidia.github.io/TensorRT-LLM/installation/windows.html) - [Supported Hardware, Models, and other Software](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html) + +## Community +- [Model zoo](https://huggingface.co/TheFloat16) (generated by TRT-LLM rel 0.9 a9356d4b7610330e89c1010f342a9ac644215c52) diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md index a7cc4520d..f459df432 100644 --- a/benchmarks/cpp/README.md +++ b/benchmarks/cpp/README.md @@ -210,8 +210,10 @@ TP=2 PP=1 MAX_LEN=1024 MAX_BATCH=32 -MAX_LORA_RANK=32 +NUM_LAYERS=40 +MAX_LORA_RANK=64 NUM_LORA_MODS=7 +EOS_ID=2 SOURCE_LORA=chinese-llama-2-lora-13b CPP_LORA=chinese-llama-2-lora-13b-cpp @@ -234,7 +236,7 @@ ${HOME}/.local/bin/trtllm-build \ --gemm_plugin float16 \ --lora_plugin float16 \ --use_paged_context_fmha enable \ - --lora_target_modules attn_qkv \ + --lora_target_modules attn_q attn_k attn_v attn_dense mlp_h_to_4h mlp_4h_to_h mlp_gate \ --max_lora_rank ${MAX_LORA_RANK} NUM_LORAS=(8 16 24 32 64 128 256) @@ -252,8 +254,6 @@ mkdir -p $EG_DIR/data # Prepare dataset without lora_task_id python benchmarks/cpp/prepare_dataset.py \ --output "${EG_DIR}/data/token-norm-dist.json" \ - --request-rate -1 \ - --time-delay-dist constant \ --tokenizer $TOKENIZER \ token-norm-dist \ --num-requests $NUM_REQUESTS \ @@ -263,8 +263,6 @@ python benchmarks/cpp/prepare_dataset.py \ for nloras in ${NUM_LORAS[@]}; do python benchmarks/cpp/prepare_dataset.py \ --output "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \ - --request-rate -1 \ - --time-delay-dist constant \ --rand-task-id 0 $(( $nloras - 1 )) \ --tokenizer $TOKENIZER \ token-norm-dist \ @@ -292,7 +290,7 @@ mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \ # Now run inference with various numbers or loras # The host cache is set large enough to hold all the LoRAs in lora_dir -# GPU cache is set to hold 32 LoRAs +# GPU cache is set to hold 16 LoRAs # This benchmark will preload all the LoRAs into the host cache # We run inference on a range of active LoRAs exercising different cache miss rates. for nloras in ${NUM_LORAS[@]}; do @@ -303,7 +301,7 @@ for nloras in ${NUM_LORAS[@]}; do --type IFB \ --dataset "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \ --lora_host_cache_bytes 8589934592 \ - --lora_num_device_mod_layers $(( 32 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \ + --lora_num_device_mod_layers $(( 16 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \ --kv_cache_free_gpu_mem_fraction 0.80 \ --log_level info \ --eos_id ${EOS_ID} \ diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index d7f664c19..00033ef16 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -458,10 +458,6 @@ class Recorder { this->recordEnd(requestId, hasError); - if (mRespJsonFile.empty()) - return; - int32_t outputSeqLen; - for (auto& tensor : responseTensors) { if (tensor.name == inference_request::kOutputIdsTensorName) @@ -471,7 +467,7 @@ class Recorder else if (tensor.name == inference_request::kSequenceLengthTensorName) { // Tensor of shape nBeams, and we only need the first one - outputSeqLen = *(bufferCast(*(tensor.tensor))); + int32_t outputSeqLen = *(bufferCast(*(tensor.tensor))); if (mOutputHasInput) { int inputSeqLen = mRequestBenchInfos[requestId].inputLength; @@ -482,6 +478,30 @@ class Recorder } } + void recordEnd(uint64_t requestId, texec::Response const& response) + { + + this->recordEnd(requestId, response.hasError()); + + // Get the actual output length + if (!response.hasError()) + { + auto outputTokenIds = response.getResult().outputTokenIds; + + int32_t outSeqLen = 0; + for (auto const& beam : outputTokenIds) + { + outSeqLen = std::max(static_cast(beam.size()), outSeqLen); + } + if (mOutputHasInput) + { + int inputSeqLen = mRequestBenchInfos[requestId].inputLength; + outSeqLen -= inputSeqLen; + } + mRequestBenchInfos[requestId].outputLength = outSeqLen; + } + } + float calcPercentile(std::vector const& latencies, int percentile) { int const index = static_cast(std::ceil((percentile / 100.0) * latencies.size())) - 1; @@ -827,7 +847,7 @@ class ExecutorServer numFinished++; if (!warmup) { - mRecorder->recordEnd(reqId, response.hasError()); + mRecorder->recordEnd(reqId, response); } } } diff --git a/benchmarks/cpp/gptSessionBenchmark.cpp b/benchmarks/cpp/gptSessionBenchmark.cpp index 1505fc493..b4eced7f9 100644 --- a/benchmarks/cpp/gptSessionBenchmark.cpp +++ b/benchmarks/cpp/gptSessionBenchmark.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -213,6 +214,7 @@ void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector std::vector latencies; std::vector generationTimes; auto generationProfiler = std::make_shared(); + cudaProfilerStart(); while (iterIdx < numRuns) { auto const start = std::chrono::steady_clock::now(); @@ -242,6 +244,7 @@ void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector break; } } + cudaProfilerStop(); TLLM_LOG_INFO(memoryCounter.toString()); done = true; diff --git a/benchmarks/python/benchmark.py b/benchmarks/python/benchmark.py index a7087dd1f..c081cb140 100644 --- a/benchmarks/python/benchmark.py +++ b/benchmarks/python/benchmark.py @@ -198,10 +198,6 @@ def parse_arguments(): help= 'Quick sanity check with num_layer=1; will be silently ignored if --engine_dir is specified.' ) - parser.add_argument('--strongly_typed', - default=False, - action='store_true', - help='This option will reduce the building time.') parser.add_argument( '--gpu_weights_percent', type=str, diff --git a/benchmarks/python/build.py b/benchmarks/python/build.py index 11b888d74..e5b4efef7 100644 --- a/benchmarks/python/build.py +++ b/benchmarks/python/build.py @@ -151,10 +151,6 @@ def parse_arguments(): default=False, action='store_true', help="Build engines serially") - parser.add_argument('--strongly_typed', - default=False, - action='store_true', - help='This option will reduce the building time.') parser.add_argument( '--multiple_profiles', default=False, @@ -251,9 +247,6 @@ def build_gpt(args): if not args.serial_build: torch.cuda.set_device(runtime_rank) - strongly_typed = args.strongly_typed - if args.quantization is not None and "fp8" in args.quantization: - strongly_typed = True num_kv_heads = build_config['num_heads'] \ if build_config['num_kv_heads'] is None else build_config['num_kv_heads'] apply_query_key_layer_scaling = False @@ -321,7 +314,7 @@ def build_gpt(args): quant_mode=quant_mode, use_refit=False, opt_level=build_config['builder_opt'], - strongly_typed=strongly_typed, + strongly_typed=True, weight_streaming=is_weight_streaming, **builder_config_extra_kwargs) engine_name = get_engine_name(args.model, args.dtype, world_size, @@ -363,8 +356,10 @@ def build_gpt(args): 'apply_query_key_layer_scaling': builder_config.apply_query_key_layer_scaling, 'rotary_pct': build_config['rotary_pct'], - 'moe_num_experts': build_config["moe_num_experts"], - 'moe_top_k': build_config["moe_top_k"], + 'moe': { + 'num_experts': build_config["moe_num_experts"], + 'top_k': build_config["moe_top_k"], + }, } config = PretrainedConfig.from_dict(config) tensorrt_llm_model = tensorrt_llm.models.GPTForCausalLM(config) @@ -399,7 +394,7 @@ def build_gpt(args): elif family == "llama": config = { 'architecture': - 'LLaMAForCausalLM', + 'LlamaForCausalLM', 'dtype': args.dtype, 'num_hidden_layers': @@ -430,10 +425,10 @@ def build_gpt(args): 'world_size': world_size, 'tp_size': world_size }, - 'moe_num_experts': - build_config["moe_num_experts"], - 'moe_top_k': - build_config["moe_top_k"], + 'moe': { + 'num_experts': build_config["moe_num_experts"], + 'top_k': build_config["moe_top_k"], + } } config = PretrainedConfig.from_dict(config) tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config) @@ -602,9 +597,6 @@ def build_gpt(args): } config = PretrainedConfig.from_dict(config) tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(config) - tensorrt_llm_model = optimize_model( - tensorrt_llm_model, - use_parallel_embedding=config.use_parallel_embedding) elif family == "falcon": config = { 'architecture': @@ -696,7 +688,7 @@ def build_gpt(args): elif family == "internlm": config = { 'architecture': - 'LLaMAForCausalLM', + 'LlamaForCausalLM', 'dtype': args.dtype, 'num_hidden_layers': @@ -778,10 +770,10 @@ def build_gpt(args): 'world_size': world_size, 'tp_size': world_size }, - 'moe_num_experts': - build_config["moe_num_experts"], - 'moe_top_k': - build_config["moe_top_k"], + 'moe': { + 'num_experts': build_config["moe_num_experts"], + 'top_k': build_config["moe_top_k"], + }, 'qwen_type': 'qwen', } @@ -821,10 +813,10 @@ def build_gpt(args): 'world_size': world_size, 'tp_size': world_size }, - 'moe_num_experts': - build_config["moe_num_experts"], - 'moe_top_k': - build_config["moe_top_k"], + 'moe': { + 'num_experts': build_config["moe_num_experts"], + 'top_k': build_config["moe_top_k"], + }, 'qwen_type': 'qwen2', } @@ -1029,7 +1021,7 @@ def build_bert(args): max_batch_size=max_batch_size, max_input_len=max_input_len, opt_level=build_config['builder_opt'], - strongly_typed=args.strongly_typed, + strongly_typed=True, weight_streaming=is_weight_streaming, ) engine_name = get_engine_name(args.model, args.dtype, world_size, @@ -1207,7 +1199,7 @@ def enc_dec_build_helper(component, config, args): cross_attention=(component == 'decoder'), has_position_embedding=has_position_embedding, has_token_type_embedding=False, # by default - strongly_typed=False, # by default + strongly_typed=True, gather_all_token_logits=False, # by default int8=(quant_mode.has_act_and_weight_quant() or quant_mode.is_int8_weight_only()), diff --git a/benchmarks/python/check_accuracy_mlperf.py b/benchmarks/python/check_accuracy_mlperf.py index 025b08201..c02c8537a 100644 --- a/benchmarks/python/check_accuracy_mlperf.py +++ b/benchmarks/python/check_accuracy_mlperf.py @@ -1,4 +1,5 @@ import json +import os from enum import Enum import evaluate @@ -82,9 +83,11 @@ def calculate_toks_per_sample(preds, eos_id): return avg_len / num_samples -def calculate_rouge_score(preds, targets): +def calculate_rouge_score(preds, targets, rouge_dir=None): print("Calculating ROUGE scores...") - metric = evaluate.load("rouge") + rouge_dir = rouge_dir if rouge_dir and os.path.exists( + rouge_dir) else "rouge" + metric = evaluate.load(rouge_dir) preds, targets = postprocess_text(preds, targets[0:len(preds)]) result = metric.compute(predictions=preds, references=targets, @@ -114,6 +117,15 @@ def parse_arguments(): parser.add_argument("--base_model", type=str, help="Location of the model used (to create tokenizer)") + + parser.add_argument( + '--rouge_dir', + default=None, + type=str, + help= + "evaluate.load('rouge') will attempt to pull rouge package from HF. Use cached rouge can avoid network outage of host or HF." + ) + args = parser.parse_args() return args @@ -146,7 +158,8 @@ def main(): tps_score = calculate_toks_per_sample(pred_toks, tokenizer.eos_token) pred_texts = tokenizer.batch_decode(pred_toks, skip_special_tokens=True) - achieved_scores = calculate_rouge_score(pred_texts, target_texts) + achieved_scores = calculate_rouge_score(pred_texts, target_texts, + args.rouge_dir) achieved_scores['tokens_per_sample'] = tps_score targets = ACCURACY_TARGETS[model] diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py index b35cac85e..b075a47c0 100644 --- a/benchmarks/python/gpt_benchmark.py +++ b/benchmarks/python/gpt_benchmark.py @@ -279,14 +279,10 @@ def check_memory(self, io_shapes: list, raise_exception=False): self.kv_cache_elem_per_token(self.build_config, self.runtime_mapping.tp_size, self.runtime_mapping.pp_size) * element_size(self.kv_dtype) # when MHA is OOTB, it requires extra KV cache size, because OOTB don't support inplace updating KV cache. if not self.use_gpt_attention_plugin: - if os.getenv('TRTLLM_DISABLE_OOTB_KVCACHE_REUSE') != 'ON': - local_n_layer = ceil(self.build_config.num_layers / - self.runtime_mapping.pp_size) - kv_cache_size_in_bytes = kv_cache_size_in_bytes / local_n_layer * ( - local_n_layer + 1) - else: - # without reusing, we need one for past as engine inputs, one for present as engine outputs. - kv_cache_size_in_bytes *= 2 + local_n_layer = ceil(self.build_config.num_layers / + self.runtime_mapping.pp_size) + kv_cache_size_in_bytes = kv_cache_size_in_bytes / local_n_layer * ( + local_n_layer + 1) kv_cache_size_in_mb = bytes_to_target_unit(kv_cache_size_in_bytes, "MiB") diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/enums.py b/benchmarks/suite/tensorrt_llm_bench/utils/enums.py index cfd6c34b8..a9bd3e478 100644 --- a/benchmarks/suite/tensorrt_llm_bench/utils/enums.py +++ b/benchmarks/suite/tensorrt_llm_bench/utils/enums.py @@ -51,9 +51,7 @@ def get_build_options(self, dtype: str) -> List[str]: List[str]: A list of command line arguments to be added to build commands. """ - if self.value == self.FP8: - return ["--strongly_typed"] - else: + if not self.value == self.FP8: return ["--gemm_plugin", dtype] diff --git a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py b/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py index 53e3cefd2..d7302098f 100644 --- a/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py +++ b/benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py @@ -3,8 +3,7 @@ from argparse import ArgumentParser from typing import Literal, Optional -from pydantic import (AliasChoices, AliasPath, BaseModel, Field, computed_field, - model_validator) +from pydantic import AliasChoices, AliasPath, BaseModel, Field, model_validator from transformers import AutoConfig from utils import VALID_QUANT_ALGOS @@ -132,7 +131,7 @@ class TRTLLMConfig(BaseModel): mapping: TRTLLM_Mapping quantization: TRTLLM_Quantization - @computed_field + @property def kv_dtype(self) -> str: if self.quantization.kv_cache_quant_algo == "FP8": return "fp8" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4cba62c87..553d5a767 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -181,6 +181,8 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) endif() message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES}") +# Store CMAKE_CUDA_ARCHITECTURES for later use since torch sets this to "OFF" +set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}") enable_language(C CXX CUDA) diff --git a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h index 4f6673745..65d12b388 100644 --- a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h @@ -60,6 +60,7 @@ auto constexpr kReturnGenerationLogitsTensorName = "return_generation_logits"; auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table"; auto constexpr kPromptVocabSizeName = "prompt_vocab_size"; auto constexpr kLoraTaskId = "lora_task_id"; +auto constexpr kNoRepeatNgramSizeTensorName = "noRepeatNgramSize"; // weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] // where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer // each of the in / out tensors are first flattened and then concatenated together in the format above. @@ -194,6 +195,7 @@ class GenericInferenceRequest inference_request::kReturnGenerationLogitsTensorName, inference_request::kPromptEmbeddingTableName, inference_request::kPromptVocabSizeName, + inference_request::kNoRepeatNgramSizeTensorName, // obsolete names for backward compatibility inference_request::kInputLengthsTensorName, inference_request::kLoraTaskId, @@ -264,6 +266,7 @@ class GenericInferenceRequest TENSOR_GETTER_SETTER(LoraTaskId, inference_request::kLoraTaskId) TENSOR_GETTER_SETTER(LoraWeights, inference_request::kLoraWeights) TENSOR_GETTER_SETTER(LoraConfig, inference_request::kLoraConfig) + TENSOR_GETTER_SETTER(NoRepeatNgramSize, inference_request::kNoRepeatNgramSizeTensorName) #undef TENSOR_GETTER_SETTER diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h index a3bf684af..a7f1ba430 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h @@ -25,6 +25,12 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager { +enum class CacheType +{ + kSELF = 0, + kCROSS = 1, +}; + //! @brief Encapsulates parameters to configure paged KV cache. class KvCacheConfig { diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 03c4fa116..a28ab7269 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -286,10 +286,11 @@ class BlockManager { public: using SizeType32 = tensorrt_llm::runtime::SizeType32; + using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType; explicit BlockManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, - std::shared_ptr stream, bool onboardBlocks); + std::shared_ptr stream, bool onboardBlocks, CacheType cacheType = CacheType::kSELF); ~BlockManager(); @@ -453,6 +454,8 @@ class BlockManager BlockPtr mCachedBlocksRoot; // Statistics for block allocations/reuse std::size_t mAllocTotalBlocks, mAllocNewBlocks, mReusedBlocks; + // KV cache type (self or cross) + CacheType mCacheType; }; class KVCacheManager @@ -461,11 +464,13 @@ class KVCacheManager using SizeType32 = tensorrt_llm::runtime::SizeType32; using SequencesPtr = GenerationRequest::SharedPtr; using CudaStreamPtr = std::shared_ptr; + using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType; KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock, - CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true); + CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true, + CacheType cacheType = CacheType::kSELF); void allocatePools(nvinfer1::DataType dtype, bool useUvm = false); @@ -576,6 +581,11 @@ class KVCacheManager void removeToken(SizeType32 seqSlotIdx); void rewindKVCache(SizeType32 seqSlotIdx, SizeType32 rewindLengths); + [[nodiscard]] bool isCrossKv() const + { + return mCacheType == CacheType::kCROSS; + } + private: void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 seqSlotIdx, SizeType32 beamIdx, SizeType32 blockIdx, KVCacheBlock::IdType blockId) const; @@ -610,5 +620,7 @@ class KVCacheManager runtime::ITensor::SharedPtr mSequenceBlockIndices; // Whether to cache KV pages for reuse bool mEnableBlockReuse; + // KV cache type (self or cross) + CacheType mCacheType; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index c1a64e8bb..5d8308feb 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/samplingConfig.h" @@ -32,15 +33,21 @@ namespace tensorrt_llm::batch_manager { -// TODO(rkobus): refactor +/** + * @brief The state of the request. + * + * Enum order must follow chronological order for state dependency check, @see hasReachedState(). + * + * @todo(rkobus): refactor + */ enum LlmRequestState_t { - REQUEST_STATE_UNKNOWN = 0, - REQUEST_STATE_CONTEXT_INIT = 1, - REQUEST_STATE_GENERATION_IN_PROGRESS = 2, - REQUEST_STATE_GENERATION_TO_COMPLETE = 3, - REQUEST_STATE_GENERATION_COMPLETE = 4, - REQUEST_STATE_ENC_INIT = 5 // For enc-dec models, encoder output has been computed + REQUEST_STATE_UNKNOWN = 0, ///< Unknown state + REQUEST_STATE_ENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models) + REQUEST_STATE_CONTEXT_INIT = 2, ///< Context phase starts + REQUEST_STATE_GENERATION_IN_PROGRESS = 3, ///< Generation phase is in progress + REQUEST_STATE_GENERATION_TO_COMPLETE = 4, ///< Generation phase is to be completed + REQUEST_STATE_GENERATION_COMPLETE = 5, ///< Generation phase completed }; template @@ -69,7 +76,7 @@ class GenericLlmRequest std::optional> draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt, - std::shared_ptr encoderInputTokens = nullptr) + std::optional> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false) : mRequestId(requestId) , mPromptLen(inputTokens->size()) , mMaxNewTokens(maxNewTokens) @@ -99,9 +106,14 @@ class GenericLlmRequest , mReturnContextLogits(returnContextLogits) , mReturnGenerationLogits(returnGenerationLogits) , mExcludeInputFromOutput(excludeInputFromOutput) - , mEncoderInputTokens(encoderInputTokens) + , mEncoderTokens(std::move(encoderInputTokens)) + , mReturnEncoderOutput(returnEncoderOutput) , mDecodingIter(0) { + if (mEncoderTokens.has_value()) + { + mState = REQUEST_STATE_ENCODER_INIT; + } initialize(*inputTokens, returnLogProbs); } @@ -134,8 +146,15 @@ class GenericLlmRequest , mReturnContextLogits(req.getOutputConfig().returnContextLogits) , mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits) , mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput) + , mEncoderTokens(std::nullopt) + , mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput) , mDecodingIter(0) { + if (req.getEncoderInputTokenIds()) + { + mState = REQUEST_STATE_ENCODER_INIT; + mEncoderTokens = std::make_shared(req.getEncoderInputTokenIds().value()); + } if (req.getEmbeddingBias()) { mEmbeddingBias = executor::detail::toITensor(req.getEmbeddingBias().value()); @@ -194,8 +213,13 @@ class GenericLlmRequest initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs); } - void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen) + void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen, + std::optional maxEncoderInputLen = std::nullopt) { + TLLM_CHECK_WITH_INFO(!(maxEncoderInputLen.has_value() && getEncoderLen() > maxEncoderInputLen.value()), + "Encoder length (%d) exceeds maximum encoder input length (%d).", getEncoderLen(), + maxEncoderInputLen.value()); + if (mPromptLen > maxInputLen) { TLLM_THROW("Prompt length (%d) exceeds maximum input length (%d).", mPromptLen, maxInputLen); @@ -287,15 +311,17 @@ class GenericLlmRequest /// @brief Get input tokens to encoder /// @return A vector of tokens. - std::shared_ptr const& getEncoderInputTokens() const + [[nodiscard]] std::optional> const& getEncoderTokens() const { - return mEncoderInputTokens; + return mEncoderTokens; } - SizeType32 getEncoderInputSize() const + /// @brief Get the number of input tokens to encoder + /// @return The number of encoder input tokens. + [[nodiscard]] SizeType32 getEncoderLen() const { - TLLM_CHECK_WITH_INFO(static_cast(getEncoderInputTokens()), "Encoder input tokens must not be nullptr"); - return getEncoderInputTokens()->size(); + TLLM_CHECK_WITH_INFO(getEncoderTokens().has_value(), "Encoder tokens are not given"); + return getEncoderTokens().value()->size(); } /// @brief Get the draft tokens @@ -396,7 +422,9 @@ class GenericLlmRequest mMaxNewTokens -= (newPromptLen - mPromptLen); mPromptLen = newPromptLen; } - mState = REQUEST_STATE_CONTEXT_INIT; + + // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase + mState = mEncoderTokens.has_value() ? REQUEST_STATE_ENCODER_INIT : REQUEST_STATE_CONTEXT_INIT; mContextCurrentPosition = 0; mContextChunkSize = std::nullopt; mSeqSlot.reset(); @@ -550,6 +578,68 @@ class GenericLlmRequest return mNumTokensPerIteration; } + void setReturnEncoderOutput(bool const returnEncoderOutput) + { + mReturnEncoderOutput = returnEncoderOutput; + } + + [[nodiscard]] bool getReturnEncoderOutput() const + { + return mReturnEncoderOutput; + } + + [[nodiscard]] TensorPtr const& getEncoderOutputHost() const + { + return mEncoderOutputHost; + } + + void setEncoderOutputHost(TensorPtr encoderOutputHost) + { + mEncoderOutputHost = std::move(encoderOutputHost); + } + + void allocEncoderOutputHost(SizeType32 encoderHiddenSize, nvinfer1::DataType dataType) + { + mEncoderOutputHost = runtime::BufferManager::pinned( + runtime::ITensor::makeShape({getEncoderLen(), encoderHiddenSize}), dataType); + } + + [[nodiscard]] TensorPtr const& getEncoderOutput() const noexcept + { + return mEncoderOutput; + } + + [[nodiscard]] TensorPtr const& getEncoderHiddenStates() const noexcept + { + return mEncoderHiddenStates; + } + + void allocEncoderOutput(runtime::BufferManager const& manager, nvinfer1::DataType dataType) + { + // unique_ptr --> shared_ptr ownership move + mEncoderOutput = std::move(manager.emptyTensor(runtime::MemoryType::kGPU, dataType)); + } + + void allocEncoderHiddenStates(runtime::BufferManager const& manager, nvinfer1::DataType dataType) + { + // unique_ptr --> shared_ptr ownership move + mEncoderHiddenStates = std::move(manager.emptyTensor(runtime::MemoryType::kGPU, dataType)); + } + + void freeEncoderOutputBuffers() + { + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + TLLM_LOG_DEBUG( + "Encoder output buffers use count: %u, %u", mEncoderOutput.use_count(), mEncoderHiddenStates.use_count()); + + // TODO: better ways to free shared_ptr buffers + mEncoderOutput.reset(); + mEncoderHiddenStates.reset(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + } + void setReturnContextLogits(bool const returnContextLogits) { mReturnContextLogits = returnContextLogits; @@ -629,6 +719,16 @@ class GenericLlmRequest mGenerationLogitsFragments.clear(); } + [[nodiscard]] bool hasReachedState(LlmRequestState_t state) const noexcept + { + return mState >= state; + } + + [[nodiscard]] bool isEncoderInitState() const noexcept + { + return mState == REQUEST_STATE_ENCODER_INIT; + } + [[nodiscard]] bool isContextInitState() const noexcept { return mState == REQUEST_STATE_CONTEXT_INIT; @@ -664,16 +764,6 @@ class GenericLlmRequest return mPromptLen - getContextCurrentPosition(); } - TensorPtr getEncoderOutput() const noexcept - { - return mEncoderOutput; - } - - void setEncoderOutput(TensorPtr encoderOutput) - { - mEncoderOutput = std::move(encoderOutput); - } - /// To retrieve the context chunk size, throw an exception when the context is not chunked. [[nodiscard]] SizeType32 getContextChunkSize() const { @@ -746,6 +836,8 @@ class GenericLlmRequest { if (isGenerationCompleteState() || (mIsStreaming && isGenerationInProgressState())) { + TLLM_LOG_DEBUG("Creating response for request %lu", mRequestId); + executor::Result result; result.isFinal = isGenerationCompleteState(); @@ -809,6 +901,11 @@ class GenericLlmRequest result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost()); } + if (getReturnEncoderOutput()) + { + result.encoderOutput = executor::detail::ofITensor(getEncoderOutputHost()); + } + // Update position of last sent response mMaxSentTokenPos = tokenPos; @@ -835,8 +932,8 @@ class GenericLlmRequest std::optional mLogitsPostProcessor; protected: - SizeType32 mOrigPromptLen; BeamTokens mTokens; + SizeType32 mOrigPromptLen; SizeType32 mMaxSentTokenPos; std::optional mEmbeddingBias; @@ -850,9 +947,6 @@ class GenericLlmRequest std::optional mLoraWeights; std::optional mLoraConfig; - // encoder output, saved for computing cross attention KV Cache - TensorPtr mEncoderOutput; - // To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one, // the size of the context chunk needs to be an integer multiple of the kv-cache block size. The meaning // of null value is that the context is not chunked. @@ -868,15 +962,21 @@ class GenericLlmRequest // Save logits bool mReturnContextLogits; bool mReturnGenerationLogits; - TensorPtr mContextLogits; // [mPromptLen, vocab_size_padded] - TensorPtr mContextLogitsHost; - TensorPtr mGenerationLogits; // [beam_size, mMaxNewTokens, vocab_size_padded] - TensorPtr mGenerationLogitsHost; + bool mReturnLogProbs; + TensorPtr mContextLogitsHost; // [mPromptLen, vocab_size_padded] + TensorPtr mGenerationLogitsHost; // [beam_size, mMaxNewTokens, vocab_size_padded] std::vector mGenerationLogitsFragments; bool mExcludeInputFromOutput; - std::shared_ptr - mEncoderInputTokens; // Input tokens to the encoder for enc only models and enc-dec models + + // Encoder-only and Encoder-Decoder models + // Encoder input tokens + std::optional> mEncoderTokens; + bool mReturnEncoderOutput; + // Encoder output, used to compute cross attention KV Cache + TensorPtr mEncoderOutput; // [numTokens, hidden_size] + TensorPtr mEncoderHiddenStates; // for pipeline parallelism, [numTokens, hiddenSize] + TensorPtr mEncoderOutputHost; SizeType32 mDecodingIter; @@ -954,12 +1054,12 @@ class LlmRequest : public GenericLlmRequest std::optional> draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt, - std::shared_ptr encoderInputTokens = nullptr) + std::optional> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false) : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig), returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits), - excludeInputFromOutput, std::move(logitsPostProcessor), encoderInputTokens) + excludeInputFromOutput, std::move(logitsPostProcessor), std::move(encoderInputTokens), returnEncoderOutput) { } @@ -970,28 +1070,6 @@ class LlmRequest : public GenericLlmRequest mLogitsPostProcessor = std::move(logitsPostProcessor); } - static LlmRequest createEncoderRequest(RequestIdType requestId, SizeType32 maxNewTokens, - std::shared_ptr encoderInputTokens, std::shared_ptr inputTokens, - runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional endId = std::nullopt, - std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, - std::optional badWordsList = std::nullopt, std::optional stopWordsList = std::nullopt, - std::optional promptEmbeddingTable = std::nullopt, - std::optional promptVocabSize = std::nullopt, - std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt, - std::optional loraConfig = std::nullopt, bool returnLogProbs = false, - bool returnContextLogits = false, bool returnGenerationLogits = false, - std::optional> draftTokens = std::nullopt, - std::optional draftLogits = std::nullopt, bool excludeInputFromOutput = false, - std::optional logitsPostProcessor = std::nullopt) - { - LlmRequest request(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, - embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, loraTaskId, loraWeights, - loraConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, draftTokens, draftLogits, - excludeInputFromOutput, logitsPostProcessor, encoderInputTokens); - request.mState = REQUEST_STATE_ENC_INIT; - return request; - } - void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager) { if (!mPromptEmbeddingTable.has_value() diff --git a/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h index 984d0215d..65808134b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h @@ -15,6 +15,7 @@ #include "tensorrt_llm/batch_manager/common.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h" +#include "tensorrt_llm/common/tllmException.h" #include "tensorrt_llm/runtime/loraCache.h" #include "tensorrt_llm/runtime/modelConfig.h" #include "tensorrt_llm/runtime/workerPool.h" @@ -34,7 +35,7 @@ namespace tensorrt_llm::batch_manager using runtime::SizeType32; -class PeftTaskNotCachedException : public std::runtime_error +class PeftTaskNotCachedException : public runtime::LoraExpectedException { public: explicit PeftTaskNotCachedException(std::string const& msg); diff --git a/cpp/tensorrt_llm/common/cudaBf16Wrapper.h b/cpp/include/tensorrt_llm/common/cudaBf16Wrapper.h similarity index 100% rename from cpp/tensorrt_llm/common/cudaBf16Wrapper.h rename to cpp/include/tensorrt_llm/common/cudaBf16Wrapper.h diff --git a/cpp/tensorrt_llm/common/cudaFp8Utils.h b/cpp/include/tensorrt_llm/common/cudaFp8Utils.h similarity index 100% rename from cpp/tensorrt_llm/common/cudaFp8Utils.h rename to cpp/include/tensorrt_llm/common/cudaFp8Utils.h diff --git a/cpp/tensorrt_llm/common/cudaUtils.h b/cpp/include/tensorrt_llm/common/cudaUtils.h similarity index 100% rename from cpp/tensorrt_llm/common/cudaUtils.h rename to cpp/include/tensorrt_llm/common/cudaUtils.h diff --git a/cpp/tensorrt_llm/common/dataType.h b/cpp/include/tensorrt_llm/common/dataType.h similarity index 100% rename from cpp/tensorrt_llm/common/dataType.h rename to cpp/include/tensorrt_llm/common/dataType.h diff --git a/cpp/include/tensorrt_llm/common/mpiUtils.h b/cpp/include/tensorrt_llm/common/mpiUtils.h index 4f5ff38db..97e5c47f0 100644 --- a/cpp/include/tensorrt_llm/common/mpiUtils.h +++ b/cpp/include/tensorrt_llm/common/mpiUtils.h @@ -368,7 +368,7 @@ class MpiComm bool mFreeComm; }; -void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_FUNNELED); +void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_FUNNELED, bool forwardAbortToParent = false); } // namespace tensorrt_llm::mpi diff --git a/cpp/tensorrt_llm/common/quantization.h b/cpp/include/tensorrt_llm/common/quantization.h similarity index 99% rename from cpp/tensorrt_llm/common/quantization.h rename to cpp/include/tensorrt_llm/common/quantization.h index 93fd29600..c2636b3b4 100644 --- a/cpp/tensorrt_llm/common/quantization.h +++ b/cpp/include/tensorrt_llm/common/quantization.h @@ -16,8 +16,9 @@ #pragma once -#include "stdlib.h" -#include "tensor.h" +#include +#include +#include namespace tensorrt_llm { diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index a2fce596c..aef5eb617 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -58,7 +58,8 @@ class SamplingConfig std::optional const& presencePenalty = std::nullopt, std::optional const& frequencyPenalty = std::nullopt, std::optional const& lengthPenalty = std::nullopt, - std::optional const& earlyStopping = std::nullopt); + std::optional const& earlyStopping = std::nullopt, + std::optional const& noRepeatNgramSize = std::nullopt); bool operator==(SamplingConfig const& other) const; @@ -77,6 +78,7 @@ class SamplingConfig [[nodiscard]] std::optional getFrequencyPenalty() const; [[nodiscard]] std::optional getLengthPenalty() const; [[nodiscard]] std::optional getEarlyStopping() const; + [[nodiscard]] std::optional getNoRepeatNgramSize() const; void setBeamWidth(SizeType32 beamWidth); void setTopK(std::optional const& topK); @@ -93,6 +95,7 @@ class SamplingConfig void setFrequencyPenalty(std::optional const& frequencyPenalty); void setLengthPenalty(std::optional const& lengthPenalty); void setEarlyStopping(std::optional const& earlyStopping); + void setNoRepeatNgramSize(std::optional const& noRepeatNgramSize); private: static SizeType32 checkBeamWidth(SizeType32 beamWidth); @@ -103,6 +106,7 @@ class SamplingConfig static std::optional const& checkTopPDecay(std::optional const& topPDecay); static std::optional const& checkTemperature(std::optional const& temperature); static std::optional const& checkMinLength(std::optional const& minLength); + static std::optional const& checkNoRepeatNgramSize(std::optional const& noRepeatNgramSize); static std::optional const& checkBeamSearchDiversityRate( std::optional const& beamSearchDiversityRate); @@ -142,6 +146,8 @@ class SamplingConfig /// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with /// end_token) std::optional mEarlyStopping; + /// @brief Controls how many repeat ngram size are acceptable. Default is 1 << 30. + std::optional mNoRepeatNgramSize; }; /// @brief Configuration that controls the outputs of a Result @@ -149,7 +155,7 @@ class OutputConfig { public: explicit OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, - bool returnGenerationLogits = false, bool excludeInputFromOutput = false); + bool returnGenerationLogits = false, bool excludeInputFromOutput = false, bool returnEncoderOutput = false); /// @brief Controls if Result should contain log probabilities. Default is false. bool returnLogProbs; @@ -159,6 +165,9 @@ class OutputConfig bool returnGenerationLogits; /// @brief Controls if output tokens in Result should include the input tokens. Default is false. bool excludeInputFromOutput; + /// @brief Controls if Result should contain encoder output hidden states (for encoder-only and encoder-decoder + /// models). Default is false. + bool returnEncoderOutput; }; /// @brief Configuration for speculative decoding with external draft tokens. @@ -241,6 +250,7 @@ class Request /// @param loraConfig The LoRA configuration /// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor /// name provided to the ExecutorConfig. + /// @param encoderInputTokenIds The encoder input token ids for encoder-decoder models, or encoder-only models Request(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming = false, SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(), std::optional const& endId = std::nullopt, std::optional const& padId = std::nullopt, @@ -250,7 +260,8 @@ class Request std::optional externalDraftTokensConfig = std::nullopt, std::optional pTuningConfig = std::nullopt, std::optional loraConfig = std::nullopt, - std::optional logitsPostProcessorName = std::nullopt); + std::optional logitsPostProcessorName = std::nullopt, + std::optional encoderInputTokenIds = std::nullopt); Request(Request const& other); Request(Request&& other) noexcept; @@ -272,6 +283,7 @@ class Request [[nodiscard]] std::optional getPromptTuningConfig() const; [[nodiscard]] std::optional getLoraConfig() const; [[nodiscard]] std::optional getLogitsPostProcessorName() const; + [[nodiscard]] std::optional getEncoderInputTokenIds() const; void setStreaming(bool streaming); void setSamplingConfig(SamplingConfig const& config); @@ -285,6 +297,7 @@ class Request void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig); void setLoraConfig(LoraConfig const& loraConfig); void setLogitsPostProcessorName(std::string const& logitsPostProcessorName); + void setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds); private: friend class Serialization; @@ -312,6 +325,9 @@ struct Result /// @brief The context logits. Size [beamSize, maxNewTokens, vocabSizePadded] std::optional generationLogits; + + /// @brief The encoder output. Size [encoderLen, hiddenSize] + std::optional encoderOutput; }; /// @brief Class that holds either an error or a result @@ -695,11 +711,21 @@ class Executor /// @param comm An optional inter-process communicator configuration Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig const& executorConfig); + Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath, + ModelType modelType, ExecutorConfig const& executorConfig); + Executor(std::vector const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType, ExecutorConfig const& executorConfig); + Executor(std::vector const& encoderEngineBuffer, std::string const& encoderJsonConfigStr, + std::vector const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, ModelType modelType, + ExecutorConfig const& executorConfig); + Executor(std::shared_ptr model, ExecutorConfig const& executorConfig); + Executor( + std::shared_ptr encoderModel, std::shared_ptr decoderModel, ExecutorConfig const& executorConfig); + ~Executor(); /// @brief Enqueue a new request diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index 2af73044f..ff147e882 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -122,6 +122,28 @@ class Serialization static void serialize(ExecutorConfig const& executorConfig, std::ostream& os); static size_t serializedSize(ExecutorConfig const& executorConfig); + // KvCacheStats + static KvCacheStats deserializeKvCacheStats(std::istream& is); + static void serialize(KvCacheStats const& kvCacheStats, std::ostream& os); + static size_t serializedSize(KvCacheStats const& kvCacheStats); + + // StaticBatchingStats + static StaticBatchingStats deserializeStaticBatchingStats(std::istream& is); + static void serialize(StaticBatchingStats const& staticBatchingStats, std::ostream& os); + static size_t serializedSize(StaticBatchingStats const& staticBatchingStats); + + // InflightBatchingStats + static InflightBatchingStats deserializeInflightBatchingStats(std::istream& is); + static void serialize(InflightBatchingStats const& inflightBatchingStats, std::ostream& os); + static size_t serializedSize(InflightBatchingStats const& inflightBatchingStats); + + // IterationStats + static IterationStats deserializeIterationStats(std::vector& buffer); + static IterationStats deserializeIterationStats(std::istream& is); + static void serialize(IterationStats const& iterStats, std::ostream& os); + static std::vector serialize(IterationStats const& iterStats); + static size_t serializedSize(IterationStats const& iterStats); + // String static std::string deserializeString(std::istream& is); diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 537751956..e9f163450 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -153,6 +153,8 @@ enum class MemoryType enum class ModelType { kDECODER_ONLY = 0, + kENCODER_ONLY = 1, + kENCODER_DECODER = 2, }; /// @brief The batching type @@ -276,6 +278,8 @@ struct IterationStats size_t pinnedMemUsage; /// @brief Stats specific to KV caches std::optional kvCacheStats; + /// @brief Stats specific to cross KV caches + std::optional crossKvCacheStats; /// @brief Stats specific to static batching std::optional staticBatchingStats; /// @brief Stats specific to inflight batching @@ -288,13 +292,14 @@ enum class RequestStage /// @brief Request that have been received but not yet included in the active requests (due to constraints such as /// maximum batch size for example). kQUEUED, + /// @brief Active request in encoder phase + kENCODER_IN_PROGRESS, /// @brief Active request in context phase kCONTEXT_IN_PROGRESS, /// @brief Active request in generation phase kGENERATION_IN_PROGRESS, /// @brief Active request for which generation has completed kGENERATION_COMPLETE, - }; /// @brief Struct that holds the stats of a single request @@ -339,22 +344,22 @@ class DecodingMode static auto constexpr TopK() { - return DecodingMode{kTopK | kUsePenalties | kUseBanWords | kStandardStopCriteria}; + return DecodingMode{kTopK | kUsePenalties | kUseBanTokens | kStandardStopCriteria}; } static auto constexpr TopP() { - return DecodingMode{kTopP | kUsePenalties | kUseBanWords | kStandardStopCriteria}; + return DecodingMode{kTopP | kUsePenalties | kUseBanTokens | kStandardStopCriteria}; } static auto constexpr TopKTopP() { - return DecodingMode{kTopKTopP | kUsePenalties | kUseBanWords | kStandardStopCriteria}; + return DecodingMode{kTopKTopP | kUsePenalties | kUseBanTokens | kStandardStopCriteria}; } static auto constexpr BeamSearch() { - return DecodingMode{kBeamSearch | kUsePenalties | kUseBanWords | kStandardStopCriteria}; + return DecodingMode{kBeamSearch | kUsePenalties | kUseBanTokens | kStandardStopCriteria}; } static auto constexpr Medusa() @@ -408,12 +413,24 @@ class DecodingMode return *this; } + auto constexpr useBanTokens(bool banTokens) + { + mState = setBitTo(kUseBanTokens, banTokens); + return *this; + } + auto constexpr useBanWords(bool banWords) { mState = setBitTo(kUseBanWords, banWords); return *this; } + auto constexpr useNoRepeatNgramSize(bool noRepeatNgramSize) + { + mState = setBitTo(kUseNoRepeatNgramSize, noRepeatNgramSize); + return *this; + } + auto constexpr useStopWords(bool stopWords) { mState = setBitTo(kUseStopWords, stopWords); @@ -517,6 +534,16 @@ class DecodingMode return anyBitSet(kUseBanWords); } + bool constexpr isUseNoRepeatNgramSize() const + { + return anyBitSet(kUseNoRepeatNgramSize); + } + + bool constexpr isUseBanTokens() const + { + return anyBitSet(kUseBanTokens); + } + bool constexpr isUseStopWords() const { return anyBitSet(kUseStopWords); @@ -566,11 +593,13 @@ class DecodingMode static UnderlyingType constexpr kUseStopWords{1u << 6}; static UnderlyingType constexpr kUseMaxLengthStop{1u << 7}; static UnderlyingType constexpr kUseExplicitEosStop{1u << 8}; + static UnderlyingType constexpr kUseNoRepeatNgramSize{1u << 9}; static UnderlyingType constexpr kStandardStopCriteria{kUseStopWords | kUseMaxLengthStop}; static UnderlyingType constexpr kUseOccurrencePenalties{ kUseRepetitionPenalties | kUseFrequencyPenalties | kUsePresencePenalties}; static UnderlyingType constexpr kUsePenalties{kUseOccurrencePenalties | kUseTemperature | kUseMinLength}; - static SizeType32 constexpr kNumFlags{9}; + static UnderlyingType constexpr kUseBanTokens{kUseNoRepeatNgramSize | kUseBanWords}; + static SizeType32 constexpr kNumFlags{10}; static UnderlyingType constexpr kAuto{1u << (kNumFlags + 0)}; static UnderlyingType constexpr kTopK{1u << (kNumFlags + 1)}; static UnderlyingType constexpr kTopP{1u << (kNumFlags + 2)}; diff --git a/cpp/tensorrt_llm/kernels/kvCacheIndex.h b/cpp/include/tensorrt_llm/kernels/kvCacheIndex.h similarity index 100% rename from cpp/tensorrt_llm/kernels/kvCacheIndex.h rename to cpp/include/tensorrt_llm/kernels/kvCacheIndex.h diff --git a/cpp/tensorrt_llm/layers/defaultDecodingParams.h b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h similarity index 95% rename from cpp/tensorrt_llm/layers/defaultDecodingParams.h rename to cpp/include/tensorrt_llm/layers/defaultDecodingParams.h index 9236ea9ba..7b2e05629 100644 --- a/cpp/tensorrt_llm/layers/defaultDecodingParams.h +++ b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h @@ -109,6 +109,11 @@ class DefaultDecodingParams { return {}; } + + [[nodiscard]] __host__ __device__ static constexpr runtime::SizeType32 getNoRepeatNgramSize() + { + return 1 << 30; + } }; } // namespace layers } // namespace tensorrt_llm diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h index 0802b272b..88daf100f 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h @@ -132,7 +132,7 @@ inline std::unique_ptr IGptDecoder::create(executor::DecodingMode c case nvinfer1::DataType::kHALF: return std::make_unique>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream, maxTokensPerStep, maxAcceptedDraftTokensPerStep); - default: return nullptr; + default: TLLM_THROW("Unsupported decoder data type. Use either kFLOAT or kHALF."); return nullptr; } } } // namespace runtime diff --git a/cpp/include/tensorrt_llm/runtime/gptJsonConfig.h b/cpp/include/tensorrt_llm/runtime/gptJsonConfig.h index 9dc7c5180..e2fea529d 100644 --- a/cpp/include/tensorrt_llm/runtime/gptJsonConfig.h +++ b/cpp/include/tensorrt_llm/runtime/gptJsonConfig.h @@ -49,7 +49,12 @@ class GptJsonConfig static GptJsonConfig parse(std::filesystem::path const& path); - [[nodiscard]] ModelConfig getModelConfig() const + [[nodiscard]] ModelConfig const& getModelConfig() const + { + return mModelConfig; + } + + [[nodiscard]] ModelConfig& getModelConfigMutable() { return mModelConfig; } @@ -103,7 +108,7 @@ class GptJsonConfig SizeType32 const mTensorParallelism; SizeType32 const mPipelineParallelism; SizeType32 const mGpusPerNode; - ModelConfig const mModelConfig; + ModelConfig mModelConfig; // remove const qualifier because config has to mutable after json parsing }; } // namespace tensorrt_llm::runtime diff --git a/cpp/include/tensorrt_llm/runtime/iBuffer.h b/cpp/include/tensorrt_llm/runtime/iBuffer.h index 5a55c66ff..2a9ab9372 100644 --- a/cpp/include/tensorrt_llm/runtime/iBuffer.h +++ b/cpp/include/tensorrt_llm/runtime/iBuffer.h @@ -579,8 +579,15 @@ class BufferRange : public tensorrt_llm::common::ArrayView { } + template , bool> = true> explicit BufferRange(IBuffer& buffer) - : BufferRange(bufferCast(buffer), buffer.getSize()) + : BufferRange(bufferCast(buffer), buffer.getSize()) + { + } + + template , bool> = true> + explicit BufferRange(IBuffer const& buffer) + : BufferRange(bufferCast(buffer), buffer.getSize()) { } }; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h index 37f7006d3..6352e8eb1 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h @@ -113,6 +113,7 @@ class Input using Output = decoder::Output; +// TODO: is this a bad name to mix up with token concept in LLM? Would 'Event' be better? And should move to common.h class Token { public: diff --git a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h index e0b504e95..60e97f549 100644 --- a/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/iStatefulGptDecoder.h @@ -21,6 +21,7 @@ #include "tensorrt_llm/runtime/generationInput.h" #include "tensorrt_llm/runtime/generationOutput.h" #include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/modelConfig.h" #include "tensorrt_llm/runtime/samplingConfig.h" #include diff --git a/cpp/include/tensorrt_llm/runtime/iTensor.h b/cpp/include/tensorrt_llm/runtime/iTensor.h index 6da87035f..54a975bf4 100644 --- a/cpp/include/tensorrt_llm/runtime/iTensor.h +++ b/cpp/include/tensorrt_llm/runtime/iTensor.h @@ -106,6 +106,25 @@ class ITensor : virtual public IBuffer return static_cast(vol); } + //! + //! \brief Returns the strides of each dimemsion in a Shape. + //! + static Shape strides(Shape const& dims) + { + auto const nbDims = dims.nbDims; + Shape strides{}; + strides.nbDims = nbDims; + if (nbDims > 0) + { + strides.d[nbDims - 1] = 1; + } + for (int i = nbDims - 2; i >= 0; i--) + { + strides.d[i] = dims.d[i + 1] * strides.d[i + 1]; + } + return strides; + } + //! //! \brief Removes the given *unit* dimension from `shape`. //! @@ -169,6 +188,95 @@ class ITensor : virtual public IBuffer return ITensor::slice(constPointerCast(std::forward(tensor)), offset); } + //! + //! \param offsetDims The offset in multiple dimensions. + //! + //! \param tensor The tensor to view. + //! \param offsetDims The offset dimensions of the view. + //! \param size The size of the view w.r.t. the last dimension in offsetDims. + //! \return A view of shape [size, the rest dimensions] + //! or [size] when \param offsetDims specifies all dimensions. + //! \throw Whenever offset overflows or the last dimension offset+size overflows. + //! + static UniquePtr slice(SharedPtr tensor, Shape const& offsetDims, DimType64 size); + + static UniquePtr slice(SharedPtr tensor, std::initializer_list const& offsetDims, DimType64 size) + { + return slice(std::move(tensor), makeShape(offsetDims), size); + } + + template >, int> = 0> + static UniqueConstPtr slice(TConstPtr&& tensor, Shape const& offsetDims, std::size_t size) + { + return slice(constPointerCast(std::forward(tensor)), offsetDims, size); + } + + template >, int> = 0> + static UniqueConstPtr slice( + TConstPtr&& tensor, std::initializer_list const& offsetDims, std::size_t size) + { + return slice(constPointerCast(std::forward(tensor)), offsetDims, size); + } + + //! + //! \brief return the rest slices at the last dimension when `size` omitted. + //! + static UniquePtr slice(SharedPtr tensor, Shape const& offsetDims) + { + auto const dims = tensor->getShape(); + auto const nbDims = offsetDims.nbDims; + auto const size = (dims.nbDims > 0 && nbDims > 0) ? dims.d[nbDims - 1] - offsetDims.d[nbDims - 1] : 0; + return ITensor::slice(std::move(tensor), offsetDims, size); + } + + static UniquePtr slice(SharedPtr tensor, std::initializer_list const& offsetDims) + { + return slice(std::move(tensor), makeShape(offsetDims)); + } + + template >, int> = 0> + static UniqueConstPtr slice(TConstPtr&& tensor, Shape const& offsetDims) + { + return slice(constPointerCast(std::forward(tensor)), offsetDims); + } + + template >, int> = 0> + static UniqueConstPtr slice(TConstPtr&& tensor, std::initializer_list const& offsetDims) + { + return slice(constPointerCast(std::forward(tensor)), offsetDims); + } + + //! + //! \return Just the block at the point, with shape of [the rest dimensions] + //! or [1] when \param offsetDims specifies all dimensions. + //! + static UniquePtr at(SharedPtr tensor, Shape const& offsetDims) + { + auto result = slice(std::move(tensor), offsetDims, 1); + if (result->getShape().nbDims > 1) + { + result->squeeze(0); + } + return result; + } + + static UniquePtr at(SharedPtr tensor, std::initializer_list const& offsetDims) + { + return at(std::move(tensor), makeShape(offsetDims)); + } + + template >, int> = 0> + static UniqueConstPtr at(TConstPtr&& tensor, Shape const& offsetDims) + { + return at(constPointerCast(std::forward(tensor)), offsetDims); + } + + template >, int> = 0> + static ITensor::UniqueConstPtr at(TConstPtr&& tensor, std::initializer_list const& offsetDims) + { + return at(constPointerCast(std::forward(tensor)), offsetDims); + } + //! //! \brief Returns a view on the underlying `buffer` (or tensor) with the given shape. //! @@ -196,6 +304,23 @@ class ITensor : virtual public IBuffer return ITensor::view(std::move(tensor), shapes); } + //! + //! \brief Returns a flattened view on the underlying `tensor` which can be independently reshaped. + //! + //! \param tensor The tensor to flatten. + //! \param sliceN Slice the first N elements after flattening. -1 means take the whole flattened tensor. + //! \return A flatten view on the `tensor`. + //! + static UniquePtr flattenN(SharedPtr tensor, std::int64_t sliceN = -1) + { + UniquePtr flatten = ITensor::view(tensor, ITensor::makeShape({ITensor::volume(tensor->getShape()), 1})); + if (sliceN > 0) + { + flatten = ITensor::slice(std::move(flatten), 0, sliceN); + } + return flatten; + } + //! //! \brief Wraps the given `data` in an `ITensor`. The `ITensor` will not own the underlying `data` and cannot //! be reshaped beyond `capacity`. diff --git a/cpp/tensorrt_llm/runtime/lookaheadModule.h b/cpp/include/tensorrt_llm/runtime/lookaheadModule.h similarity index 100% rename from cpp/tensorrt_llm/runtime/lookaheadModule.h rename to cpp/include/tensorrt_llm/runtime/lookaheadModule.h diff --git a/cpp/include/tensorrt_llm/runtime/loraCache.h b/cpp/include/tensorrt_llm/runtime/loraCache.h index d96177a85..11491d10d 100644 --- a/cpp/include/tensorrt_llm/runtime/loraCache.h +++ b/cpp/include/tensorrt_llm/runtime/loraCache.h @@ -16,6 +16,7 @@ #pragma once +#include "tensorrt_llm/common/tllmException.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iTensor.h" @@ -31,11 +32,26 @@ #include #include #include +#include #include namespace tensorrt_llm::runtime { +class LoraExpectedException : public std::runtime_error +{ +public: + explicit LoraExpectedException(std::string const& msg); + ~LoraExpectedException() noexcept override; +}; + +class LoraCacheFullException : public LoraExpectedException +{ +public: + explicit LoraCacheFullException(std::string const& msg); + ~LoraCacheFullException() noexcept override; +}; + /** * Holds memory of lora cache pages, and manages allocation and freeing of whole pages. * Memory is pre-allocated either on the host or device diff --git a/cpp/tensorrt_llm/runtime/medusaModule.h b/cpp/include/tensorrt_llm/runtime/medusaModule.h similarity index 100% rename from cpp/tensorrt_llm/runtime/medusaModule.h rename to cpp/include/tensorrt_llm/runtime/medusaModule.h diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h index 49387397b..d0d3759be 100644 --- a/cpp/include/tensorrt_llm/runtime/modelConfig.h +++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h @@ -36,6 +36,7 @@ class ModelConfig kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B kMamba = 2, // https://github.com/state-spaces/mamba kRecurrentGemma = 3, // https://github.com/google-deepmind/recurrentgemma + kEncDec = 4, }; struct RnnConfig @@ -84,7 +85,7 @@ class ModelConfig , mUseLoraPlugin(false) , mMlpHiddenSize(0) , mUseCrossAttention(false) - , mUsePositionEmbedding(true) // TODO: remove these two properties? + , mUsePositionEmbedding(false) , mUseTokenTypeEmbedding(false) , mSpeculativeDecodingMode(SpeculativeDecodingMode::None()) { @@ -132,6 +133,16 @@ class ModelConfig return mHiddenSize; } + [[nodiscard]] SizeType32 constexpr getEncoderHiddenSize() const noexcept + { + return mEncoderHiddenSize; + } + + void constexpr setEncoderHiddenSize(SizeType32 encoderHiddenSize) noexcept + { + mEncoderHiddenSize = encoderHiddenSize; + } + [[nodiscard]] SizeType32 constexpr getSizePerHead() const noexcept { return mSizePerHead; @@ -273,6 +284,16 @@ class ModelConfig mMaxNumTokens = maxNumTokens; } + [[nodiscard]] SizeType32 constexpr getMaxEncoderLen() const noexcept + { + return mMaxEncoderLen; + } + + void constexpr setMaxEncoderLen(SizeType32 maxEncoderLen) noexcept + { + mMaxEncoderLen = maxEncoderLen; + } + [[nodiscard]] bool constexpr usePromptTuning() const noexcept { return mMaxPromptEmbeddingTableSize > 0; @@ -398,9 +419,9 @@ class ModelConfig return mUseCrossAttention; } - void constexpr useCrossAttention(bool newCrossAttention) noexcept + void constexpr setUseCrossAttention(bool useCrossAttention) noexcept { - mUseCrossAttention = newCrossAttention; + mUseCrossAttention = useCrossAttention; } [[nodiscard]] bool constexpr usePositionEmbedding() const noexcept @@ -408,9 +429,9 @@ class ModelConfig return mUsePositionEmbedding; } - void constexpr usePositionEmbedding(bool newPositionEmbedding) noexcept + void constexpr setUsePositionEmbedding(bool usePositionEmbedding) noexcept { - mUsePositionEmbedding = newPositionEmbedding; + mUsePositionEmbedding = usePositionEmbedding; } [[nodiscard]] bool constexpr useTokenTypeEmbedding() const noexcept @@ -418,19 +439,9 @@ class ModelConfig return mUseTokenTypeEmbedding; } - void constexpr useTokenTypeEmbedding(bool newTokenTypeEmbedding) noexcept - { - mUseTokenTypeEmbedding = newTokenTypeEmbedding; - } - - [[nodiscard]] SizeType32 constexpr getFfnHiddenSize() const noexcept - { - return mFfnHiddenSize; - } - - void constexpr setFfnHiddenSize(SizeType32 ffnHiddenSize) noexcept + void constexpr setUseTokenTypeEmbedding(bool useTokenTypeEmbedding) noexcept { - mFfnHiddenSize = ffnHiddenSize; + mUseTokenTypeEmbedding = useTokenTypeEmbedding; } [[nodiscard]] SizeType32 constexpr getMaxLoraRank() const noexcept @@ -575,10 +586,11 @@ class ModelConfig std::optional mRnnConfig; // Configs related to encoder / enc-dec models + SizeType32 mMaxEncoderLen{}; + SizeType32 mEncoderHiddenSize{}; bool mUseCrossAttention; bool mUsePositionEmbedding; bool mUseTokenTypeEmbedding; - SizeType32 mFfnHiddenSize; // indicates encoder output hidden size std::vector mLayerTypes; // Speculative decoding members diff --git a/cpp/include/tensorrt_llm/runtime/samplingConfig.h b/cpp/include/tensorrt_llm/runtime/samplingConfig.h index 4336740b1..3597dfc6f 100644 --- a/cpp/include/tensorrt_llm/runtime/samplingConfig.h +++ b/cpp/include/tensorrt_llm/runtime/samplingConfig.h @@ -131,6 +131,9 @@ class SamplingConfig frequencyPenalty = fuseValues( configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; }, layers::DefaultDecodingParams::getFrequencyPenalty()); + noRepeatNgramSize = fuseValues( + configs, [&configs](size_t ci) { return configs[ci].noRepeatNgramSize; }, + layers::DefaultDecodingParams::getNoRepeatNgramSize()); topK = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topK; }, layers::DefaultDecodingParams::getTopK()); topP = fuseValues( @@ -200,6 +203,7 @@ class SamplingConfig SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType) SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType) SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32) + SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32) #undef SET_FROM_OPTIONAL } @@ -225,6 +229,7 @@ class SamplingConfig valid &= validateVec("temperature", temperature, -fltEpsilon); valid &= validateVec("repetitionPenalty", repetitionPenalty, 0.f); valid &= validateVec("minLength", minLength, -1); + valid &= validateVec("noRepeatNgramSize", noRepeatNgramSize, 0); valid &= validateVec("beamSearchDiversityRate", beamSearchDiversityRate, -fltEpsilon); @@ -256,11 +261,12 @@ class SamplingConfig SizeType32 beamWidth; // penalties - OptVec temperature; // [1] or [batch_size] on cpu - OptVec minLength; // [1] or [batch_size] on cpu - OptVec repetitionPenalty; // [1] or [batch_size] on cpu - OptVec presencePenalty; // [1] or [batch_size] on cpu - OptVec frequencyPenalty; // [1] or [batch_size] on cpu + OptVec temperature; // [1] or [batch_size] on cpu + OptVec minLength; // [1] or [batch_size] on cpu + OptVec repetitionPenalty; // [1] or [batch_size] on cpu + OptVec presencePenalty; // [1] or [batch_size] on cpu + OptVec frequencyPenalty; // [1] or [batch_size] on cpu + OptVec noRepeatNgramSize; // [1] or [batch_size] on cpu // probs OptVec outputLogProbs; @@ -291,13 +297,13 @@ class SamplingConfig { return beamWidth == other.beamWidth && temperature == other.temperature && minLength == other.minLength && repetitionPenalty == other.repetitionPenalty && presencePenalty == other.presencePenalty - && frequencyPenalty == other.frequencyPenalty && topK == other.topK && topP == other.topP - && randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin - && topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate - && lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping - && draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads - && normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs - && cumLogProbs == other.cumLogProbs; + && frequencyPenalty == other.frequencyPenalty && noRepeatNgramSize == other.noRepeatNgramSize + && topK == other.topK && topP == other.topP && randomSeed == other.randomSeed + && topPDecay == other.topPDecay && topPMin == other.topPMin && topPResetIds == other.topPResetIds + && beamSearchDiversityRate == other.beamSearchDiversityRate && lengthPenalty == other.lengthPenalty + && earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold + && topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs + && outputLogProbs == other.outputLogProbs && cumLogProbs == other.cumLogProbs; } }; diff --git a/cpp/tensorrt_llm/runtime/speculativeDecodingModule.h b/cpp/include/tensorrt_llm/runtime/speculativeDecodingModule.h similarity index 100% rename from cpp/tensorrt_llm/runtime/speculativeDecodingModule.h rename to cpp/include/tensorrt_llm/runtime/speculativeDecodingModule.h diff --git a/cpp/tensorrt_llm/runtime/utils/debugUtils.h b/cpp/include/tensorrt_llm/runtime/utils/debugUtils.h similarity index 70% rename from cpp/tensorrt_llm/runtime/utils/debugUtils.h rename to cpp/include/tensorrt_llm/runtime/utils/debugUtils.h index cb03f7bc7..35b466a20 100644 --- a/cpp/tensorrt_llm/runtime/utils/debugUtils.h +++ b/cpp/include/tensorrt_llm/runtime/utils/debugUtils.h @@ -20,6 +20,10 @@ namespace tensorrt_llm::runtime::utils { -bool tensorHasNan(IBuffer const& tensor, BufferManager const& manager); +template +bool tensorHasNan(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr); -} +bool tensorHasNan( + size_t M, size_t K, nvinfer1::DataType type, void const* data, cudaStream_t stream, std::string const& infoStr); + +} // namespace tensorrt_llm::runtime::utils diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index 3a4dfcd38..6553a5570 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -256,6 +256,7 @@ set(TRTLLM_LINK_LIBS decoder_attention_src fpA_intB_gemm_src moe_gemm_src + gemm_swiglu_sm90_src cutlass_src layers_src runtime_src) diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index e5a766f64..11f1e420e 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2eaa6076f6ae6341710ca84097a3083f57235093bfc3bb1cd04f2bce06e1a99e -size 3412616 +oid sha256:c3617c1311c26ceaa826cdb5c1529e4bb4e200b314dccf28152537ab1b3f7c0d +size 3871208 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index e5a766f64..11f1e420e 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2eaa6076f6ae6341710ca84097a3083f57235093bfc3bb1cd04f2bce06e1a99e -size 3412616 +oid sha256:c3617c1311c26ceaa826cdb5c1529e4bb4e200b314dccf28152537ab1b3f7c0d +size 3871208 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index 14002db23..440357fcc 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -4b12adf3182aabf1df0b5dac217509e0 libtensorrt_llm_batch_manager_static.a -4b12adf3182aabf1df0b5dac217509e0 libtensorrt_llm_batch_manager_static.pre_cxx11.a -fc46fa01e555f9f97387340e46e9571fabf73988 commit \ No newline at end of file +333d12d1d85551be1a3f14f8f0147f0b libtensorrt_llm_batch_manager_static.a +333d12d1d85551be1a3f14f8f0147f0b libtensorrt_llm_batch_manager_static.pre_cxx11.a +736b3fc4259916d31211104b91e6b2b4db995b17 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 437cb844f..50b607de3 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e4fb588419244c8c07a9ce949edd7ed4e3dded008ed82aa993ff69e524394be9 -size 3310186 +oid sha256:c23b9d348006141fc3aa8b6a6fe2b64ac9efedb99bfdbb46a3a7b749d28ed169 +size 3743066 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index e488d6d19..da739cabe 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4b31c50b2879f57022e788700ebf3a86fa8f30133a01533b03c7bc15d64ad364 -size 3283536 +oid sha256:e715ff09c6bd27539faac3922f0fee146e77cf5269a9c43cc9387002adddd71d +size 3716272 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib index f742fdfcc..9f09a8dd9 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib +++ b/cpp/tensorrt_llm/batch_manager/x86_64-windows-msvc/tensorrt_llm_batch_manager_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8f458ce861720a4ce15c9cedeae4bb6c1f6a8a98f0fced35198c9802feaddb10 -size 20305606 +oid sha256:1733af6b615c9eb6d755a37fc786d0063e5d8cb5f6c177a1bd898508282702c4 +size 21651836 diff --git a/cpp/tensorrt_llm/common/assert.cpp b/cpp/tensorrt_llm/common/assert.cpp index c3683de41..eaaf66244 100755 --- a/cpp/tensorrt_llm/common/assert.cpp +++ b/cpp/tensorrt_llm/common/assert.cpp @@ -21,7 +21,7 @@ namespace bool initCheckDebug() { - auto constexpr kDebugEnabled = "TRT_LLM_DEBUG_MODE"; + auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE"; auto const debugEnabled = std::getenv(kDebugEnabled); return debugEnabled && debugEnabled[0] == '1'; } diff --git a/cpp/tensorrt_llm/common/mpiUtils.cpp b/cpp/tensorrt_llm/common/mpiUtils.cpp index d3a4d6d4c..c3cb8aad9 100644 --- a/cpp/tensorrt_llm/common/mpiUtils.cpp +++ b/cpp/tensorrt_llm/common/mpiUtils.cpp @@ -22,8 +22,12 @@ #include "tensorrt_llm/runtime/iBuffer.h" #include +#include #include #include +#ifndef _WIN32 +#include +#endif // We rely on SizeType32 being int32_t in some places with weak type checking, // i.e. we're passing void ptr to some function. To prevent mysterious errors @@ -91,7 +95,7 @@ std::mutex mpiMutex; } // namespace -void initialize(MpiThreadSupport threadMode) +void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent) { std::lock_guard lk(mpiMutex); if (mpiInitialized) @@ -110,11 +114,36 @@ void initialize(MpiThreadSupport threadMode) TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed"); std::atexit([]() { MPI_Finalize(); }); - auto previousHandler = std::signal(SIGABRT, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); - TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); + /* + * We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2 + * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers + * correctly. + */ + for (int sig : {SIGABRT, SIGSEGV}) + { + __sighandler_t previousHandler = nullptr; + if (forwardAbortToParent) + { + previousHandler = std::signal(sig, + [](int signal) + { +#ifndef _WIN32 + pid_t parentProcessId = getppid(); + kill(parentProcessId, SIGKILL); +#endif + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + }); + } + else + { + previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); + } + TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); + } // ensure local MPI communicator is initialized MpiComm::localSession(); + TLLM_LOG_INFO("Initialized MPI"); } #endif // ENABLE_MULTI_DEVICE mpiInitialized = true; diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl new file mode 100644 index 000000000..593eca06e --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout stage_count) +{ + // 32 bytes to account for barriers etc. + constexpr int stage_barrier_bytes = 32; + constexpr int a_bits = static_cast(sizeof_bits::value); + constexpr int b_bits = static_cast(sizeof_bits::value); + constexpr int stage_bytes = [&]() -> int + { + if constexpr (SwapAB) + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes; + } + else + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes; + } + }(); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v) &¬ detail:: + is_use_rmem_A()>> +{ + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), + "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t + || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v>> +{ + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert( + detail::is_input_fp8(), "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + using AtomLayoutMNK + = cute::conditional_t + || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp new file mode 100644 index 000000000..2f2422c99 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp" + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, + bool SwapAB = false, class Enable = void> +struct CollectiveBuilderGated +{ + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp new file mode 100644 index 000000000..d850f36df --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, bool SwapAB = false> +struct CollectiveMmaGated +{ + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 000000000..8c8191fbc --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,643 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + using InternalElementAux = cute::conditional_t; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 000000000..28605060d --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA + * instructions. */ + implementable = implementable && (args.mma_promotion_interval % 4 == 0); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation0.promote_residue_if_needed(); + accumulation1.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index ee084116a..3db9bf532 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -140,9 +140,9 @@ struct MixedGemmArchTraits::value || cutlass::platform::is_same::value #ifdef ENABLE_FP8 - || cutlass::platform::is_same::value>::type + || cutlass::platform::is_same::value #endif - > + >::type> { private: using LayoutDetails = LayoutDetailsB; diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh new file mode 100644 index 000000000..05dc41925 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh @@ -0,0 +1,217 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Fused_Moe_Kernel_sm80 +{ + static constexpr int kMaxTileM = MaxTileM_; + static constexpr int kTileN = isGateActivation(activation_type_) ? TileN_ / 2 : TileN_; + static constexpr int kTileK = TileK_; + static constexpr int kStages = Stages_; + static constexpr Activation_Type activation_type = activation_type_; + + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementOutput = ElementOutput_; + using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80; + using Routine_Arguments = Routine_Arguments; + using Routine_Params = Routine_Params; + using ProblemVisitor + = cutlass::gemm::kernel::MoeProblemVisitor, false>, + cutlass::gemm::GemmShape, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>; + + struct Arguments + { + Routine_Arguments routine_args; + int problem_count{}; + int threadblock_count{}; + }; + + struct Params + { + Routine_Params routine_params; + int threadblock_count{}; + typename ProblemVisitor::Params problem_visitor_param; + }; + + using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80; + static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64 + + static constexpr int kSmemSize = use_m16 + ? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize + : BaseKernelTraits_m16::kSmemSize) + : BaseKernelTraits::kSmemSize; + static constexpr int kThreadCount = BaseKernelTraits::kThreadCount; + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return BaseKernelTraits::can_implement(avaliable_smem_size); + } + + static Params to_underlying_arguments(Arguments const& args) + { + return {{args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias, + args.routine_args.ptr_output, args.routine_args.total_rows_before_expert, args.routine_args.gemm_n, + args.routine_args.gemm_k, args.routine_args.num_expert}, + args.threadblock_count, + {args.routine_args.total_rows_before_expert, args.routine_args.gemm_n, args.routine_args.gemm_k, + args.problem_count, nullptr, 0}}; + } + + CUTE_DEVICE + void run_device(Params const& params) + { +#define ROUTINE_PATH(kTileM_size) \ + { \ + constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 32 : (kTileM_size)); \ + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; \ + RoutineTraits routine{}; \ + const int block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \ + routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \ + } + typename ProblemVisitor::SharedStorage dummy_storage{}; + ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x); + while (problem_visitor.next_tile()) + { + auto problem_size = problem_visitor.problem_size(); + auto grid_size = problem_visitor.grid_shape(problem_size); + auto problem_index = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + int const gemm_m = problem_size.m(); + const int32_t block_m_idx_temp = cta_idx / grid_size.n(); + const int32_t block_n_idx = cta_idx % grid_size.n(); + + int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp; + if (residue_m > kMaxTileM / 2) + { + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; + RoutineTraits routine{}; + routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m); + } + else + { + + if constexpr (kMaxTileM >= 128) + { + if (residue_m > 32) + { + ROUTINE_PATH(64); + } + else if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 64) + { + if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 32) + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + problem_visitor.advance(gridDim.x); + } +#undef ROUTINE_PATH + } +}; + +template +__global__ void run_global(__grid_constant__ typename GemmType::Params const params) +{ + GemmType gemm; + gemm.run_device(params); +} + +/// Computes the maximum number of active blocks per multiprocessor +template +static int fused_gemm_maximum_active_blocks(int smem_capacity = -1) +{ + + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); + + constexpr int smem_size = GemmType::kSmemSize; + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) + { + result = cudaFuncSetAttribute(run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, run_global, GemmType::kThreadCount, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; +} +} // namespace fused_moe diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh new file mode 100644 index 000000000..e4f061e4e --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh @@ -0,0 +1,694 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace fused_moe +{ + +template +struct Fused_Moe_Kernel_routine_sm80; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_rows_before_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_gate_ + = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementWeight const* ptr_fc1_ + = params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementInput const* ptr_bias_ + = (params.ptr_bias == nullptr) ? nullptr : params.ptr_bias + 2 * problem_index * N1; + typename KT::ElementInput const* ptr_bias_gate_ + = (params.ptr_bias == nullptr) ? nullptr : params.ptr_bias + (2 * problem_index + 1) * N1; + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_gate_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_gate_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(cute::Int<0>{}, cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mBias_gate_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_gate_)), cute::make_shape(M, N1), + cute::make_stride(cute::Int<0>{}, cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. + CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + // gmem tensor partition .. + auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn] + = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_gate_nk); + + // smem tensor .. + cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sfc1_gate_weight + = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. + auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 + int k_tile_count = cute::size<2>(gInput); + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + // use copy_if + cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe)); + cute::cp_async_fence(); + k_tile_count--; + if (k_tile_count > 0) + { + ++k_tile_iter; + } + } + + // (1.3) get partition for rf + typename KT::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) + cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + + cute::Tensor accum + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::Tensor accum_gate + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::clear(accum); + cute::clear(accum_gate); + // checkout the shape + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); + + // (1.4)retiling the smem and rf for copy.. + auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop (k - (stage - 1) -> -(stage - 1), if k_tile_count > 0, it means we still need to + // fetch gmem to smem) + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > -(KT::Stages - 1); --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + } + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. + cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, 0, block_n_idx); + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + accum_gate(i) += tOgBiasg(i); + } + } + + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) + cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + // remember, for all the threads in the same col, they have the same idx for bias.. + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + // cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. + auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + // auto tOgBias = gmem_thr_copy_O.partition_D(gBias); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_rows_before_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1; + typename KT::ElementInput const* ptr_bias_ + = (params.ptr_bias == nullptr) ? nullptr : params.ptr_bias + problem_index * N1; + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(cute::Int<0>{}, cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. + CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + // gmem tensor partition .. + auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_nk); + + // smem tensor .. + cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. + auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 + int k_tile_count = cute::size<2>(gInput); + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + // use copy_if + cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); + cute::cp_async_fence(); + k_tile_count--; + if (k_tile_count > 0) + { + ++k_tile_iter; + } + } + + // (1.3) get partition for rf + typename KT::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) + cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + + cute::Tensor accum + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::clear(accum); + // checkout the shape + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); + + // (1.4)retiling the smem and rf for copy.. + auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop (k - (stage - 1) -> -(stage - 1), if k_tile_count > 0, it means we still need to + // fetch gmem to smem) + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > -(KT::Stages - 1); --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + }); + } + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + } + } + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum(temp_iter)); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) + cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +} // namespace fused_moe diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh new file mode 100644 index 000000000..b7536b6fb --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh @@ -0,0 +1,213 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Routine_Arguments +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t* total_rows_before_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; +}; + +template +struct Routine_Params +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t* total_rows_before_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; +}; + +enum class Activation_Type +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGateActivation(Activation_Type const& activation_type) +{ + return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu; +} + +template +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::InvalidType; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Identity; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Relu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu; +} + +/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/ +template +struct Fused_Moe_Kernel_traits_sm80 +{ + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementAccum = float; + using ElementOutput = ElementOutput_; + + using index_t = uint32_t; + static_assert(TileM_ % 16 == 0); + static_assert(TileN_ % 32 == 0); + static_assert(TileK_ % 32 == 0); + static constexpr int Stages = Stages_; + static constexpr int kTileM = TileM_; + static constexpr int kTileN = TileN_; + static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64); + + // tile shape + using TileShape = cute::Shape, cute::Int, cute::Int>; + static constexpr int kWarpsCount = 4; + static constexpr int kThreadCount = kWarpsCount * 32; + + // MMA atom arch and layout + using MMA_Atom_Arch = std::conditional_t, + cute::MMA_Atom, cute::MMA_Atom>; + // using ValLayoutMNK = cute::Layout>; + using ThreadLayoutMNK + = std::conditional_t, cute::_1>>, + cute::Layout, cute::_1>>>; + using ValLayoutMNK = std::conditional_t, + cute::Tile>; + using TiledMma = cute::TiledMMA; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4 + static constexpr int kAlignment = 8; + static constexpr int kBlcokKSmem = (kTileM == 16) ? 64 : 32; + // A memory copy operand + using DefaultOperandA + = DefaultGemm_TensorOpSm80_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B memory copy operand + using DefaultOperandB + = DefaultGemm_TensorOpSm80_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Output memory copy operand + using SmemLayoutAtomO = SmemLayoutAtomA; + using SmemCopyAtomO = cute::Copy_Atom; + static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput); + static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad; + using GmemLayoutAtomO + = cute::Layout, cute::Int>, + cute::Stride, cute::_1>>; + using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom{}, + GmemLayoutAtomO{}, cute::Layout>{})); + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2); + static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K + static_assert(cute::rank(SmemLayoutAtomB{}) == 2); + static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K + + using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{}, + cute::make_shape( + cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_M, BLK_K, Stages + using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{}, + cute::make_shape( + cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_N, BLK_K, Stages + using SmemLayoutO = decltype(cute::tile_to_shape( + SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N + + // we need at least 2 stages.. + static_assert(Stages >= 2); + + struct SharedStorageNormal : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + struct SharedStorageGate : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_gate_weight; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + using SharedStorage = std::conditional_t; + + using ActivationFn = std::conditional_t, + std::conditional_t, + std::conditional_t, cutlass::epilogue::thread::Identity>>>; + + static constexpr int kSmemSize = static_cast(sizeof(SharedStorage)); + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return avaliable_smem_size > kSmemSize; + } + + // #endif +}; +} // namespace fused_moe diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp new file mode 100644 index 000000000..3a084ee04 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. + **/ +template +class GemmUniversalGated; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp" +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh new file mode 100644 index 000000000..aac2cb357 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh @@ -0,0 +1,185 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +CUTE_DEVICE auto util_convert_type(cute::Tensor const& tensor) +{ + using From_type = typename Engine::value_type; + constexpr int numel = decltype(cute::size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast const*>(tensor.data())); + return cute::make_tensor(cute::make_rmem_ptr(&frag), tensor.layout()); +} + +template +CUTE_DEVICE void util_copy( + TiledCopy const& tiled_copy, cute::Tensor const& S, cute::Tensor& D) +{ + CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D)); + CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D)); + CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D)); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(S); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < cute::size<2>(S); ++k) + { + cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k)); + } + } +} diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp new file mode 100644 index 000000000..9b93382c1 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp @@ -0,0 +1,646 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, + ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + scheduler, workspace}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, + NumEpilogueSubTiles); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = []() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + work_tile_info = fetch_next_work(work_tile_info, scheduler); + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the + // work. + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter + = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + while (work_tile_info.is_valid()) + { + if (!TileScheduler::requires_separate_reduction(params.scheduler)) + { + load_order_barrier.wait(); + } + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = collective_epilogue.load(epi_load_pipeline, + epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, + accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, + params.mainloop); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + if (do_store_tail) + { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state); + } + } // Consumer Warp Groups End +#endif + } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo fetch_next_work( + typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const + { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) + { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp new file mode 100644 index 000000000..5eb1e1cd5 --- /dev/null +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp @@ -0,0 +1,621 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +#include "cute/util/debug.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(!cute::is_same_v, + "Ping-pong kernel does not currently support stream-K scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void) workspace; + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier( + shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&]() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + if (warp_group_role == WarpGroupRole::Consumer1) + { + // Advance 2nd Math WG to the next work tile for the startup + scheduler.advance_to_next_work(); + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + } + auto work_tile_info = scheduler.get_current_work(); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + load_order_barrier.wait(); + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state + = collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, + blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue); + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Allocate the accumulators for the (M,N) blk_shape + Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + // Order two Math WG's MMA one after the other, helps hide Epilogue + math_wg_order_barrier.wait(); + + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1, + k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop); + + // Cue for next Math WG's MMA to start + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + // Order two Math WG's Epilogue one after the other + math_wg_order_barrier.wait(); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue); + + // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels + // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives + // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. + auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] + = collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + // Get next work tile + scheduler.advance_to_next_work(NumMmaWarpGroups); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + } // Consumer Warp Groups End +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index 0a3c3b0b1..d76f95f56 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -66,6 +67,7 @@ enum class SplitKStyle { NO_SPLIT_K, SPLIT_K_SERIAL, + STREAM_K, // Sm80+ // SPLIT_K_PARALLEL // Not supported yet }; @@ -110,7 +112,9 @@ enum class ClusterShape ClusterShape_1x1x1, ClusterShape_2x1x1, ClusterShape_1x2x1, - ClusterShape_2x2x1 + ClusterShape_2x2x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 }; struct CutlassGemmConfig @@ -185,5 +189,26 @@ struct CutlassGemmConfig } }; +inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) +{ + // clang-format off + if (config.is_sm90) + { + out << "tile_config_sm90_enum: " << int(config.tile_config_sm90) + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape); + } + else + { + out << "tile_config_enum: " << int(config.tile_config) + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages; + } + // clang-format on + return out; +} + } // namespace cutlass_extensions } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a index a3f5b9575..95f4fcd0f 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dc6db658e67f3ea6a11bc37be2b90090f23b831719559a3202bf8b2b397f6f3b -size 1334290 +oid sha256:4321e6f05a2010b566249058795d9af2e057628c2dd5a0cc584374721179d811 +size 1364476 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index a3f5b9575..95f4fcd0f 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dc6db658e67f3ea6a11bc37be2b90090f23b831719559a3202bf8b2b397f6f3b -size 1334290 +oid sha256:4321e6f05a2010b566249058795d9af2e057628c2dd5a0cc584374721179d811 +size 1364476 diff --git a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt index b2b662b4a..ccd22a55c 100644 --- a/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/executor/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -8e5f1d6bb88c80004b4260aa2d022420 libtensorrt_llm_executor_static.a -8e5f1d6bb88c80004b4260aa2d022420 libtensorrt_llm_executor_static.pre_cxx11.a -fc46fa01e555f9f97387340e46e9571fabf73988 commit \ No newline at end of file +f1aa8db2043d7aa305950d708ba7a2c0 libtensorrt_llm_executor_static.a +f1aa8db2043d7aa305950d708ba7a2c0 libtensorrt_llm_executor_static.pre_cxx11.a +736b3fc4259916d31211104b91e6b2b4db995b17 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a index 28271595d..1ac1dccbf 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f0421ca1e637adfdebc9718c47537ed81b55cad4f7fbd062b1d83ca0ab7ebbe5 -size 1371948 +oid sha256:cf64e85fa387f57f1d76a3245785fc1d13c281c9c973c74e53b8cbd2ae2082d9 +size 1408774 diff --git a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a index d0ec82211..c7915b383 100644 --- a/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/executor/x86_64-linux-gnu/libtensorrt_llm_executor_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:46b6253eef9136f91d2877e9baa827a8ff229b54b6fb1f2717fb6c85a7ffa047 -size 1306830 +oid sha256:96a753478e563fd6769f6f7ed3c19d963fb3929ff8124af2bf1ee56408ae74dc +size 1344186 diff --git a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib index 5aed4916b..433304114 100644 --- a/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib +++ b/cpp/tensorrt_llm/executor/x86_64-windows-msvc/tensorrt_llm_executor_static.lib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fa1bde9d020eac84321954b0cb393ec7af63393e29947d6c147700063b0267da -size 12726212 +oid sha256:eb7c4993d6cd8d44a5fc9200b4976d293f3f484a7e44264b0c8a7a72566eb9a9 +size 12987162 diff --git a/cpp/tensorrt_llm/executor_worker/executorWorker.cpp b/cpp/tensorrt_llm/executor_worker/executorWorker.cpp index b057b1839..c777ce21b 100644 --- a/cpp/tensorrt_llm/executor_worker/executorWorker.cpp +++ b/cpp/tensorrt_llm/executor_worker/executorWorker.cpp @@ -20,6 +20,7 @@ #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/serialization.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" +#include namespace tle = tensorrt_llm::executor; @@ -29,7 +30,7 @@ int main(int argc, char* argv[]) // Register the TRT-LLM plugins initTrtLlmPlugins(); - tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE); + tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE, true); MPI_Comm parentComm; MPI_Comm_get_parent(&parentComm); diff --git a/cpp/tensorrt_llm/kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/CMakeLists.txt index 9364d59cd..81dde2963 100644 --- a/cpp/tensorrt_llm/kernels/CMakeLists.txt +++ b/cpp/tensorrt_llm/kernels/CMakeLists.txt @@ -18,12 +18,44 @@ file(GLOB_RECURSE SRC_CPP *.cpp) file(GLOB_RECURSE SRC_CU *.cu) -# Exclude files in the cutlass_kernels and unfusedAttentionKernels folder +# Exclude files in the cutlass_kernels and decoderMaskedMultiheadAttention +# folder list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*") list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*") list(FILTER SRC_CPP EXCLUDE REGEX "decoderMaskedMultiheadAttention/.*") list(FILTER SRC_CU EXCLUDE REGEX "decoderMaskedMultiheadAttention/.*") +function(filter_cuda_archs ARCH SOURCES_VAR) + if(NOT "${ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES + AND NOT "${ARCH}-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) + set(FILTER_REGEX ".*_sm(_)?${ARCH}[.]cubin[.]cpp") + list(APPEND SOURCES ${${SOURCES_VAR}}) + list(APPEND SOURCES_FILTERED ${SOURCES}) + list(FILTER SOURCES_FILTERED INCLUDE REGEX "${FILTER_REGEX}") + list(LENGTH SOURCES_FILTERED SOURCES_FILTERED_LEN) + message( + STATUS + "Excluding ${SOURCES_FILTERED_LEN} cubins for SM ${ARCH} from ${CMAKE_CURRENT_SOURCE_DIR}" + ) + foreach(filtered_item ${SOURCES_FILTERED}) + message(VERBOSE "- ${filtered_item}") + endforeach() + list(FILTER SOURCES EXCLUDE REGEX "${FILTER_REGEX}") + set(${SOURCES_VAR} + "${SOURCES}" + PARENT_SCOPE) + if(SOURCES_FILTERED_LEN GREATER 0) + add_compile_definitions("EXCLUDE_SM_${ARCH}") + endif() + endif() +endfunction() + +filter_cuda_archs("70" SRC_CPP) +filter_cuda_archs("80" SRC_CPP) +filter_cuda_archs("86" SRC_CPP) +filter_cuda_archs("89" SRC_CPP) +filter_cuda_archs("90" SRC_CPP) + if(ENABLE_MULTI_DEVICE EQUAL 0) list(FILTER SRC_CU EXCLUDE REGEX "customAllReduceKernels*.*cu$") endif() diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h index 3fd399b93..17e946745 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h @@ -24,7 +24,7 @@ namespace kernels - +#ifndef EXCLUDE_SM_90 extern unsigned char cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin[]; @@ -215,270 +215,6 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_12 extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm90_cu_cubin[]; @@ -523,138 +259,6 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm90 extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin[]; @@ -677,27 +281,6 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128 extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_80_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_96_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_104_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_128_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_80_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_96_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_104_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin[]; -// QK Tanh Scale. extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin[]; @@ -714,50 +297,12 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tan extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin[]; extern uint32_t cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len; @@ -949,270 +494,6 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_ali extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm90_cu_cubin_len; @@ -1257,138 +538,6 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm90_cu_c extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin_len; @@ -1411,27 +560,6 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90 extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_80_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_96_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_104_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_128_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_80_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_96_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_104_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin_len; -// QK Tanh Scale. extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len; @@ -1448,1527 +576,1626 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin_len; +#endif -static const struct FusedMultiHeadAttentionKernelMetaInfoV2 -{ - Data_type mDataType; - unsigned int mS; - unsigned int mStepQ; - unsigned int mStepKV; - unsigned int mD; - unsigned int mSM; - const unsigned char* mCubin; - unsigned int mCubinSize; - const char* mFuncName; - unsigned int mSharedMemBytes; - unsigned int mThreadsPerCTA; - unsigned int mUnrollStep; - bool mInterleaved; - bool mFlashAttention; - bool mWarpSpecialization; - bool mFP32Accumulation; - int mAttentionMaskType; - bool mAlibiSupported; - bool mTiled; - bool mPagedKV; - bool mEnableQKTanhScale; -} sMhaKernelMetaInfosV2[] = { -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_sliding_window_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_sliding_window_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_sliding_window_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_sliding_window_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_sliding_window_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_sliding_window_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_sliding_window_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_sliding_window_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_sliding_window_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_sliding_window_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, false, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_sliding_window_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_sliding_window_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_sliding_window_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_sliding_window_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_sliding_window_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_sliding_window_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_sliding_window_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_sliding_window_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_sliding_window_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_sliding_window_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_sliding_window_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_sliding_window_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_sliding_window_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_sliding_window_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_sliding_window_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_sliding_window_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_sliding_window_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_sliding_window_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_sliding_window_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_sliding_window_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +#ifndef EXCLUDE_SM_89 +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin[]; + +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len; +#endif + +#ifndef EXCLUDE_SM_80 +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin[]; + +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len; +#endif + +#ifndef EXCLUDE_SM_86 +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin[]; + +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len; +#endif + +#ifndef EXCLUDE_SM_70 +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_80_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_96_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_104_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_128_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_80_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_96_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_104_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin[]; +// QK Tanh Scale. +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin[]; +extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin[]; + + +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_80_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_96_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_104_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_128_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_80_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_96_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_104_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin_len; +// QK Tanh Scale. +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin_len; +extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin_len; +#endif + +static const struct FusedMultiHeadAttentionKernelMetaInfoV2 +{ + Data_type mDataType; + unsigned int mS; + unsigned int mStepQ; + unsigned int mStepKV; + unsigned int mD; + unsigned int mSM; + const unsigned char* mCubin; + unsigned int mCubinSize; + const char* mFuncName; + unsigned int mSharedMemBytes; + unsigned int mThreadsPerCTA; + unsigned int mUnrollStep; + bool mInterleaved; + bool mFlashAttention; + bool mWarpSpecialization; + bool mFP32Accumulation; + int mAttentionMaskType; + bool mAlibiSupported; + bool mTiled; + bool mPagedKV; + bool mEnableQKTanhScale; +} sMhaKernelMetaInfosV2[] = { +#ifndef EXCLUDE_SM_90 +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_sliding_window_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_sliding_window_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_sliding_window_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_sliding_window_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_64_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_sliding_window_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_128_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_sliding_window_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_256_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_sliding_window_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_sliding_window_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_sliding_window_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_384_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_sliding_window_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_512_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, false, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_sliding_window_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_sliding_window_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_sliding_window_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_sliding_window_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_64_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_sliding_window_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_128_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_sliding_window_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_256_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_sliding_window_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_sliding_window_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_sliding_window_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_384_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_sliding_window_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_bf16_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_bf16_512_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_sliding_window_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_causal_ldgsts_sm90_kernel", 17408, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 32, kSM_90, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 17408, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_sliding_window_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_causal_ldgsts_sm90_kernel", 25600, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 32, kSM_90, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 25600, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_sliding_window_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_causal_ldgsts_sm90_kernel", 41984, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 32, kSM_90, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 41984, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_sliding_window_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_causal_ldgsts_sm90_kernel", 33792, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 64, 64, 64, 64, kSM_90, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_64_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_64_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 33792, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_sliding_window_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_causal_ldgsts_sm90_kernel", 50176, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 128, 64, 128, 64, kSM_90, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_128_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_128_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 50176, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_sliding_window_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_causal_ldgsts_sm90_kernel", 82944, 128, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 256, 64, 256, 64, kSM_90, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_256_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_256_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 82944, 128, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_sliding_window_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_causal_ldgsts_sm90_kernel", 67072, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 32, kSM_90, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 67072, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_sliding_window_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_causal_ldgsts_sm90_kernel", 83456, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 32, kSM_90, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_32_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_32_sliding_window_causal_ldgsts_sm90_kernel_nl", 83456, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_sliding_window_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_causal_ldgsts_sm90_kernel", 132608, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 384, 64, 384, 64, kSM_90, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_384_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_384_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 132608, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_sliding_window_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_causal_ldgsts_sm90_kernel", 165376, 256, 0, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 512, 64, 512, 64, kSM_90, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin, cubin_fmha_v2_fp16_fp32_512_64_ldgsts_sm90_cu_cubin_len, "fmha_v2_fp16_fp32_512_64_sliding_window_causal_ldgsts_sm90_kernel_nl", 165376, 256, 64, false, false, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_causal_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 82304, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 78208, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 156032, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_E4M3, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_sliding_window_causal_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_sliding_window_causal_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 256, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_256_S_pagedKV_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false, true, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, true, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_causal_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, true, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sliding_window_causal_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, true, false}, @@ -3101,6 +2328,839 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, { DATA_TYPE_BF16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, { DATA_TYPE_BF16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_causal_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sliding_window_causal_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_causal_sm90_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sliding_window_causal_sm90_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sliding_window_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sliding_window_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_causal_sm90_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sliding_window_causal_sm90_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_causal_sm90_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sliding_window_causal_sm90_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sliding_window_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sliding_window_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, true}, +{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +#endif + +#ifndef EXCLUDE_SM_89 +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +#endif +#ifndef EXCLUDE_SM_80 +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sliding_window_causal_sm80_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sliding_window_causal_sm80_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sliding_window_causal_sm80_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sliding_window_causal_sm80_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, @@ -3233,6 +3293,308 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, { DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, { DATA_TYPE_FP16, 0, 64, 16, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +#endif +#ifndef EXCLUDE_SM_86 +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, false, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_128_128_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 128, 128, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_128_128_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_BF16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, { DATA_TYPE_FP16, 0, 128, 128, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sliding_window_causal_sm86_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, @@ -3319,250 +3681,90 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, 0, 64, 128, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, { DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, { DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, false, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sliding_window_causal_sm89_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sliding_window_causal_sm89_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sliding_window_causal_sm89_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_causal_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_16_sliding_window_causal_sm90_kernel_nl_tiled", 16384, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_causal_sm90_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_32_sliding_window_causal_sm90_kernel_nl_tiled", 32768, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_40_sliding_window_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 128, 128, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_128_128_S_pagedKV_64_sliding_window_causal_sm90_kernel_nl_tiled", 65536, 128, 128, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_80_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_96_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_causal_sm90_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sliding_window_causal_sm90_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_causal_sm90_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sliding_window_causal_sm90_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sliding_window_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sliding_window_causal_sm90_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, -{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 16, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_16_sliding_window_causal_sm86_kernel_nl", 6144, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 64, 32, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_pagedKV_32_sliding_window_causal_sm86_kernel_nl", 12288, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 40, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_40_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 64, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_64_sliding_window_causal_sm86_kernel_nl", 16384, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 80, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_80_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 96, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_96_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 104, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_104_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 16, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_pagedKV_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false, true, false}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, +{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, +#endif +#ifndef EXCLUDE_SM_70 { DATA_TYPE_FP16, 0, 64, 64, 32, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sm70_kernel_nl", 12288, 128, 64, false, true, false, false, 0, true, false, false, false}, { DATA_TYPE_FP16, 0, 64, 64, 32, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_causal_sm70_kernel_nl", 12288, 128, 64, false, true, false, false, 1, true, false, false, false}, { DATA_TYPE_FP16, 0, 64, 64, 32, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_32_sliding_window_causal_sm70_kernel_nl", 12288, 128, 64, false, true, false, false, 2, true, false, false, false}, @@ -3624,186 +3826,13 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, 0, 64, 16, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_causal_sm70_kernel_nl", 65536, 128, 64, false, true, false, false, 1, true, false, true, false}, { DATA_TYPE_FP16, 0, 64, 16, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_256_sliding_window_causal_sm70_kernel_nl", 65536, 128, 64, false, true, false, false, 2, true, false, true, false}, // QK Tanh. -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, false, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, false, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, false, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, false, 2, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, false, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, false, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, false, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false, true, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false, true, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, false, false, true, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, false, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, false, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, false, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false, true, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false, true, true}, -{ DATA_TYPE_E4M3, 0, 64, 256, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, false, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, false, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, false, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_alibi_qk_tanh_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_alibi_qk_tanh_tma_ws_sm90_kernel", 164096, 384, 64, false, true, true, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_BF16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm80_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm86_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, false, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm89_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 128, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 0, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 1, true, false, true, true}, -{ DATA_TYPE_FP16, 0, 64, 32, 128, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_qk_tanh_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_pagedKV_128_sliding_window_causal_qk_tanh_sm90_kernel_nl", 32768, 128, 64, false, true, false, true, 2, true, false, true, true}, { DATA_TYPE_FP16, 0, 64, 16, 128, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, false, true}, { DATA_TYPE_FP16, 0, 64, 16, 128, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_128_causal_qk_tanh_sm70_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, false, true}, { DATA_TYPE_FP16, 0, 64, 16, 128, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_128_qk_tanh_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_128_sliding_window_causal_qk_tanh_sm70_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, false, true}, { DATA_TYPE_FP16, 0, 64, 16, 128, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_kernel_nl", 32768, 128, 64, false, true, false, false, 0, true, false, true, true}, { DATA_TYPE_FP16, 0, 64, 16, 128, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_causal_qk_tanh_sm70_kernel_nl", 32768, 128, 64, false, true, false, false, 1, true, false, true, true}, { DATA_TYPE_FP16, 0, 64, 16, 128, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_qk_tanh_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_pagedKV_128_sliding_window_causal_qk_tanh_sm70_kernel_nl", 32768, 128, 64, false, true, false, false, 2, true, false, true, true} +#endif }; // clang-format on diff --git a/cpp/tensorrt_llm/kernels/cumsumLastDim.cu b/cpp/tensorrt_llm/kernels/cumsumLastDim.cu index daed22a64..f22d3a7da 100644 --- a/cpp/tensorrt_llm/kernels/cumsumLastDim.cu +++ b/cpp/tensorrt_llm/kernels/cumsumLastDim.cu @@ -16,25 +16,28 @@ #include -#include - #include "cumsumLastDim.h" +#include + namespace tensorrt_llm { namespace kernels { -template -size_t invokeComputeCumsumLastDimWorkspaceSize(int input_length) + +/////////////// + +template +size_t invokeComputeCumsumLastDimWorkspaceSize(SizeType32 inputLength) { - input_t* iodata = nullptr; - size_t temp_storage_bytes; - cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, iodata, iodata, input_length); - return temp_storage_bytes; + T* iodata = nullptr; + size_t tempStorageBytes; + cub::DeviceScan::InclusiveSum(nullptr, tempStorageBytes, iodata, iodata, inputLength); + return tempStorageBytes; } -#define INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(input_t) \ - template size_t invokeComputeCumsumLastDimWorkspaceSize(int input_length) +#define INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(T) \ + template size_t invokeComputeCumsumLastDimWorkspaceSize(int inputLength) INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(int); INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(float); @@ -46,21 +49,111 @@ INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(__nv_bfloat16); /////////////// -template -void invokeCumsumLastDim(int batch_size, int input_length, void const* __restrict__ input, void* __restrict__ output, - void* d_temp_storage, size_t temp_storage_bytes, cudaStream_t stream) +template +__global__ void cumsum_last_dim(T const* d_in, T* d_out, int length) { - for (int i = 0; i < batch_size; i++) + typedef cub::BlockLoad BlockLoadT; + typedef cub::BlockStore BlockStoreT; + typedef cub::BlockScan BlockScanT; + + int const row_idx = blockIdx.x; + T const* local_d_in = d_in + row_idx * length; + T* local_d_out = d_out + row_idx * length; + + // Shared memory + __shared__ union TempStorage + { + typename BlockLoadT::TempStorage load; + typename BlockStoreT::TempStorage store; + typename BlockScanT::TempStorage scan; + } temp_storage; + + int tile_size = THREADS_PER_BLOCK * ITEMS_PER_THREAD; + T aggregate = static_cast(0); + T const* cur_d_in = local_d_in; + T* cur_d_out = local_d_out; + for (int tile_start = 0; tile_start < length; + tile_start += tile_size, cur_d_in += tile_size, cur_d_out += tile_size) + { + int cur_tile_size = (tile_start + tile_size) <= length ? tile_size : (length - tile_start); + T data[ITEMS_PER_THREAD]; // Per-thread tile data + + // Load items into a blocked arrangement + BlockLoadT(temp_storage.load).Load(cur_d_in, data, cur_tile_size, static_cast(0)); + if (threadIdx.x == 0) + { + data[0] += aggregate; + } + __syncthreads(); + + BlockScanT(temp_storage.scan).InclusiveSum(data, data, aggregate); + __syncthreads(); + + // Store items from a blocked arrangement + BlockStoreT(temp_storage.store).Store(cur_d_out, data, cur_tile_size); + } +} + +/////////////// + +template +void invokeDeviceScan(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input, + void* __restrict__ output, void* d_temp_storage, size_t tempStorageBytes, cudaStream_t stream) +{ + for (SizeType32 i = 0; i < batchSize; i++) + { + T const* inputPtr = reinterpret_cast(input) + i * inputLength; + T* outputPtr = reinterpret_cast(output) + i * inputLength; + cub::DeviceScan::InclusiveSum(d_temp_storage, tempStorageBytes, inputPtr, outputPtr, inputLength, stream); + } +} + +/////////////// + +template +void invokeCumsumLastDim(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input, + void* __restrict__ output, void* deviceTempStorage, size_t tempStorageBytes, cudaStream_t stream) +{ + + if (deviceTempStorage != nullptr) // we need to use DeviceScan + { + invokeDeviceScan(batchSize, inputLength, input, output, deviceTempStorage, tempStorageBytes, stream); + return; + } + + T const* inputPtr = reinterpret_cast(input); + T* outputPtr = reinterpret_cast(output); + + // Launch the kernel + if (inputLength <= 64) + { + int const ITP = 1; + int const TPB = 32; + const size_t SHMEM = sizeof(T) * TPB * ITP; + const cub::BlockScanAlgorithm ALG = cub::BLOCK_SCAN_WARP_SCANS; + cumsum_last_dim<<>>(inputPtr, outputPtr, inputLength); + } + else if (inputLength < 512) + { + int const ITP = 2; + int const TPB = 64; + const size_t SHMEM = sizeof(T) * TPB * ITP; + const cub::BlockScanAlgorithm ALG = cub::BLOCK_SCAN_WARP_SCANS; + cumsum_last_dim<<>>(inputPtr, outputPtr, inputLength); + } + else // if () { - input_t const* input_ptr = reinterpret_cast(input) + i * input_length; - input_t* output_ptr = reinterpret_cast(output) + i * input_length; - cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, input_ptr, output_ptr, input_length, stream); + int const ITP = 8; + int const TPB = 256; + const size_t SHMEM = sizeof(T) * TPB * ITP; + const cub::BlockScanAlgorithm ALG = cub::BLOCK_SCAN_WARP_SCANS; + cumsum_last_dim<<>>(inputPtr, outputPtr, inputLength); } } -#define INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(input_t) \ - template void invokeCumsumLastDim(int batch_size, int input_length, const void* __restrict__ input, \ - void* __restrict__ output, void* workspace, size_t temp_storage_bytes, cudaStream_t stream) +#define INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(T) \ + template void invokeCumsumLastDim(SizeType32 batchSize, SizeType32 inputLength, const void* __restrict__ input, \ + void* __restrict__ output, void* workspace, size_t tempStorageBytes, cudaStream_t stream) INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(int); INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(float); diff --git a/cpp/tensorrt_llm/kernels/cumsumLastDim.h b/cpp/tensorrt_llm/kernels/cumsumLastDim.h index 6955acc2e..2266f685e 100644 --- a/cpp/tensorrt_llm/kernels/cumsumLastDim.h +++ b/cpp/tensorrt_llm/kernels/cumsumLastDim.h @@ -18,18 +18,20 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/runtime/common.h" namespace tensorrt_llm { namespace kernels { +using SizeType32 = tensorrt_llm::runtime::SizeType32; -template -size_t invokeComputeCumsumLastDimWorkspaceSize(int input_length); +template +size_t invokeComputeCumsumLastDimWorkspaceSize(SizeType32 inputLength); -template -void invokeCumsumLastDim(int batch_size, int input_length, void const* __restrict__ input, void* __restrict__ output, - void* workspace, size_t temp_storage_bytes, cudaStream_t stream); +template +void invokeCumsumLastDim(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input, + void* __restrict__ output, void* workspace, size_t tempStorageBytes, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt index 2672dc53e..4d91f6f0a 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt @@ -73,6 +73,8 @@ file(GLOB_RECURSE SRC_CU *.cu) set(ALL_SRCS ${SRC_CPP};${SRC_CU}) list(FILTER ALL_SRCS EXCLUDE REGEX "fpA_intB_gemm/.*") list(FILTER ALL_SRCS EXCLUDE REGEX "moe_gemm/.*") +list(REMOVE_ITEM ALL_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/fused_gated_gemm/gemm_swiglu_e4m3.cu") message( STATUS @@ -92,7 +94,11 @@ add_library(fpA_intB_gemm_src STATIC ${MIXED_SRC_CPP} ${MIXED_SRC_CU} # WARNING: Building with `-G` flag may generate invalid results for this target add_library(moe_gemm_src STATIC ${GROUPED_SRC_CU} ${GROUPED_SRC_CPP} ${GROUPED_CU_INSTANTIATIONS}) -foreach(target_name fpA_intB_gemm_src;moe_gemm_src) + +set(GEMM_SWIGLU_SM90_SRC_CU + ${CMAKE_CURRENT_SOURCE_DIR}/fused_gated_gemm/gemm_swiglu_e4m3.cu) +add_library(gemm_swiglu_sm90_src STATIC ${GEMM_SWIGLU_SM90_SRC_CU}) +foreach(target_name fpA_intB_gemm_src;moe_gemm_src;gemm_swiglu_sm90_src) set_property(TARGET ${target_name} PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET ${target_name} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) @@ -100,8 +106,8 @@ foreach(target_name fpA_intB_gemm_src;moe_gemm_src) # specified). This is because sm_90a has arch conditional instructions that # are not forward compatible. As a result, it does not make sense to embed PTX # into the binary anyway. - if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST - OR "9.0+PTX" IN_LIST TORCH_CUDA_ARCH_LIST + if("90" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG + OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_NATIVE) message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.") diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h index a64b908dd..ffd765cef 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -19,6 +19,7 @@ #include "cutlass_extensions/gemm_configs.h" #include "cutlass_extensions/weight_only_quant_op.h" #include +#include namespace tkc = tensorrt_llm::cutlass_extensions; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm.h new file mode 100644 index 000000000..6e670d2d3 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/gemm_configs.h" +#include "tensorrt_llm/common/quantization.h" + +#include +#include + +namespace tk = tensorrt_llm::common; +namespace tkc = tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +/* + This runner supports: + + Activations and outputs are all assumed to be row-major. + Weights are assumed to be column-major. +*/ + +class CutlassFusedGatedGemmRunnerInterface +{ +public: + CutlassFusedGatedGemmRunnerInterface() {} + + virtual ~CutlassFusedGatedGemmRunnerInterface() {} + + virtual void gemm(void* D, void const* A, void const* B, void const* C_bias, tk::QuantMode quantOption, int m, + int n, int k, float scale_d0, float scale_d1, float scale_output, tkc::CutlassGemmConfig gemmConfig, + char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) + = 0; + + // Returns desired workspace size in bytes. + virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; + + virtual std::vector getConfigs() const = 0; +}; + +template +class CutlassFusedGatedGemmRunner : public virtual CutlassFusedGatedGemmRunnerInterface +{ +public: + CutlassFusedGatedGemmRunner(); + ~CutlassFusedGatedGemmRunner(); + + void gemm(void* D, void const* A, void const* B, void const* C_bias, tk::QuantMode quantOption, int m, int n, int k, + float scale_d0, float scale_d1, float scale_output, tkc::CutlassGemmConfig gemmConfig, char* workspace, + size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) override; + + // Returns desired workspace size in bytes. + size_t getWorkspaceSize(int const m, int const n, int const k) override; + + std::vector getConfigs() const override; + +private: + size_t dispatchToArch(void* D, void const* A, void const* B, void const* C_bias, tk::QuantMode quantOption, int m, + int n, int k, float scale_d0, float scale_d1, float scale_output, tkc::CutlassGemmConfig gemmConfig, + char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr); + + size_t getWorkspaceSizeImpl(int const m, int const n, int const k); + + int mSm; +}; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm_kernel_template_sm90.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm_kernel_template_sm90.h new file mode 100644 index 000000000..8637f49cf --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm_kernel_template_sm90.h @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cute/tensor.hpp" +#include "cutlass/conv/convolution.h" +// Order matters here, packed_stride.hpp is missing cute and convolution includes +#include "cutlass/util/packed_stride.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp" +#include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +using namespace cute; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template class Activation = cutlass::epilogue::thread::SiLu, bool SwapAB = false> +struct DeviceGemmGatedSm90 +{ + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + // A matrix configuration + using ElementA = ElementType; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementType; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = ElementType; // Element type for C matrix operands + // using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands + using LayoutC = cute::conditional_t; + static constexpr int AlignmentC + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in units of + // elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = ElementType; // Element type for output matrix operands + // using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands + using LayoutOutput = cute::conditional_t; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // Multiply-accumulate blocking/pipelining details + using ElementAccumulator = AccumElementType; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + using KernelSchedule = MainloopScheduleType; + using EpilogueSchedule = EpilogueScheduleType; + using TileScheduler = TileSchedulerType; + + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using FusionOperation = cutlass::epilogue::fusion::ScaledAcc; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderGated( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule, Activation, SwapAB>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm_template.h new file mode 100644 index 000000000..8ef0f9bb6 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm_template.h @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cute/tensor.hpp" +#include "cutlass/conv/convolution.h" +// Order matters here, packed_stride.hpp is missing cute and convolution includes +#include "cutlass/util/packed_stride.hpp" +#include "cutlass_extensions/gemm_configs.h" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +#include "fused_gated_gemm.h" +#include "fused_gated_gemm_kernel_template_sm90.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/quantization.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" + +#include +#include + +namespace tk = tensorrt_llm::common; +namespace tkc = tensorrt_llm::cutlass_extensions; + +using namespace cute; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template +size_t typedGemmGatedKernelLauncher(Gemm gemm, typename Gemm::Arguments args, void* D, void const* A, void const* B, + void const* C_bias, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + using ElementT = typename Gemm::ElementA; + + // Check shared memory size; throw when SMEM exceeds + int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); + static int mMaxSmemSize = tk::getMaxSharedMemoryPerBlockOptin(); + if (smem_size > mMaxSmemSize) + { + std::string errMsg = "SMEM size exceeds maximum allowed. Required " + std::to_string(smem_size) + ", got " + + std::to_string(mMaxSmemSize); + throw std::runtime_error("[TensorRT-LLM Error][fusedGatedGemm Runner] " + errMsg); + } + + // Return workspace size + if (!A && !B && !C_bias && !D) + { + return gemm.get_workspace_size(args); + } + + if (gemm.get_workspace_size(args) > workspaceBytes) + { + std::string errMsg("Requested workspace size insufficient. Required " + + std::to_string(gemm.get_workspace_size(args)) + ", got " + std::to_string(workspaceBytes)); + throw std::runtime_error("[TensorRT-LLM Error][fusedGatedGemm Runner] " + errMsg); + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) + { + std::string errMsg = "fusedGatedGemm cutlass kernel not implemented given the params. Error: " + + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[TensorRT-LLM Error][fusedGatedGemm Runner] " + errMsg); + } + + auto initStatus = gemm.initialize(args, workspace, stream); + if (initStatus != cutlass::Status::kSuccess) + { + std::string errMsg = "Failed to initialize. Error: " + std::string(cutlassGetStatusString(initStatus)); + throw std::runtime_error("[TensorRT-LLM Error][fusedGatedGemm Runner] " + errMsg); + } + + auto runStatus = gemm.run(stream); + if (runStatus != cutlass::Status::kSuccess) + { + std::string errMsg = "Failed to run gemm. Error: " + std::string(cutlassGetStatusString(runStatus)); + throw std::runtime_error("[TensorRT-LLM Error][fusedGatedGemm Runner] " + errMsg); + } + return gemm.get_workspace_size(args); +} + +template +typename Gemm::Arguments prepareGemmArgsSm90(void* D, void const* A, void const* B, void const* C_bias, + tk::QuantMode quantOption, int m, int n, int k, float scale_d0, float scale_d1, float scale_output, + tkc::CutlassGemmConfig gemmConfig) +{ + using ElementT = typename Gemm::ElementA; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + int arg_m = m; + int arg_n = n / 2; + ElementT const* ptr_A = reinterpret_cast(A); + ElementT const* ptr_B = reinterpret_cast(B); + if constexpr (SwapAB) + { + arg_m = n / 2; + arg_n = m; + ptr_A = reinterpret_cast(B); + ptr_B = reinterpret_cast(A); + } + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(arg_m, k, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(arg_n, k, 1)); + StrideC stride_C; + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(arg_m, arg_n, 1)); + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, {arg_m, arg_n, k, 1}, + {ptr_A, stride_A, ptr_B, stride_B, scale_d0, scale_d1}, + {{}, // epilogue.thread + nullptr, stride_C, reinterpret_cast(D), stride_D}}; + args.epilogue.thread.alpha = scale_output; + return args; +} + +template typename Activation = cutlass::epilogue::thread::SiLu, bool SwapAB = true> +size_t genericGemmGatedKernelLauncherSm90(void* D, void const* A, void const* B, void const* C_bias, + tk::QuantMode quantOption, int m, int n, int k, float scale_d0, float scale_d1, float scale_output, + tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, + int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + +#ifdef COMPILE_HOPPER_TMA_GEMMS + using ElementT = typename TllmToCutlassTypeAdapter::type; + using AccumElementType = float; + using MainloopScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{}, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using EpilogueScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{}, + cutlass::epilogue::TmaWarpSpecialized, cutlass::epilogue::TmaWarpSpecializedCooperative>; + using TileSchedulerType = void; + using Gemm = typename DeviceGemmGatedSm90::Gemm; + auto args = prepareGemmArgsSm90( + D, A, B, C_bias, quantOption, m, n, k, scale_d0, scale_d1, scale_output, gemmConfig); + return typedGemmGatedKernelLauncher(Gemm{}, args, D, A, B, C_bias, workspace, workspaceBytes, stream, occupancy); +#else // COMPILE_HOPPER_TMA_GEMMS + throw std::runtime_error( + "[TensorRT-LLm Error][GemmGatedKernelLauncherSm90] Please recompile with support for hopper by passing 90-real " + "as an arch to build_wheel.py."); +#endif // COMPILE_HOPPER_TMA_GEMMS +} + +template +size_t dispatchGemmConfigSm90(void* D, void const* A, void const* B, void const* C_bias, tk::QuantMode quantOption, + int m, int n, int k, float scale_d0, float scale_d1, float scale_output, tkc::CutlassGemmConfig gemmConfig, + char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemmConfig.cluster_shape) + { + case tkc::ClusterShape::ClusterShape_1x1x1: + return genericGemmGatedKernelLauncherSm90>(D, A, B, C_bias, quantOption, m, n, k, + scale_d0, scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x1x1: + return genericGemmGatedKernelLauncherSm90>(D, A, B, C_bias, quantOption, m, n, k, + scale_d0, scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_1x2x1: + return genericGemmGatedKernelLauncherSm90>(D, A, B, C_bias, quantOption, m, n, k, + scale_d0, scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x2x1: + return genericGemmGatedKernelLauncherSm90>(D, A, B, C_bias, quantOption, m, n, k, + scale_d0, scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_1x8x1: + return genericGemmGatedKernelLauncherSm90>(D, A, B, C_bias, quantOption, m, n, k, + scale_d0, scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_8x1x1: + return genericGemmGatedKernelLauncherSm90>(D, A, B, C_bias, quantOption, m, n, k, + scale_d0, scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + default: + throw std::runtime_error( + "[TensorRT-LLM Error][CutlassFusedGatedGemmRunner][dispatchGemmConfigSm90] Config is invalid for fused " + "gated GEMM."); + break; + } +} + +template +size_t dispatchGemmToCutlassSm90(void* D, void const* A, void const* B, void const* C_bias, tk::QuantMode quantOption, + int m, int n, int k, float scale_d0, float scale_d1, float scale_output, tkc::CutlassGemmConfig gemmConfig, + char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + static_assert(std::is_same_v, "fusedGatedGemmSm90 only support FP8(e4m3)"); + constexpr int Ktile = 128 / sizeof(T); + using _Ktile = Int; + switch (gemmConfig.tile_config_sm90) + { + case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x32x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x64x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x128x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x16x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x32x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x64x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x128x128B: + return dispatchGemmConfigSm90>(D, A, B, C_bias, quantOption, m, n, k, scale_d0, + scale_d1, scale_output, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::Undefined: + throw std::runtime_error( + "[TensorRT-LLm Error][CutlassFusedGatedGemmRunner][dispatchGemmToCutlassSm90] gemm config undefined."); + break; + case tkc::CutlassTileConfigSM90::ChooseWithHeuristic: + throw std::runtime_error( + "[TensorRT-LLm Error][CutlassFusedGatedGemmRunner][dispatchGemmToCutlassSm90] gemm config should have " + "already been set by " + "heuristic."); + break; + default: + throw std::runtime_error( + "[TensorRT-LLm Error][CutlassFusedGatedGemmRunner][dispatchGemmToCutlassSm90] Config is invalid for fused " + "gated GEMM."); + break; + } +} + +template +CutlassFusedGatedGemmRunner::CutlassFusedGatedGemmRunner() +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + mSm = tk::getSMVersion(); +} + +template +CutlassFusedGatedGemmRunner::~CutlassFusedGatedGemmRunner() +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} + +template +size_t CutlassFusedGatedGemmRunner::dispatchToArch(void* D, void const* A, void const* B, void const* C_bias, + tk::QuantMode quantOption, int m, int n, int k, float scale_d0, float scale_d1, float scale_output, + tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr (std::is_same_v) + { + if (mSm == 90) + { + return dispatchGemmToCutlassSm90(D, A, B, C_bias, quantOption, m, n, k, scale_d0, scale_d1, scale_output, + gemmConfig, workspace, workspaceBytes, stream, occupancy); + } + else + { + throw std::runtime_error( + "[TensorRT-LLM Error][CutlassFusedGatedGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS fused " + "gated GEMM"); + } + } + else + { + throw std::runtime_error( + "[TensorRT-LLM Error][CutlassFusedGatedGemmRunner][GEMM Dispatch] dtype unsupported for CUTLASS fused " + "gated " + "GEMM"); + } + return 0; +} + +template +void CutlassFusedGatedGemmRunner::gemm(void* D, void const* A, void const* B, void const* C_bias, + tk::QuantMode quantOption, int m, int n, int k, float scale_d0, float scale_d1, float scale_output, + tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + dispatchToArch(D, A, B, C_bias, quantOption, m, n, k, scale_d0, scale_d1, scale_output, gemmConfig, workspace, + workspaceBytes, stream, occupancy); +} + +template +std::vector CutlassFusedGatedGemmRunner::getConfigs() const +{ + using tkc::CutlassTileConfig; + using tkc::CutlassGemmConfig; + using tkc::SplitKStyle; + + std::vector candidateConfigs; + + if constexpr (std::is_same_v) + { + if (mSm != 90) + { + throw std::runtime_error( + "[TensorRT-LLM Error][CutlassFusedGatedGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS fused " + "gated GEMM"); + } + tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param + = tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; + std::vector commonConfigs = get_candidate_configs(mSm, 2, config_type_param); + candidateConfigs.insert(candidateConfigs.end(), commonConfigs.begin(), commonConfigs.end()); + // registers are not enough when N_tile is 256, remove some configs + candidateConfigs.erase(std::remove_if(candidateConfigs.begin(), candidateConfigs.end(), + [](auto const& config) + { + return config.tile_config_sm90 == tkc::CutlassTileConfigSM90::CtaShape64x256x128B + || config.tile_config_sm90 + == tkc::CutlassTileConfigSM90::CtaShape128x256x128B; + }), + candidateConfigs.end()); + std::vector tilesSm90 + = {tkc::CutlassTileConfigSM90::CtaShape64x16x128B, tkc::CutlassTileConfigSM90::CtaShape64x32x128B, + tkc::CutlassTileConfigSM90::CtaShape64x64x128B, tkc::CutlassTileConfigSM90::CtaShape64x128x128B, + tkc::CutlassTileConfigSM90::CtaShape128x16x128B, tkc::CutlassTileConfigSM90::CtaShape128x32x128B, + tkc::CutlassTileConfigSM90::CtaShape128x64x128B, tkc::CutlassTileConfigSM90::CtaShape128x128x128B}; + for (auto const& tile_config : tilesSm90) + { + { + CutlassGemmConfig config(tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, + tkc::ClusterShape::ClusterShape_1x8x1); + candidateConfigs.push_back(config); + } + { + CutlassGemmConfig config(tile_config, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, + tkc::ClusterShape::ClusterShape_8x1x1); + candidateConfigs.push_back(config); + } + } + } + else + { + throw std::runtime_error( + "[TensorRT-LLM Error][CutlassFusedGatedGemmRunner][GEMM Dispatch] dtype unsupported for CUTLASS fused " + "gated " + "GEMM"); + } + return candidateConfigs; +} + +// Note: can be quite heavyweight; when possible, call once +template +size_t CutlassFusedGatedGemmRunner::getWorkspaceSizeImpl(int const m, int const n, int const k) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + size_t workspace_size = 0; + auto gemmConfigs = CutlassFusedGatedGemmRunner{}.getConfigs(); + for (auto const& gemmConfig : gemmConfigs) + { + try + { + size_t curr_workspace_size = CutlassFusedGatedGemmRunner::dispatchToArch( + nullptr, nullptr, nullptr, nullptr, tk::QuantMode{}, m, n, k, 1.0, 1.0, 1.0, gemmConfig, nullptr, 0, 0); + workspace_size = std::max(workspace_size, curr_workspace_size); + } + catch (std::runtime_error& e) + { + // Swallow errors when SMEM exceeds maximum allowed + continue; + } + } + + return workspace_size; +} + +template +size_t CutlassFusedGatedGemmRunner::getWorkspaceSize(int const m, int const n, int const k) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // Custom hash function for the MNK type + using MNK = std::tuple; + + struct MNKHash + { + size_t operator()(const MNK& mnk) const + { + auto h1 = std::hash{}(std::get<0>(mnk)); + auto h2 = std::hash{}(std::get<1>(mnk)); + auto h3 = std::hash{}(std::get<2>(mnk)); + return h1 ^ h2 ^ h3; + } + }; + + static std::unordered_map workspace_hashmap; + + size_t workspace_size = 0; + if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) + { + workspace_size = CutlassFusedGatedGemmRunner::getWorkspaceSizeImpl(m, n, k); + workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size; + } + else + { + workspace_size = workspace_hashmap[std::make_tuple(m, n, k)]; + } + return workspace_size; +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/gemm_swiglu_e4m3.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/gemm_swiglu_e4m3.cu new file mode 100644 index 000000000..2e603cfb1 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/gemm_swiglu_e4m3.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fused_gated_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFusedGatedGemmRunner<__nv_fp8_e4m3>; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h index f3561dc50..722f817db 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h @@ -19,6 +19,7 @@ #include "cutlass_extensions/gemm_configs.h" #include "tensorrt_llm/common/quantization.h" #include +#include namespace tk = tensorrt_llm::common; namespace tkc = tensorrt_llm::cutlass_extensions; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h new file mode 100644 index 000000000..5a8e3b514 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, ElementType_* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy); +} diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl new file mode 100644 index 000000000..1b2e4f4d4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + +#include +#include +#include + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, ElementType_* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy) +{ + constexpr auto activation_type = fused_moe::EpilogueRouting(true); + using GemmType = fused_moe::Fused_Moe_Kernel_sm80; + + // make sure GPU has enough resources.. + if (kernel_occupancy != nullptr) + { + constexpr int smem_size = GemmType::kSmemSize; + + if (smem_size > (48 << 10)) + { + cudaFuncAttributes attr{}; + int device = 0; + int max_smem_per_block = 0; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global)); + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) + { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + // smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the + // heuristic to ignore this configuration. + *kernel_occupancy = 0; + return; + } + } + + int max_active_blocks = -1; + tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); + *kernel_occupancy = max_active_blocks; + return; + } + int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); + int const threadblock_count = multi_processor_count * occupancy; + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); + GemmType gemm; + using Arguments = typename GemmType::Arguments; + Arguments args{{const_cast(A), const_cast(B), const_cast(biases), + reinterpret_cast(C), total_rows_before_expert, static_cast(gemm_n), + static_cast(gemm_k), num_experts}, + num_experts, threadblock_count}; + auto params = GemmType::to_underlying_arguments(args); + if (GemmType::kSmemSize >= (48 << 10)) + { + cudaError_t result = cudaFuncSetAttribute( + fused_moe::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, + "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel"); + } + dim3 grid(params.threadblock_count, 1, 1); + dim3 block(GemmType::kThreadCount); + fused_moe::run_global<<>>(params); + auto result = cudaGetLastError(); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result)); +} +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h index d485fa596..643a2a916 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -16,6 +16,7 @@ */ #pragma once +#include "tensorrt_llm/common/cudaFp8Utils.h" #include "tensorrt_llm/common/workspace.h" #include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" #include @@ -161,11 +162,11 @@ class MoeGemmRunner void moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream); + int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, cudaStream_t stream); void moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cudaStream_t stream); + bool use_fused_moe, cudaStream_t stream); std::vector getConfigs() const; static std::vector getConfigs(int sm); @@ -174,20 +175,23 @@ class MoeGemmRunner bool isHopperSpecialised() const; bool supportsHopperSpecialisation() const; + [[nodiscard]] bool isFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; size_t calcMaxWorkspaceSize(int num_experts) const; + [[nodiscard]] int getSM() const; + private: template void dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, - int* occupancy = nullptr); + int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, + cudaStream_t stream, int* occupancy = nullptr); template void runGemm(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, cudaStream_t stream); + int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream); private: int sm_; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h index 349ae4cc8..3b1302358 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h @@ -52,10 +52,10 @@ #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" #include "moe_gemm_kernels_template_sm90.h" - #include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" #include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" #include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" +#include #include #include @@ -72,8 +72,8 @@ template void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, cudaStream_t stream, - int* kernel_occupancy = nullptr) + cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe, + cudaStream_t stream, int* kernel_occupancy = nullptr) { #ifdef ENABLE_BF16 static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value @@ -95,64 +95,78 @@ void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weig // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. using ElementType = typename TllmToCutlassTypeAdapter::type; using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; + if (!use_fused_moe) + { + // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + int const threadblock_count = multi_processor_count * occupancy; - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; + typename EpilogueOp::Params epilogue_op( + ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); - using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + int const group_size = gemm_k; + typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, + reinterpret_cast(A), reinterpret_cast(B), + reinterpret_cast(weight_scales), reinterpret_cast(biases), + reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); - // Finally, set up the kernel. - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; + GemmGrouped gemm; - using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); - using GemmGrouped = cutlass::gemm::device::GemmGrouped; + auto init_status = gemm.initialize(args); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status))); - if (kernel_occupancy != nullptr) + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status))); + } + else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2 + && (std::is_same_v + || std::is_same_v) ) // use fused moe gemm + // kernel.. (only support + // fp16 or bf16) { - *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; + sm80_generic_fused_moe_gemm_kernelLauncher(reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(biases), + reinterpret_cast(C), total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, + multi_processor_count, stream, kernel_occupancy); } - int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); - int const threadblock_count = multi_processor_count * occupancy; - - typename EpilogueOp::Params epilogue_op( - ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); - - int const group_size = gemm_k; - typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, - reinterpret_cast(A), reinterpret_cast(B), - reinterpret_cast(weight_scales), reinterpret_cast(biases), - reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); - - GemmGrouped gemm; - - auto can_implement = gemm.can_implement(args); - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, - "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); - - auto init_status = gemm.initialize(args); - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, - "Failed to initialize cutlass variable batched gemm. Error: " - + std::string(cutlassGetStatusString(init_status))); - - auto run_status = gemm.run(stream); - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, - "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status))); } } // namespace kernels::cutlass_kernels @@ -161,8 +175,8 @@ template static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + cudaStream_t stream, int* occupancy = nullptr) { static_assert(!std::is_same_v, "Use TMA specialised functions for arch SM90"); constexpr bool isFp8 = std::is_same_v || std::is_same_v; @@ -170,7 +184,7 @@ static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T { kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, stream, occupancy); + num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); } else { @@ -183,25 +197,25 @@ template void dispatchGemmConfig(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.stages) { case 2: dispatch(A, B, weight_scales, biases, C, - total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, - occupancy); + total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, stream, occupancy); break; case 3: dispatch(A, B, weight_scales, biases, C, - total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, - occupancy); + total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, stream, occupancy); break; case 4: dispatch(A, B, weight_scales, biases, C, - total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, - occupancy); + total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, stream, occupancy); break; default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; } @@ -213,8 +227,8 @@ template ::value && std::is_same::value>::type* = nullptr> void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { @@ -224,7 +238,8 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s { dispatchGemmConfig, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, + occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: @@ -233,23 +248,24 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s { dispatchGemmConfig, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, + occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatchGemmConfig, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: dispatchGemmConfig, cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: @@ -266,8 +282,8 @@ template ::value && !std::is_same::value>::type* = nullptr> void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { @@ -277,7 +293,8 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s { dispatchGemmConfig, cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, + occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: @@ -286,23 +303,24 @@ void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_s { dispatchGemmConfig, cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, - total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, + occupancy); } break; case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatchGemmConfig, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: dispatchGemmConfig, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: @@ -317,15 +335,15 @@ template ::value>::type* = nullptr> void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: dispatchGemmConfig, cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, use_fused_moe, stream, occupancy); break; case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: @@ -414,6 +432,21 @@ bool MoeGemmRunner::supportsHopperSpecialisation() const return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation(); } +template +int MoeGemmRunner::getSM() const +{ + return this->sm_; +} + +// currently support sm80 bf16/fp16 gate ativation, only set predication tensor for m direction +template +bool MoeGemmRunner::isFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const +{ + return is_gated_activation + && std::is_same_v && (!std::is_same_v) &&(!this->isHopperSpecialised()) + && this->getSM() >= 80 && (gemm_k % 32 == 0) && (gemm_n % 32 == 0); +} + template MoeGemmRunner::MoeGemmRunner() { @@ -429,7 +462,7 @@ template void MoeGemmRunner::dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, - cudaStream_t stream, int* occupancy) + bool use_fused_moe, cudaStream_t stream, int* occupancy) { TLLM_CHECK_WITH_INFO( @@ -441,19 +474,19 @@ void MoeGemmRunner::dispatchToArch(T const* A, Weigh { dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - stream, occupancy); + use_fused_moe, stream, occupancy); } else if (sm_ >= 75 && sm_ < 80) { dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - stream, occupancy); + use_fused_moe, stream, occupancy); } else if (sm_ >= 80 && sm_ < 90) { dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - stream, occupancy); + use_fused_moe, stream, occupancy); } else if (sm_ >= 90) { @@ -486,7 +519,7 @@ void MoeGemmRunner::dispatchToArch(T const* A, Weigh dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, - stream, occupancy); + use_fused_moe, stream, occupancy); } else { @@ -539,36 +572,45 @@ template template void MoeGemmRunner::runGemm(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) + int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream) { TLLM_CHECK_WITH_INFO(this->best_config_, "No MOE GEMM config set at runtime"); auto chosen_conf = *this->best_config_; dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, hopper_input, total_rows, - gemm_n, gemm_k, num_experts, chosen_conf, stream); + gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, stream); } template void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream) + int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, + cudaStream_t stream) { switch (activation_type) { case ActivationType::Relu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream); break; case ActivationType::Gelu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream); break; case ActivationType::Silu: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream); break; case ActivationType::Identity: runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream); + break; + case ActivationType::Swiglu: + runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream); + break; + case ActivationType::Geglu: + runGemm(A, B, weight_scales, biases, C, total_rows_before_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream); break; case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; default: TLLM_THROW("Invalid activation type."); break; @@ -578,10 +620,10 @@ void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* template void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, int64_t* total_rows_before_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, cudaStream_t stream) + int64_t gemm_k, int num_experts, bool use_fused_moe, cudaStream_t stream) { runGemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, stream); + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, stream); } } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py index 21322bde9..95ab8e899 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py @@ -11,19 +11,27 @@ class TrtLlm_EpilogueTag(enum.Enum): epilogue_op_default = enum_auto() epilogue_op_bias = enum_auto() + epilogue_op_silu = enum_auto() + epilogue_op_gelu = enum_auto() EpiTagNames = { TrtLlm_EpilogueTag.epilogue_op_default: "lc", # linear combination TrtLlm_EpilogueTag.epilogue_op_bias: - "lc_bias" # linear combination with bias addition + "lc_bias", # linear combination with bias addition + TrtLlm_EpilogueTag.epilogue_op_silu: "silu", # silu or swiglu + TrtLlm_EpilogueTag.epilogue_op_gelu: "gelu" # gelu or geglu } EpiTag = { TrtLlm_EpilogueTag.epilogue_op_default: "tensorrt_llm::cutlass_extensions::EpilogueOpDefault", TrtLlm_EpilogueTag.epilogue_op_bias: - "tensorrt_llm::cutlass_extensions::EpilogueOpBias" + "tensorrt_llm::cutlass_extensions::EpilogueOpBias", + TrtLlm_EpilogueTag.epilogue_op_silu: + "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu", + TrtLlm_EpilogueTag.epilogue_op_gelu: + "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu" } @@ -350,6 +358,70 @@ def generate_sm90_operations(): return operations +def generate_sm80_fused_grouped_gemm_operations(): + arch = 80 + supported_dtypes = [DataType.f16, DataType.bf16] + epi_tags = [ + TrtLlm_EpilogueTag.epilogue_op_silu, TrtLlm_EpilogueTag.epilogue_op_gelu + ] + cta_shapes_mnk = [(16, 128, 64), (16, 256, 64), (32, 128, 64), + (64, 128, 64), (128, 128, 64)] + + stages = [2, 3, 4] + + partial_args = product(supported_dtypes, epi_tags, cta_shapes_mnk, stages) + + operations = list() + for dtype, epi_tag, cta_shape_mnk, stage in partial_args: + item = { + "arch": arch, + "dtype": dtype, + "epi_tag": epi_tag, + "cta_shape": cta_shape_mnk, + "stage": stage + } + operations.append(item) + return operations + + +def generate_sm80_operations(): + operations = generate_sm80_fused_grouped_gemm_operations() + return operations + + +def get_sm80_file_content(op_item): + includes = f"#include " + act_tag = DataTypeTag[op_item['dtype']] + weight_tag = DataTypeTag[op_item['dtype']] + epi_tag = EpiTag[op_item['epi_tag']] + + instantiations = f""" + template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {op_item['cta_shape'][0]}, {op_item['cta_shape'][1]}, {op_item['cta_shape'][2]}, {op_item['stage']}, {epi_tag}> + ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, {act_tag}* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy); +""" + file_content = f"""{includes} +namespace tensorrt_llm +{{ +namespace kernels +{{ +namespace cutlass_kernels +{{ + +{instantiations} + +}} // namespace cutlass_kernels +}} // namespace kernels +}} // namespace tensorrt_llm +""" + return file_content + + +def write_sm80_file(op_item, file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, mode="w") as f: + f.write(get_sm80_file_content(op_item)) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description='Print the output directory') @@ -389,3 +461,14 @@ def generate_sm90_operations(): f"cutlass_kernel_file_{file_counter}.generated.cu") write_file(inl_map[gemm_kind], value, out_file) file_counter += 1 + + # Since GemmKind.Grouped is used for gen sm90 moe code. + sm80_operations = generate_sm80_operations() + for op_item in sm80_operations: + # print(op_item) + out_file_path = os.path.join( + output_dir, "gemm_grouped", + f"fused_moe_sm{op_item['arch']}_{op_item['cta_shape'][0]}_{op_item['cta_shape'][1]}_{op_item['cta_shape'][2]}_{op_item['stage']}_{DataTypeNames[op_item['dtype']]}_{EpiTagNames[op_item['epi_tag']]}.generated.cu" + ) + write_sm80_file(op_item, out_file_path) + # print(out_file_path) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu index 7b0b9d546..5191c7e03 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu @@ -28,8 +28,8 @@ namespace mmha //////////////////////////////////////////////////////////////////////////////////////////////////// // Forward declaration of the kernel launcher to avoid including decoderMaskedMultiheadAttentionLaunch.h -template +template void mmha_launch_kernel(const T_PARAMS& params, KVCacheBuffer const& kv_cache_buffer, KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream); @@ -39,19 +39,19 @@ namespace { #define MMHA_LAUNCH_KERNEL(Dh) \ - mmha::mmha_launch_kernel( \ + mmha::mmha_launch_kernel( \ params, kv_cache_buffer, shift_k_cache, stream); \ break; #define MMHA_LAUNCH_KERNE_EX1(Dh) \ if (has_implicit_rel_attn_bias) \ { \ - mmha::mmha_launch_kernel( \ + mmha::mmha_launch_kernel( \ params, kv_cache_buffer, shift_k_cache, stream); \ } \ else \ { \ - mmha::mmha_launch_kernel( \ + mmha::mmha_launch_kernel( \ params, kv_cache_buffer, shift_k_cache, stream); \ } \ break; @@ -59,17 +59,22 @@ namespace #define MMHA_LAUNCH_KERNE_EX2(Dh) \ if (has_implicit_rel_attn_bias) \ { \ - mmha::mmha_launch_kernel( \ + mmha::mmha_launch_kernel( \ params, kv_cache_buffer, shift_k_cache, stream); \ } \ else if (has_qk_tanh_scale) \ { \ - mmha::mmha_launch_kernel( \ + mmha::mmha_launch_kernel( \ + params, kv_cache_buffer, shift_k_cache, stream); \ + } \ + else if (has_block_sparse_attn) \ + { \ + mmha::mmha_launch_kernel( \ params, kv_cache_buffer, shift_k_cache, stream); \ } \ else \ { \ - mmha::mmha_launch_kernel( \ + mmha::mmha_launch_kernel( \ params, kv_cache_buffer, shift_k_cache, stream); \ } \ break; @@ -88,6 +93,13 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, KVCacheBuffer const& TLLM_CHECK_WITH_INFO(!(has_qk_tanh_scale && has_implicit_rel_attn_bias), "MMHA kernels haven't instantiate implicit_relative_attention_bias + qk_tanh_scale paths for head size %d.", head_size); + + bool const has_block_sparse_attn = params.block_sparse_attention; + TLLM_CHECK_WITH_INFO(!has_block_sparse_attn || head_size == 128, + "MMHA kernels were not instantiated for block_sparse_attention for head size %d.", head_size); + TLLM_CHECK_WITH_INFO(!(has_implicit_rel_attn_bias && has_block_sparse_attn), + "MMHA kernels do not support combining implicit_relative_attention_bias and block_sparse_attention"); + switch (params.hidden_size_per_head) { case 32: MMHA_LAUNCH_KERNE_EX1(32); diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h index d71f9d149..6ce276827 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h @@ -123,9 +123,11 @@ struct Multihead_attention_params_base float rotary_embedding_base = 0.0f; RotaryScalingType rotary_embedding_scale_type = RotaryScalingType::kNONE; float rotary_embedding_scale = 0.0f; - float rotary_embedding_m_scale = 0.0f; + float rotary_embedding_short_m_scale = 0.0f; + float rotary_embedding_long_m_scale = 0.0f; float const* rotary_embedding_scaling_factors = nullptr; int rotary_embedding_max_positions = 0; + int rotary_embedding_original_max_positions = 0; int rotary_cogvlm_vision_start = -1; int rotary_cogvlm_vision_length = -1; // Position shift for streamingllm @@ -146,6 +148,10 @@ struct Multihead_attention_params_base int relative_attention_bias_stride = 0; int max_distance = 0; + // block sparse config + bool block_sparse_attention = false; + BlockSparseParams block_sparse_params{64, false, 16, 8}; + // The slope per head of linear position bias to attention score (H). T const* linear_bias_slopes = nullptr; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/CMakeLists.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/CMakeLists.txt index 9825ca1a9..a65e4bc02 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/CMakeLists.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/CMakeLists.txt @@ -32,6 +32,12 @@ if(FAST_BUILD) ) endif() +filter_cuda_archs("70" SRC_CPP) +filter_cuda_archs("80" SRC_CPP) +filter_cuda_archs("86" SRC_CPP) +filter_cuda_archs("89" SRC_CPP) +filter_cuda_archs("90" SRC_CPP) + add_library(decoder_attention_src OBJECT ${SRC_CPP} ${SRC_CU}) set_property(TARGET decoder_attention_src PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET decoder_attention_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h index ae6eb6b72..c95d40c3e 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h @@ -20,6 +20,7 @@ namespace kernels { // clang-format off // SingleQueryToken kernels. +#ifndef EXCLUDE_SM_80 extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[]; @@ -44,6 +45,98 @@ extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_n extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; + +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; +#endif + +#ifndef EXCLUDE_SM_86 extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[]; @@ -68,6 +161,98 @@ extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_n extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; + +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; +#endif + +#ifndef EXCLUDE_SM_89 extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[]; @@ -104,92 +289,6 @@ extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_n extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; - -// MultiQueryToken kernels. -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin[]; @@ -226,6 +325,144 @@ extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nq extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; + +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; +#endif + +#ifndef EXCLUDE_SM_90 +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[]; +extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[]; + +// MultiQueryToken kernels. extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[]; @@ -264,34 +501,6 @@ extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_n extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin[]; // MHA with beamWidth=4 -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin[]; -extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin[]; extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_90_cubin[]; @@ -306,90 +515,6 @@ extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nq extern unsigned long long xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_90_cubin[]; // SingleQueryToken kernels. -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len; @@ -424,94 +549,10 @@ extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len; - -// MultiQueryToken kernels. -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len; +extern uint32_t xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len; + +// MultiQueryToken kernels. extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len; @@ -550,34 +591,6 @@ extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_1 extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len; // MHA with beamWidth=4 -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len; -extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len; @@ -590,6 +603,7 @@ extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_ extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len; extern uint32_t xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len; +#endif static const struct XQAKernelMetaInfo { @@ -608,6 +622,7 @@ static const struct XQAKernelMetaInfo const char* mFuncName; } sXqaKernelMetaInfo[] = { // SingleQueryToken kernels. +#ifndef EXCLUDE_SM_80 { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"}, @@ -632,6 +647,40 @@ static const struct XQAKernelMetaInfo { DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, +#endif +#ifndef EXCLUDE_SM_86 { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"}, @@ -656,6 +705,40 @@ static const struct XQAKernelMetaInfo { DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, +#endif +#ifndef EXCLUDE_SM_89 { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"}, @@ -692,6 +775,56 @@ static const struct XQAKernelMetaInfo { DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, +#endif +#ifndef EXCLUDE_SM_90 { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"}, @@ -729,90 +862,6 @@ static const struct XQAKernelMetaInfo { DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_256_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"}, // MultiQueryToken kernels. -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"}, @@ -850,34 +899,6 @@ static const struct XQAKernelMetaInfo { DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"}, // MHA with beamWidth=4 -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_80_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_86_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_bf16_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 4, 1, 1, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, -{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 4, 1, 1, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_89_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_FP16, DATA_TYPE_INT8, 256, 4, 1, 1, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_int8_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"}, @@ -890,6 +911,7 @@ static const struct XQAKernelMetaInfo { DATA_TYPE_BF16, DATA_TYPE_INT8, 256, 4, 1, 1, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_int8_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 4, 1, 1, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"}, { DATA_TYPE_BF16, DATA_TYPE_E4M3, 256, 4, 1, 1, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_bf16_d_256_beam_4_kvt_e4m3_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"} +#endif }; // clang-format on diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h index df3fd88ac..de96f3507 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h @@ -152,18 +152,18 @@ inline void multi_block_grid_setup(dim3& grid, Multihead_attention_params= 46 * 1024) \ { \ - cudaError_t res \ - = cudaFuncSetAttribute(mmha::masked_multihead_attention_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \ + cudaError_t res = cudaFuncSetAttribute( \ + mmha::masked_multihead_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \ TLLM_CHECK_WITH_INFO( \ res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \ } \ TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&available_blocks, \ mmha::masked_multihead_attention_kernel, \ + BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>, \ DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz)); #define MMHA_KERNEL(DYNAMIC_THDS_PER_BLOCK, ENABLE_MULTI_BLOCK) \ @@ -175,14 +175,14 @@ inline void multi_block_grid_setup(dim3& grid, Multihead_attention_params, \ + POS_SHIFT, BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \ TLLM_CHECK_WITH_INFO( \ res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \ } \ mmha::masked_multihead_attention_kernel \ + BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE> \ <<>>(params, kv_cache_buffer, k_cache_buffer); // if resources are not enough to launch 512 threads per block, we will fallback to 256. @@ -214,7 +214,7 @@ inline void multi_block_grid_setup(dim3& grid, Multihead_attention_params + bool BLOCK_SPARSE_ATTN, bool IMPLICIT_REL_ATTN_BIAS, bool QK_TANH_SCALE> void mmha_launch_kernel_ex(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, KCacheBuffer const& k_cache_buffer, cudaStream_t const& stream, int tlength) { @@ -235,8 +235,8 @@ void mmha_launch_kernel_ex(KernelParamsType const& params, KVCacheBuffer const& // Dynamic shared memory is fixed for different block size. TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, mmha::masked_multihead_attention_kernel, + KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK, POS_SHIFT, BLOCK_SPARSE_ATTN, + IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>, THDS_PER_BLOCK, 0)); int block_size_factor @@ -305,52 +305,52 @@ void mmha_launch_kernel_ex(KernelParamsType const& params, KVCacheBuffer const& } template + bool HAS_BEAMS, bool DO_MULTI_BLOCK, bool BLOCK_SPARSE_ATTN, bool IMPLICIT_REL_ATTN_BIAS, bool QK_TANH_SCALE> void mmha_launch_kernel_dispatch_pos_shift(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream, int tlength) { if (params.position_shift_enabled && !KernelParamsType::DO_CROSS_ATTENTION) { mmha_launch_kernel_ex( + HAS_BEAMS, DO_MULTI_BLOCK, true, BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>( params, kv_cache_buffer, shift_k_cache, stream, tlength); } else { mmha_launch_kernel_ex( + HAS_BEAMS, DO_MULTI_BLOCK, false, BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>( params, kv_cache_buffer, kv_cache_buffer, stream, tlength); } } template + bool DO_MULTI_BLOCK, bool BLOCK_SPARSE_ATTN, bool IMPLICIT_REL_ATTN_BIAS, bool QK_TANH_SCALE> void mmha_launch_kernel_dispatch_8bits_kv_cache(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream, int tlength) { if (params.int8_kv_cache) { mmha_launch_kernel_dispatch_pos_shift( + DO_MULTI_BLOCK, BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>( params, kv_cache_buffer, shift_k_cache, stream, tlength); } #ifdef ENABLE_FP8 else if (params.fp8_kv_cache) { mmha_launch_kernel_dispatch_pos_shift( + HAS_BEAMS, DO_MULTI_BLOCK, BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>( params, kv_cache_buffer, shift_k_cache, stream, tlength); } #endif // ENABLE_FP8 else { mmha_launch_kernel_dispatch_pos_shift( + DO_MULTI_BLOCK, BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>( params, kv_cache_buffer, shift_k_cache, stream, tlength); } } -template void mmha_launch_kernel_dispatch(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream) @@ -359,17 +359,19 @@ void mmha_launch_kernel_dispatch(KernelParamsType const& params, KVCacheBuffer c if (params.multi_block_mode) { mmha_launch_kernel_dispatch_8bits_kv_cache(params, kv_cache_buffer, shift_k_cache, stream, tlength); + BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>( + params, kv_cache_buffer, shift_k_cache, stream, tlength); } else { mmha_launch_kernel_dispatch_8bits_kv_cache(params, kv_cache_buffer, shift_k_cache, stream, tlength); + BLOCK_SPARSE_ATTN, IMPLICIT_REL_ATTN_BIAS, QK_TANH_SCALE>( + params, kv_cache_buffer, shift_k_cache, stream, tlength); } } -template +template void mmha_launch_kernel(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream) { @@ -379,57 +381,71 @@ void mmha_launch_kernel(KernelParamsType const& params, KVCacheBuffer const& kv_ || params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE)); if (params.beam_width == 1) { - mmha_launch_kernel_dispatch(params, kv_cache_buffer, shift_k_cache, stream); + mmha_launch_kernel_dispatch(params, kv_cache_buffer, shift_k_cache, stream); } else { - mmha_launch_kernel_dispatch(params, kv_cache_buffer, shift_k_cache, stream); + mmha_launch_kernel_dispatch(params, kv_cache_buffer, shift_k_cache, stream); } } } // namespace mmha #define INSTANTIATE_MMHA_LAUNCHERS(T, Dh) \ - template void mmha_launch_kernel, Dh, false, false>( \ - const Masked_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ + template void mmha_launch_kernel, Dh, false, false, \ + false>(const Masked_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, false, false>( \ + template void mmha_launch_kernel, Dh, false, false, false>( \ const Masked_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, false, false>( \ + template void mmha_launch_kernel, Dh, false, false, false>( \ const Cross_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, false, false>( \ + template void mmha_launch_kernel, Dh, false, false, false>( \ const Cross_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); #define INSTANTIATE_MMHA_LAUNCHERS_WITH_IMPLICIT_REL_ATTN_BIAS(T, Dh) \ - template void mmha_launch_kernel, Dh, true, false>( \ + template void mmha_launch_kernel, Dh, false, true, false>( \ const Masked_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, true, false>( \ + template void mmha_launch_kernel, Dh, false, true, false>( \ const Masked_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, true, false>( \ + template void mmha_launch_kernel, Dh, false, true, false>( \ const Cross_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, true, false>( \ + template void mmha_launch_kernel, Dh, false, true, false>( \ const Cross_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); #define INSTANTIATE_MMHA_LAUNCHERS_WITH_QK_TANH_SCALE(T, Dh) \ - template void mmha_launch_kernel, Dh, false, true>( \ + template void mmha_launch_kernel, Dh, false, false, true>( \ + const Masked_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ + const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ + template void mmha_launch_kernel, Dh, false, false, true>( \ + const Masked_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ + const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ + template void mmha_launch_kernel, Dh, false, false, true>( \ + const Cross_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ + const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ + template void mmha_launch_kernel, Dh, false, false, true>( \ + const Cross_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ + const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); + +#define INSTANTIATE_MMHA_LAUNCHERS_WITH_BLOCK_SPARSE_ATTN(T, Dh) \ + template void mmha_launch_kernel, Dh, true, false, false>( \ const Masked_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, false, true>( \ + template void mmha_launch_kernel, Dh, true, false, false>( \ const Masked_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, false, true>( \ + template void mmha_launch_kernel, Dh, true, false, false>( \ const Cross_multihead_attention_params& params, const KVLinearBuffer& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \ - template void mmha_launch_kernel, Dh, false, true>( \ + template void mmha_launch_kernel, Dh, true, false, false>( \ const Cross_multihead_attention_params& params, const KVBlockArray& kv_cache_buffer, \ const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h index 78c89ac92..a4e45a004 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h @@ -1282,6 +1282,8 @@ template < bool DO_MULTI_BLOCK = false, // Whether enable position shift for streamingllm bool POS_SHIFT = false, + // Whether to compute and apply block sparse attention mask + bool BLOCK_SPARSE_ATTN = false, // Whether compute implicit relative attention bias on the fly. bool IMPLICIT_REL_ATTN_BIAS = false, // Whether apply tanh scale to the qk product. @@ -1680,13 +1682,16 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; if (do_rotary) { + float rotary_embedding_m_scale = tlength <= params.rotary_embedding_original_max_positions + ? params.rotary_embedding_short_m_scale + : params.rotary_embedding_long_m_scale; mmha::vec_from_smem_transpose(q, q_smem_, transpose_idx, smem_pitch); if (HANDLE_KV) { mmha::vec_from_smem_transpose(k, k_smem_, transpose_idx, smem_pitch); mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, - rotary_embedding_base, rotary_embedding_scale, params.rotary_embedding_m_scale, + rotary_embedding_base, rotary_embedding_scale, rotary_embedding_m_scale, params.rotary_embedding_scaling_factors, current_pos_idx, params.rotary_cogvlm_vision_start, params.rotary_cogvlm_vision_length); @@ -1695,7 +1700,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske else { mmha::apply_rotary_embedding(q, transpose_idx / tidx_factor, params.rotary_embedding_dim, - rotary_embedding_base, rotary_embedding_scale, params.rotary_embedding_m_scale, + rotary_embedding_base, rotary_embedding_scale, rotary_embedding_m_scale, params.rotary_embedding_scaling_factors, current_pos_idx, params.rotary_cogvlm_vision_start, params.rotary_cogvlm_vision_length); } @@ -2045,6 +2050,14 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // All the threads do the work even if it's not relevant to avoid divergence. qk_ += linear_bias_slope * (local_time_now - tlength) + relative_attention_bias; + if constexpr (BLOCK_SPARSE_ATTN) + { + float mask_val + = params.block_sparse_params.computeMask(tlength, local_time_now, tlength + 1, num_heads, hi) ? 1.f + : 0.f; + qk_ += (1.0f - mask_val) * -10000.0f; + } + // There's one qk value per timestep. // Make sure only leader threads stores qk value within the bound. if (is_active && is_leader) @@ -2177,6 +2190,13 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // All the threads perform that step to avoid divergence. qk_ += linear_bias_slope * (time_now - tlength) + relative_attention_bias; + if constexpr (BLOCK_SPARSE_ATTN) + { + float mask_val + = params.block_sparse_params.computeMask(tlength, time_now, tlength + 1, num_heads, hi) ? 1.f : 0.f; + qk_ += (1.0f - mask_val) * -10000.0f; + } + // There's one qk value per timestep. // Make sure only leader threads stores qk value within the bound. if (is_active && is_leader) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h index 2d305f7dd..6aaf43760 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h @@ -17,6 +17,7 @@ */ #pragma once #include "decoderXQAConstants.h" +#include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/common/workspace.h" #include "tensorrt_llm/kernels/kvCacheUtils.h" @@ -210,10 +211,10 @@ struct XQALaunchParam void* scratch = nullptr; }; -// Setup launch params. +// Setup launch params and ioScratch. ioScratch is for RoPE and output type conversion. template -void buildXQALaunchParams( - XQALaunchParam& launchParams, XQAParams const& params, KVCacheBuffer kv_cache_buffer) +void buildXQALaunchParams(XQALaunchParam& launchParams, void*& ioScratch, XQAParams const& params, + KVCacheBuffer kv_cache_buffer) { TLLM_CHECK_WITH_INFO( params.data_type == DATA_TYPE_FP16 || params.data_type == DATA_TYPE_BF16, "Only fp16 or bf16 supported now."); @@ -228,6 +229,9 @@ void buildXQALaunchParams( // Workspace. size_t offset = 0; int8_t* workspace = reinterpret_cast(params.workspaces); + ioScratch = workspace; + workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment( + workspace, 2 * params.head_size * params.num_q_heads * params.total_num_input_tokens); unsigned int batch_beam_size = params.batch_size * params.beam_width; const size_t cu_seqlens_size = sizeof(int) * (batch_beam_size + 1); const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index f748afb0f..7a375a929 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -192,7 +192,13 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& : (xqaParams.kv_cache_quant_mode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE); XQALaunchParam launchParams; - buildXQALaunchParams(launchParams, xqaParams, kv_cache_buffer); + void* ioScratch = nullptr; + buildXQALaunchParams(launchParams, ioScratch, xqaParams, kv_cache_buffer); + bool const needOutputCvt = (xqaParams.fp8_out_scale != nullptr); + if (needOutputCvt) + { + launchParams.output = ioScratch; + } // Build cu_seqlens, padding_offset, and rotary inv freq tensors BuildDecoderInfoParams decoder_params; @@ -214,18 +220,18 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& // IDEA: Store rotary_processed Q buffer to output buffer. // NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache. - void const* xqa_q_input_ptr = xqaParams.output; + void* xqa_q_input_ptr = ioScratch; QKVPreprocessingParams preprocessingParms{static_cast(const_cast(xqaParams.qkv)), - nullptr, static_cast(const_cast(xqaParams.output)), kv_cache_buffer, - static_cast(xqaParams.qkv_bias), nullptr, xqaParams.sequence_lengths, nullptr, - launchParams.rotary_inv_freq_buf, (float2 const*) nullptr, xqaParams.kv_scale_orig_quant, - xqaParams.spec_decoding_position_offsets, int(batch_beam_size), xqaParams.generation_input_length, - xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length, - int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads, - xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, - xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, - xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, - xqaParams.position_shift_enabled, cache_type, true, false, multiprocessor_count}; + nullptr, static_cast(xqa_q_input_ptr), kv_cache_buffer, static_cast(xqaParams.qkv_bias), nullptr, + xqaParams.sequence_lengths, nullptr, launchParams.rotary_inv_freq_buf, (float2 const*) nullptr, + xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, int(batch_beam_size), + xqaParams.generation_input_length, xqaParams.timestep, xqaParams.cyclic_attention_window_size, + xqaParams.sink_token_length, int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), + xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, + xqaParams.head_size, xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, + xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale, + xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, xqaParams.position_shift_enabled, + cache_type, true, false, multiprocessor_count}; invokeQKVPreprocessing(preprocessingParms, stream); sync_check_cuda_error(); @@ -311,6 +317,15 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& } sync_check_cuda_error(); + + if (needOutputCvt) + { + tensorrt_llm::kernels::invokeConversion<__nv_fp8_e4m3, T>(static_cast<__nv_fp8_e4m3*>(xqaParams.output), + static_cast(launchParams.output), + xqaParams.head_size * xqaParams.num_q_heads * xqaParams.total_num_input_tokens, xqaParams.fp8_out_scale, + stream); + sync_check_cuda_error(); + } } } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt index 67e1825c3..1a295e432 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ b3823dd8e1d7f154019fb7dc24172ff4 libtensorrt_llm_nvrtc_wrapper.so -fc46fa01e555f9f97387340e46e9571fabf73988 commit \ No newline at end of file +736b3fc4259916d31211104b91e6b2b4db995b17 commit \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll index 88182893a..69bc28c51 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/x86_64-windows-msvc/tensorrt_llm_nvrtc_wrapper.dll @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:afd0ee85e633f116ef53fe9f71c2821b937274d835548239f5f8dae306143a27 +oid sha256:763493a4f1996e97a1c449fad132a6119a29b6c4d98262a315ee765cde780350 size 1011200 diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp index 66e4be0d4..42c6d5e6b 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp @@ -160,7 +160,13 @@ class XQAKernelList : (xqaParams.kv_cache_quant_mode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE); XQALaunchParam launchParams; - buildXQALaunchParams(launchParams, xqaParams, kv_cache_buffer); + void* ioScratch = nullptr; + buildXQALaunchParams(launchParams, ioScratch, xqaParams, kv_cache_buffer); + bool const needOutputCvt = (xqaParams.fp8_out_scale != nullptr); + if (needOutputCvt) + { + launchParams.output = ioScratch; + } // Build cu_seqlens, padding_offset, and rotary inv freq tensors BuildDecoderInfoParams decoder_params; @@ -186,14 +192,14 @@ class XQAKernelList // IDEA: Store rotary_processed Q buffer to output buffer. // NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache. - void const* xqa_q_input_ptr = xqaParams.output; + void* xqa_q_input_ptr = ioScratch; QKVPreprocessingParams preprocessingParms{static_cast(const_cast(xqaParams.qkv)), - nullptr, static_cast(const_cast(xqaParams.output)), kv_cache_buffer, - static_cast(xqaParams.qkv_bias), xqaParams.spec_decoding_generation_lengths, - xqaParams.sequence_lengths, xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr, - launchParams.rotary_inv_freq_buf, (float2 const*) nullptr, xqaParams.kv_scale_orig_quant, - xqaParams.spec_decoding_position_offsets, int(batch_beam_size), xqaParams.generation_input_length, - xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length, + nullptr, static_cast(xqa_q_input_ptr), kv_cache_buffer, static_cast(xqaParams.qkv_bias), + xqaParams.spec_decoding_generation_lengths, xqaParams.sequence_lengths, + xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr, launchParams.rotary_inv_freq_buf, + (float2 const*) nullptr, xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, + int(batch_beam_size), xqaParams.generation_input_length, xqaParams.timestep, + xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length, int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, @@ -293,6 +299,15 @@ class XQAKernelList } sync_check_cuda_error(); + + if (needOutputCvt) + { + tensorrt_llm::kernels::invokeConversion<__nv_fp8_e4m3, T>(static_cast<__nv_fp8_e4m3*>(xqaParams.output), + static_cast(launchParams.output), + xqaParams.head_size * xqaParams.num_q_heads * xqaParams.total_num_input_tokens, xqaParams.fp8_out_scale, + stream); + sync_check_cuda_error(); + } } private: diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp index a83a7b13b..76ccd2e03 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp @@ -73,9 +73,11 @@ constexpr inline T roundUp(T a, T b) } // namespace -size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size) +size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size, int max_num_tokens) { - size_t workspace_size = 0; + // buffer for RoPE / output quantization. + constexpr size_t kXQA_OUT_ELEM_SIZE = 2; // fp16 or bf16. + size_t workspace_size = kXQA_OUT_ELEM_SIZE * mHeadSize * mNumHeads * max_num_tokens; if (mMultiBlockMode) { int workspaces[4]; @@ -90,7 +92,8 @@ size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size) = roundUp(sizeof(__half) * kMaxBeamWidth * group_size * mHeadSize, 128); workspaces[3] = multi_block_workspace_alignment * xqaMaxNbCtaPerKVHeadFactor() * mNumKVHeads * divUp(max_batch_beam_size, kMaxBeamWidth); - workspace_size = roundUp(workspaces[0], multi_block_workspace_alignment) + workspace_size = roundUp(workspace_size, multi_block_workspace_alignment) + + roundUp(workspaces[0], multi_block_workspace_alignment) + roundUp(workspaces[1], multi_block_workspace_alignment) + roundUp(workspaces[2], multi_block_workspace_alignment) + roundUp(workspaces[3], multi_block_workspace_alignment) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h index cb2ce35f0..da79fda3b 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h @@ -155,7 +155,7 @@ class DecoderXQARunner return shouldUseImpl(xqaParams, forConfigurePlugin); } - size_t getWorkspaceSize(int max_batch_beam_size); + size_t getWorkspaceSize(int max_batch_beam_size, int max_num_tokens); void prepare(XQAParams const& xqa_params) { diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h index 0160cc4ef..3f7bb5288 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h @@ -84,6 +84,10 @@ struct XQAParams int max_distance = 0; bool multi_block_mode; bool multi_query_tokens = false; + + int32_t total_num_input_tokens; // total number of input tokens. may differ from batch_size due to medusa. + float const* fp8_out_scale = nullptr; // fp8 output scale in case we need post-processing to convert output to fp8. + // nullptr means no conversion. }; } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/gptKernels.cu b/cpp/tensorrt_llm/kernels/gptKernels.cu index 4b226f38a..de64e6377 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.cu +++ b/cpp/tensorrt_llm/kernels/gptKernels.cu @@ -189,7 +189,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets template __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, int const* seqLengths, int maxQSeqLength, - int attentionWindowSize, AttentionMaskType attentionMaskType) + int attentionWindowSize, AttentionMaskType attentionMaskType, BlockSparseParams blockSparseParams) { // The index of the sequence in the batch. int batchIdx = blockIdx.y; @@ -264,6 +264,9 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, int c // 1 1 1 1 0 // 1 1 1 1 1 break; + case AttentionMaskType::BLOCKSPARSE: + isValid = blockSparseParams.computeMask(rowIdx, colIdx, seqLength, 1 /*num_heads*/, 0 /*head_id*/); + break; } // Store the mask. @@ -313,7 +316,7 @@ void invokeBuildDecoderInfo(BuildDecoderInfoParams const& params, cudaStream_ } dim3 grid(blocksPerSeq, params.batchSize); computeAttentionMask<<>>(params.attentionMask, params.seqQLengths, - params.maxQSeqLength, params.attentionWindowSize, params.attentionMaskType); + params.maxQSeqLength, params.attentionWindowSize, params.attentionMaskType, params.blockSparseParams); } } diff --git a/cpp/tensorrt_llm/kernels/gptKernels.h b/cpp/tensorrt_llm/kernels/gptKernels.h index 1310a1bd1..7f960f2d6 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.h +++ b/cpp/tensorrt_llm/kernels/gptKernels.h @@ -36,7 +36,9 @@ enum class AttentionMaskType BIDIRECTIONAL = 2, // See GLM-10B mask. // TODO: merge this mask into BIDIRECTIONAL - BIDIRECTIONALGLM = 3 + BIDIRECTIONALGLM = 3, + // For Phi-3-small model + BLOCKSPARSE = 4, }; enum class PositionEmbeddingType : int8_t @@ -59,6 +61,31 @@ enum class RotaryScalingType : int8_t kDYNAMIC = 2, }; +struct BlockSparseParams +{ + int block_size; + int homo_head_pattern; + int num_local_blocks; // Sliding window blocks + int vertical_stride; + + __device__ bool computeMask(int row_idx, int col_idx, int seq_length, int num_heads, int head_idx) const + { + bool causal_mask = row_idx < seq_length && col_idx < seq_length && col_idx <= row_idx; + + // Mask 1/0 decision is made at block_size granularity + int block_row_idx = row_idx / block_size; + int block_col_idx = col_idx / block_size; + + bool block_local_mask = (block_row_idx - block_col_idx) < num_local_blocks; + + int head_sliding_step = homo_head_pattern ? 0 : std::max(1, int(vertical_stride / num_heads)); + bool block_vertical_stride_mask = ((block_col_idx + head_idx * head_sliding_step + 1) % vertical_stride) == 0; + + bool is_valid = causal_mask && (block_local_mask || block_vertical_stride_mask); + return is_valid; + } +}; + template struct BuildDecoderInfoParams { @@ -97,6 +124,8 @@ struct BuildDecoderInfoParams int numTokens; // The type of attention. AttentionMaskType attentionMaskType; + // Params for block sparse pattern + BlockSparseParams blockSparseParams; // Rotary Embedding inv_freq. // [batch_size, halfRotaryDim] variable across different requests due to dynamic scaling. diff --git a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu index 07cf8492f..3472eca9b 100644 --- a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu @@ -975,7 +975,8 @@ std::vector CutlassMoeFCRunner::getWo size_t const permuted_elems = num_moe_inputs * hidden_size; size_t const interbuf_elems = num_moe_inputs * inter_size; size_t glu_inter_elems = 0; - if (isGatedActivation(activation_type)) + bool is_gated_activation = isGatedActivation(activation_type); + if (is_gated_activation) { glu_inter_elems = interbuf_elems * 2; } @@ -1180,7 +1181,8 @@ void CutlassMoeFCRunner::runMoe(void const* i sync_check_cuda_error(); bool const is_gated_activation = isGatedActivation(fc1_activation_type); - size_t const fc1_out_size = is_gated_activation ? inter_size * 2 : inter_size; + bool const use_fused_moe = moe_gemm_runner_.isFusedGatedActivation(is_gated_activation, inter_size, hidden_size); + size_t const fc1_out_size = ((!use_fused_moe) && is_gated_activation) ? inter_size * 2 : inter_size; // Upper bound on number of expanded rows int64_t const expanded_active_expert_rows = k * active_rows; @@ -1209,7 +1211,7 @@ void CutlassMoeFCRunner::runMoe(void const* i sync_check_cuda_error(); moe_gemm_runner_.moeGemm(permuted_data_, nullptr, nullptr, nullptr, total_rows_before_expert_, hopper_input, - expanded_active_expert_rows, fc1_out_size, hidden_size, num_experts_per_node, stream); + expanded_active_expert_rows, fc1_out_size, hidden_size, num_experts_per_node, false, stream); sync_check_cuda_error(); @@ -1223,24 +1225,27 @@ void CutlassMoeFCRunner::runMoe(void const* i { moe_gemm_runner_.moeGemmBiasAct(permuted_data_, fc1_expert_weights, fc1_int_scales, fc1_expert_biases, fc1_result_, total_rows_before_expert_, HopperGroupedGemmInput{}, expanded_active_expert_rows, fc1_out_size, - hidden_size, num_experts_per_node, fc1_activation_type, stream); + hidden_size, num_experts_per_node, fc1_activation_type, use_fused_moe, stream); sync_check_cuda_error(); } else { // Run the GEMM with activation function overridden with `Identity`, we do the activation separately + ActivationType activation_type = (use_fused_moe) ? fc1_activation_type : ActivationType::Identity; + T* gemm_result = (use_fused_moe) ? fc1_result_ : static_cast(glu_inter_result_); moe_gemm_runner_.moeGemmBiasAct(permuted_data_, fc1_expert_weights, fc1_int_scales, fc1_expert_biases, - static_cast(glu_inter_result_), total_rows_before_expert_, HopperGroupedGemmInput{}, - expanded_active_expert_rows, fc1_out_size, hidden_size, num_experts_per_node, ActivationType::Identity, - stream); + gemm_result, total_rows_before_expert_, HopperGroupedGemmInput{}, expanded_active_expert_rows, fc1_out_size, + hidden_size, num_experts_per_node, activation_type, use_fused_moe, stream); sync_check_cuda_error(); + if (!use_fused_moe) + { + doGatedActivation(fc1_result_, static_cast(glu_inter_result_), num_valid_tokens_ptr, + inter_size, num_rows * k, fc1_activation_type, stream); - doGatedActivation(fc1_result_, static_cast(glu_inter_result_), num_valid_tokens_ptr, inter_size, - num_rows * k, fc1_activation_type, stream); - - sync_check_cuda_error(); + sync_check_cuda_error(); + } } sync_check_cuda_error(); @@ -1255,7 +1260,7 @@ void CutlassMoeFCRunner::runMoe(void const* i moe_gemm_runner_.moeGemm(fc1_result_, fc2_expert_weights, fc2_int_scales, static_cast(fc2_result_), total_rows_before_expert_, hopper_input, expanded_active_expert_rows, hidden_size, inter_size, - num_experts_per_node, stream); + num_experts_per_node, false, stream); sync_check_cuda_error(); diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu index 7edae7ac3..02bcd2a5a 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu @@ -180,7 +180,8 @@ INSTANTIATE_ADDQKVBIASIA3_TRANSPOSE(__nv_bfloat16); template __global__ void softmax_kernel(T* attn_score, const T_IN* qk, T const* attn_mask, T const* linear_bias_slopes, const int64_t batch_size, const int64_t head_num, const int64_t q_length, const int64_t k_length, - float const qk_scale, float const qk_tanh_scale, float const qk_tanh_inverse_scale) + float const qk_scale, float const qk_tanh_scale, float const qk_tanh_inverse_scale, bool const block_sparse_attn, + BlockSparseParams const block_sparse_params, int const* q_seq_lengths) { // attn_score, [batch_size, num_heads, q_length, k_length] // qk, [batch_size, num_heads, q_length, k_length] @@ -217,8 +218,17 @@ __global__ void softmax_kernel(T* attn_score, const T_IN* qk, T const* attn_mask qk_bias += static_cast(linear_bias_slope * (ki - qi)); } - int64_t mask_offset = ((int64_t) bi * q_length + qi) * k_length + ki; - float mask_val = static_cast(ldg(&attn_mask[mask_offset])); + float mask_val; + if (block_sparse_attn && block_sparse_params.homo_head_pattern == false) + { + // We cannot share attention mask across heads. Instead, we compute mask on the fly here. + mask_val = block_sparse_params.computeMask(qi, ki, q_seq_lengths[bi], head_num, hi) ? 1.f : 0.f; + } + else + { + int64_t mask_offset = ((int64_t) bi * q_length + qi) * k_length + ki; + mask_val = static_cast(ldg(&attn_mask[mask_offset])); + } qk_bias += (1.0f - mask_val) * -10000.0f; data[i] = qk_scale * qk_val + qk_bias; @@ -264,7 +274,8 @@ __global__ void softmax_kernel(T* attn_score, const T_IN* qk, T const* attn_mask template __global__ void softmax_kernel_h2(T* attn_score, T const* qk_buf, T const* attn_mask, T const* linear_bias_slopes, const int64_t batch_size, const int64_t head_num, const int64_t q_length, const int64_t k_length, const T qk_scale, - float const qk_tanh_scale, float const qk_tanh_inverse_scale) + float const qk_tanh_scale, float const qk_tanh_inverse_scale, bool const block_sparse_attn, + BlockSparseParams const block_sparse_params, int const* q_seq_lengths) { // attn_score, [batch_size, num_heads, q_length, k_length] // qk, [batch_size, num_heads, q_length, k_length] @@ -324,7 +335,15 @@ __global__ void softmax_kernel_h2(T* attn_score, T const* qk_buf, T const* attn_ qk_bias = hadd2(qk_bias, hmul2(linear_bias_slope, dist)); } - T2 mask_val = ldg(&attn_mask_h2[mask_offset]); + T2 mask_val; + if (block_sparse_attn && block_sparse_params.homo_head_pattern == false) + { + mask_val = block_sparse_params.computeMask(qi, ki, q_seq_lengths[bi], head_num, hi) ? ONE : ZERO; + } + else + { + mask_val = ldg(&attn_mask_h2[mask_offset]); + } qk_bias = hadd2(qk_bias, hmul2(hsub2(ONE, mask_val), NEG_INFTY)); data[i] = hadd2(hmul2(qk, qk_scale_h2), qk_bias); @@ -374,7 +393,8 @@ __global__ void softmax_kernel_h2(T* attn_score, T const* qk_buf, T const* attn_ template __global__ void softmax_kernel_h2_v2(T* attn_score, T const* qk_buf, T const* attn_mask, T const* linear_bias_slopes, const int64_t batch_size, const int64_t head_num, const int64_t q_length, const int64_t k_length, const T scalar, - float const qk_tanh_scale, float const qk_tanh_inverse_scale) + float const qk_tanh_scale, float const qk_tanh_inverse_scale, bool const block_sparse_attn, + BlockSparseParams const block_sparse_params, int const* q_seq_lengths) { // attn_score, [batch_size, num_heads, q_length, k_length] // qk, [batch_size, num_heads, q_length, k_length] @@ -436,7 +456,14 @@ __global__ void softmax_kernel_h2_v2(T* attn_score, T const* qk_buf, T const* at T2 mask_val[Q_ITEMS_PER_THREAD]; for (int j = 0; j < q_items; j++) { - mask_val[j] = ldg(&attn_mask_h2[mask_offset[j]]); + if (block_sparse_attn && block_sparse_params.homo_head_pattern == false) + { + mask_val[j] = block_sparse_params.computeMask(qi, ki, q_seq_lengths[bi], head_num, hi) ? ONE : ZERO; + } + else + { + mask_val[j] = ldg(&attn_mask_h2[mask_offset[j]]); + } } T2 qk[Q_ITEMS_PER_THREAD]; @@ -573,21 +600,24 @@ __global__ void softmax_kernel_h2_v2(T* attn_score, T const* qk_buf, T const* at softmax_kernel_h2_v2<<>>((T_*) param.attention_score, \ (const T_*) param.qk, (const T_*) param.attention_mask, (const T_*) param.linear_bias_slopes, \ param.batch_size, param.num_heads, param.q_length, param.k_length, (const T_) param.qk_scale, \ - param.qk_tanh_scale, param.qk_tanh_inverse_scale); \ + param.qk_tanh_scale, param.qk_tanh_inverse_scale, param.block_sparse_attn, param.block_sparse_params, \ + param.q_seq_lengths); \ } \ else \ { \ softmax_kernel_h2<<>>((T_*) param.attention_score, \ (const T_*) param.qk, (const T_*) param.attention_mask, (const T_*) param.linear_bias_slopes, \ param.batch_size, param.num_heads, param.q_length, param.k_length, (const T_) param.qk_scale, \ - param.qk_tanh_scale, param.qk_tanh_inverse_scale); \ + param.qk_tanh_scale, param.qk_tanh_inverse_scale, param.block_sparse_attn, param.block_sparse_params, \ + param.q_seq_lengths); \ } \ } \ else \ { \ softmax_kernel<<>>(param.attention_score, param.qk, \ param.attention_mask, param.linear_bias_slopes, param.batch_size, param.num_heads, param.q_length, \ - param.k_length, param.qk_scale, param.qk_tanh_scale, param.qk_tanh_inverse_scale); \ + param.k_length, param.qk_scale, param.qk_tanh_scale, param.qk_tanh_inverse_scale, param.block_sparse_attn, \ + param.block_sparse_params, param.q_seq_lengths); \ } #define LAUNCH_MASKED_SOFTMAX(ITEMS_PER_THREAD) LAUNCH_MASKED_SOFTMAX_(half, ITEMS_PER_THREAD) @@ -1996,5 +2026,108 @@ INSTANTIATE_SHIFT_K_CACHE(__nv_bfloat16); #undef INSTANTIATE_SHIFT_K_CACHE_CACHE_TYPE #undef INSTANTIATE_SHIFT_K_CACHE +namespace +{ +template +struct alignas(std::max(alignof(T), std::min(sizeof(T) * size_, 16))) Vec +{ + using Elem = T; + static constexpr uint32_t size = size_; + Elem data[size]; + + __device__ inline void fill(T val) + { +#pragma unroll + for (uint32_t i = 0; i < size; i++) + { + data[i] = val; + } + } + + static __device__ inline Vec filled(T val) + { + Vec ret; + ret.fill(val); + return ret; + } + + __device__ inline Elem const& operator[](uint32_t i) const + { + assert(i < size); + return data[i]; + } + + __device__ inline Elem& operator[](uint32_t i) + { + assert(i < size); + return data[i]; + } +}; + +template +__global__ void convertData(Dst* dst, Src const* src, int64_t size, float const* __restrict__ pScale) +{ + constexpr uint32_t srcElemSize = sizeof(Src); + constexpr uint32_t dstElemSize = sizeof(Dst); + static_assert((srcElemSize & (srcElemSize - 1)) == 0 && (dstElemSize & (dstElemSize - 1)) == 0); + assert(reinterpret_cast(dst) % 16 == 0 && reinterpret_cast(src) % 16 == 0); + constexpr uint32_t packSize = 16 / std::max(srcElemSize, dstElemSize); + auto const tid = blockDim.x * blockIdx.x + threadIdx.x; + auto const nbThrds = blockDim.x * gridDim.x; + if (nbThrds * packSize + packSize - 1 >= size) + { + return; + } + float const scale = (pScale == nullptr ? 1.F : pScale[0]); + using SrcPack = Vec; + using DstPack = Vec; + int64_t const stride = packSize * nbThrds; + for (int64_t i = tid * packSize; i < size; i += stride) + { + if (i + packSize < size) + { + auto const srcPack = reinterpret_cast(src[i]); + DstPack dstPack; +#pragma unroll + for (int32_t j = 0; j < packSize; j++) + { + dstPack[j] = Dst{float{srcPack[j]} * scale}; + } + reinterpret_cast(dst[i]) = dstPack; + } + else + { +#pragma unroll + for (int64_t j = 0; j < packSize; j++) + { + if (i + j >= size) + { + break; + } + dst[i + j] = Dst{float{src[i + j]} * scale}; + } + } + } +} +} // unnamed namespace + +template +void invokeConversion(Dst* dst, Src const* src, int64_t size, float const* __restrict__ scale, cudaStream_t stream) +{ + auto const packSize = 16 / std::max(sizeof(Dst), sizeof(Src)); + auto const nbPack = divUp(size, packSize); + uint32_t const ctaSize = 256; + auto const nbCta = std::min(divUp(nbPack, ctaSize), 4096); + + convertData<<>>(dst, src, size, scale); +} + +#define INSTANTIATE_invokeConversion(Dst, Src) \ + template void invokeConversion( \ + Dst * dst, Src const* src, int64_t size, float const* __restrict__ scale, cudaStream_t stream) +INSTANTIATE_invokeConversion(__nv_fp8_e4m3, half); +INSTANTIATE_invokeConversion(__nv_fp8_e4m3, __nv_bfloat16); +#undef INSTANTIATE_invokeConversion + } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h index 26edc56bd..e9b18f5fb 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h @@ -45,6 +45,9 @@ struct MaskedSoftmaxParam // always float compute data type. float qk_tanh_scale = 0.f; float qk_tanh_inverse_scale = 0.f; + bool block_sparse_attn = false; + BlockSparseParams block_sparse_params; + int const* q_seq_lengths = nullptr; // (batch_size) // Optional parameters that depend on the type of attention. // The slopes of the linear position bias of ALiBi. @@ -272,5 +275,9 @@ void invokeShiftKCache(KVCacheBuffer const& kvCacheBuffer, KVLinearBuffer const& float const* kScaleQuantOrig, int const* sequence_lengths, int const* input_lengths, int const rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, cudaStream_t stream); + +// compute src[x] * scale[0] and write into dst[x] +template +void invokeConversion(Dst* dst, Src const* src, int64_t size, float const* __restrict__ scale, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/layers/banWordsLayer.cpp b/cpp/tensorrt_llm/layers/banWordsLayer.cpp index 1dee8a3a9..b414092d8 100644 --- a/cpp/tensorrt_llm/layers/banWordsLayer.cpp +++ b/cpp/tensorrt_llm/layers/banWordsLayer.cpp @@ -20,6 +20,7 @@ #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/banBadWords.h" #include "tensorrt_llm/kernels/banRepeatNgram.h" +#include "tensorrt_llm/layers/defaultDecodingParams.h" #include "tensorrt_llm/layers/layerUtils.h" #include @@ -40,37 +41,101 @@ BanWordsLayer::BanWordsLayer(executor::DecodingMode const& mode, DecoderDomai , mDecodingMode(mode) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + initialize(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +BanWordsLayer::~BanWordsLayer() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + freeBuffer(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void BanWordsLayer::initialize() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + allocateBuffer(); + + mNoRepeatNgramSize.resize(mDecoderDomain.getBatchSize()); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void BanWordsLayer::allocateBuffer() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + if (mDecodingMode.isUseNoRepeatNgramSize()) + { + mNoRepeatNgramSizeDevice + = mAllocator->reMalloc(mNoRepeatNgramSizeDevice, sizeof(SizeType32) * mDecoderDomain.getBatchSize(), false); + } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void BanWordsLayer::freeBuffer() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + if (mDecodingMode.isUseNoRepeatNgramSize()) + { + mAllocator->free((void**) (&mNoRepeatNgramSizeDevice)); + } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void BanWordsLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType32 const* batchSlots, - std::shared_ptr setupParams) + std::shared_ptr baseSetupParams) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + auto setupParams = std::dynamic_pointer_cast(baseSetupParams); + std::vector batchSlotsVec(batchSize); + std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0); + auto batchSlotsHost = batchSlots ? batchSlots : batchSlotsVec.data(); + auto const& penaltyParams = setupParams->penaltyParams; + bool const useNoRepeatNgramSize + = mDecodingMode.isUseNoRepeatNgramSize() && penaltyParams.noRepeatNgramSize.has_value(); + FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mStream}; + mUseNoRepeatNgramSize |= useNoRepeatNgramSize; + if (mUseNoRepeatNgramSize) + { + fillBuffers(penaltyParams.noRepeatNgramSize, DefaultDecodingParams::getNoRepeatNgramSize(), mNoRepeatNgramSize, + mNoRepeatNgramSizeDevice, batchSlotsHost, std::make_pair(0.f, std::numeric_limits::max()), + "no_repeat_ngram_size"); + } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } template void BanWordsLayer::banRepeatNGrams(Tensor& logits, std::shared_ptr const& outputs, std::shared_ptr const& inputs, SizeType32 const* batchSlots, - DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, cudaStream_t stream) + SizeType32 const* noRepeatNgramSizeDevice, DecoderDomain const& decoderDomain, SizeType32 maxSeqLen, + bool useNoRepeatNgramSize, cudaStream_t stream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto const maxStep = inputs->step; - if (inputs->no_repeat_ngram_size) + // auto const maxStep = inputs->step; // TODO (bhsueh) Should we use step? but current inputs->step is always 0. + auto const maxStep = maxSeqLen; + if (useNoRepeatNgramSize) { - SizeType32 const* noRepeatNgramSizeBuf - = inputs->no_repeat_ngram_size.value().template getPtr(); - invokeBanRepeatNgram(logits.template getPtr(), outputs->output_ids_ptr.template getPtr(), reinterpret_cast( inputs->finished.value_or(Tensor{}).template getPtr()), outputs->parent_ids_ptr.template getPtr(), batchSlots, outputs->sequence_length->template getPtr(), decoderDomain.getBatchSize(), - decoderDomain.getBeamWidth(), maxSeqLen, - inputs->no_repeat_ngram_size.value().template getPtr(), - decoderDomain.getVocabSizePadded(), maxStep, stream); + decoderDomain.getBeamWidth(), maxSeqLen, noRepeatNgramSizeDevice, decoderDomain.getVocabSizePadded(), + maxStep, stream); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -110,7 +175,8 @@ void BanWordsLayer::forwardAsync( auto const maxSeqLen = outputs->output_ids.shape[outputs->output_ids.shape.size() - 1]; auto batchSlots = inputs->batch_slots ? inputs->batch_slots->template getPtr() : nullptr; - banRepeatNGrams(inputs->logits.value(), outputs, inputs, batchSlots, localDecoderDomain, maxSeqLen, mStream); + banRepeatNGrams(inputs->logits.value(), outputs, inputs, batchSlots, mNoRepeatNgramSizeDevice, localDecoderDomain, + maxSeqLen, mUseNoRepeatNgramSize, mStream); banBadWords(inputs->logits.value(), outputs, inputs, batchSlots, localDecoderDomain, maxSeqLen, mStream); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/layers/banWordsLayer.h b/cpp/tensorrt_llm/layers/banWordsLayer.h index 74e578a96..4630e03f1 100644 --- a/cpp/tensorrt_llm/layers/banWordsLayer.h +++ b/cpp/tensorrt_llm/layers/banWordsLayer.h @@ -42,21 +42,25 @@ class BanWordsLayer : public BaseLayer BanWordsLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain, cudaStream_t stream, std::shared_ptr allocator); - ~BanWordsLayer() override = default; + ~BanWordsLayer() override; void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 const* batchSlots, - std::shared_ptr setupParams) override; + std::shared_ptr baseSetupParams) override; //! \brief Modifies 'outputs->logits' in-place with -INF for banned words void forwardAsync(std::shared_ptr outputs, std::shared_ptr inputs) override; private: - static void banRepeatNGrams(tc::Tensor& logits, std::shared_ptr const& outputs, - std::shared_ptr const& params, runtime::SizeType32 const* batchSlots, - DecoderDomain const& decoderDomain, runtime::SizeType32 maxSeqLen, cudaStream_t stream); + void initialize(); + void allocateBuffer(); + void freeBuffer(); static void banBadWords(tc::Tensor& logits, std::shared_ptr const& outputs, std::shared_ptr const& params, runtime::SizeType32 const* batchSlots, DecoderDomain const& decoderDomain, runtime::SizeType32 maxSeqLen, cudaStream_t stream); + static void banRepeatNGrams(tc::Tensor& logits, std::shared_ptr const& outputs, + std::shared_ptr const& inputs, runtime::SizeType32 const* batchSlots, + runtime::SizeType32 const* noRepeatNgramSizeDevice, DecoderDomain const& decoderDomain, + runtime::SizeType32 maxSeqLen, bool useNoRepeatNgramSize, cudaStream_t stream); private: using BaseLayer::mWorkspaceSize; @@ -66,6 +70,10 @@ class BanWordsLayer : public BaseLayer using BaseLayer::mAllocator; executor::DecodingMode mDecodingMode; + + runtime::SizeType32* mNoRepeatNgramSizeDevice{nullptr}; + std::vector mNoRepeatNgramSize; + bool mUseNoRepeatNgramSize{false}; }; } // namespace layers diff --git a/cpp/tensorrt_llm/layers/beamSearchLayer.cu b/cpp/tensorrt_llm/layers/beamSearchLayer.cu index a877808ad..50084144a 100644 --- a/cpp/tensorrt_llm/layers/beamSearchLayer.cu +++ b/cpp/tensorrt_llm/layers/beamSearchLayer.cu @@ -146,8 +146,9 @@ void BeamSearchLayer::forwardAsyncSingleRequest( T const* logits = ip->logits.template getPtr(); T const* bias = static_cast(nullptr); - TLLM_CHECK_WITH_INFO(mWorkspaceSize >= 2 * bh.nMaxBatchSize * bh.nBeamWidth * bh.nBeamWidth * 2, - std::string("Workspace size is not enough for topk softmax.")); + TLLM_CHECK_WITH_INFO(mWorkspaceSize >= 2 * bh.nBatchSize * bh.nBeamWidth * bh.nBeamWidth * 2, + fmtstr("Workspace size (%lu) is not enough for topk softmax required (%lu).", (uint64_t) mWorkspaceSize, + (uint64_t) (2 * bh.nMaxBatchSize * bh.nBeamWidth * bh.nBeamWidth * 2))); invokeTopkSoftMax(logits, bias, mWorkspace, bh, mStream); sync_check_cuda_error(); diff --git a/cpp/tensorrt_llm/layers/decodingLayer.cpp b/cpp/tensorrt_llm/layers/decodingLayer.cpp index 169d0262e..e318ac88b 100644 --- a/cpp/tensorrt_llm/layers/decodingLayer.cpp +++ b/cpp/tensorrt_llm/layers/decodingLayer.cpp @@ -61,7 +61,7 @@ bool hasDiffRuntimeArgs(std::shared_ptrpenaltyParams.frequencyPenalty) || !allSame(params->penaltyParams.presencePenalty) || !allSame(params->penaltyParams.repetitionPenalty) || !allSame(params->penaltyParams.temperature) - || !allSame(params->penaltyParams.minLength); + || !allSame(params->penaltyParams.minLength) || !allSame(params->penaltyParams.noRepeatNgramSize); } } // namespace diff --git a/cpp/tensorrt_llm/layers/decodingParams.h b/cpp/tensorrt_llm/layers/decodingParams.h index 2514afcc9..0a3afc8e0 100644 --- a/cpp/tensorrt_llm/layers/decodingParams.h +++ b/cpp/tensorrt_llm/layers/decodingParams.h @@ -112,11 +112,12 @@ class DynamicDecodeSetupParams : public BaseSetupParams // Penalty layer struct PenaltyParams { - std::optional> temperature; // [1] or [setupBatchSize] on cpu - std::optional> minLength; // [1] or [setupBatchSize] on cpu - std::optional> repetitionPenalty; // [1] or [setupBatchSize] on cpu - std::optional> presencePenalty; // [1] or [setupBatchSize] on cpu - std::optional> frequencyPenalty; // [1] or [setupBatchSize] on cpu + std::optional> temperature; // [1] or [setupBatchSize] on cpu + std::optional> minLength; // [1] or [setupBatchSize] on cpu + std::optional> repetitionPenalty; // [1] or [setupBatchSize] on cpu + std::optional> presencePenalty; // [1] or [setupBatchSize] on cpu + std::optional> frequencyPenalty; // [1] or [setupBatchSize] on cpu + std::optional> noRepeatNgramSize; // [1] or [setupBatchSize] on cpu }; struct SamplingParams @@ -220,7 +221,6 @@ class DynamicDecodeInputParams : public BaseInputParams std::optional bad_words_lengths; // [maxBatchSize], on gpu std::optional stop_words_ptr; // [maxBatchSize][2, stop_words_length], on gpu std::optional stop_words_lengths; // [maxBatchSize], on gpu - std::optional no_repeat_ngram_size; // [maxBatchSize], on gpu // Medusa inputs class MedusaInputs diff --git a/cpp/tensorrt_llm/layers/penaltyLayer.cpp b/cpp/tensorrt_llm/layers/penaltyLayer.cpp index 806280e60..614893573 100644 --- a/cpp/tensorrt_llm/layers/penaltyLayer.cpp +++ b/cpp/tensorrt_llm/layers/penaltyLayer.cpp @@ -215,44 +215,44 @@ void PenaltyLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, SizeType bool const useFrequencyPenalty = mDecodingMode.isUseFrequencyPenalty() && penaltyParams.frequencyPenalty.has_value(); bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams.minLength.has_value(); - if (useTemperature) + // FIXME(nkorobov): once one of the requests has some penalty, we will always have to compute it. + // To avoid that we need to scan through all active requests at each iteration. + mUseTemperature |= useTemperature; + mUseRepetitionPenalty |= useRepetitionPenalty; + mUsePresencePenalty |= usePresencePenalty; + mUseFrequencyPenalty |= useFrequencyPenalty; + mUseMinLength |= useMinLength; + + if (mUseTemperature) { fillBuffers(penaltyParams.temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty"); } - if (useRepetitionPenalty) + if (mUseRepetitionPenalty) { fillBuffers(penaltyParams.repetitionPenalty, DefaultDecodingParams::getRepetitionPenalty(), mRepetitionPenalty, mRepetitionPenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Repetition), "repetition penalty"); } - if (usePresencePenalty) + if (mUsePresencePenalty) { fillBuffers(penaltyParams.presencePenalty, DefaultDecodingParams::getPresencePenalty(), mPresencePenalty, mPresencePenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Presence), "presence penalty"); } - if (useFrequencyPenalty) + if (mUseFrequencyPenalty) { fillBuffers(penaltyParams.frequencyPenalty, DefaultDecodingParams::getFrequencyPenalty(), mFrequencyPenalty, mFrequencyPenaltyDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::Frequency), "frequency penalty"); } - if (useMinLength) + if (mUseMinLength) { fillBuffers(penaltyParams.minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice, batchSlotsHost, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length"); } - // FIXME(nkorobov): once of the requests has some penalty, we will always have to compute it. - // To avoid that need scan through all active requests for each iteration. - mUseTemperature |= useTemperature; - mUseRepetitionPenalty |= useRepetitionPenalty; - mUsePresencePenalty |= usePresencePenalty; - mUseFrequencyPenalty |= useFrequencyPenalty; - mUseMinLength |= useMinLength; - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } diff --git a/cpp/tensorrt_llm/plugins/CMakeLists.txt b/cpp/tensorrt_llm/plugins/CMakeLists.txt index f8a281491..18c40d543 100755 --- a/cpp/tensorrt_llm/plugins/CMakeLists.txt +++ b/cpp/tensorrt_llm/plugins/CMakeLists.txt @@ -35,6 +35,7 @@ set(PLUGIN_LISTS gptAttentionPlugin identityPlugin gemmPlugin + gemmSwigluPlugin smoothQuantGemmPlugin quantizePerTokenPlugin quantizeTensorPlugin diff --git a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp index d72fb58be..d8c63c565 100644 --- a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp @@ -21,6 +21,7 @@ #include "tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h" #include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h" +#include "tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h" #include "tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h" #include "tensorrt_llm/plugins/identityPlugin/identityPlugin.h" #include "tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h" @@ -173,6 +174,7 @@ extern "C" static tensorrt_llm::plugins::BertAttentionPluginCreator bertAttentionPluginCreator; static tensorrt_llm::plugins::GPTAttentionPluginCreator gptAttentionPluginCreator; static tensorrt_llm::plugins::GemmPluginCreator gemmPluginCreator; + static tensorrt_llm::plugins::GemmSwigluPluginCreator gemmSwigluPluginCreator; static tensorrt_llm::plugins::MixtureOfExpertsPluginCreator moePluginCreator; #if ENABLE_MULTI_DEVICE static tensorrt_llm::plugins::SendPluginCreator sendPluginCreator; @@ -201,6 +203,7 @@ extern "C" creatorPtr(bertAttentionPluginCreator), creatorPtr(gptAttentionPluginCreator), creatorPtr(gemmPluginCreator), + creatorPtr(gemmSwigluPluginCreator), creatorPtr(moePluginCreator), #if ENABLE_MULTI_DEVICE creatorPtr(sendPluginCreator), diff --git a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp index 301fb6fa4..a3aeb0c68 100644 --- a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp +++ b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp @@ -18,6 +18,7 @@ #include "tensorrt_llm/plugins/common/gemmPluginProfiler.h" #include "tensorrt_llm/common/cublasMMWrapper.h" #include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" +#include "tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm.h" #include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h" #include "tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h" @@ -93,6 +94,19 @@ size_t GemmPluginProfiler::getSer * sizeof(std::pair>); // size of the tactics map } +template +int GemmPluginProfiler::getMaxProfileM() const +{ + return 8192; +} + +template +void GemmPluginProfiler::initTmpData( + int m, int n, int k, char* workspace, size_t size, cudaStream_t stream) +{ + /* Do nothing */ +} + template void GemmPluginProfiler::profileTactics( RunnerPtr const& runner, nvinfer1::DataType const& type, GemmDims const& dims, GemmIdType const& gemmId) @@ -107,7 +121,7 @@ void GemmPluginProfiler::profileT mRunner = runner; mType = type; - int const maxM = std::min(nextPowerOfTwo(dims.maxM), MAX_PROFILE_M); + int const maxM = std::min(nextPowerOfTwo(dims.maxM), getMaxProfileM()); computeTmpSize(maxM, dims.n, dims.k); if (!mMNKProfileMap->existsMProfileMap(gemmId)) @@ -170,7 +184,7 @@ std::optional GemmPluginProfilergetMProfileMap(gemmId)->at(mRounded); } @@ -301,4 +315,8 @@ template class GemmPluginProfiler; +template class GemmPluginProfiler, GemmIdCore, + GemmIdCoreHash>; + } // namespace tensorrt_llm::plugins diff --git a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h index aaf9e5c9e..a04bf6a6f 100644 --- a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h +++ b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h @@ -172,8 +172,6 @@ template >; using MProfileMapPtr = std::shared_ptr; @@ -244,6 +242,8 @@ class GemmPluginProfiler std::optional getBestConfig(int m, GemmIdType const& gemmId) const; + virtual int getMaxProfileM() const; + protected: virtual void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) = 0; @@ -256,7 +256,7 @@ class GemmPluginProfiler virtual std::vector getTactics(int m, int n, int k) const = 0; - virtual void initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream){}; + virtual void initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream); private: void allocateTmpData(); diff --git a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp index 0e4fc20ec..2c4b84dd9 100644 --- a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.cpp @@ -29,8 +29,11 @@ static char const* CUMSUM_LAST_DIM_PLUGIN_NAME{"CumsumLastDim"}; PluginFieldCollection CumsumLastDimPluginCreator::mFC{}; std::vector CumsumLastDimPluginCreator::mPluginAttributes; -CumsumLastDimPlugin::CumsumLastDimPlugin(int input_length, nvinfer1::DataType type) - : mInputLength(input_length) +static constexpr SizeType32 LENGTH_LIMIT_FOR_BLOCKSCAN = 4096; + +CumsumLastDimPlugin::CumsumLastDimPlugin(SizeType32 inputLength, nvinfer1::DataType type, size_t temp_storage_bytes) + : mInputLength(inputLength) + , mTempStorageBytes(temp_storage_bytes) , mType(type) { TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16), @@ -38,6 +41,10 @@ CumsumLastDimPlugin::CumsumLastDimPlugin(int input_length, nvinfer1::DataType ty TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF) || (mType == DataType::kINT32), "Only support int, float, half, and bfloat16."); + if (mTempStorageBytes == 0) + { + mTempStorageBytes = getWorkspaceSizeNeeded(inputLength, type); + } } // Parameterized constructor @@ -45,6 +52,7 @@ CumsumLastDimPlugin::CumsumLastDimPlugin(void const* data, size_t length) { char const *d = reinterpret_cast(data), *a = d; read(d, mInputLength); + read(d, mTempStorageBytes); read(d, mType); TLLM_CHECK(d == a + length); TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16), "Unsupported data type"); @@ -56,13 +64,13 @@ CumsumLastDimPlugin::CumsumLastDimPlugin(void const* data, size_t length) // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* CumsumLastDimPlugin::clone() const noexcept { - auto* plugin = new CumsumLastDimPlugin(mInputLength, mType); + auto* plugin = new CumsumLastDimPlugin(mInputLength, mType, mTempStorageBytes); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; } // Outputs -// output_tensor: [batch_size, input_length] +// output_tensor: [batch_size, inputLength] nvinfer1::DimsExprs CumsumLastDimPlugin::getOutputDimensions( int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { @@ -81,28 +89,38 @@ void CumsumLastDimPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc cons { } -size_t CumsumLastDimPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, - nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept +size_t CumsumLastDimPlugin::getWorkspaceSizeNeeded(SizeType32 inputLength, nvinfer1::DataType type) { - if (mType == DataType::kINT32) + size_t tempStorageBytes; + if (inputLength < LENGTH_LIMIT_FOR_BLOCKSCAN) // last dim unknown or small, use BlockScan { - return invokeComputeCumsumLastDimWorkspaceSize(mInputLength); + tempStorageBytes = 0; } - else if (mType == DataType::kHALF) + else if (type == DataType::kINT32) { - return invokeComputeCumsumLastDimWorkspaceSize(mInputLength); + tempStorageBytes = invokeComputeCumsumLastDimWorkspaceSize(inputLength); } - else if (mType == DataType::kFLOAT) + else if (type == DataType::kHALF) { - return invokeComputeCumsumLastDimWorkspaceSize(mInputLength); + tempStorageBytes = invokeComputeCumsumLastDimWorkspaceSize(inputLength); + } + else if (type == DataType::kFLOAT) + { + tempStorageBytes = invokeComputeCumsumLastDimWorkspaceSize(inputLength); } #ifdef ENABLE_BF16 - else if (mType == DataType::kBF16) + else if (type == DataType::kBF16) { - return invokeComputeCumsumLastDimWorkspaceSize<__nv_bfloat16>(mInputLength); + tempStorageBytes = invokeComputeCumsumLastDimWorkspaceSize<__nv_bfloat16>(inputLength); } #endif - return 0; + return tempStorageBytes; +} + +size_t CumsumLastDimPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept +{ + return mTempStorageBytes; } template @@ -111,13 +129,19 @@ int CumsumLastDimPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc cudaStream_t stream) { // inputs - // 0. input_tensor [batch_size, input_length] + // 0. input_tensor [batch_size, inputLength] // outputs - // 0. output_tensor [batch_size, input_length] - auto const batch_size = inputDesc[getInputTensorIdx()].dims.d[0]; - size_t temp_storage_bytes = invokeComputeCumsumLastDimWorkspaceSize(mInputLength); + // 0. output_tensor [batch_size, inputLength] + auto const batchSize = inputDesc[getInputTensorIdx()].dims.d[0]; + auto const inputLength = inputDesc[getInputTensorIdx()].dims.d[1]; + /* + Two cases where we should use BlockScan: + 1. inputLength is small + 2. batchSize is large (since DeviceScan causes kernel launch per row) + */ + void* wp = inputLength < LENGTH_LIMIT_FOR_BLOCKSCAN || batchSize > 2 ? nullptr : workspace; invokeCumsumLastDim( - batch_size, mInputLength, inputs[getInputTensorIdx()], outputs[0], workspace, temp_storage_bytes, stream); + batchSize, inputLength, inputs[getInputTensorIdx()], outputs[0], wp, mTempStorageBytes, stream); return 0; } @@ -181,13 +205,14 @@ void CumsumLastDimPlugin::terminate() noexcept {} size_t CumsumLastDimPlugin::getSerializationSize() const noexcept { - return sizeof(mInputLength) + sizeof(mType); + return sizeof(mInputLength) + sizeof(mTempStorageBytes) + sizeof(mType); } void CumsumLastDimPlugin::serialize(void* buffer) const noexcept { char *d = static_cast(buffer), *a = d; write(d, mInputLength); + write(d, mTempStorageBytes); write(d, mType); assert(d == a + getSerializationSize()); } @@ -227,7 +252,7 @@ PluginFieldCollection const* CumsumLastDimPluginCreator::getFieldNames() noexcep IPluginV2* CumsumLastDimPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { PluginField const* fields = fc->fields; - int input_length; + int inputLength; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) @@ -236,7 +261,7 @@ IPluginV2* CumsumLastDimPluginCreator::createPlugin(char const* name, PluginFiel if (!strcmp(attrName, "input_length")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - input_length = static_cast(*(static_cast(fields[i].data))); + inputLength = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { @@ -246,7 +271,7 @@ IPluginV2* CumsumLastDimPluginCreator::createPlugin(char const* name, PluginFiel } try { - auto* obj = new CumsumLastDimPlugin(input_length, type); + auto* obj = new CumsumLastDimPlugin(inputLength, type); obj->setPluginNamespace(mNamespace.c_str()); return obj; } diff --git a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h index 813168e29..3cbf4e235 100644 --- a/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h +++ b/cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h @@ -27,7 +27,9 @@ namespace tensorrt_llm::plugins class CumsumLastDimPlugin : public BasePlugin { public: - CumsumLastDimPlugin(int mInputLength, nvinfer1::DataType type); + using SizeType32 = tensorrt_llm::kernels::SizeType32; + + CumsumLastDimPlugin(SizeType32 inputLength, nvinfer1::DataType type, size_t tempStorageBytes = 0); CumsumLastDimPlugin(void const* data, size_t length); ~CumsumLastDimPlugin() override = default; // IPluginV2DynamicExt Methods @@ -45,6 +47,7 @@ class CumsumLastDimPlugin : public BasePlugin template int enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + size_t getWorkspaceSizeNeeded(SizeType32 inputLength, nvinfer1::DataType type); // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( @@ -69,7 +72,8 @@ class CumsumLastDimPlugin : public BasePlugin }; private: - int mInputLength; + SizeType32 mInputLength; + size_t mTempStorageBytes; nvinfer1::DataType mType; }; diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp index 689bc6b23..c64d09e32 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp @@ -20,6 +20,7 @@ #include "gemmPluginProfiler.h" #include "plugin.h" #include "pluginUtils.h" +#include "tensorrt_llm/runtime/utils/debugUtils.h" #include @@ -348,8 +349,26 @@ int GemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P mTransA ? inputDesc[0].dims.d[0] - padK : inputDesc[0].dims.d[nbDimsA - 1] - padK); auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId); - runGemm(M, N, K, mTransA, mTransB, mPadLda, mPadLdb, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0], - bestTactic, workspace, stream); + + std::string mnkStr = "MNK={" + std::to_string(M) + ", " + std::to_string(N) + ", " + std::to_string(K) + "}"; + { + std::string const activationStr = "GEMM layer's activation before GEMM with " + mnkStr; + TLLM_CHECK_DEBUG_WITH_INFO( + tensorrt_llm::runtime::utils::tensorHasNan(M, K, mType, inputs[0], stream, activationStr) == false, + "Found NaN in " + activationStr); + } + + { + runGemm(M, N, K, mTransA, mTransB, mPadLda, mPadLdb, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0], + bestTactic, workspace, stream); + } + + { + std::string const outputStr = "GEMM layer's output after GEMM with " + mnkStr; + TLLM_CHECK_DEBUG_WITH_INFO( + tensorrt_llm::runtime::utils::tensorHasNan(M, N, mType, outputs[0], stream, outputStr) == false, + "Found NaN in " + outputStr); + } return 0; } diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/CMakeLists.txt b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/CMakeLists.txt new file mode 100644 index 000000000..3b714a392 --- /dev/null +++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/CMakeLists.txt @@ -0,0 +1,21 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# +file(GLOB SRCS *.cpp *.cu) +set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS}) +set(PLUGIN_SOURCES + ${PLUGIN_SOURCES} + PARENT_SCOPE) diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cpp b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cpp new file mode 100644 index 000000000..92ff693ca --- /dev/null +++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cpp @@ -0,0 +1,446 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gemmSwigluPlugin.h" +#include "cutlass_extensions/gemm_configs.h" + +#include +#include + +using namespace nvinfer1; +using namespace tensorrt_llm::common; +using namespace tensorrt_llm::kernels::cutlass_kernels; +using tensorrt_llm::plugins::GemmSwigluPluginCreator; +using tensorrt_llm::plugins::GemmSwigluPlugin; +using tensorrt_llm::plugins::GemmSwigluPluginProfiler; +using tensorrt_llm::plugins::read; +using tensorrt_llm::plugins::write; + +static char const* GEMM_SWIGLU_PLUGIN_VERSION{"1"}; +static char const* GEMM_SWIGLU_PLUGIN_NAME{"GemmSwiglu"}; +PluginFieldCollection GemmSwigluPluginCreator::mFC{}; +std::vector GemmSwigluPluginCreator::mPluginAttributes; + +size_t GemmSwigluPluginProfiler::getBytePerElement(nvinfer1::DataType type) +{ + size_t bpe; + if (type == nvinfer1::DataType::kHALF || type == nvinfer1::DataType::kBF16) + { + bpe = 2; + } + else if (type == nvinfer1::DataType::kINT8 || type == nvinfer1::DataType::kFP8) + { + bpe = 1; + } + else + { + TLLM_THROW("Not recognized/implemented"); + } + return bpe; +} + +void GemmSwigluPluginProfiler::setQuantMode(tensorrt_llm::common::QuantMode const& quantMode) +{ + mQuantMode = quantMode; +} + +void GemmSwigluPluginProfiler::runTactic( + int m, int n, int k, GemmSwigluPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream) +{ + size_t bpe = getBytePerElement(mType); + + // Workspace size required by gemm runner + // NB: this function will throw exception when selected tactic exceeds SMEM, which is then + // caught by gemmPluginProfiler and it will register this tactic as invalid + size_t wsSizeRunner = mRunner->getWorkspaceSize(m, n, k); + + // Workspace size required by profiling + size_t wsByteOffset = 0; + int8_t* wsBytePointer = reinterpret_cast(workspace); + void* aTmp = reinterpret_cast(nextWorkspacePtr(wsBytePointer, wsByteOffset, m * k * bpe)); + void* bTmp = reinterpret_cast(nextWorkspacePtr(wsBytePointer, wsByteOffset, n * k * bpe)); + void* cTmp = reinterpret_cast(nextWorkspacePtr(wsBytePointer, wsByteOffset, 1 * n * bpe)); + void* dTmp = reinterpret_cast(nextWorkspacePtr(wsBytePointer, wsByteOffset, m * (n / 2) * bpe)); + char* workspaceTmp = reinterpret_cast(nextWorkspacePtr(wsBytePointer, wsByteOffset, wsSizeRunner)); + + // Run profiling + mRunner->gemm( + dTmp, aTmp, bTmp, cTmp, mQuantMode, m, n, k, 1.0, 1.0, 1.0, tactic, workspaceTmp, wsSizeRunner, stream); +} + +int GemmSwigluPluginProfiler::getMaxProfileM() const +{ + return 32768; +} + +void GemmSwigluPluginProfiler::computeTmpSize(int maxM, int n, int k) +{ + std::vector workspaces = { + maxM * k * getBytePerElement(mType), // A + n * k * getBytePerElement(mType), // B + 1 * n * getBytePerElement(mType), // C_bias + maxM * (n / 2) * getBytePerElement(mType), // D + mRunner->getWorkspaceSize(maxM, n, k) // workspace + }; + size_t bytes = calculateTotalWorkspaceSize(workspaces.data(), workspaces.size()); + setTmpWorkspaceSizeInBytes(bytes); +} + +std::vector GemmSwigluPluginProfiler::getTactics(int m, int n, int k) const +{ + return mRunner->getConfigs(); +} + +GemmSwigluPlugin::GemmSwigluPlugin(QuantMode quantMode, nvinfer1::DataType type, bool hasBias, float scale_d0, + float scale_d1, float scale_output, GemmSwigluPlugin::PluginProfilerPtr const& pluginProfiler) + : mQuantMode(quantMode) + , mHasBias(hasBias) + , mScaleD0(scale_d0) + , mScaleD1(scale_d1) + , mScaleOutput(scale_output) + , mPluginProfiler(pluginProfiler) +{ + init(type); +} + +// Parameterized constructor +GemmSwigluPlugin::GemmSwigluPlugin( + void const* data, size_t length, GemmSwigluPlugin::PluginProfilerPtr const& pluginProfiler) + : mPluginProfiler(pluginProfiler) +{ + char const *d = reinterpret_cast(data), *a = d; + nvinfer1::DataType type; + unsigned int quantMode; + read(d, quantMode); + read(d, type); + read(d, mHasBias); + read(d, mScaleD0); + read(d, mScaleD1); + read(d, mScaleOutput); + read(d, mDims); + + mQuantMode = QuantMode(quantMode); + + init(type); + + mPluginProfiler->deserialize(d, mDims, mGemmId); + + TLLM_CHECK(d == a + length); +} + +void GemmSwigluPlugin::init(nvinfer1::DataType type) +{ + mType = type; + if (mType == nvinfer1::DataType::kFP8) + { + mGemmRunner = std::make_shared>(); + } + else + { + TLLM_THROW("Gemm Swiglu plugin only supports fp8 now"); + } + + mPluginProfiler->setQuantMode(mQuantMode); + + mGemmId = GemmIdCore(mDims.n, mDims.k, mType); +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* GemmSwigluPlugin::clone() const noexcept +{ + auto* plugin = new GemmSwigluPlugin(*this); + return plugin; +} + +nvinfer1::DimsExprs GemmSwigluPlugin::getOutputDimensions( + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept +{ + try + { + TLLM_CHECK(nbInputs == 3); + TLLM_CHECK(outputIndex == 0); + int const nbDimsA = inputs[0].nbDims; + TLLM_CHECK(nbDimsA >= 2); + DimsExprs ret; + ret.nbDims = nbDimsA; + for (int ii = 0; ii < nbDimsA - 1; ++ii) + { + ret.d[ii] = inputs[0].d[ii]; + } + ret.d[nbDimsA - 1] = exprBuilder.constant(inputs[1].d[1]->getConstantValue() / 2); + return ret; + } + catch (std::exception const& e) + { + caughtError(e); + } + return DimsExprs{}; +} + +bool GemmSwigluPlugin::supportsFormatCombination( + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept +{ + switch (pos) + { + case 0: + // activation + return inOut[pos].type == mType && inOut[pos].format == TensorFormat::kLINEAR; + case 1: + // weights + return inOut[pos].type == mType && inOut[pos].format == TensorFormat::kLINEAR; + case 2: + // bias + return inOut[pos].type == mType && inOut[pos].format == TensorFormat::kLINEAR; + case 3: + // out + return inOut[pos].type == mType && inOut[pos].format == TensorFormat::kLINEAR; + default: + // Never should be here + TLLM_CHECK(false); + return false; + } +} + +void GemmSwigluPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept +{ + auto const minM = std::accumulate(in[0].min.d, in[0].min.d + in[0].min.nbDims - 1, 1, std::multiplies()); + auto const maxM = std::accumulate(in[0].max.d, in[0].max.d + in[0].max.nbDims - 1, 1, std::multiplies()); + + int const maxK = in[0].max.d[in[0].max.nbDims - 1]; + int const maxN = in[1].max.d[1]; + int const minK = in[0].min.d[in[0].min.nbDims - 1]; + int const minN = in[1].min.d[1]; + + TLLM_CHECK_WITH_INFO(minN == maxN, "Variable out channels is not allowed"); + TLLM_CHECK_WITH_INFO(minK == maxK, "Variable in channels is not allowed"); + + if (!mDims.isInitialized()) + { + mDims = {minM, maxM, maxN, maxK}; + } + mGemmId = {maxN, maxK, mType}; + + mWorkspaceMaxSize = mGemmRunner->getWorkspaceSize(maxM, maxN, maxK); +} + +size_t GemmSwigluPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept +{ + return mWorkspaceMaxSize; +} + +int GemmSwigluPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +{ + // inputs + // mat1 [M(*), K] + // mat2 [K, N] + // bias [1, N] + // outputs + // mat [M(*), N / 2] + int m = 1; + for (int ii = 0; ii < inputDesc[0].dims.nbDims - 1; ++ii) + { + m *= inputDesc[0].dims.d[ii]; + } + int const n = inputDesc[1].dims.d[1]; + int const k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; + size_t const wsSize = mGemmRunner->getWorkspaceSize(m, n, k); + + auto const bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); + TLLM_CHECK_WITH_INFO(bestTactic, "No valid GEMM tactic"); + mGemmRunner->gemm(outputs[0], inputs[0], inputs[1], inputs[2], mQuantMode, m, n, k, mScaleD0, mScaleD1, + mScaleOutput, *bestTactic, reinterpret_cast(workspace), wsSize, stream); + + return 0; +} + +// IPluginV2Ext Methods +nvinfer1::DataType GemmSwigluPlugin::getOutputDataType( + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept +{ + TLLM_CHECK(index == 0); + return mType; +} + +// IPluginV2 Methods + +char const* GemmSwigluPlugin::getPluginType() const noexcept +{ + return GEMM_SWIGLU_PLUGIN_NAME; +} + +char const* GemmSwigluPlugin::getPluginVersion() const noexcept +{ + return GEMM_SWIGLU_PLUGIN_VERSION; +} + +int GemmSwigluPlugin::getNbOutputs() const noexcept +{ + return 1; +} + +int GemmSwigluPlugin::initialize() noexcept +{ + configGemm(); // gemm profiler in action + return 0; +} + +void GemmSwigluPlugin::terminate() noexcept {} + +size_t GemmSwigluPlugin::getSerializationSize() const noexcept +{ + return sizeof(unsigned int) + // QuantMode + sizeof(nvinfer1::DataType) + // dtype + sizeof(bool) + // hasBias + sizeof(float) * 3 + // scales + sizeof(mDims) + // Dimensions + mPluginProfiler->getSerializationSize(mGemmId); // selected tactics container size +} + +void GemmSwigluPlugin::serialize(void* buffer) const noexcept +{ + char *d = static_cast(buffer), *a = d; + write(d, mQuantMode.value()); + write(d, mType); + write(d, mHasBias); + write(d, mScaleD0); + write(d, mScaleD1); + write(d, mScaleOutput); + write(d, mDims); + + mPluginProfiler->serialize(d, mGemmId); + TLLM_CHECK(d == a + getSerializationSize()); +} + +void GemmSwigluPlugin::destroy() noexcept +{ + // This gets called when the network containing plugin is destroyed + delete this; +} + +void GemmSwigluPlugin::configGemm() +{ + mPluginProfiler->profileTactics(mGemmRunner, mType, mDims, mGemmId); +} + +/////////////// + +GemmSwigluPluginCreator::GemmSwigluPluginCreator() +{ + // Fill PluginFieldCollection with PluginField arguments metadata + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("has_bias", nullptr, PluginFieldType::kINT8, 1)); + mPluginAttributes.emplace_back(PluginField("scale_d0", nullptr, PluginFieldType::kFLOAT32, 1.0)); + mPluginAttributes.emplace_back(PluginField("scale_d1", nullptr, PluginFieldType::kFLOAT32, 1.0)); + mPluginAttributes.emplace_back(PluginField("scale_output", nullptr, PluginFieldType::kFLOAT32, 1.0)); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* GemmSwigluPluginCreator::getPluginName() const noexcept +{ + return GEMM_SWIGLU_PLUGIN_NAME; +} + +char const* GemmSwigluPluginCreator::getPluginVersion() const noexcept +{ + return GEMM_SWIGLU_PLUGIN_VERSION; +} + +PluginFieldCollection const* GemmSwigluPluginCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2* GemmSwigluPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept +{ + PluginField const* fields = fc->fields; + TLLM_CHECK(fc->nbFields == 5); + nvinfer1::DataType type; + bool hasBias; + float scale_d0; + float scale_d1; + float scale_output; + // Read configurations from each fields + for (int i = 0; i < fc->nbFields; ++i) + { + char const* attrName = fields[i].name; + if (!strcmp(attrName, "type_id")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); + type = static_cast(*(static_cast(fields[i].data))); + } + else if (!strcmp(attrName, "has_bias")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); + hasBias = static_cast(*(static_cast(fields[i].data))); + } + else if (!strcmp(attrName, "scale_d0")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32); + scale_d0 = static_cast(*(static_cast(fields[i].data))); + } + else if (!strcmp(attrName, "scale_d1")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32); + scale_d1 = static_cast(*(static_cast(fields[i].data))); + } + else if (!strcmp(attrName, "scale_output")) + { + TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32); + scale_output = static_cast(*(static_cast(fields[i].data))); + } + } + try + { + // GemmSwigluPluginCreator is unique and shared for an engine generation + // Create plugin profiler with shared tactics map + auto pluginProfiler = mGemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false); + QuantMode quantMode = QuantMode::fromDescription(); + auto* obj = new GemmSwigluPlugin(quantMode, type, hasBias, scale_d0, scale_d1, scale_output, pluginProfiler); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +IPluginV2* GemmSwigluPluginCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept +{ + // This object will be deleted when the network is destroyed, which will + // call GemmSwigluPlugin::destroy() + try + { + // Create plugin profiler with private tactics map which is read from the serialized engine + auto pluginProfiler = mGemmPluginProfileManager.createGemmPluginProfiler(/* inference */ true); + auto* obj = new GemmSwigluPlugin(serialData, serialLength, pluginProfiler); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu new file mode 100644 index 000000000..1fe21bd91 --- /dev/null +++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu @@ -0,0 +1,41 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gemmSwigluPlugin.h" + +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass_extensions/gemm_configs.h" + +using namespace nvinfer1; +using namespace tensorrt_llm::common; +using namespace tensorrt_llm::kernels::cutlass_kernels; +using tensorrt_llm::plugins::GemmSwigluPluginCreator; +using tensorrt_llm::plugins::GemmSwigluPlugin; +using tensorrt_llm::plugins::GemmSwigluPluginProfiler; +using tensorrt_llm::plugins::read; +using tensorrt_llm::plugins::write; + +void GemmSwigluPluginProfiler::initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream) +{ + size_t bpe = getBytePerElement(mType); + + if (mType == nvinfer1::DataType::kFP8) + { + cutlass::reference::device::BlockFillRandomUniform(reinterpret_cast(workspace), + m * k + n * k + 1 * n, 42, cutlass::float_e4m3_t{128}, -cutlass::float_e4m3_t{128}, -1, stream); + } +} diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h new file mode 100644 index 000000000..744f1fc04 --- /dev/null +++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.h @@ -0,0 +1,150 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm.h" +#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h" +#include "tensorrt_llm/plugins/common/plugin.h" +#include +#include +#include +#include + +namespace tensorrt_llm::plugins +{ + +using GemmSwigluRunnerPtr + = std::shared_ptr; + +class GemmSwigluPluginProfiler : public GemmPluginProfiler + +{ +public: + using Config = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; + + void setQuantMode(tensorrt_llm::common::QuantMode const& quantMode); + + virtual int getMaxProfileM() const override; + +protected: + void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; + + void computeTmpSize(int maxM, int n, int k) override; + + // TODO(anchengc) implement checkTactic + // bool checkTactic(int m, int n, int k, const Config& tactic) const override; + + std::vector getTactics(int m, int n, int k) const override; + + void initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream) override; + +private: + size_t getBytePerElement(nvinfer1::DataType type); + + tensorrt_llm::common::QuantMode mQuantMode; +}; + +class GemmSwigluPlugin : public BasePlugin +{ +public: + using PluginProfilerPtr = std::shared_ptr; + + GemmSwigluPlugin() = delete; + + GemmSwigluPlugin(tensorrt_llm::common::QuantMode quantMode, nvinfer1::DataType type, bool hasBias, float scale_d0, + float scale_d1, float scale_output, PluginProfilerPtr const& pluginProfiler); + + GemmSwigluPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler); + + ~GemmSwigluPlugin() override = default; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + +private: + void init(nvinfer1::DataType type); + + void configGemm(); + // void setGemmConfig(); + +private: + const std::string mLayerName; + + GemmSwigluRunnerPtr mGemmRunner; + tensorrt_llm::common::QuantMode mQuantMode; // not configurable yet + size_t mWorkspaceMaxSize; + + GemmDims mDims{}; + GemmIdCore mGemmId{}; + + PluginProfilerPtr mPluginProfiler; + + nvinfer1::DataType mType; + bool mHasBias; + float mScaleD0; + float mScaleD1; + float mScaleOutput; +}; + +class GemmSwigluPluginCreator : public BaseCreator +{ +public: + GemmSwigluPluginCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept override; + +private: + GemmPluginProfilerManager mGemmPluginProfileManager; + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; +}; + +} // namespace tensorrt_llm::plugins diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp index a5d84abd1..8863d8135 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp @@ -25,6 +25,7 @@ #include "tensorrt_llm/kernels/unfusedAttentionKernels.h" #include "tensorrt_llm/plugins/common/checkMacrosPlugin.h" #include "tensorrt_llm/runtime/iBuffer.h" +#include "tensorrt_llm/runtime/utils/debugUtils.h" #include #include #include @@ -68,9 +69,11 @@ struct FusedQKVMaskedAttentionDispatchParams float rotary_embedding_base; RotaryScalingType rotary_embedding_scale_type; float rotary_embedding_scale; - float rotary_embedding_m_scale; + float rotary_embedding_short_m_scale; + float rotary_embedding_long_m_scale; float const* rotary_embedding_scaling_factors; int rotary_embedding_max_positions; + int rotary_embedding_original_max_positions; int rotary_cogvlm_vision_start; int rotary_cogvlm_vision_length; PositionEmbeddingType position_embedding_type; @@ -108,6 +111,8 @@ struct FusedQKVMaskedAttentionDispatchParams bool cross_attention = false; int const* memory_length_per_sample = nullptr; int max_distance = 0; + bool block_sparse_attention = false; + BlockSparseParams block_sparse_params; }; template @@ -237,6 +242,9 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel xqaParams.spec_decoding_packed_mask = generationsParams.spec_decoding_packed_mask; xqaParams.spec_decoding_position_offsets = generationsParams.spec_decoding_position_offsets; xqaParams.spec_decoding_generation_lengths = generationsParams.spec_decoding_generation_lengths; + + xqaParams.total_num_input_tokens = generationsParams.total_num_input_tokens; + xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr); return true; } @@ -299,9 +307,11 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params(input_params.relative_attention_bias); params.relative_attention_bias_stride = input_params.relative_attention_bias_stride; params.max_distance = input_params.max_distance; + params.block_sparse_attention = input_params.block_sparse_attention; + params.block_sparse_params = input_params.block_sparse_params; // The slope of linear position bias per head, e.g., ALiBi. if (input_params.linear_bias_slopes != nullptr) @@ -378,14 +390,16 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, - float rotary_embedding_scale, float rotary_embedding_m_scale, int rotary_embedding_max_positions, int tp_size, + float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale, + int rotary_embedding_max_positions, int rotary_embedding_original_max_positions, int tp_size, int tp_rank, // for ALiBi bool unfuse_qkv_gemm, // for AutoPP tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, bool enable_xqa, int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, - bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, - bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha, - bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_spec_decoding_enabled) + tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, + nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention, int max_distance, + bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha, + bool use_cache, bool is_spec_decoding_enabled) : mLayerIdx(layer_idx) , mNumHeads(num_heads) , mVisionStart(vision_start) @@ -399,13 +413,16 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, , mRotaryEmbeddingBase(rotary_embedding_base) , mRotaryEmbeddingScaleType(rotary_embedding_scale_type) , mRotaryEmbeddingScale(rotary_embedding_scale) - , mRotaryEmbeddingMscale(rotary_embedding_m_scale) + , mRotaryEmbeddingShortMscale(rotary_embedding_short_m_scale) + , mRotaryEmbeddingLongMscale(rotary_embedding_long_m_scale) , mRotaryEmbeddingMaxPositions(rotary_embedding_max_positions) + , mRotaryEmbeddingOriginalMaxPositions(rotary_embedding_original_max_positions) , mPositionEmbeddingType(position_embedding_type) , mEnableContextFMHA(context_fmha_type != ContextFMHAType::DISABLED) , mFMHAForceFP32Acc( context_fmha_type == ContextFMHAType::ENABLED_WITH_FP32_ACC || type == nvinfer1::DataType::kBF16) , mMaskType(mask_type) + , mBlockSparseParams(block_sparse_params) , mType(type) , mMultiBlockMode(multi_block_mode) , mEnableXQA(enable_xqa) @@ -514,8 +531,10 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng read(d, mRotaryEmbeddingBase); read(d, mRotaryEmbeddingScaleType); read(d, mRotaryEmbeddingScale); - read(d, mRotaryEmbeddingMscale); + read(d, mRotaryEmbeddingShortMscale); + read(d, mRotaryEmbeddingLongMscale); read(d, mRotaryEmbeddingMaxPositions); + read(d, mRotaryEmbeddingOriginalMaxPositions); read(d, mTpSize); read(d, mTpRank); read(d, mUnfuseQkvGemm); @@ -526,6 +545,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng read(d, kvCacheQuantMode); read(d, mRemovePadding); read(d, mMaskType); + read(d, mBlockSparseParams); read(d, mPagedKVCache); read(d, mTokensPerBlock); read(d, mType); @@ -619,7 +639,7 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t } size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration( - nvinfer1::DataType type, int32_t total_num_seq, int32_t max_attention_window) const noexcept + nvinfer1::DataType type, int32_t total_num_seq, int32_t max_attention_window, int32_t max_num_tokens) const noexcept { int const local_hidden_units_qo = mNumHeads * getHeadSize(); int const local_hidden_units_kv = mNumKVHeads * getHeadSize(); @@ -655,7 +675,7 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration( size_t mqa_workspaces[XQA_NUM_BUFFERS]; size_t const cu_seqlens_size = sizeof(int) * (batch_beam + 1); size_t const rotary_inv_freq_size = sizeof(float) * batch_beam * mRotaryEmbeddingDim / 2; - mqa_workspaces[0] = mDecoderXQARunner->getWorkspaceSize(batch_beam); + mqa_workspaces[0] = mDecoderXQARunner->getWorkspaceSize(batch_beam, max_num_tokens); mqa_workspaces[1] = cu_seqlens_size; mqa_workspaces[2] = rotary_inv_freq_size; mqa_workspace_size = tc::calculateTotalWorkspaceSize(mqa_workspaces, XQA_NUM_BUFFERS); @@ -819,6 +839,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams(params.attention_input), stream, beforeRopeStr) + == false, + "Found Nan in " + beforeRopeStr); + } invokeQKVPreprocessing(preprocessingParams, stream); + { + std::string const afterRopeStr = "ctx attention after RoPE at layer " + std::to_string(mLayerIdx); + TLLM_CHECK_DEBUG_WITH_INFO(tensorrt_llm::runtime::utils::tensorHasNan(params.num_tokens, + (local_hidden_units_qo + 2 * local_hidden_units_kv), mType, + const_cast(params.attention_input), stream, afterRopeStr) + == false, + "Found Nan in " + afterRopeStr); + } sync_check_cuda_error(); @@ -1114,6 +1152,9 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams(linear_bias_slopes); // (head_num,), optional + param.block_sparse_attn = mMaskType == AttentionMaskType::BLOCKSPARSE; + param.block_sparse_params = mBlockSparseParams; + param.q_seq_lengths = params.q_seq_lengths; invokeMaskedSoftmax(param, stream); } else @@ -1144,6 +1185,9 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams(linear_bias_slopes); // (head_num,), optional + param.block_sparse_attn = mMaskType == AttentionMaskType::BLOCKSPARSE; + param.block_sparse_params = mBlockSparseParams; + param.q_seq_lengths = params.q_seq_lengths; invokeMaskedSoftmax(param, stream); } @@ -1428,14 +1472,18 @@ int GPTAttentionPluginCommon::enqueueGeneration( dispatch_params.rotary_embedding_base = mRotaryEmbeddingBase; dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType; dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale; - dispatch_params.rotary_embedding_m_scale = mRotaryEmbeddingMscale; + dispatch_params.rotary_embedding_short_m_scale = mRotaryEmbeddingShortMscale; + dispatch_params.rotary_embedding_long_m_scale = mRotaryEmbeddingLongMscale; dispatch_params.rotary_embedding_scaling_factors = params.rotary_embedding_scaling_factors; dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions; + dispatch_params.rotary_embedding_original_max_positions = mRotaryEmbeddingOriginalMaxPositions; dispatch_params.position_shift_enabled = mPosShiftEnabled; dispatch_params.rotary_cogvlm_vision_start = mVisionStart; dispatch_params.rotary_cogvlm_vision_length = mVisionLength; dispatch_params.cross_attention = mCrossAttention; dispatch_params.memory_length_per_sample = params.encoder_input_lengths; + dispatch_params.block_sparse_attention = mMaskType == AttentionMaskType::BLOCKSPARSE; + dispatch_params.block_sparse_params = mBlockSparseParams; using DataType = typename SATypeConverter::Type; if (!mCrossAttention) @@ -1600,15 +1648,17 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept return sizeof(mLayerIdx) + sizeof(mNumHeads) + +sizeof(mVisionStart) + sizeof(mVisionLength) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) + sizeof(mQKTanhScale) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase) - + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMscale) - + sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingShortMscale) + + sizeof(mRotaryEmbeddingLongMscale) + sizeof(mRotaryEmbeddingMaxPositions) + + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode - + sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType) - + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance) - + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA) - + sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm) + sizeof(mIsSpecDecodingEnabled) - + sizeof(mNbMultiBlockSemaphores) + sizeof(uint32_t) // size of mDecoderXQARunnerResource buffer. + + sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mBlockSparseParams) + sizeof(mPagedKVCache) + + sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + + sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + + sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA) + sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm) + + sizeof(mIsSpecDecodingEnabled) + sizeof(mNbMultiBlockSemaphores) + + sizeof(uint32_t) // size of mDecoderXQARunnerResource buffer. + mDecoderXQARunnerResource.getSerializationSize(); } @@ -1629,8 +1679,10 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept write(d, mRotaryEmbeddingBase); write(d, mRotaryEmbeddingScaleType); write(d, mRotaryEmbeddingScale); - write(d, mRotaryEmbeddingMscale); + write(d, mRotaryEmbeddingShortMscale); + write(d, mRotaryEmbeddingLongMscale); write(d, mRotaryEmbeddingMaxPositions); + write(d, mRotaryEmbeddingOriginalMaxPositions); write(d, mTpSize); write(d, mTpRank); write(d, mUnfuseQkvGemm); @@ -1641,6 +1693,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept write(d, mKVCacheQuantMode.value()); write(d, mRemovePadding); write(d, mMaskType); + write(d, mBlockSparseParams); write(d, mPagedKVCache); write(d, mTokensPerBlock); write(d, mType); @@ -1718,8 +1771,12 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon() mPluginAttributes.emplace_back(PluginField("rotary_embedding_base", nullptr, PluginFieldType::kFLOAT32, 0)); mPluginAttributes.emplace_back(PluginField("rotary_embedding_scale_type", nullptr, PluginFieldType::kINT8, 0)); mPluginAttributes.emplace_back(PluginField("rotary_embedding_scale", nullptr, PluginFieldType::kFLOAT32, 0)); - mPluginAttributes.emplace_back(PluginField("rotary_embedding_m_scale", nullptr, PluginFieldType::kFLOAT32, 0)); + mPluginAttributes.emplace_back( + PluginField("rotary_embedding_short_m_scale", nullptr, PluginFieldType::kFLOAT32, 0)); + mPluginAttributes.emplace_back(PluginField("rotary_embedding_long_m_scale", nullptr, PluginFieldType::kFLOAT32, 0)); mPluginAttributes.emplace_back(PluginField("rotary_embedding_max_positions", nullptr, PluginFieldType::kINT32, 0)); + mPluginAttributes.emplace_back( + PluginField("rotary_embedding_original_max_positions", nullptr, PluginFieldType::kINT32, 0)); mPluginAttributes.emplace_back(PluginField("tp_size", nullptr, PluginFieldType::kINT32, 0)); mPluginAttributes.emplace_back(PluginField("tp_rank", nullptr, PluginFieldType::kINT32, 0)); mPluginAttributes.emplace_back(PluginField("unfuse_qkv_gemm", nullptr, PluginFieldType::kINT8, 0)); diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h index 6333accf0..848dee911 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h @@ -42,15 +42,17 @@ class GPTAttentionPluginCommon : public BasePlugin tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, - float rotary_embedding_scale, float rotary_embedding_m_scale, int rotary_embedding_max_positions, int tp_size, + float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale, + int rotary_embedding_max_positions, int rotary_embedding_original_max_positions, int tp_size, int tp_rank, // for ALiBi bool unfuse_qkv_gemm, // for AutoPP tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, bool enable_xqa, int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, - bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, - bool qkv_bias_enabled, bool cross_attention = false, int max_distance = 0, bool pos_shift_enabled = false, - bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, - bool use_cache = true, bool is_spec_decoding_enabled = false); + tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, + nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false, + int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false, + bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, bool use_cache = true, + bool is_spec_decoding_enabled = false); GPTAttentionPluginCommon(void const* data, size_t length); @@ -84,8 +86,8 @@ class GPTAttentionPluginCommon : public BasePlugin size_t getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t nbReq, int32_t max_input_length, int32_t max_kv_cache_len, int32_t cross_qkv_length = 0, int32_t max_num_tokens = 0) const noexcept; // total_num_seq is the sum of beam_width for multiple requests - size_t getWorkspaceSizeForGeneration( - nvinfer1::DataType type, int32_t total_num_seq, int32_t max_kv_cache_length) const noexcept; + size_t getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t total_num_seq, int32_t max_kv_cache_length, + int32_t max_num_tokens) const noexcept; template struct EnqueueContextParams @@ -222,6 +224,7 @@ class GPTAttentionPluginCommon : public BasePlugin int32_t const* spec_decoding_packed_mask = nullptr; int32_t const* spec_decoding_position_offsets = nullptr; int32_t const* spec_decoding_generation_lengths = nullptr; + int32_t total_num_input_tokens; }; template @@ -295,11 +298,15 @@ class GPTAttentionPluginCommon : public BasePlugin float mRotaryEmbeddingBase; tensorrt_llm::kernels::RotaryScalingType mRotaryEmbeddingScaleType; float mRotaryEmbeddingScale; - float mRotaryEmbeddingMscale; + float mRotaryEmbeddingShortMscale; + float mRotaryEmbeddingLongMscale; int mRotaryEmbeddingMaxPositions; + int mRotaryEmbeddingOriginalMaxPositions; tensorrt_llm::kernels::PositionEmbeddingType mPositionEmbeddingType; bool mRemovePadding = false; tensorrt_llm::kernels::AttentionMaskType mMaskType; + tensorrt_llm::kernels::BlockSparseParams mBlockSparseParams; + // NOTE: default values for paged kv cache. bool mPagedKVCache = false; int mTokensPerBlock = 0; diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index 6497f9f37..07517d7a1 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -22,8 +22,10 @@ #include "tensorrt_llm/plugins/common/plugin.h" #include "tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h" #include "tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommonImpl.h" +#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/utils/debugUtils.h" #include #include #include @@ -43,21 +45,25 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_ tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int rotary_embedding_dim, // for RoPE. 0 for non-RoPE float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, - float rotary_embedding_scale, float rotary_embedding_m_scale, int rotary_embedding_max_positions, int tp_size, - int tp_rank, // for ALiBi - bool unfuse_qkv_gemm, // for AutoPP + float rotary_embedding_scale, float rotary_embedding_short_m_scale, + float rotary_embedding_long_m_scale, // magnitude scaling factors for Phi-3 long RoPE + int rotary_embedding_max_positions, int rotary_embedding_original_max_positions, int tp_size, + int tp_rank, // for ALiBi + bool unfuse_qkv_gemm, // for AutoPP tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, bool enable_xqa, int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, - bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, - bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha, - bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_spec_decoding_enabled) + tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, + nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention, int max_distance, + bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha, + bool use_cache, bool is_spec_decoding_enabled) : GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size, unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base, - rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_m_scale, rotary_embedding_max_positions, - tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode, enable_xqa, kv_cache_quant_mode, - remove_input_padding, mask_type, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled, - cross_attention, max_distance, pos_shift_enabled, dense_context_fmha, use_paged_context_fmha, - use_fp8_context_fmha, use_cache, is_spec_decoding_enabled) + rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale, + rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size, + tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode, enable_xqa, kv_cache_quant_mode, + remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block, type, + max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled, dense_context_fmha, + use_paged_context_fmha, use_fp8_context_fmha, use_cache, is_spec_decoding_enabled) { initEntryIdx(); } @@ -362,7 +368,13 @@ size_t GPTAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* in type, nbReq, max_context_length, max_kv_cache_length, cross_qkv_length, max_num_tokens); int const total_num_seq = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; - size_t const generation_workspace_size = getWorkspaceSizeForGeneration(type, total_num_seq, max_kv_cache_length); + + int32_t const num_spec_dec_tokens + = mIsSpecDecodingEnabled ? inputs[getIdx(IdxEntry::SPEC_DECODING_POSITION_OFFSETS)].dims.d[1] : 1; + int32_t const max_batch_beam = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; + int32_t const max_num_gen_tokens = std::min(max_num_tokens, num_spec_dec_tokens * max_batch_beam); + size_t const generation_workspace_size + = getWorkspaceSizeForGeneration(type, total_num_seq, max_kv_cache_length, max_num_tokens); size_t attention_input_workspace_size = 0; if (mUnfuseQkvGemm) @@ -392,6 +404,7 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { + TLLM_LOG_TRACE("Attention plugin start at layer %d", mLayerIdx); int32_t const nbSeq = inputDesc[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; int32_t const beam_width = useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[1] : 1; RequestType const* reqTypes = static_cast(inputs[getIdx(IdxEntry::REQUEST_TYPES)]); @@ -444,6 +457,8 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, inputs, outputs, workspace, stream); } + TLLM_LOG_TRACE("Attention plugin stop at layer %d", mLayerIdx); + return 0; } @@ -705,6 +720,15 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 } enqueueContext(enqueue_params, stream); + + { + std::string const afterContexStr = "ctx attention at layer " + std::to_string(mLayerIdx); + TLLM_CHECK_DEBUG_WITH_INFO(tensorrt_llm::runtime::utils::tensorHasNan(localNbTokens, + outputDesc[0].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)], mType, + context_buf_, stream, afterContexStr) + == false, + "Found Nan in " + afterContexStr); + } } else // generation stage; max_context_q_len == input_seq_len == 1 { @@ -753,8 +777,18 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 enqueue_params.spec_decoding_position_offsets = spec_decoding_position_offsets; enqueue_params.spec_decoding_generation_lengths = spec_decoding_generation_lengths; } + enqueue_params.total_num_input_tokens = localNbTokens; enqueueGeneration(enqueue_params, stream); + + { + std::string const afterGenStr = "gen attention at layer " + std::to_string(mLayerIdx); + TLLM_CHECK_DEBUG_WITH_INFO(tensorrt_llm::runtime::utils::tensorHasNan(localNbTokens, + outputDesc[0].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)], mType, + context_buf_, stream, afterGenStr) + == false, + "Found Nan in " + afterGenStr); + } } return 0; @@ -885,8 +919,10 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField p.getScalar("rotary_embedding_dim").value(), p.getScalar("rotary_embedding_base").value(), static_cast(p.getScalar("rotary_embedding_scale_type").value()), p.getScalar("rotary_embedding_scale").value(), - p.getScalar("rotary_embedding_m_scale").value(), + p.getScalar("rotary_embedding_short_m_scale").value(), + p.getScalar("rotary_embedding_long_m_scale").value(), p.getScalar("rotary_embedding_max_positions").value(), + p.getScalar("rotary_embedding_original_max_positions").value(), static_cast(p.getScalar("tp_size").value()), static_cast(p.getScalar("tp_rank").value()), static_cast(p.getScalar("unfuse_qkv_gemm").value()), @@ -896,6 +932,10 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField p.getScalar("kv_cache_quant_mode").value(), static_cast(p.getScalar("remove_input_padding").value()), static_cast(p.getScalar("mask_type").value()), + BlockSparseParams{p.getScalar("block_sparse_block_size").value(), + static_cast(p.getScalar("block_sparse_homo_head_pattern").value()), + p.getScalar("block_sparse_num_local_blocks").value(), + p.getScalar("block_sparse_vertical_stride").value()}, static_cast(p.getScalar("paged_kv_cache").value()), p.getScalar("tokens_per_block").value(), static_cast(p.getScalar("type_id").value()), diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h index 9ed2c4342..4c27d4981 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h @@ -84,15 +84,17 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, int rotary_embedding_dim, // for RoPE. 0 for non-RoPE float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type, - float rotary_embedding_scale, float rotary_embedding_m_scale, int rotary_embedding_max_positions, int tp_size, + float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale, + int rotary_embedding_max_positions, int rotary_embedding_original_max_positions, int tp_size, int tp_rank, // for ALiBi bool unfuse_qkv_gemm, // for AutoPP tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, bool enable_xqa, int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, - bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, - bool qkv_bias_enabled, bool cross_attention = false, int max_distance = 0, bool pos_shift_enabled = false, - bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, - bool use_cache = true, bool is_spec_decoding_enabled = false); + tensorrt_llm::kernels::BlockSparseParams block_sparse_params, bool paged_kv_cache, int tokens_per_block, + nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled, bool cross_attention = false, + int max_distance = 0, bool pos_shift_enabled = false, bool dense_context_fmha = false, + bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false, bool use_cache = true, + bool is_spec_decoding_enabled = false); GPTAttentionPlugin(void const* data, size_t length); diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp index 29339d9ff..1edbe5f86 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp @@ -287,7 +287,7 @@ int AllreducePlugin::getNbOutputs() const noexcept return 1; } -bool AllreducePlugin::isCustomAllReduceSuported(int ranks_per_node) const noexcept +bool AllreducePlugin::isCustomAllReduceSupported(int ranks_per_node) const noexcept { constexpr bool isCudaVersionSupported = #if defined(CUDART_VERSION) && CUDART_VERSION >= 11020 @@ -381,13 +381,31 @@ std::set getLocalGroup(std::set const& group) void AllreducePlugin::initGroupTopology() noexcept { + static std::map, std::tuple> cache; + if (cache.find(mGroup) != cache.end()) + { + auto [isNVLINKSupported, isP2PSupported] = cache[mGroup]; + mIsNVLINKSupported = isNVLINKSupported; + mIsP2PSupported = isP2PSupported; + return; + } + setGroupTopology(); + cache[mGroup] = {mIsNVLINKSupported, mIsP2PSupported}; +} + +void AllreducePlugin::setGroupTopology() noexcept +{ + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_INFO("Detecting local TP group for rank %d", rank); std::set localGroup = getLocalGroup(mGroup); if (mGroup.size() != localGroup.size()) { mIsP2PSupported = false; mIsNVLINKSupported = false; + TLLM_LOG_INFO("Found inter-node TP group for rank %d", rank); return; } + TLLM_LOG_INFO("TP group is intra-node for rank %d", rank); NvmlManager nvmlManager; std::unordered_set visitedDevice; @@ -492,7 +510,10 @@ int AllreducePlugin::initialize() noexcept } initCommMap(mGroup); - initGroupTopology(); + if (mStrategy != AllReduceStrategyType::NCCL) + { + initGroupTopology(); + } return 0; } diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h index dcfe72b45..b8c2b93a0 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h @@ -67,8 +67,9 @@ class AllreducePlugin : public BasePlugin void destroy() noexcept override; private: - bool isCustomAllReduceSuported(int ranks_per_node) const noexcept; + bool isCustomAllReduceSupported(int ranks_per_node) const noexcept; void initGroupTopology() noexcept; + void setGroupTopology() noexcept; kernels::AllReduceStrategyType selectImplementation( size_t messageSize, int worldSize, nvinfer1::DataType type) noexcept; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp index 339c11500..db4fbeb7f 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp @@ -16,6 +16,8 @@ */ #include "weightOnlyGroupwiseQuantMatmulPlugin.h" +#include + using namespace nvinfer1; using namespace tensorrt_llm::common; using namespace tensorrt_llm::kernels::cutlass_kernels; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp index af6ab94cc..9583bcfaa 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp @@ -16,6 +16,8 @@ */ #include "weightOnlyQuantMatmulPlugin.h" +#include + using namespace nvinfer1; using namespace tensorrt_llm::common; using namespace tensorrt_llm::kernels::cutlass_kernels; diff --git a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp index f8abd8780..7fc401128 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp @@ -146,6 +146,8 @@ void InferenceRequest::initBindings(py::module_& m) .def_property("lora_weights", &InferenceRequest::getLoraWeightsUnchecked, &InferenceRequest::setLoraWeights) .def_property("lora_config", &InferenceRequest::getLoraConfigUnchecked, &InferenceRequest::setLoraConfig) .def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming) + .def_property("no_repeat_ngram_size", &InferenceRequest::getNoRepeatNgramSizeUnchecked, + &InferenceRequest::setNoRepeatNgramSize) .def_property_readonly("request_id", &InferenceRequest::getRequestId) .def(py::pickle( [](InferenceRequest const& p) { // __getstate__ diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 75fe48751..a6698bd4d 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -64,6 +64,10 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) { m.doc() = "TensorRT-LLM Python bindings for C++ runtime"; + // Create submodule for executor bindings. + py::module_ executor_submodule = m.def_submodule("executor", "Executor bindings"); + tensorrt_llm::pybind::executor::InitBindings(executor_submodule); + tpr::PromptTuningParams::initBindings(m); tpr::GenerationInput::initBindings(m); tpr::GenerationOutput::initBindings(m); @@ -257,11 +261,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) return py::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty, config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty, - config.earlyStopping); + config.earlyStopping, config.noRepeatNgramSize); }; auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig { - assert(t.size() == 15); + assert(t.size() == 16); tr::SamplingConfig config; config.beamWidth = t[0].cast(); @@ -279,6 +283,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) config.beamSearchDiversityRate = t[12].cast>(); config.lengthPenalty = t[13].cast>(); config.earlyStopping = t[14].cast>(); + config.noRepeatNgramSize = t[15].cast>(); return std::move(config); }; @@ -300,6 +305,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate) .def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty) .def_readwrite("early_stopping", &tr::SamplingConfig::earlyStopping) + .def_readwrite("no_repeat_ngram_size", &tr::SamplingConfig::noRepeatNgramSize) .def(py::pickle(SamplingConfigGetState, SamplingConfigSetState)) .def("__eq__", &tr::SamplingConfig::operator==); @@ -360,11 +366,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) py::enum_(m, "LlmRequestState") .value("REQUEST_STATE_UNKNOWN", tb::LlmRequestState_t::REQUEST_STATE_UNKNOWN) + .value("REQUEST_STATE_ENCODER_INIT", tb::LlmRequestState_t::REQUEST_STATE_ENCODER_INIT) .value("REQUEST_STATE_CONTEXT_INIT", tb::LlmRequestState_t::REQUEST_STATE_CONTEXT_INIT) .value("REQUEST_STATE_GENERATION_IN_PROGRESS", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_IN_PROGRESS) .value("REQUEST_STATE_GENERATION_TO_COMPLETE", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_TO_COMPLETE) - .value("REQUEST_STATE_GENERATION_COMPLETE", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_COMPLETE) - .value("REQUEST_STATE_ENC_INIT", tb::LlmRequestState_t::REQUEST_STATE_ENC_INIT); + .value("REQUEST_STATE_GENERATION_COMPLETE", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_COMPLETE); tpb::NamedTensor::initBindings(m); tpb::LlmRequest::initBindings(m); @@ -396,6 +402,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) tensorNames.attr("RETURN_GENERATION_LOGITS") = py::str(tb::inference_request::kReturnGenerationLogitsTensorName); tensorNames.attr("PROMPT_EMBEDDING_TABLE") = py::str(tb::inference_request::kPromptEmbeddingTableName); tensorNames.attr("PROMPT_VOCAB_SIZE") = py::str(tb::inference_request::kPromptVocabSizeName); + tensorNames.attr("NO_REPEAT_NGRAM_SIZE") = py::str(tb::inference_request::kNoRepeatNgramSizeTensorName); // Output tensor names tensorNames.attr("OUTPUT_IDS") = py::str(tb::inference_request::kOutputIdsTensorName); @@ -443,10 +450,6 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def(py::pickle(gptModelParamsGetState, gptModelParamsSetState)) .def("__eq__", &tb::TrtGptModelOptionalParams::operator==); - // Create submodule for executor bindings. - py::module_ executor_submodule = m.def_submodule("executor", "Executor bindings"); - tensorrt_llm::pybind::executor::InitBindings(executor_submodule); - tpb::GptManager::initBindings(m); py::class_(m, "MemoryCounters") diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp index c97d3f930..fa6defc8c 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -44,7 +44,10 @@ namespace tensorrt_llm::pybind::executor void InitBindings(pybind11::module_& m) { - py::enum_(m, "ModelType").value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY); + py::enum_(m, "ModelType") + .value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY) + .value("ENCODER_ONLY", tle::ModelType::kENCODER_ONLY) + .value("ENCODER_DECODER", tle::ModelType::kENCODER_DECODER); py::enum_(m, "BatchingType") .value("STATIC", tle::BatchingType::kSTATIC) @@ -124,6 +127,7 @@ void InitBindings(pybind11::module_& m) py::enum_(m, "RequestStage") .value("QUEUED", tle::RequestStage::kQUEUED) + .value("ENCODER_IN_PROGRESS", tle::RequestStage::kENCODER_IN_PROGRESS) .value("CONTEXT_IN_PROGRESS", tle::RequestStage::kCONTEXT_IN_PROGRESS) .value("GENERATION_IN_PROGRESS", tle::RequestStage::kGENERATION_IN_PROGRESS) .value("GENERATION_COMPLETE", tle::RequestStage::kGENERATION_COMPLETE); @@ -154,14 +158,15 @@ void InitBindings(pybind11::module_& m) std::optional const&, std::optional const&, std::optional const&, std::optional const&, std::optional const&, std::optional const&, std::optional const&, std::optional const&, - std::optional const&>(), + std::optional const&, std::optional const&>(), py::arg("beam_width") = 1, py::arg("top_k") = py::none(), py::arg("top_p") = py::none(), py::arg("top_p_min") = py::none(), py::arg("top_p_reset_ids") = py::none(), py::arg("top_p_decay") = py::none(), py::arg("random_seed") = py::none(), py::arg("temperature") = py::none(), py::arg("min_length") = py::none(), py::arg("beam_search_diversity_rate") = py::none(), py::arg("repetition_penalty") = py::none(), py::arg("presence_penalty") = py::none(), py::arg("frequency_penalty") = py::none(), - py::arg("length_penalty") = py::none(), py::arg("early_stopping") = py::none()) + py::arg("length_penalty") = py::none(), py::arg("early_stopping") = py::none(), + py::arg("no_repeat_ngram_size") = py::none()) .def_property("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth) .def_property("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK) .def_property("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP) @@ -180,16 +185,19 @@ void InitBindings(pybind11::module_& m) .def_property( "frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty) .def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty) - .def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping); + .def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping) + .def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize, + &tle::SamplingConfig::setNoRepeatNgramSize); py::class_(m, "OutputConfig") - .def(py::init(), py::arg("return_log_probs") = false, + .def(py::init(), py::arg("return_log_probs") = false, py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false, - py::arg("exclude_input_from_output") = false) + py::arg("exclude_input_from_output") = false, py::arg("return_encoder_output") = false) .def_readwrite("return_log_probs", &tle::OutputConfig::returnLogProbs) .def_readwrite("return_context_logits", &tle::OutputConfig::returnContextLogits) .def_readwrite("return_generation_logits", &tle::OutputConfig::returnGenerationLogits) - .def_readwrite("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput); + .def_readwrite("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput) + .def_readwrite("return_encoder_output", &tle::OutputConfig::returnEncoderOutput); py::class_(m, "ExternalDraftTokensConfig") .def(py::init, std::optional const&>(), py::arg("tokens"), @@ -214,14 +222,14 @@ void InitBindings(pybind11::module_& m) std::optional const&, std::optional const&, std::optional>, std::optional>, std::optional, std::optional, std::optional, - std::optional, std::optional>(), + std::optional, std::optional, std::optional>(), py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false, py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"), py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(), py::arg("pad_id") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(), py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(), py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(), - py::arg("logits_post_processor_name") = py::none()) + py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none()) .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds) .def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens) .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -238,7 +246,9 @@ void InitBindings(pybind11::module_& m) "prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig) .def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig) .def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName, - &tle::Request::setLogitsPostProcessorName); + &tle::Request::setLogitsPostProcessorName) + .def_property( + "encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds); py::class_(m, "Result") .def(py::init<>()) @@ -247,7 +257,8 @@ void InitBindings(pybind11::module_& m) .def_readwrite("cum_log_probs", &tle::Result::cumLogProbs) .def_readwrite("log_probs", &tle::Result::logProbs) .def_readwrite("context_logits", &tle::Result::contextLogits) - .def_readwrite("generation_logits", &tle::Result::generationLogits); + .def_readwrite("generation_logits", &tle::Result::generationLogits) + .def_readwrite("encoder_output", &tle::Result::encoderOutput); py::class_(m, "Response") .def(py::init(), py::arg("request_id"), py::arg("error_msg")) diff --git a/cpp/tensorrt_llm/pybind/executor/executor.cpp b/cpp/tensorrt_llm/pybind/executor/executor.cpp index 20746282a..87af73ab6 100644 --- a/cpp/tensorrt_llm/pybind/executor/executor.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executor.cpp @@ -36,6 +36,12 @@ Executor::Executor( mExecutor = std::make_unique(modelPath, modelType, executorConfig); } +Executor::Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath, + tle::ModelType modelType, tle::ExecutorConfig const& executorConfig) +{ + mExecutor = std::make_unique(encoderModelPath, decoderModelPath, modelType, executorConfig); +} + Executor::Executor(std::string const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig) { @@ -43,6 +49,16 @@ Executor::Executor(std::string const& engineBuffer, std::string const& jsonConfi std::vector(engineBuffer.begin(), engineBuffer.end()), jsonConfigStr, modelType, executorConfig); } +Executor::Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr, + std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig) +{ + mExecutor + = std::make_unique(std::vector(encoderEngineBuffer.begin(), encoderEngineBuffer.end()), + encoderJsonConfigStr, std::vector(decoderEngineBuffer.begin(), decoderEngineBuffer.end()), + decoderJsonConfigStr, modelType, executorConfig); +} + py::object Executor::enter() { TLLM_CHECK(static_cast(mExecutor)); @@ -72,8 +88,16 @@ void Executor::initBindings(py::module_& m) py::class_(m, "Executor") .def(py::init(), py::arg("model_path"), py::arg("model_type"), py::arg("executor_config")) + .def(py::init(), + py::arg("encoder_model_path"), py::arg("decoder_model_path"), py::arg("model_type"), + py::arg("executor_config")) .def(py::init(), py::arg("engine_buffer"), py::arg("json_config_str"), py::arg("model_type"), py::arg("executor_config")) + .def(py::init(), + py::arg("encoder_engine_buffer"), py::arg("encoder_json_config_str"), py::arg("decoder_engine_buffer"), + py::arg("decoder_json_config_str"), py::arg("model_type"), py::arg("executor_config")) .def("shutdown", &Executor::shutdown) .def("__enter__", &Executor::enter) .def("__exit__", &Executor::exit) diff --git a/cpp/tensorrt_llm/pybind/executor/executor.h b/cpp/tensorrt_llm/pybind/executor/executor.h index 13b12e74a..5c950a0ff 100644 --- a/cpp/tensorrt_llm/pybind/executor/executor.h +++ b/cpp/tensorrt_llm/pybind/executor/executor.h @@ -31,9 +31,16 @@ class Executor Executor( std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); + Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath, + tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); + Executor(std::string const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig); + Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr, + std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType, + tle::ExecutorConfig const& executorConfig); + pybind11::object enter(); void exit([[maybe_unused]] pybind11::handle type, [[maybe_unused]] pybind11::handle value, [[maybe_unused]] pybind11::handle traceback); diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index 5bb433f01..7c3c4196d 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -67,6 +67,7 @@ void GptDecoder::setup( setupParams->penaltyParams.frequencyPenalty = mSamplingConfig.frequencyPenalty; setupParams->penaltyParams.temperature = mSamplingConfig.temperature; setupParams->penaltyParams.minLength = mSamplingConfig.minLength; + setupParams->penaltyParams.noRepeatNgramSize = mSamplingConfig.noRepeatNgramSize; setupParams->randomSeed = mSamplingConfig.randomSeed; diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp index 13d4e503d..911153b54 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp @@ -54,6 +54,7 @@ SamplingConfig extractSamplingConfig(SamplingConfig const& batchSamplingConfig, extractOptional(samplingConfig.repetitionPenalty, batchSamplingConfig.repetitionPenalty); extractOptional(samplingConfig.presencePenalty, batchSamplingConfig.presencePenalty); extractOptional(samplingConfig.frequencyPenalty, batchSamplingConfig.frequencyPenalty); + extractOptional(samplingConfig.noRepeatNgramSize, batchSamplingConfig.noRepeatNgramSize); // sampling layers extractOptional(samplingConfig.topK, batchSamplingConfig.topK); extractOptional(samplingConfig.topP, batchSamplingConfig.topP); diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp index 01feb3c14..7d9420d83 100644 --- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp +++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp @@ -113,12 +113,13 @@ ModelConfig createModelConfig( { auto const& config = engineVersionNone ? json.at("builder_config") : json.at("pretrained_config"); - auto const useCrossAttention = parseJsonFieldOptional(config, "cross_attention"); + auto const* const archField = "architecture"; auto const* const numLayersField = engineVersionNone ? "num_layers" : "num_hidden_layers"; auto const* const numHeadsField = engineVersionNone ? "num_heads" : "num_attention_heads"; auto const* const numKvHeadsField = engineVersionNone ? "num_kv_heads" : "num_key_value_heads"; auto const* const mlpHiddenSizeField = engineVersionNone ? "mlp_hidden_size" : "intermediate_size"; + auto const arch = engineVersionNone ? std::string("none") : config.at(archField).template get(); auto const numLayers = config.at(numLayersField).template get(); auto const numHeads = config.at(numHeadsField).template get() / tensorParallelism; auto const layerStringTypes @@ -145,10 +146,14 @@ ModelConfig createModelConfig( modelConfig.setNbKvHeads(numKvHeads); modelConfig.setLayerTypes(layerTypes); - if (useCrossAttention.has_value()) - { - modelConfig.useCrossAttention(useCrossAttention.value()); - } + // only enable cross attention for the decoder in encoder-decoder model + // TODO: add cross_attention and has_token_type_embedding as fields in pretrained config + auto const useCrossAttention = arch == std::string("DecoderModel") ? true : false; + auto const usePositionEmbedding = parseJsonFieldOr(config, "has_position_embedding", false); + auto const useTokenTypeEmbedding = parseJsonFieldOr(config, "has_token_type_embedding", false); + modelConfig.setUseCrossAttention(useCrossAttention); + modelConfig.setUsePositionEmbedding(usePositionEmbedding); + modelConfig.setUseTokenTypeEmbedding(useTokenTypeEmbedding); if (mlpHiddenSize.has_value()) { diff --git a/cpp/tensorrt_llm/runtime/iTensor.cpp b/cpp/tensorrt_llm/runtime/iTensor.cpp index 4597df563..52923a7e4 100644 --- a/cpp/tensorrt_llm/runtime/iTensor.cpp +++ b/cpp/tensorrt_llm/runtime/iTensor.cpp @@ -34,6 +34,32 @@ ITensor::UniquePtr ITensor::slice(SharedPtr tensor, std::size_t offset, std::siz return std::make_unique(std::move(tensor), offset, size); } +ITensor::UniquePtr ITensor::slice(SharedPtr tensor, Shape const& offsetDims, ITensor::DimType64 size) +{ + auto shape = tensor->getShape(); + TLLM_CHECK(offsetDims.nbDims > 0); + TLLM_CHECK(shape.nbDims >= offsetDims.nbDims); + + Shape strides = ITensor::strides(shape); + DimType64 offset{0}; + for (SizeType32 di = 0; di < offsetDims.nbDims; di++) + { + TLLM_CHECK(0 <= offsetDims.d[di] && offsetDims.d[di] < shape.d[di]); + offset += offsetDims.d[di] * strides.d[di]; + } + TLLM_CHECK(offsetDims.d[offsetDims.nbDims - 1] + size <= shape.d[offsetDims.nbDims - 1]); + + Shape dims; + dims.nbDims = shape.nbDims - offsetDims.nbDims + 1; + dims.d[0] = size; + for (SizeType32 di = 1; di < dims.nbDims; di++) + { + dims.d[di] = shape.d[di - 1 + offsetDims.nbDims]; + } + + return std::make_unique(std::move(tensor), offset, volume(dims), dims); +} + ITensor::UniquePtr ITensor::view(IBuffer::SharedPtr buffer, nvinfer1::Dims const& dims) { auto const size = buffer->getSize(); diff --git a/cpp/tensorrt_llm/runtime/loraCache.cpp b/cpp/tensorrt_llm/runtime/loraCache.cpp index 278df30aa..3bdadcf06 100644 --- a/cpp/tensorrt_llm/runtime/loraCache.cpp +++ b/cpp/tensorrt_llm/runtime/loraCache.cpp @@ -35,6 +35,20 @@ namespace tensorrt_llm::runtime { +LoraExpectedException::LoraExpectedException(std::string const& msg) + : std::runtime_error(msg) +{ +} + +LoraExpectedException::~LoraExpectedException() noexcept = default; + +LoraCacheFullException::LoraCacheFullException(std::string const& msg) + : LoraExpectedException(msg) +{ +} + +LoraCacheFullException::~LoraCacheFullException() noexcept = default; + LoraCachePageManager::LoraCachePageManager(LoraCachePageManagerConfig const& config, BufferManager const& bufferManager) : mConfig(config) { @@ -305,7 +319,7 @@ std::vector LoraCache::claimPagesWithEvict(SizeType32 numPages) } if (it == mDoneTasks.rend()) { - TLLM_THROW("Cache is full. There are no done tasks to evict"); + throw LoraCacheFullException("Cache is full. There are no done tasks to evict"); } TLLM_LOG_DEBUG("evicting " + std::to_string(taskIdsToEvict.size())); diff --git a/cpp/tensorrt_llm/runtime/tllmBuffers.h b/cpp/tensorrt_llm/runtime/tllmBuffers.h index d64a333bd..02480210e 100644 --- a/cpp/tensorrt_llm/runtime/tllmBuffers.h +++ b/cpp/tensorrt_llm/runtime/tllmBuffers.h @@ -46,10 +46,9 @@ class BaseAllocator public: using ValueType = void; using PointerType = ValueType*; - using SizeType32 = std::size_t; static auto constexpr kMemoryType = memoryType; - PointerType allocate(SizeType32 n) + PointerType allocate(std::size_t n) { PointerType ptr{}; static_cast(this)->allocateImpl(&ptr, n); @@ -58,7 +57,7 @@ class BaseAllocator return ptr; } - void deallocate(PointerType ptr, SizeType32 n) + void deallocate(PointerType ptr, std::size_t n) { if (ptr) { @@ -82,13 +81,13 @@ class CudaAllocator : public BaseAllocator CudaAllocator() noexcept = default; protected: - void allocateImpl(PointerType* ptr, SizeType32 n) // NOLINT(readability-convert-member-functions-to-static) + void allocateImpl(PointerType* ptr, std::size_t n) // NOLINT(readability-convert-member-functions-to-static) { TLLM_CUDA_CHECK(::cudaMalloc(ptr, n)); } void deallocateImpl( // NOLINT(readability-convert-member-functions-to-static) - PointerType ptr, [[maybe_unused]] SizeType32 n) + PointerType ptr, [[maybe_unused]] std::size_t n) { TLLM_CUDA_CHECK_FREE_RESOURCE(::cudaFree(ptr)); } @@ -113,12 +112,12 @@ class CudaAllocatorAsync : public BaseAllocatorget())); } - void deallocateImpl(PointerType ptr, [[maybe_unused]] SizeType32 n) + void deallocateImpl(PointerType ptr, [[maybe_unused]] std::size_t n) { TLLM_CUDA_CHECK_FREE_RESOURCE(::cudaFreeAsync(ptr, mCudaStream->get())); } @@ -136,14 +135,14 @@ class UVMAllocator : public BaseAllocator UVMAllocator() noexcept = default; protected: - void allocateImpl(PointerType* ptr, SizeType32 n) // NOLINT(readability-convert-member-functions-to-static) + void allocateImpl(PointerType* ptr, std::size_t n) // NOLINT(readability-convert-member-functions-to-static) { TLLM_CUDA_CHECK(::cudaMallocManaged(ptr, n)); // TLLM_CUDA_CHECK(::cudaMemAdvise(ptr, n, cudaMemAdviseSetPreferredLocation, 0)); } void deallocateImpl( // NOLINT(readability-convert-member-functions-to-static) - PointerType ptr, [[maybe_unused]] SizeType32 n) + PointerType ptr, [[maybe_unused]] std::size_t n) { TLLM_CUDA_CHECK_FREE_RESOURCE(::cudaFree(ptr)); } @@ -158,13 +157,13 @@ class PinnedAllocator : public BaseAllocator HostAllocator() noexcept = default; protected: - void allocateImpl(PointerType* ptr, SizeType32 n) // NOLINT(readability-convert-member-functions-to-static) + void allocateImpl(PointerType* ptr, std::size_t n) // NOLINT(readability-convert-member-functions-to-static) { *ptr = std::malloc(n); if (*ptr == nullptr) @@ -188,7 +187,7 @@ class HostAllocator : public BaseAllocator } void deallocateImpl( // NOLINT(readability-convert-member-functions-to-static) - PointerType ptr, [[maybe_unused]] SizeType32 n) + PointerType ptr, [[maybe_unused]] std::size_t n) { std::free(ptr); } @@ -202,9 +201,8 @@ class BorrowingAllocator : public BaseAllocator, public: using Base = BaseAllocator, memoryType, false>; using PointerType = typename Base::PointerType; - using SizeType32 = typename Base::SizeType32; - BorrowingAllocator(void* ptr, SizeType32 capacity) + BorrowingAllocator(void* ptr, std::size_t capacity) : mPtr(ptr) , mCapacity(capacity) { @@ -213,7 +211,7 @@ class BorrowingAllocator : public BaseAllocator, } protected: - void allocateImpl(PointerType* ptr, SizeType32 n) // NOLINT(readability-convert-member-functions-to-static) + void allocateImpl(PointerType* ptr, std::size_t n) // NOLINT(readability-convert-member-functions-to-static) { if (n <= mCapacity) { @@ -226,13 +224,13 @@ class BorrowingAllocator : public BaseAllocator, } void deallocateImpl( // NOLINT(readability-convert-member-functions-to-static) - [[maybe_unused]] PointerType ptr, [[maybe_unused]] SizeType32 n) + [[maybe_unused]] PointerType ptr, [[maybe_unused]] std::size_t n) { } private: PointerType mPtr; - SizeType32 mCapacity; + std::size_t mCapacity; }; using CpuBorrowingAllocator = BorrowingAllocator; @@ -254,17 +252,14 @@ class MemoryPool : public BaseAllocator, TAllocator::kMem public: using Base = BaseAllocator, TAllocator::kMemoryType, false>; using PointerType = typename Base::PointerType; - using SizeType32 = typename Base::SizeType32; using Allocator = TAllocator; static_assert(std::is_same_v); - static_assert(std::is_same_v); - static SizeType32 constexpr kInitialChunkSize{SizeType32{1} << 30}; // 1 GB - static SizeType32 constexpr kChunkResizeFactor{2}; - static SizeType32 constexpr kAlignment{256}; + static std::size_t constexpr kInitialChunkSize{std::size_t{1} << 29}; // 512 MB + static std::size_t constexpr kAlignment{256}; - explicit MemoryPool(SizeType32 chunkSize = kInitialChunkSize, Allocator allocator = Allocator{}) + explicit MemoryPool(std::size_t chunkSize = kInitialChunkSize, Allocator allocator = Allocator{}) : mChunkSize(chunkSize) , mAllocator{allocator} { @@ -289,36 +284,36 @@ class MemoryPool : public BaseAllocator, TAllocator::kMem mAllocatedChunks.clear(); } - [[nodiscard]] SizeType32 getChunkSize() const + [[nodiscard]] std::size_t getChunkSize() const { std::lock_guard lock(mLock); return mChunkSize; } - void setChunkSize(SizeType32 chunkSize) + void setChunkSize(std::size_t chunkSize) { std::lock_guard lock(mLock); mChunkSize = chunkSize; } - [[nodiscard]] SizeType32 getUsedSize() const + [[nodiscard]] std::size_t getUsedSize() const { std::lock_guard lock(mLock); - return std::accumulate(mMemorySegments.cbegin(), mMemorySegments.cend(), SizeType32{0}, - [](SizeType32 sum, auto const& chunk) { return chunk.tag ? sum + chunk.size : sum; }); + return std::accumulate(mMemorySegments.cbegin(), mMemorySegments.cend(), std::size_t{0}, + [](std::size_t sum, auto const& chunk) { return chunk.tag ? sum + chunk.size : sum; }); } - [[nodiscard]] SizeType32 getReservedSize() const + [[nodiscard]] std::size_t getReservedSize() const { std::lock_guard lock(mLock); - return std::accumulate(mAllocatedChunks.cbegin(), mAllocatedChunks.cend(), SizeType32{0}, - [](SizeType32 sum, auto const& chunk) { return sum + std::get<1>(chunk); }); + return std::accumulate(mAllocatedChunks.cbegin(), mAllocatedChunks.cend(), std::size_t{0}, + [](std::size_t sum, auto const& chunk) { return sum + std::get<1>(chunk); }); } class MemorySegment { public: - MemorySegment(PointerType basePointer, SizeType32 size, SizeType32 offset = 0, PointerType tag = nullptr) + MemorySegment(PointerType basePointer, std::size_t size, std::size_t offset = 0, PointerType tag = nullptr) : basePointer{basePointer} , size{size} , offset{offset} @@ -327,8 +322,8 @@ class MemoryPool : public BaseAllocator, TAllocator::kMem } PointerType const basePointer; - SizeType32 size; - SizeType32 offset; + std::size_t size; + std::size_t offset; PointerType tag; }; @@ -343,17 +338,17 @@ class MemoryPool : public BaseAllocator, TAllocator::kMem void logSegments() const; protected: - void allocateImpl(PointerType* ptr, SizeType32 requestedSize); + void allocateImpl(PointerType* ptr, std::size_t requestedSize); - void deallocateImpl(PointerType tag, SizeType32 n); + void deallocateImpl(PointerType tag, std::size_t n); private: - SizeType32 mChunkSize; + std::size_t mChunkSize; TAllocator mAllocator; std::mutex mutable mLock{}; std::list mMemorySegments = {}; - std::vector> mAllocatedChunks = {}; + std::vector> mAllocatedChunks = {}; void allocateChunk() { @@ -365,9 +360,10 @@ class MemoryPool : public BaseAllocator, TAllocator::kMem }; template -void MemoryPool::allocateImpl(MemoryPool::PointerType* ptr, MemoryPool::SizeType32 requestedSize) +void MemoryPool::allocateImpl(MemoryPool::PointerType* ptr, std::size_t requestedSize) { std::lock_guard lock(mLock); + // Align requested size to kAlignment // When requesting 0 B, default to allocating 1 B (from "Effective C++", item 51) // See https://stackoverflow.com/questions/2660076/returning-aligned-memory-with-new @@ -382,15 +378,13 @@ void MemoryPool::allocateImpl(MemoryPool::PointerType* ptr, MemoryPo if (it == mMemorySegments.end()) { - // There is no space available for this request - // If the request is bigger than mChunkSize / chunkResizeFactor, adapt mChunkSize to request * - // chunkResizeFactor - // Allocate more space in mChunkSize, and fulfill this request + // There is no space available for this request: + // Adapt mChunkSize to the aligned requested size in case it doesn't fit, + // allocate a chunk of mChunkSize and fulfill this request TLLM_LOG_DEBUG("MemoryPool: Needs more space to accommodate request of %zu B", requestedSize); - auto const minChunkSize = alignedRequest * kChunkResizeFactor; - if (mChunkSize < minChunkSize) + if (mChunkSize < alignedRequest) { - mChunkSize = minChunkSize; + mChunkSize = alignedRequest; TLLM_LOG_DEBUG("MemoryPool: Increasing chunk size to %zu B", mChunkSize); } allocateChunk(); @@ -417,7 +411,7 @@ void MemoryPool::allocateImpl(MemoryPool::PointerType* ptr, MemoryPo } template -void MemoryPool::deallocateImpl(PointerType tag, SizeType32 n) +void MemoryPool::deallocateImpl(PointerType tag, std::size_t n) { std::lock_guard lock(mLock); auto it = std::find_if(mMemorySegments.begin(), mMemorySegments.end(), @@ -477,19 +471,18 @@ class PoolAllocator : public BaseAllocator, TAllocator public: using Base = BaseAllocator, TAllocator::kMemoryType, false>; using PointerType = typename Base::PointerType; - using SizeType32 = typename Base::SizeType32; using PoolType = MemoryPool; static PoolType& getPool(); protected: - void allocateImpl(PointerType* ptr, SizeType32 n) // NOLINT(readability-convert-member-functions-to-static) + void allocateImpl(PointerType* ptr, std::size_t n) // NOLINT(readability-convert-member-functions-to-static) { *ptr = getPool().allocate(n); } void deallocateImpl( // NOLINT(readability-convert-member-functions-to-static) - typename TAllocator::PointerType ptr, SizeType32 n) + typename TAllocator::PointerType ptr, std::size_t n) { getPool().deallocate(ptr, n); } diff --git a/cpp/tensorrt_llm/runtime/transformerBuffers.cpp b/cpp/tensorrt_llm/runtime/transformerBuffers.cpp index fd965bd7d..440da6441 100644 --- a/cpp/tensorrt_llm/runtime/transformerBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/transformerBuffers.cpp @@ -87,10 +87,7 @@ TransformerBuffers::TransformerBuffers( } else { - char* disableReuseChar = std::getenv("TRTLLM_DISABLE_OOTB_KVCACHE_REUSE"); - bool reuse = (disableReuseChar == nullptr || std::string(disableReuseChar) != "ON"); - - int32_t extraKeyValBufferNum = reuse ? 1 : localNbLayers; + constexpr int32_t extraKeyValBufferNum = 1; presentKeysValsAlt = utils::createBufferVector(runtime, extraKeyValBufferNum, MemoryType::kGPU, kvDtype); } @@ -737,42 +734,32 @@ void TransformerBuffers::getRuntimeBuffers(RuntimeBuffers const* runtimeBuffers, kvCacheShape = presentKeysValsAlt.at(0)->getShape(); kvCacheShape.d[3] = 0; } - char* disableReuseChar = std::getenv("TRTLLM_DISABLE_OOTB_KVCACHE_REUSE"); - bool reuse = (disableReuseChar == nullptr || std::string(disableReuseChar) != "ON"); // TODO: fix for recurrentgemma for (int32_t idx = 0; idx < localNbLayers; ++idx) { TensorPtr input; TensorPtr output; - if (reuse) + // We will make current layer's output KV-cache overwrite previous layers input KV-cache + // buffer id: ... 5, 6, 7, 8, 9, ... + // layer n: out in + // layer n+1: out in + // layer n+2 out in + // And when finish a step, we will make every layer's in/out buffer index subtract 1 in + // a circular buffer way to make sure current outputs become next step's inputs. + int32_t input_ind = idx - (step % (localNbLayers + 1)); // Subtract 1 for every step. + if (input_ind < 0) { - // We will make current layer's output KV-cache overwrite previous layers input KV-cache - // buffer id: ... 5, 6, 7, 8, 9, ... - // layer n: out in - // layer n+1: out in - // layer n+2 out in - // And when finish a step, we will make every layer's in/out buffer index subtract 1 in - // a circular buffer way to make sure current outputs become next step's inputs. - int32_t input_ind = idx - (step % (localNbLayers + 1)); // Subtract 1 for every step. - if (input_ind < 0) - { - // When underflow, go to the back to achieve a circular buffers. - input_ind = localNbLayers + 1 + input_ind; - } - // Output buffer is just before input buffer. When input is buffer 0, - // output should use the back buffer to achieve circular buffers. - int32_t output_ind = input_ind > 0 ? input_ind - 1 : localNbLayers; - - // We only allocate localNbLayers of normal buffers. If index is overflow, use the extra buffer. - input = input_ind < localNbLayers ? presentKeysVals[input_ind] : presentKeysValsAlt[0]; - output = output_ind < localNbLayers ? presentKeysVals[output_ind] : presentKeysValsAlt[0]; - } - else - { - input = step % 2 ? presentKeysVals[idx] : presentKeysValsAlt[idx]; - output = step % 2 ? presentKeysValsAlt[idx] : presentKeysVals[idx]; + // When underflow, go to the back to achieve a circular buffers. + input_ind = localNbLayers + 1 + input_ind; } + // Output buffer is just before input buffer. When input is buffer 0, + // output should use the back buffer to achieve circular buffers. + int32_t output_ind = input_ind > 0 ? input_ind - 1 : localNbLayers; + + // We only allocate localNbLayers of normal buffers. If index is overflow, use the extra buffer. + input = input_ind < localNbLayers ? presentKeysVals[input_ind] : presentKeysValsAlt[0]; + output = output_ind < localNbLayers ? presentKeysVals[output_ind] : presentKeysValsAlt[0]; if (step == 0) { diff --git a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu index a9265dfa7..4ec3d4d8a 100644 --- a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu +++ b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu @@ -14,15 +14,16 @@ * limitations under the License. */ -#include "debugUtils.h" +#include "tensorrt_llm/runtime/utils/debugUtils.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/memoryUtils.h" +#include namespace { - -__global__ void checkTensorNanKernel(float const* data, std::size_t size, int* foundNan) +template +__global__ void checkTensorNanKernel(T const* data, std::size_t size, int* foundNan) { auto tidx = blockIdx.x * blockDim.x + threadIdx.x; @@ -30,7 +31,7 @@ __global__ void checkTensorNanKernel(float const* data, std::size_t size, int* f for (auto idx = tidx; idx < size; idx += blockDim.x * gridDim.x) { - auto value = data[idx]; + auto value = static_cast(data[idx]); if (isnan(value)) { found = 1; @@ -47,20 +48,123 @@ namespace tc = tensorrt_llm::common; namespace tensorrt_llm::runtime::utils { -void invokeCheckTensorNanKernel(float const* data, std::size_t size, int* foundNan, cudaStream_t stream) +template +void invokeCheckTensorNanKernel(T const* data, std::size_t size, int* foundNan, cudaStream_t stream) { constexpr uint32_t kThreadsPerCta = 256; checkTensorNanKernel<<>>(data, size, foundNan); } -bool tensorHasNan(IBuffer const& tensor, BufferManager const& manager) +template void invokeCheckTensorNanKernel(float const* data, std::size_t size, int* foundNan, cudaStream_t stream); +template void invokeCheckTensorNanKernel(half const* data, std::size_t size, int* foundNan, cudaStream_t stream); +template void invokeCheckTensorNanKernel( + __nv_bfloat16 const* data, std::size_t size, int* foundNan, cudaStream_t stream); + +template +void printLogitsKeyInfo(ITensor const& tensor, std::string const& infoStr) +{ + auto const& shape = tensor.getShape(); + auto const volume = ITensor::volume(shape); + + BufferManager::ITensorPtr host{}; + T const* hostData; + if (tensor.getMemoryType() == MemoryType::kGPU) + { + auto streamPtr = std::make_shared(); + BufferManager manager{streamPtr}; + host = manager.copyFrom(tensor, MemoryType::kCPU); + streamPtr->synchronize(); + hostData = bufferCast(*host); + } + else + { + hostData = bufferCast(tensor); + } + + std::stringstream ss; + ss << infoStr; + ss << " Shape: " << shape; + ss << "; Top 5: "; + for (size_t ki = 0; ki < 5; ++ki) + { + ss << static_cast(hostData[ki]) << ", "; + } + + ss << " Last 5: "; + for (size_t ki = volume - 6; ki < volume; ++ki) + { + ss << static_cast(hostData[ki]) << ", "; + } + + // find max, min, avg + double mSum = 0.f; + float mMax = -FLT_MAX; + float mMin = FLT_MAX; + + for (size_t ki = 0; ki < volume; ++ki) + { + float value = static_cast(hostData[ki]); + mSum += value; + if (value > mMax) + { + mMax = value; + } + if (value < mMin) + { + mMin = value; + } + } + float mAvg = mSum / volume; + + ss << " avg: " << mAvg << ", min: " << mMin << ", max: " << mMax << std::endl; + + TLLM_LOG_TRACE(ss.str()); +} + +template void printLogitsKeyInfo(ITensor const& tensor, std::string const& infoStr); +template void printLogitsKeyInfo(ITensor const& tensor, std::string const& infoStr); +template void printLogitsKeyInfo<__nv_bfloat16>(ITensor const& tensor, std::string const& infoStr); + +template +bool tensorHasNan(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr) { + printLogitsKeyInfo(tensor, infoStr); auto foundNan = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32); auto foundNanPtr = bufferCast(*foundNan); foundNanPtr[0] = 0; auto const size = tensor.getSize(); - invokeCheckTensorNanKernel(bufferCast(tensor), size, foundNanPtr, manager.getStream().get()); + invokeCheckTensorNanKernel(bufferCast(tensor), size, foundNanPtr, manager.getStream().get()); manager.getStream().synchronize(); return static_cast(foundNanPtr[0]); } + +template bool tensorHasNan(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr); +template bool tensorHasNan(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr); +template bool tensorHasNan<__nv_bfloat16>( + ITensor const& tensor, BufferManager const& manager, std::string const& infoStr); + +bool tensorHasNan( + size_t M, size_t K, nvinfer1::DataType type, void const* data, cudaStream_t stream, std::string const& infoStr) +{ + auto tensorView = ITensor::wrap( + const_cast(data), type, ITensor::makeShape({static_cast(M), static_cast(K)})); + auto manager = BufferManager(std::make_shared(stream)); + if (type == nvinfer1::DataType::kFLOAT) + { + return tensorHasNan(*tensorView, manager, infoStr); + } + else if (type == nvinfer1::DataType::kHALF) + { + return tensorHasNan(*tensorView, manager, infoStr); + } + else if (type == nvinfer1::DataType::kBF16) + { + return tensorHasNan<__nv_bfloat16>(*tensorView, manager, infoStr); + } + else + { + TLLM_THROW("Not supported type for Nan check"); + } +} + } // namespace tensorrt_llm::runtime::utils diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp index 565b93069..b03863ebe 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp @@ -106,8 +106,8 @@ void FtDynamicDecode::setup(size_t const batch_size, size_t const beam_width, th::optional min_length_opt, th::optional length_penalty_opt, th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, th::optional top_p_decay_opt, - th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, bool output_log_probs, - bool cum_log_probs) + th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, + th::optional no_repeat_ngram_size_opt, bool output_log_probs, bool cum_log_probs) { auto stream = at::cuda::getCurrentCUDAStream().stream(); dynamic_decode_layer_->setStream(stream); @@ -118,6 +118,7 @@ void FtDynamicDecode::setup(size_t const batch_size, size_t const beam_width, safeInsert(presence_penalty_opt, setupParams->penaltyParams.presencePenalty); safeInsert(frequency_penalty_opt, setupParams->penaltyParams.frequencyPenalty); safeInsert(min_length_opt, setupParams->penaltyParams.minLength); + safeInsert(no_repeat_ngram_size_opt, setupParams->penaltyParams.noRepeatNgramSize); safeInsert(runtime_top_k_opt, setupParams->samplingParams.runtime_top_k); safeInsert(runtime_top_p_opt, setupParams->samplingParams.runtime_top_p); safeInsert(random_seed_opt, setupParams->randomSeed); @@ -141,17 +142,16 @@ void FtDynamicDecode::forward(th::Tensor const& logits, int const step, int c th::optional sequence_limit_length_opt, th::optional stop_words_list_ptrs_opt, th::optional stop_words_lens_opt, int32_t const max_stop_words_len, th::optional bad_words_list_ptrs_opt, th::optional bad_words_lens_opt, - int32_t const max_bad_words_len, th::optional no_repeat_ngram_size_opt, - th::optional src_cache_indirection_opt, th::Tensor& output_token_ids, th::Tensor& newTokens, - th::Tensor& should_stop, th::optional finished_input, th::optional finished_output, - th::optional sequence_lengths_opt, th::optional cum_log_probs_opt, - th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, - th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, - th::optional beam_hyps_output_ids_cba_opt, th::optional beam_hyps_seq_len_cba_opt, - th::optional beam_hyps_cum_log_probs_cba_opt, th::optional beam_hyps_normed_scores_cba_opt, - th::optional beam_hyps_log_probs_cba_opt, th::optional beam_hyps_min_normed_scores_opt, - th::optional beam_hyps_num_beams_opt, th::optional beam_hyps_is_done_opt, - bool const use_beam_hyps) + int32_t const max_bad_words_len, th::optional src_cache_indirection_opt, th::Tensor& output_token_ids, + th::Tensor& newTokens, th::Tensor& should_stop, th::optional finished_input, + th::optional finished_output, th::optional sequence_lengths_opt, + th::optional cum_log_probs_opt, th::optional output_log_probs_opt, + th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, + th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_cba_opt, + th::optional beam_hyps_seq_len_cba_opt, th::optional beam_hyps_cum_log_probs_cba_opt, + th::optional beam_hyps_normed_scores_cba_opt, th::optional beam_hyps_log_probs_cba_opt, + th::optional beam_hyps_min_normed_scores_opt, th::optional beam_hyps_num_beams_opt, + th::optional beam_hyps_is_done_opt, bool const use_beam_hyps) { auto forwardParams = std::make_shared(step, static_cast(ite), max_input_length, max_attention_window, sink_token_length, local_batch_size, convert_tensor(end_id)); @@ -167,7 +167,6 @@ void FtDynamicDecode::forward(th::Tensor const& logits, int const step, int c safeUpdate(bad_words_list_ptrs_opt, forwardParams->bad_words_ptr); safeUpdate(bad_words_lens_opt, forwardParams->bad_words_lengths); forwardParams->max_bad_words_len = max_bad_words_len; - safeUpdate(no_repeat_ngram_size_opt, forwardParams->no_repeat_ngram_size); safeUpdate(src_cache_indirection_opt, forwardParams->src_cache_indirection); auto const& output_ids_converted = convert_tensor(output_token_ids); @@ -263,8 +262,8 @@ void DynamicDecodeOp::setup(int64_t const batch_size, int64_t const beam_width, th::optional min_length_opt, th::optional length_penalty_opt, th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, th::optional top_p_decay_opt, - th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, bool output_log_probs, - bool cum_log_probs) + th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, + th::optional no_repeat_ngram_size_opt, bool output_log_probs, bool cum_log_probs) { // TODO: Revise DynamicDecodeLayer and make the decode arguments consistent. // TODO: add parameters "normalize_log_probs" and "topKMedusaHeads" @@ -277,6 +276,7 @@ void DynamicDecodeOp::setup(int64_t const batch_size, int64_t const beam_width, CHECK_OPTIONAL_CPU_INPUT(min_length_opt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(length_penalty_opt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(early_stopping_opt, torch::kInt32); + CHECK_OPTIONAL_CPU_INPUT(no_repeat_ngram_size_opt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(beam_search_diversity_rate_opt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(random_seed_opt, torch::kInt64); CHECK_OPTIONAL_INPUT(top_p_decay_opt, torch::kFloat); @@ -286,7 +286,7 @@ void DynamicDecodeOp::setup(int64_t const batch_size, int64_t const beam_width, dynamic_decode_->setup(static_cast(batch_size), static_cast(beam_width), runtime_top_k_opt, runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt, min_length_opt, length_penalty_opt, early_stopping_opt, beam_search_diversity_rate_opt, random_seed_opt, - top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt, output_log_probs, cum_log_probs); + top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt, no_repeat_ngram_size_opt, output_log_probs, cum_log_probs); } th::Tensor DynamicDecodeOp::forward( @@ -308,7 +308,6 @@ th::Tensor DynamicDecodeOp::forward( th::optional bad_words_list_ptrs_opt, // [BS][2, bad_words_length], int64 th::optional bad_words_lens_opt, // [BS], int int64_t const max_bad_words_len, // - th::optional no_repeat_ngram_size_opt, // [BS], int th::optional src_cache_indirection_opt, // [local_BS, BM, MSL], int // Outputs th::Tensor output_token_ids, // [BS, BM, MSL], variables for output @@ -347,7 +346,6 @@ th::Tensor DynamicDecodeOp::forward( CHECK_OPTIONAL_INPUT(stop_words_lens_opt, torch::kInt32); CHECK_OPTIONAL_INPUT(bad_words_list_ptrs_opt, torch::kInt64); CHECK_OPTIONAL_INPUT(bad_words_lens_opt, torch::kInt32); - CHECK_OPTIONAL_INPUT(no_repeat_ngram_size_opt, torch::kInt32); CHECK_OPTIONAL_INPUT(src_cache_indirection_opt, torch::kInt32); CHECK_INPUT(output_token_ids, torch::kInt32); CHECK_INPUT(newTokens, torch::kInt32); @@ -368,7 +366,7 @@ th::Tensor DynamicDecodeOp::forward( static_cast(sink_token_length), static_cast(ite), static_cast(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt, sequence_limit_length_opt, stop_words_list_ptrs_opt, stop_words_lens_opt, static_cast(max_stop_words_len), bad_words_list_ptrs_opt, bad_words_lens_opt, - static_cast(max_bad_words_len), no_repeat_ngram_size_opt, src_cache_indirection_opt, + static_cast(max_bad_words_len), src_cache_indirection_opt, // Outputs output_token_ids, newTokens, should_stop, finished_input, finished_output, sequence_lengths_opt, cum_log_probs_opt, output_log_probs_opt, output_log_probs_tiled_opt, parent_ids_opt, tgt_cache_indirection_opt, diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h index da3f19ab6..805370d71 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h @@ -34,7 +34,8 @@ class IFtDynamicDecode th::optional length_penalty_opt, th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, th::optional top_p_decay_opt, th::optional top_p_min_opt, - th::optional top_p_reset_ids_opt, bool output_log_probs, bool cum_log_probs) + th::optional top_p_reset_ids_opt, th::optional no_repeat_ngram_size_opt, + bool output_log_probs, bool cum_log_probs) = 0; virtual void forward(th::Tensor const& logits, int const step, int const max_input_length, @@ -43,9 +44,9 @@ class IFtDynamicDecode th::optional sequence_limit_length_opt, th::optional stop_words_list_ptrs_opt, th::optional stop_words_lens_opt, int32_t const max_stop_words_len, th::optional bad_words_list_ptrs_opt, th::optional bad_words_lens_opt, - int32_t const max_bad_words_len, th::optional no_repeat_ngram_size_opt, - th::optional src_cache_indirection_opt, th::Tensor& output_token_ids, th::Tensor& newTokens, - th::Tensor& should_stop, th::optional finished_input, th::optional finished_output, + int32_t const max_bad_words_len, th::optional src_cache_indirection_opt, + th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, + th::optional finished_input, th::optional finished_output, th::optional sequence_lengths_opt, th::optional cum_log_probs_opt, th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, @@ -71,7 +72,8 @@ class FtDynamicDecode : public IFtDynamicDecode th::optional length_penalty_opt, th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, th::optional top_p_decay_opt, th::optional top_p_min_opt, - th::optional top_p_reset_ids_opt, bool output_log_probs, bool cum_log_probs) override; + th::optional top_p_reset_ids_opt, th::optional no_repeat_ngram_size_opt, + bool output_log_probs, bool cum_log_probs) override; void forward(th::Tensor const& logits, int const step, int const max_input_length, int const max_attention_window, int const sink_token_length, uint64_t const ite, int const local_batch_size, th::Tensor end_id, @@ -79,9 +81,9 @@ class FtDynamicDecode : public IFtDynamicDecode th::optional sequence_limit_length_opt, th::optional stop_words_list_ptrs_opt, th::optional stop_words_lens_opt, int32_t const max_stop_words_len, th::optional bad_words_list_ptrs_opt, th::optional bad_words_lens_opt, - int32_t const max_bad_words_len, th::optional no_repeat_ngram_size_opt, - th::optional src_cache_indirection_opt, th::Tensor& output_token_ids, th::Tensor& newTokens, - th::Tensor& should_stop, th::optional finished_input, th::optional finished_output, + int32_t const max_bad_words_len, th::optional src_cache_indirection_opt, + th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, + th::optional finished_input, th::optional finished_output, th::optional sequence_lengths_opt, th::optional cum_log_probs_opt, th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, @@ -110,7 +112,8 @@ class DynamicDecodeOp : public th::jit::CustomClassHolder th::optional length_penalty_opt, th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, th::optional top_p_decay_opt, th::optional top_p_min_opt, - th::optional top_p_reset_ids_opt, bool output_log_probs, bool cum_log_probs); + th::optional top_p_reset_ids_opt, th::optional no_repeat_ngram_size_opt, + bool output_log_probs, bool cum_log_probs); th::Tensor forward(th::Tensor const& logits, int64_t const step, int64_t const max_input_length, int64_t const max_attention_window, int64_t const sink_token_length, int64_t const ite, @@ -119,13 +122,13 @@ class DynamicDecodeOp : public th::jit::CustomClassHolder th::optional stop_words_list_ptrs_opt, th::optional stop_words_lens_opt, int64_t const max_stop_words_len, th::optional bad_words_list_ptrs_opt, th::optional bad_words_lens_opt, int64_t const max_bad_words_len, - th::optional no_repeat_ngram_size_opt, th::optional src_cache_indirection_opt, - th::Tensor output_token_ids, th::Tensor newTokens, th::optional finished_input, - th::optional finished_output, th::optional sequence_lengths_opt, - th::optional cum_log_probs_opt, th::optional output_log_probs_opt, - th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, - th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_cba_opt, - th::optional beam_hyps_seq_len_cba_opt, th::optional beam_hyps_cum_log_probs_cba_opt, + th::optional src_cache_indirection_opt, th::Tensor output_token_ids, th::Tensor newTokens, + th::optional finished_input, th::optional finished_output, + th::optional sequence_lengths_opt, th::optional cum_log_probs_opt, + th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, + th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, + th::optional beam_hyps_output_ids_cba_opt, th::optional beam_hyps_seq_len_cba_opt, + th::optional beam_hyps_cum_log_probs_cba_opt, th::optional beam_hyps_normed_scores_cba_opt, th::optional beam_hyps_log_probs_cba_opt, th::optional beam_hyps_min_normed_scores_opt, th::optional beam_hyps_num_beams_opt, th::optional beam_hyps_is_done_opt, bool const use_beam_hyps); diff --git a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp index d468f0588..02156e4e3 100644 --- a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp +++ b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp @@ -142,8 +142,10 @@ Tensor preprocess_weights_for_mixed_gemm( int8_t* input_byte_ptr = get_ptr(row_major_quantized_weight); int8_t* output_byte_ptr = get_ptr(processed_tensor); + bool force_interleave = row_major_quantized_weight.dim() == 3; // WAR for MoE 3-D tensors. + preprocess_weights_for_mixed_gemm( - output_byte_ptr, input_byte_ptr, {num_experts, num_rows, num_cols}, ft_quant_type); + output_byte_ptr, input_byte_ptr, {num_experts, num_rows, num_cols}, ft_quant_type, force_interleave); return processed_tensor; } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 257aa5439..8245cecb8 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -18,7 +18,8 @@ include(FetchContent) FetchContent_Declare( googletest - GIT_REPOSITORY https://github.com/google/googletest.git + GIT_REPOSITORY + https://github.com/google/googletest.git GIT_TAG release-1.12.1) FetchContent_MakeAvailable(googletest) include(GoogleTest) @@ -34,18 +35,27 @@ find_library_create_target(nvonnxparser ${ONNX_PARSER_LIB_NAME} SHARED include_directories( ${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include - ${PROJECT_SOURCE_DIR}/include) + ${PROJECT_SOURCE_DIR}/include ${3RDPARTY_DIR}/cutlass/include + ${3RDPARTY_DIR}/cutlass/tools/util/include) set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..") add_custom_target(google-tests) function(add_gtest test_name test_src) + set(options NO_GTEST_MAIN NO_TLLM_LINKAGE) + cmake_parse_arguments(ARGS "${options}" "${oneValueArgs}" "${multiValueArgs}" + ${ARGN}) add_executable(${test_name} ${test_src}) - target_link_libraries( - ${test_name} PUBLIC ${SHARED_TARGET} gtest_main gmock_main nvonnxparser - nvinfer_plugin_tensorrt_llm) + target_link_libraries(${test_name} PUBLIC gmock_main nvonnxparser) + if(NOT ARGS_NO_GTEST_MAIN) + target_link_libraries(${test_name} PUBLIC gtest_main) + endif() + if(NOT ARGS_NO_TLLM_LINKAGE) + target_link_libraries(${test_name} PUBLIC ${SHARED_TARGET} + nvinfer_plugin_tensorrt_llm) + endif() target_compile_features(${test_name} PRIVATE cxx_std_17) target_compile_definitions(${test_name} @@ -120,6 +130,49 @@ set(LOOKAHEAD_RANDOMLLM_TEST_SRC layers/randomLlm.cpp layers/lookaheadRandomLlmTest.cpp) add_gtest(lookaheadRandomLlmTest "${LOOKAHEAD_RANDOMLLM_TEST_SRC}") +add_gtest( + gemmSwigluRunnerTest + kernels/fused_gated_gemm/gemmSwigluRunnerTest.cu + ${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/kernels/fused_gated_gemm/gemm_swiglu_e4m3.cu + NO_GTEST_MAIN) +add_gtest(gemmSwigluKernelTestSm90Fp8 + kernels/fused_gated_gemm/gemmSwigluKernelTestSm90Fp8.cu NO_GTEST_MAIN + NO_TLLM_LINKAGE) + +foreach(target_name gemmSwigluRunnerTest;gemmSwigluKernelTestSm90Fp8) + set_property(TARGET ${target_name} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + + # Note - we deliberately do not include 90a PTX (even when 9.0+PTX is + # specified). This is because sm_90a has arch conditional instructions that + # are not forward compatible. As a result, it does not make sense to embed PTX + # into the binary anyway. + if("90" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG + OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG + OR "90-real" IN_LIST CMAKE_CUDA_ARCHITECTURES_NATIVE) + + message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.") + target_compile_options( + ${target_name} + PRIVATE $<$:-gencode=arch=compute_90a,code=sm_90a + -res-usage>) + + # Hopper kernels require cuda lib for TMA APIs + target_link_libraries(${target_name} PRIVATE CUDA::cuda_driver) + + # No kernels should be parsed, unless hopper is specified. This is a build + # time improvement + target_compile_definitions(${target_name} PRIVATE COMPILE_HOPPER_TMA_GEMMS) + endif() + + # Suppress GCC note: the ABI for passing parameters with 64-byte alignment has + # changed in GCC 4.6 This note appears for kernels using TMA and clutters the + # compilation output. + if(NOT WIN32) + target_compile_options( + ${target_name} PRIVATE $<$:-Xcompiler=-Wno-psabi>) + endif() +endforeach() + if(BUILD_BATCH_MANAGER) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/batch_manager) add_subdirectory(batch_manager) diff --git a/cpp/tests/kernels/fused_gated_gemm/fused_gated_gemm_util.h b/cpp/tests/kernels/fused_gated_gemm/fused_gated_gemm_util.h new file mode 100644 index 000000000..e3488a53d --- /dev/null +++ b/cpp/tests/kernels/fused_gated_gemm/fused_gated_gemm_util.h @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass/gemm_coord.h" +#include "cutlass/layout/layout.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result(double avg_runtime_ms = 0, double gflops = 0, cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : avg_runtime_ms(avg_runtime_ms) + , gflops(gflops) + , status(status) + , error(error) + , passed(false) + { + } +}; + +/// Command line options parsing +struct Options +{ + std::string command_name; + bool help; + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord problem_size_out; + float alpha; + float beta; + float scale_d0; + float scale_d1; + float scale_output; + bool has_bias; + int split_k_factor; + int avail_sms; + int iterations; + bool real; + bool debug; + bool no_check; + + Options(std::string command_name) + : command_name(command_name) + , help(false) + , problem_size({32, 96, 128}) + , alpha(1.0f) + , beta(1.0f) + , scale_d0(1.0f) + , scale_d1(1.0f) + , scale_output(1.0f) + , has_bias(false) + , split_k_factor(1) + , avail_sms(-1) // Number of device SMs to use is unlimited + , real(false) + , iterations(10) + , debug(false) + , no_check(false) + { + parse(0, nullptr); + } + + Options() + : Options("") + { + } + + Options(Options const& other) + : command_name(other.command_name) + , help(other.help) + , problem_size((other.problem_size)) + , problem_size_out((other.problem_size_out)) + , alpha(other.alpha) + , beta(other.beta) + , scale_d0(other.scale_d0) + , scale_d1(other.scale_d1) + , scale_output(other.scale_output) + , has_bias(other.has_bias) + , split_k_factor(other.split_k_factor) + , avail_sms(other.avail_sms) // Number of device SMs to use is unlimited + , real(other.real) + , iterations(other.iterations) + , debug(other.debug) + , no_check(other.no_check) + { + } + + bool valid() const + { + return true; + } + + void parse(int argc, char const** args) + { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) + { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("scale_d0", scale_d0); + cmd.get_cmd_line_argument("scale_d1", scale_d1); + cmd.get_cmd_line_argument("scale_output", scale_output); + cmd.get_cmd_line_argument("split", split_k_factor); + cmd.get_cmd_line_argument("iterations", iterations); + real = cmd.check_cmd_line_flag("real"); + debug = cmd.check_cmd_line_flag("debug"); + no_check = cmd.check_cmd_line_flag("nocheck"); + has_bias = cmd.check_cmd_line_flag("bias"); + + problem_size_out = cutlass::gemm::GemmCoord(problem_size.m(), problem_size.n() / 2, problem_size.k()); + } + + /// Prints the usage statement. + std::ostream& print_usage(std::ostream& out) const + { + out << "Performs a GEMM computation.\n" + << "\n" + << "Options:\n" + << "\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= GEMM M dimension\n" + << " --n= GEMM N dimension\n" + << " --k= GEMM K dimension\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --scale_d0= Epilogue scalar scale_d0\n" + << " --scale_d1= Epilogue scalar scale_d1\n\n" + << " --scale_output= Epilogue scalar scale_output\n\n" + << " --split= Split-K factor to emulate\n\n" + << " --real If specified, initializes with real values instead of whole numbers. " + "Errors are to be expected.\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s; + } +}; diff --git a/cpp/tests/kernels/fused_gated_gemm/gemmSwigluKernelTestSm90Fp8.cu b/cpp/tests/kernels/fused_gated_gemm/gemmSwigluKernelTestSm90Fp8.cu new file mode 100644 index 000000000..5aa2f0a51 --- /dev/null +++ b/cpp/tests/kernels/fused_gated_gemm/gemmSwigluKernelTestSm90Fp8.cu @@ -0,0 +1,499 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*************************************************************************************************** + This test code is adapted from CUTLASS + https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm + + Requires NVIDIA Hopper or newer device (SM00+). + **************************************************************************************************/ + +#include +#include +#include +#include + +#include "fused_gated_gemm_util.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm_kernel_template_sm90.h" + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass/conv/convolution.h" +// Order matters here, packed_stride.hpp is missing cute and convolution includes +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr bool SwapAB = true; + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands + +// D matrix configuration +using ElementD = cutlass::float_e4m3_t; +using LayoutD = cutlass::layout::RowMajor; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64, _16, _128>; // Threadblock-level tile size +using ClusterShape = Shape<_8, _1, _1>; // Shape of the threadblocks in a cluster +// using MainloopScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative; +// using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecializedCooperative; +using MainloopScheduleType = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; +using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + +// Reference device GEMM implementation type +// always use float for ElementC here because we need float output +using DeviceGemmReference = cutlass::reference::device::Gemm; + +// NOTE: debug purpose +template +struct Passthrough +{ + + CUTLASS_HOST_DEVICE + T operator()(T const& value) const + { + return 1; + } +}; + +struct Buffers +{ + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_c_bias; + cutlass::HostTensor tensor_d; + cutlass::HostTensor tensor_ref_d; + // we need float dtype for reference GEMM output + cutlass::HostTensor tensor_ref_d_2x; +}; + +// Activation +template +using Activation = cutlass::epilogue::thread::SiLu; +// using Activation = Passthrough; + +// using TileSchedulerType = cutlass::gemm::StreamKScheduler; +using TileSchedulerType = void; + +using Gemm = typename tensorrt_llm::kernels::cutlass_kernels::DeviceGemmGatedSm90::Gemm; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +typename Gemm::Arguments args_from_options(Gemm const& gemm, Options const& options, + cutlass::HostTensor& tensor_a, cutlass::HostTensor& tensor_b, + cutlass::HostTensor& tensor_d, cutlass::HostTensor& tensor_c_bias) +{ + using ElementT = typename Gemm::ElementA; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) + { + scheduler_args = {2}; + } + if constexpr (SwapAB) + { + int m = options.problem_size.n() / 2; + int n = options.problem_size.m(); + int k = options.problem_size.k(); + std::cout << "m: " << m << ", n: " << n << ", k: " << k << std::endl; + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideC stride_C; + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); + printf("stride_A: "); + cute::print(stride_A); + printf("\nstride_B: "); + cute::print(stride_B); + printf("\nstride_D: "); + cute::print(stride_D); + printf("\n"); + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, + {tensor_b.device_data(), stride_A, tensor_a.device_data(), stride_B, options.scale_d0, options.scale_d1}, + {{}, tensor_c_bias.device_data(), stride_C, tensor_d.device_data(), stride_D}}; + args.epilogue.thread.alpha = options.scale_output; + args.scheduler = scheduler_args; + return args; + } + else + { + int m = options.problem_size.m(); + int n = options.problem_size.n() / 2; + int k = options.problem_size.k(); + std::cout << "m: " << m << ", n: " << n << ", k: " << k << std::endl; + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideC stride_C; + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); + printf("stride_A: "); + cute::print(stride_A); + printf("\nstride_B: "); + cute::print(stride_B); + printf("\nstride_D: "); + cute::print(stride_D); + printf("\n"); + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, + {tensor_a.device_data(), stride_A, tensor_b.device_data(), stride_B, options.scale_d0, options.scale_d1}, + {{}, tensor_c_bias.device_data(), stride_C, tensor_d.device_data(), stride_D}}; + args.epilogue.thread.alpha = options.scale_output; + args.scheduler = scheduler_args; + return args; + } +} + +/// Execute a given example GEMM computation +template +Result run(std::string description, Options& options, Buffers& buffers) +{ + // Display test description + std::cout << std::endl << description << std::endl; + + // Zero-initialize test output matrix D + cutlass::reference::host::TensorFill(buffers.tensor_d.host_view()); + buffers.tensor_d.sync_device(); + + // Instantiate CUTLASS kernel depending on templates + DeviceGemmT device_gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT + auto arguments = args_from_options(device_gemm, options, buffers.tensor_a, buffers.tensor_b, buffers.tensor_d, + buffers.tensor_c_bias /*, buffers.tensor_Tensor*/); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = DeviceGemmT::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + // device_gemm.can_implement(arguments); + auto can_implement = device_gemm.can_implement(arguments); + if (can_implement != cutlass::Status::kSuccess) + { + throw std::runtime_error("[TensorRT-LLM Error][fusedGatedGemm Runner]"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + device_gemm.initialize(arguments, workspace.get()); + + // Correctness / Warmup iteration + device_gemm(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + buffers.tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + + if (!options.no_check) + { + result.passed = cutlass::reference::host::TensorRelativelyEquals( + buffers.tensor_d.host_view(), buffers.tensor_ref_d.host_view(), ElementD{1e-2}, ElementD{1e-2}); + result.passed + = cutlass::reference::host::TensorEquals(buffers.tensor_d.host_view(), buffers.tensor_ref_d.host_view()); + EXPECT_TRUE(result.passed); + + double err = cutlass::reference::host::TensorRelativeErrorMetric( + buffers.tensor_d.host_view(), buffers.tensor_ref_d.host_view()); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err + << std::endl; + + if (!result.passed && options.debug) + { + std::cout << "ref_output=\n" + << buffers.tensor_ref_d.host_view() << "\noutput=\n" + << buffers.tensor_d.host_view() << std::endl; + } + } + + // Run profiling loop + if (options.iterations > 0) + { + cudaEvent_t start; + cudaEvent_t stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + cudaDeviceSynchronize(); + cudaEventRecord(start, 0); + for (int iter = 0; iter < options.iterations; ++iter) + { + device_gemm(); + } + cudaEventRecord(stop, 0); + cudaEventSynchronize(stop); + + float elapsed_ms; + cudaEventElapsedTime(&elapsed_ms, start, stop); + + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + } + + return result; +} + +/// Program entrypoint +int main(int argc, char const** argv) +{ + + // Current device must must have compute capability at least 80 + cudaDeviceProp props; + int current_device_id; + cudaGetDevice(¤t_device_id); + cudaGetDeviceProperties(&props, current_device_id); + if (!((props.major * 10 + props.minor) >= 90)) + { + std::cerr << "Hopper Tensor Core operations must be run on a machine with compute capability at least 90." + << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + exit(0); + } + + Buffers buffers; + // Parse commandline options + Options options("hopper_fp8_gemm_swiglu"); + options.parse(argc, argv); + + if (options.help) + { + options.print_usage(std::cout) << std::endl; + exit(0); + } + + std::cout << options.iterations << " timing iterations of " << options.problem_size.m() << " x " + << options.problem_size.n() << " x " << options.problem_size.k() << " matrix-matrix multiply" + << std::endl; + + if (!options.valid()) + { + std::cerr << "Invalid problem." << std::endl; + EXPECT_TRUE(false); + exit(-1); + } + + if (options.debug) + { + std::cout << "scale_d0: " << options.scale_d0 << ", scale_d1: " << options.scale_d1 + << ", scale_output: " << options.scale_output << std::endl; + } + + // + // Initialize GEMM datasets + // + + // Initialize tensors using CUTLASS helper functions + buffers.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K + buffers.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N + buffers.tensor_c_bias.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions 1 x N + buffers.tensor_d.resize( + options.problem_size_out + .mn()); // <- Create matrix D with dimensions M x N/2 used to store output from CUTLASS kernel + buffers.tensor_ref_d_2x.resize( + options.problem_size + .mn()); // <- Create temp matrix D with dimensions M x N used to store output from reference kernel + buffers.tensor_ref_d.resize( + options.problem_size_out + .mn()); // <- Create matrix D with dimensions M x N/2 used to store output from reference kernel + + int _init_bits = options.real ? -1 : 0; + + // Fill matrix A on host with uniform-random data [-2, 2] + if (options.debug) + { + cutlass::Array range; + range[0] = ElementA(256); + range[1] = ElementA(1); + cutlass::reference::host::TensorFillLinear(buffers.tensor_a.host_view(), range); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + buffers.tensor_a.host_view(), 1, ElementA(2), ElementA(-2), _init_bits); + } + + // Fill matrix B on host with uniform-random data [-2, 2] + if (options.debug) + { + cutlass::reference::host::TensorFillIdentity(buffers.tensor_b.host_view()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + buffers.tensor_b.host_view(), 1, ElementB(2), ElementB(-2), _init_bits); + } + + if (options.debug || !options.has_bias) + { + cutlass::reference::host::TensorFill(buffers.tensor_c_bias.host_view()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + buffers.tensor_c_bias.host_view(), 1, ElementC(2), ElementC(-2), _init_bits); + } + + if (options.debug) + { + std::cout << "A=" << std::endl << buffers.tensor_a.host_view() << std::endl; + std::cout << "B=" << std::endl << buffers.tensor_b.host_view() << std::endl; + std::cout << "C=" << std::endl << buffers.tensor_c_bias.host_view() << std::endl; + } + + // + // Compute reference output + // + + // Copy data from host to GPU + buffers.tensor_a.sync_device(); + buffers.tensor_b.sync_device(); + buffers.tensor_c_bias.sync_device(); + + // Zero-initialize reference output matrix D + cutlass::reference::host::TensorFill(buffers.tensor_ref_d_2x.host_view()); + buffers.tensor_ref_d_2x.sync_device(); + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference(options.problem_size, ElementAccumulator(options.alpha), buffers.tensor_a.device_ref(), + buffers.tensor_b.device_ref(), ElementAccumulator(options.beta), buffers.tensor_ref_d_2x.device_ref(), + buffers.tensor_ref_d_2x.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from reference kernel to host for comparison + buffers.tensor_ref_d_2x.sync_host(); + + // Add broadcast vector (without multiplier) + // Vector broadcast on host + // for (int i = 0; i < options.problem_size.m(); ++i) + // { + // for (int j = 0; j < options.problem_size.n(); ++j) + // { + // buffers.tensor_ref_d_2x.host_view().ref().at({i, j}) += buffers.tensor_c_bias.host_view().ref().at({0, + // j}); + // } + // } + cutlass::NumericConverter converter; + int half_n = options.problem_size.n() / 2; + for (int i = 0; i < options.problem_size.m(); i++) + { + for (int j = 0; j < half_n; j++) + { + auto s = options.scale_output + * ElementCompute(options.scale_d0 * buffers.tensor_ref_d_2x.host_view().ref().at({i, j})) + * Activation{}(options.scale_d1 * buffers.tensor_ref_d_2x.at({i, j + half_n})); + auto t = converter(s); + buffers.tensor_ref_d.host_view().ref().at({i, j}) = t; + } + } + + cudaDeviceSynchronize(); + + if (options.debug) + { + std::cout << "tensor_ref_d_2x=" << buffers.tensor_ref_d_2x.host_view() << std::endl; + } + + // + // Evaluate CUTLASS kernels + // +#ifdef COMPILE_HOPPER_TMA_GEMMS + Result hopperFp8 = run(std::string("Hopper fp8 swiglu"), options, buffers); +#else // COMPILE_HOPPER_TMA_GEMMS + std::cout << "[TensorRT-LLm Error][GemmSwigluKernelTestSm90Fp8] Please recompile with support for hopper by " + "passing 90-real as an arch to build_wheel.py." + << std::endl; +#endif // COMPILE_HOPPER_TMA_GEMMS + // for (int i = 0; i < options.problem_size_out.m(); i++) + // { + // for (int j = 0; j < options.problem_size_out.n(); j++) + // { + // std::cout << "i: " << i << ", j: " << j; + // std::cout << ", ref val: " << buffers.tensor_ref_d.host_view().ref().at({i, j}); + // std::cout << ", val: " << buffers.tensor_d.host_view().ref().at({i, j}) << std::endl; + // } + // } + + return 0; +} diff --git a/cpp/tests/kernels/fused_gated_gemm/gemmSwigluRunnerTest.cu b/cpp/tests/kernels/fused_gated_gemm/gemmSwigluRunnerTest.cu new file mode 100644 index 000000000..3db0d1a4c --- /dev/null +++ b/cpp/tests/kernels/fused_gated_gemm/gemmSwigluRunnerTest.cu @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include + +#include "fused_gated_gemm_util.h" + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" +#include "tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm.h" + +#include "cutlass/arch/mma.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +using namespace tensorrt_llm::kernels::cutlass_kernels; + +Options g_options; + +template +struct Buffers +{ + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_d; + cutlass::HostTensor tensor_ref_d_2x; + cutlass::HostTensor tensor_ref_d; + cutlass::HostTensor tensor_c_bias; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Execute a given example GEMM computation +template +Result run(std::string description, Options& options, Buffers buffers) +{ + + // Display test description + std::cout << std::endl << description << std::endl; + + // Initialize + Result result; + + // Zero-initialize test output matrix D + cutlass::reference::host::TensorFill(buffers.tensor_d.host_view()); + buffers.tensor_d.sync_device(); + + // Instantiate CUTLASS kernel depending on templates + std::shared_ptr runner + = std::make_shared::type>>(); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size + = runner->getWorkspaceSize(options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + std::vector configs = runner->getConfigs(); + + cudaEvent_t start; + cudaEvent_t stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + float bestTime = std::numeric_limits::max(); + tensorrt_llm::cutlass_extensions::CutlassGemmConfig bestConfig; + for (auto const& config : configs) + { + std::cout << config << std::endl; + try + { + // Correctness / Warmup iteration + runner->gemm(buffers.tensor_d.device_data(), buffers.tensor_a.device_data(), buffers.tensor_b.device_data(), + buffers.tensor_c_bias.device_data(), tk::QuantMode{}, options.problem_size.m(), + options.problem_size.n(), options.problem_size.k(), options.scale_d0, options.scale_d1, + options.scale_output, config, workspace.get(), workspace_size, 0); + } + catch (std::runtime_error& e) + { + // We can ignore these error because most are related to SMEM oversubscription + std::cout << e.what() << std::endl; + continue; + } + // Copy output data from CUTLASS and reference kernel to host for comparison + buffers.tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + if (!options.no_check) + { + result.passed = cutlass::reference::host::TensorRelativelyEquals( + buffers.tensor_d.host_view(), buffers.tensor_ref_d.host_view(), ElementT{1e-3}, ElementT{1e-3}); + + EXPECT_TRUE(result.passed); + + double err = cutlass::reference::host::TensorRelativeErrorMetric( + buffers.tensor_d.host_view(), buffers.tensor_ref_d.host_view()); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err + << std::endl; + + if (!result.passed || options.debug) + { + std::cout << "ref_output=\n" + << buffers.tensor_ref_d.host_view() << "\noutput=\n" + << buffers.tensor_d.host_view() << std::endl; + } + } + + // Run profiling loop + if (options.iterations > 0) + { + cudaDeviceSynchronize(); + cudaEventRecord(start, 0); + for (int iter = 0; iter < options.iterations; ++iter) + { + runner->gemm(buffers.tensor_d.device_data(), buffers.tensor_a.device_data(), + buffers.tensor_b.device_data(), buffers.tensor_c_bias.device_data(), tk::QuantMode{}, + options.problem_size.m(), options.problem_size.n(), options.problem_size.k(), options.scale_d0, + options.scale_d1, options.scale_output, config, workspace.get(), workspace_size, 0); + } + cudaEventRecord(stop, 0); + cudaEventSynchronize(stop); + + float elapsed_ms; + cudaEventElapsedTime(&elapsed_ms, start, stop); + + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + + if (result.avg_runtime_ms < bestTime) + { + bestTime = result.avg_runtime_ms; + bestConfig = config; + } + } + } + + std::cout << "Best runtime: " << bestTime << " ms" << std::endl; + std::cout << "Best config: " << bestConfig << std::endl; + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + return result; +} + +template +using Activation = cutlass::epilogue::thread::SiLu; + +TEST(GemmSwigluRunner, Sm90FP8) +{ + using ElementT = cutlass::float_e4m3_t; + using ElementAccumulatorT = float; + using ElementComputeT = float; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementD2x = float; + + Buffers buffers; + + // Parse commandline options + Options options(g_options); + + std::cout << options.iterations << " timing iterations of " << options.problem_size.m() << " x " + << options.problem_size.n() << " x " << options.problem_size.k() << " matrix-matrix multiply" + << std::endl; + + if (!options.valid()) + { + std::cerr << "Invalid problem." << std::endl; + FAIL(); + } + + if (options.debug) + { + std::cout << "scale_d0: " << options.scale_d0 << ", scale_d1: " << options.scale_d1 + << ", scale_output: " << options.scale_output << std::endl; + } + + // + // Initialize GEMM datasets + // + + // Initialize tensors using CUTLASS helper functions + buffers.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K + buffers.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N + buffers.tensor_c_bias.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions 1 x N + buffers.tensor_d.resize( + options.problem_size_out + .mn()); // <- Create matrix D with dimensions M x N/2 used to store output from CUTLASS kernel + buffers.tensor_ref_d_2x.resize( + options.problem_size + .mn()); // <- Create temp matrix D with dimensions M x N used to store output from reference kernel + buffers.tensor_ref_d.resize( + options.problem_size_out + .mn()); // <- Create matrix D with dimensions M x N/2 used to store output from reference kernel + + int _init_bits = options.real ? -1 : 0; + + // Fill matrix A on host with uniform-random data [-2, 2] + if (options.debug) + { + cutlass::Array range; + range[0] = ElementT(256); + range[1] = ElementT(1); + cutlass::reference::host::TensorFillLinear(buffers.tensor_a.host_view(), range); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + buffers.tensor_a.host_view(), 1, ElementT(2), ElementT(-2), _init_bits); + } + + // Fill matrix B on host with uniform-random data [-2, 2] + if (options.debug) + { + cutlass::reference::host::TensorFillIdentity(buffers.tensor_b.host_view()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + buffers.tensor_b.host_view(), 1, ElementT(2), ElementT(-2), _init_bits); + } + + if (options.debug || !options.has_bias) + { + cutlass::reference::host::TensorFill(buffers.tensor_c_bias.host_view()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + buffers.tensor_c_bias.host_view(), 1, ElementT(2), ElementT(-2), _init_bits); + } + + if (options.debug) + { + std::cout << "A=" << std::endl << buffers.tensor_a.host_view() << std::endl; + std::cout << "B=" << std::endl << buffers.tensor_b.host_view() << std::endl; + std::cout << "C=" << std::endl << buffers.tensor_c_bias.host_view() << std::endl; + } + + // + // Compute reference output + // + + // Copy data from host to GPU + buffers.tensor_a.sync_device(); + buffers.tensor_b.sync_device(); + buffers.tensor_c_bias.sync_device(); + + // Zero-initialize reference output matrix D + cutlass::reference::host::TensorFill(buffers.tensor_ref_d_2x.host_view()); + buffers.tensor_ref_d_2x.sync_device(); + + // Create instantiation for device reference gemm kernel + // Reference device GEMM implementation type + using DeviceGemmReference = cutlass::reference::device::Gemm; + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference(options.problem_size, ElementAccumulatorT(options.alpha), buffers.tensor_a.device_ref(), + buffers.tensor_b.device_ref(), ElementAccumulatorT(options.beta), buffers.tensor_ref_d_2x.device_ref(), + buffers.tensor_ref_d_2x.device_ref()); + + // Wait for kernels to finish + tk::check_cuda_error(cudaDeviceSynchronize()); + + // Copy output data from reference kernel to host for comparison + buffers.tensor_ref_d_2x.sync_host(); + + // Add broadcast vector (without multiplier) + // Vector broadcast on host + // for (int i = 0; i < options.problem_size.m(); ++i) + // { + // for (int j = 0; j < options.problem_size.n(); ++j) + // { + // buffers.tensor_ref_d_2x.host_view().ref().at({i, j}) += buffers.tensor_c_bias.host_view().ref().at({0, + // j}); + // } + // } + cutlass::NumericConverter converter; + int half_n = options.problem_size.n() / 2; + for (int i = 0; i < options.problem_size.m(); i++) + { + for (int j = 0; j < half_n; j++) + { + auto s = options.scale_output + * ElementComputeT(options.scale_d0 * buffers.tensor_ref_d_2x.host_view().ref().at({i, j})) + * Activation{}(options.scale_d1 * buffers.tensor_ref_d_2x.at({i, j + half_n})); + auto t = converter(s); + buffers.tensor_ref_d.host_view().ref().at({i, j}) = t; + } + } + + tk::check_cuda_error(cudaDeviceSynchronize()); + + if (options.debug) + { + std::cout << "tensor_ref_d_2x=" << buffers.tensor_ref_d_2x.host_view() << std::endl; + } + + // + // Evaluate CUTLASS kernels + // + +#ifdef COMPILE_HOPPER_TMA_GEMMS + Result hopperFp8 = run("SM90 FP8 WS GEMM", options, buffers); + EXPECT_TRUE(hopperFp8.passed); +#else // COMPILE_HOPPER_TMA_GEMMS + std::cout << "[TensorRT-LLm Error][GemmSwigluRunnerTest] Please recompile with support for hopper by passing " + "90-real as an arch to build_wheel.py." + << std::endl; +#endif // COMPILE_HOPPER_TMA_GEMMS +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + + g_options.parse(argc, const_cast(argv)); + + return RUN_ALL_TESTS(); +} diff --git a/cpp/tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/kernels/mixtureOfExpertsTest.cu index 46f9ed39e..a02e0ab2d 100644 --- a/cpp/tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/kernels/mixtureOfExpertsTest.cu @@ -159,11 +159,11 @@ protected: void SetUp() override { - assert(mBufferManager); if (shouldSkip()) { GTEST_SKIP() << "Skipping due to no/unsupported GPU"; } + assert(mBufferManager); } void TearDown() @@ -767,6 +767,10 @@ protected: return std::max(in, T(0.0f)); if (mActType == tensorrt_llm::ActivationType::Gelu || mActType == tensorrt_llm::ActivationType::Geglu) return (std::erf(float(in) * float(sqrt(0.5))) + 1) * 0.5f * float(in); + if (mActType == tensorrt_llm::ActivationType::Silu || mActType == tensorrt_llm::ActivationType::Swiglu) + { + return (float(in) / (1.f + std::exp(-(in)))); + } assert(false); return in; } @@ -1068,6 +1072,14 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteGeglu) this->BasicPermuteTest(3); } +TYPED_TEST(MixtureOfExpertsTest, PermuteSwiglu) +{ + this->mActType = tensorrt_llm::ActivationType::Swiglu; + this->BasicPermuteTest(); + this->BasicPermuteTest(2); + this->BasicPermuteTest(3); +} + TYPED_TEST(MixtureOfExpertsTest, Finished) { if (this->FP8) @@ -1227,6 +1239,13 @@ TYPED_TEST(MixtureOfExpertsTest, ExpertParallelGeglu) this->ExpertParallelTest(2); } +TYPED_TEST(MixtureOfExpertsTest, ExpertParallelSwiglu) +{ + this->mActType = tensorrt_llm::ActivationType::Swiglu; + this->ExpertParallelTest(); + this->ExpertParallelTest(2); +} + template void MixtureOfExpertsTest::TensorParallelTest(int k) { @@ -1328,36 +1347,51 @@ TYPED_TEST(MixtureOfExpertsTest, TensorParallelGeglu) this->TensorParallelTest(3); } +TYPED_TEST(MixtureOfExpertsTest, TensorParallelSwiglu) +{ + this->mActType = tensorrt_llm::ActivationType::Swiglu; + this->TensorParallelTest(); + this->TensorParallelTest(2); + this->TensorParallelTest(3); +} + TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) { + std::vector actiavtion_pool = { + tensorrt_llm::ActivationType::Relu, tensorrt_llm::ActivationType::Swiglu, tensorrt_llm::ActivationType::Geglu}; auto configs = this->mMoERunner.getTactics(); - for (auto conf : configs) + for (auto const activation_type : actiavtion_pool) { - using namespace tensorrt_llm::cutlass_extensions; - std::stringstream tactic; - tactic << "Failed " << (conf.is_sm90 ? "SM90+" : "(activation_type); + + EXPECT_NO_THROW({ + this->mActType = activation_type; + this->mSelectedConfig = conf; + this->BasicPermuteTest(); + if (::testing::Test::HasFailure()) + throw std::runtime_error("Test Failed"); + }) << tactic.str(); } - - EXPECT_NO_THROW({ - this->mSelectedConfig = conf; - this->BasicPermuteTest(); - if (::testing::Test::HasFailure()) - throw std::runtime_error("Test Failed"); - }) << tactic.str(); } } diff --git a/cpp/tests/layers/dynamicDecodeLayerTest.cpp b/cpp/tests/layers/dynamicDecodeLayerTest.cpp index 0f9377adb..85cf4f713 100644 --- a/cpp/tests/layers/dynamicDecodeLayerTest.cpp +++ b/cpp/tests/layers/dynamicDecodeLayerTest.cpp @@ -26,7 +26,6 @@ namespace tensorrt_llm::tests::layers::sampling // - finished states // - finished sum // - max length -// - repeat n grams // - padded vocab // - beam search @@ -117,10 +116,14 @@ void DynamicDecodeLayerTest::SetUp() } template -void DynamicDecodeLayerTest::allocateData(TestSamplingParams const& params) +void DynamicDecodeLayerTest::allocateData(TestSamplingParams const& params, TokenIdType endId) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + mEndId = endId == -1 ? mVocabSize - 1 : endId; + mUseMedusa = params.useMedusa; + mMaxTokensPerStep = mUseMedusa ? mMaxOutputLen - mMaxInputLen : 1; + auto const decodingMode = params.decodingMode.value_or( [this]() { @@ -357,6 +360,9 @@ void DynamicDecodeLayerTest::setup(uint64_t seed, TestSamplingParams const& p setupParams->samplingParams.normalize_log_probs = {false}; setupParams->samplingParams.outputLogProbs = {true}; setupParams->samplingParams.cumLogProbs = {true}; + setupParams->penaltyParams.noRepeatNgramSize = params.repeatNGramSizes.size() + ? std::make_optional>(params.repeatNGramSizes) + : std::nullopt; setupParams->medusaParams.topKMedusaHeads = params.topKMedusaHeads; @@ -500,7 +506,7 @@ std::shared_ptr DynamicDecodeLayerTest::createInput // std::optional src_cache_indirection; // std::optional sequence_limit_length; // std::optional input_lengths; - // std::optional no_repeat_ngram_size; + // std::optional no_repeat_ngram_size; has move to sampling config // std::optional> logits_vec; TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -645,12 +651,6 @@ void DynamicDecodeLayerTest::runTestImpl( { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - mEndId = endId == -1 ? mVocabSize - 1 : endId; - mUseMedusa = params.useMedusa; - mMaxTokensPerStep = mUseMedusa ? mMaxOutputLen - mMaxInputLen : 1; - - allocateData(params); - bool greedySearch = std::all_of(expectedOutputIds.begin(), expectedOutputIds.end(), [](auto v) { return v.size() == 1; }); for (uint64_t seed = 0; seed < mMaxSeed; ++seed) @@ -746,6 +746,8 @@ template void DynamicDecodeLayerTest::runTest( std::vector> const& expectedOutputIds, TestSamplingParams const& params, TokenIdType endId) { + allocateData(params, endId); + if (!params.useMedusa) { TLLM_LOG_DEBUG("Run test with linear logits"); @@ -1056,6 +1058,50 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPTemperatureBatch) this->runTest(expectedOutputIds, params); } +TYPED_TEST(DynamicDecodeLayerTest, TopPTemperatureMultipleRequests) +{ + this->allocateData(TestSamplingParams{}); + { + std::vector temperatures = {0.01f, 1e3f, 1.0f, 1.0f, 0.01f, 1.0f}; + TestSamplingParams params; + params.temperatures = temperatures; + params.topPs = {0.5f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4, 5, 6, 7}, {4, 5}, {4, 5}, {4}, {4, 5}, // step 0 + {0}, {0, 1, 2, 3}, {0, 1}, {0, 1}, {0}, {0, 1}, // step 1 + {2}, {2, 3, 4, 5}, {2, 3}, {2, 3}, {2}, {2, 3}, // step 2 + {0}, {0, 1, 2, 3}, {0, 1}, {0, 1}, {0}, {0, 1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + TestSamplingParams params; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {0}, {0}, {0}, {0}, {0}, {0} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + float temperature = 1.0f; + TestSamplingParams params; + params.temperatures = {temperature}; + params.topPs = {0.5f}; + std::vector> expectedOutputIds{ + {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, {4, 5}, // step 0 + {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, // step 1 + {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3}, {2, 3}, // step 2 + {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1}, {0, 1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } +} + TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPenalty) { SizeType32 topK = 1; @@ -1107,6 +1153,51 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPenaltiesBatch) this->runTest(expectedOutputIds, params); } +TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPenaltyMultipleRequests) +{ + this->allocateData(TestSamplingParams{}); + { + float repetitionPenalty = 1e9f; + TestSamplingParams params; + params.repetitionPenalties = {repetitionPenalty}; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {1}, {1}, {1}, {1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + TestSamplingParams params; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {0}, {0}, {0}, {0}, {0}, {0} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + std::vector repetitionPenalties = {1e9f, 1e9f, 1.0f, 1.0f, 1.0f, 1e9f}; + TestSamplingParams params; + params.repetitionPenalties = repetitionPenalties; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {0}, {0}, {0}, {1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } +} + TYPED_TEST(DynamicDecodeLayerTest, TopPPresencePenalty) { float presencePenalty = 1e9f; @@ -1156,6 +1247,51 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPPresencePenaltiesBatch) this->runTest(expectedOutputIds, params); } +TYPED_TEST(DynamicDecodeLayerTest, TopPPresencePenaltyMultipleRequests) +{ + this->allocateData(TestSamplingParams{}); + { + float presencePenalty = 1e9f; + TestSamplingParams params; + params.presencePenalties = {presencePenalty}; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {1}, {1}, {1}, {1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + TestSamplingParams params; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {0}, {0}, {0}, {0}, {0}, {0} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + std::vector presencePenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; + TestSamplingParams params; + params.presencePenalties = presencePenalties; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {0}, {0}, {0}, {1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } +} + TYPED_TEST(DynamicDecodeLayerTest, TopPFrequencyPenalty) { float frequencyPenalty = 1e9f; @@ -1205,6 +1341,51 @@ TYPED_TEST(DynamicDecodeLayerTest, TopPFrequencyPenaltiesBatch) this->runTest(expectedOutputIds, params); } +TYPED_TEST(DynamicDecodeLayerTest, TopPFrequencyPenaltyMultipleRequests) +{ + this->allocateData(TestSamplingParams{}); + { + float frequencyPenalty = 1e9f; + TestSamplingParams params; + params.frequencyPenalties = {frequencyPenalty}; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {1}, {1}, {1}, {1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + TestSamplingParams params; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {0}, {0}, {0}, {0}, {0}, {0} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + std::vector frequencyPenalties = {1e9f, 1e9f, 0.0f, 0.0f, 0.0f, 1e9f}; + TestSamplingParams params; + params.frequencyPenalties = frequencyPenalties; + params.topPs = {0.3f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {0}, {0}, {0}, {1} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } +} + TYPED_TEST(DynamicDecodeLayerTest, TopPRepetitionPresencePenalty) { float repetitionPenalty = 1e9f; @@ -1780,6 +1961,111 @@ TYPED_TEST(DynamicDecodeLayerTest, BadWordsNoBadWordsMode) this->runTest(expectedOutputIds, params); } +TYPED_TEST(DynamicDecodeLayerTest, NoRepeatNgramSize) +{ + SizeType32 topK = 1; + TestSamplingParams params; + params.topKs = {topK}; + params.topPs = {1.0f}; + params.badWords = {{{0}}, {{2}}, {{0}, {3}, {4, 1, 2}}, {{5}}, {{0}}, {{1}}}; + params.repeatNGramSizes = {1, 1, 2, 1, 1, 3}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {1}, {0}, {1}, {0}, {1}, {0}, // step 1 + {2}, {3}, {4}, {2}, {2}, {2}, // step 2 + {3}, {1}, {2}, {1}, {3}, {0} // step 3 + }; + this->runTest(expectedOutputIds, params); +} + +TYPED_TEST(DynamicDecodeLayerTest, NoRepeatNgramSizeNoNgramMode) +{ + SizeType32 topK = 1; + TestSamplingParams params; + params.topKs = {topK}; + params.topPs = {1.0f}; + params.badWords = {{{0}}, {{2}}, {{0}, {3}, {4, 1, 2}}, {{5}}, {{0}}, {{1}}}; + params.repeatNGramSizes = {1, 1, 2, 1, 1, 3}; + params.decodingMode = tle::DecodingMode::TopK().useNoRepeatNgramSize(false); + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {1}, {0}, {1}, {0}, {1}, {0}, // step 1 + {2}, {3}, {4}, {2}, {2}, {2}, // step 2 + {1}, {0}, {1}, {0}, {1}, {0} // step 3 + }; + this->runTest(expectedOutputIds, params); +} + +TYPED_TEST(DynamicDecodeLayerTest, NoRepeatNgramSizeNoBanTokensMode) +{ + SizeType32 topK = 1; + TestSamplingParams params; + params.topKs = {topK}; + params.topPs = {1.0f}; + params.badWords = {{{0}}, {{2}}, {{0}, {3}, {4, 1, 2}}, {{5}}, {{0}}, {{1}}}; + params.repeatNGramSizes = {1, 1, 2, 1, 1, 3}; + params.decodingMode = tle::DecodingMode::TopK().useBanTokens(false); + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {0}, {0}, {0}, {0}, {0}, {0}, // step 3 + }; + this->runTest(expectedOutputIds, params); +} + +TYPED_TEST(DynamicDecodeLayerTest, NoRepeatNgramSizeMultipleRequests) +{ + this->allocateData(TestSamplingParams{}); + { + SizeType32 topK = 1; + TestSamplingParams params; + params.topKs = {topK}; + params.topPs = {1.0f}; + params.repeatNGramSizes = {1, 1, 2, 1, 1, 3}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {0}, {1}, {1}, {0} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + SizeType32 topK = 1; + TestSamplingParams params; + params.topKs = {topK}; + params.topPs = {1.0f}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {0}, {0}, {0}, {0}, {0}, {0} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } + { + SizeType32 topK = 1; + TestSamplingParams params; + params.topKs = {topK}; + params.topPs = {1.0f}; + params.repeatNGramSizes = {1, 1, 2, 1, 1, 3}; + std::vector> expectedOutputIds{ + // batch + {4}, {4}, {4}, {4}, {4}, {4}, // step 0 + {0}, {0}, {0}, {0}, {0}, {0}, // step 1 + {2}, {2}, {2}, {2}, {2}, {2}, // step 2 + {1}, {1}, {0}, {1}, {1}, {0} // step 3 + }; + this->runTestImpl(expectedOutputIds, params); + } +} + TYPED_TEST(DynamicDecodeLayerTest, StopWords) { SizeType32 topK = 1; diff --git a/cpp/tests/layers/dynamicDecodeLayerTest.h b/cpp/tests/layers/dynamicDecodeLayerTest.h index 0188deabc..414841a6f 100644 --- a/cpp/tests/layers/dynamicDecodeLayerTest.h +++ b/cpp/tests/layers/dynamicDecodeLayerTest.h @@ -55,6 +55,7 @@ struct TestSamplingParams std::vector topPResetIds; std::vector>> badWords; std::vector>> stopWords; + std::vector repeatNGramSizes; bool useBias{false}; std::optional decodingMode; @@ -150,7 +151,6 @@ class DynamicDecodeLayerTest : public testing::Test bool mUseMedusa{false}; private: - void allocateData(TestSamplingParams const& params); void allocateMedusaData(TestSamplingParams const& params); void setup(uint64_t seed, TestSamplingParams const& params); @@ -169,9 +169,6 @@ class DynamicDecodeLayerTest : public testing::Test runtime::SizeType32* seqLens, runtime::SizeType32 leadingDim, runtime::SizeType32 stride, runtime::SizeType32 step, bool outputIdsTransposed = false, runtime::SizeType32 strideTransposed = 0); - void runTestImpl(std::vector> const& expectedOutputIds, - TestSamplingParams const& params, runtime::TokenIdType endId = -1); - void fillRefLogits(runtime::SizeType32 const* seqLenHost, std::vector> const& expectedOutputIds, runtime::SizeType32 step); @@ -181,6 +178,11 @@ class DynamicDecodeLayerTest : public testing::Test public: void runTest(std::vector> const& expectedOutputIds, TestSamplingParams const& params, runtime::TokenIdType endId = -1); + + void allocateData(TestSamplingParams const& params, runtime::TokenIdType endId = -1); + + void runTestImpl(std::vector> const& expectedOutputIds, + TestSamplingParams const& params, runtime::TokenIdType endId = -1); }; typedef testing::Types FloatAndHalfTypes; diff --git a/cpp/tests/resources/scripts/build_enc_dec_engines.py b/cpp/tests/resources/scripts/build_enc_dec_engines.py index 85f6415e8..d40b24a1d 100644 --- a/cpp/tests/resources/scripts/build_enc_dec_engines.py +++ b/cpp/tests/resources/scripts/build_enc_dec_engines.py @@ -30,9 +30,8 @@ class Arguments: rm_pad: bool = True gemm: bool = True - # rmsm: bool = True # TODO: remove this - max_new_tokens: int = 10 + max_new_tokens: int = 64 @property def ckpt(self): @@ -73,7 +72,7 @@ def __post_init__(self): k = k.name v = getattr(self, k) if isinstance(v, bool): - parser.add_argument(f'--{k}', default=int(v), type=int) + parser.add_argument(f'--{k}', action='store_true') else: parser.add_argument(f'--{k}', default=v, type=type(v)) @@ -131,8 +130,8 @@ class Build(RunCMDMixin): def command(self): args = self.args - engine_dir = join(args.engines_dir, f'tp{args.tp}') - weight_dir = join(args.trt_models_dir, f'tp{args.tp}', f'pp{args.pp}') + engine_dir = args.engines_dir + weight_dir = args.trt_models_dir encoder_build = [ f"trtllm-build --checkpoint_dir {join(weight_dir, 'encoder')}", f"--output_dir {join(engine_dir, 'encoder')}", @@ -149,7 +148,7 @@ def command(self): decoder_build = [ f"trtllm-build --checkpoint_dir {join(weight_dir, 'decoder')}", f"--output_dir {join(engine_dir, 'decoder')}", - f'--paged_kv_cache disable', f'--moe_plugin disable', + f'--paged_kv_cache enable', f'--moe_plugin disable', f'--enable_xqa disable', f'--max_beam_width {args.beams}', f'--max_batch_size 8', f'--max_output_len 200', f'--gemm_plugin {args.dtype}', diff --git a/cpp/tests/resources/scripts/generate_expected_enc_dec_output.py b/cpp/tests/resources/scripts/generate_expected_enc_dec_output.py index 3f94378b6..4fb652a64 100644 --- a/cpp/tests/resources/scripts/generate_expected_enc_dec_output.py +++ b/cpp/tests/resources/scripts/generate_expected_enc_dec_output.py @@ -1,5 +1,3 @@ -from os.path import join - from build_enc_dec_engines import Arguments, RunCMDMixin @@ -9,16 +7,18 @@ def command(self): args = self.args world_size = args.tp * args.pp mpi_run = f'mpirun --allow-run-as-root -np {world_size} ' if world_size > 1 else '' - engine_dir = join(args.engines_dir, f'tp{args.tp}') - return (mpi_run + - f'python3 examples/enc_dec/run.py --engine_dir {engine_dir} ' - f'--engine_name {args.ckpt} ' - f'--model_name "{args.hf_models_dir}" ' - f'--max_new_tokens={args.max_new_tokens} ' - f'--num_beams={args.beams} ' - f'--compare_hf_fp32 ' - f"{'--debug_mode ' if args.debug else ''} " - "--output_encoder_npy ") + ret = ( + f'python3 examples/enc_dec/run.py --engine_dir {args.engines_dir}', + f'--engine_name {args.ckpt}', + f'--model_name "{args.hf_models_dir}"', + f'--max_new_tokens={args.max_new_tokens}', + f'--num_beams={args.beams}', + f'--compare_hf_fp32', + f'--output_npy={args.data_dir}', + "--debug_mode" if args.debug else "", + ) + ret = mpi_run + ' '.join(ret) + return ret if __name__ == '__main__': diff --git a/cpp/tests/resources/scripts/test_cpp.py b/cpp/tests/resources/scripts/test_cpp.py index b5e0eb1cd..ca949811a 100755 --- a/cpp/tests/resources/scripts/test_cpp.py +++ b/cpp/tests/resources/scripts/test_cpp.py @@ -105,6 +105,8 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None, run_mamba=False, run_recurrentgemma=False, run_encoder=False, + run_bart=False, + run_t5=False, run_fp8=False, only_multi_gpu=False, trt_root: _tp.Optional[str] = None, @@ -196,6 +198,8 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None, run_mamba=run_mamba, run_recurrentgemma=run_recurrentgemma, run_encoder=run_encoder, + run_bart=run_bart, + run_t5=run_t5, run_fp8=run_fp8) if build_only: @@ -210,6 +214,8 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None, run_mamba=run_mamba, run_recurrentgemma=run_recurrentgemma, run_encoder=run_encoder, + run_bart=run_bart, + run_t5=run_t5, run_fp8=run_fp8, timeout=test_timeout) @@ -245,6 +251,8 @@ def prepare_all_model_tests(python_exe: str, run_mamba=False, run_recurrentgemma=False, run_encoder=False, + run_bart=False, + run_t5=False, run_fp8=False): model_cache_arg = ["--model_cache", model_cache] if model_cache else [] @@ -328,6 +336,24 @@ def prepare_all_model_tests(python_exe: str, else: _log.info("Skipping encoder tests") + if run_bart: + prepare_model_tests(model_name="bart", + python_exe=python_exe, + root_dir=root_dir, + resources_dir=resources_dir, + model_cache_arg=model_cache_arg) + else: + _log.info("Skipping BART tests") + + if run_t5: + prepare_model_tests(model_name="t5", + python_exe=python_exe, + root_dir=root_dir, + resources_dir=resources_dir, + model_cache_arg=model_cache_arg) + else: + _log.info("Skipping T5 tests") + def prepare_multi_gpu_model_tests(python_exe: str, root_dir: _pl.Path, @@ -354,17 +380,25 @@ def prepare_model_tests(model_name: str, scripts_dir = resources_dir / "scripts" model_env = {**_os.environ, "PYTHONPATH": f"examples/{model_name}"} + enc_dec_model_name_arg = [] + if model_name in ('bart', 't5'): + enc_dec_model_name_arg = [ + '--hf_repo_name', + 'facebook/bart-large-cnn' if model_name == 'bart' else 't5-small' + ] + model_name = 'enc_dec' + build_engines = [ python_exe, str(scripts_dir / f"build_{model_name}_engines.py") - ] + model_cache_arg + only_fp8_arg + only_multi_gpu_arg + ] + model_cache_arg + only_fp8_arg + only_multi_gpu_arg + enc_dec_model_name_arg run_command(build_engines, cwd=root_dir, env=model_env, timeout=1800) model_env["PYTHONPATH"] = "examples" generate_expected_output = [ python_exe, str(scripts_dir / f"generate_expected_{model_name}_output.py") - ] + only_fp8_arg + only_multi_gpu_arg + ] + only_fp8_arg + only_multi_gpu_arg + enc_dec_model_name_arg if "enc_dec" in model_name: generate_expected_output += model_cache_arg if only_multi_gpu_arg: @@ -402,6 +436,7 @@ def run_unit_tests(build_dir: _pl.Path, timeout=1800): excluded_tests.append("Mamba") excluded_tests.append("RecurrentGemma") excluded_tests.append("Encoder") + excluded_tests.append("EncDec") ctest.extend(["-E", "|".join(excluded_tests)]) run_command(ctest, cwd=build_dir, env=cpp_env, timeout=timeout) @@ -415,6 +450,8 @@ def run_single_gpu_tests(build_dir: _pl.Path, run_mamba, run_recurrentgemma, run_encoder, + run_bart, + run_t5, run_fp8, timeout=3600): build_tests(build_dir=build_dir) @@ -453,40 +490,55 @@ def run_single_gpu_tests(build_dir: _pl.Path, ctest.extend(["-E", "|".join(excluded_tests)]) run_command(ctest, cwd=build_dir, env=cpp_env, timeout=timeout) + def run_enc_dec_test_with_env(model: str): + enc_dec_test_command = [ + "tests/executor/executorTest", + "--gtest_filter=EncDecBasicTest/EncDecParamsTest.Forward*", + f"--gtest_output=xml:{str(build_dir)}/results-single-gpu-enc-dec.xml" + ] + run_command(enc_dec_test_command, + cwd=build_dir, + env={ + **cpp_env, 'ENC_DEC_MODEL': model + }) + + if run_bart: + run_enc_dec_test_with_env('bart') + if run_t5: + run_enc_dec_test_with_env('t5') + def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500): build_tests(build_dir=build_dir) tests_dir = build_dir / "tests" + xml_output_file = build_dir / "results-multi-gpu-real-decoder.xml" cpp_env = {**_os.environ} - # TP2+PP2 tests fail for beam search - session_test = [ - "mpirun", "-n", "4", "--allow-run-as-root", "gptSessionTest", - "--gtest_filter=*TP4*:*PP4*" - ] - run_command(session_test, cwd=tests_dir, env=cpp_env, - timeout=300) # expecting ~250s - trt_model_test = [ "mpirun", "-n", "4", "--allow-run-as-root", - "batch_manager/trtGptModelRealDecoderTest", "--gtest_filter=*TP*:*PP*" + "batch_manager/trtGptModelRealDecoderTest", "--gtest_filter=*TP*:*PP*", + f"--gtest_output=xml:{xml_output_file}" ] run_command(trt_model_test, cwd=tests_dir, env=cpp_env, timeout=timeout) # expecting ~ 1200s #Executor test in leader mode new_env = cpp_env + xml_output_file = build_dir / "results-multi-gpu-llama-exec-leader-mode.xml" new_env["RUN_LLAMA_MULTI_GPU"] = "true" trt_model_test = [ "mpirun", "-n", "4", "--allow-run-as-root", "executor/executorTest", - "--gtest_filter=*LlamaExecutorTest*LeaderMode*" + "--gtest_filter=*LlamaExecutorTest*LeaderMode*", + f"--gtest_output=xml:{xml_output_file}" ] run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500) #Executor test in orchestrator mode + xml_output_file = build_dir / "results-multi-gpu-llama-exec-orch-mode.xml" trt_model_test = [ "mpirun", "-n", "1", "--allow-run-as-root", "executor/executorTest", - "--gtest_filter=*LlamaExecutorTest*OrchMode*" + "--gtest_filter=*LlamaExecutorTest*OrchMode*", + f"--gtest_output=xml:{xml_output_file}" ] run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500) @@ -652,6 +704,12 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, parser.add_argument("--run_encoder", action="store_true", help="Run the tests for BART encoder") + parser.add_argument("--run_bart", + action="store_true", + help="Run the tests for BART") + parser.add_argument("--run_t5", + action="store_true", + help="Run the tests for T5") parser.add_argument( "--run_fp8", action="store_true", @@ -675,6 +733,8 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, args.run_mamba = True args.run_recurrentgemma = True args.run_encoder = True + args.run_bart = True + args.run_t5 = True del args.run_all_models diff --git a/cpp/tests/runtime/gptSessionTest.cpp b/cpp/tests/runtime/gptSessionTest.cpp index 0eb671843..ef412d9d7 100644 --- a/cpp/tests/runtime/gptSessionTest.cpp +++ b/cpp/tests/runtime/gptSessionTest.cpp @@ -239,6 +239,7 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model samplingConfig.topP = std::vector{0.0f}; samplingConfig.lengthPenalty = std::vector{1.0f}; samplingConfig.earlyStopping = std::vector{1}; + samplingConfig.noRepeatNgramSize = std::vector{1 << 30}; auto const padId = modelIds.padId; auto endId = modelIds.endId; diff --git a/cpp/tests/runtime/iTensorTest.cpp b/cpp/tests/runtime/iTensorTest.cpp index 073678edd..2bad82500 100644 --- a/cpp/tests/runtime/iTensorTest.cpp +++ b/cpp/tests/runtime/iTensorTest.cpp @@ -217,3 +217,384 @@ TEST(ITensorTest, TensorSlice) auto uniqueSlice = ITensor::slice(std::move(constSlice), 1); EXPECT_EQ(uniqueSlice->getShape().d[0], dims.d[0] - offset - 1); } + +TEST(ITensorTest, TensorDimsSliceAtManual) +{ + auto shape = ITensor::makeShape({5, 5, 5, 5, 5}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor(BufferManager::cpu(shape, dataType)); + auto offsetDims = ITensor::makeShape({4, 3, 3}); + auto sizeDim = 2; + auto sliced = ITensor::slice(tensor, offsetDims, sizeDim); + EXPECT_TRUE(sliced->shapeEquals({2, 5, 5})); + + auto getVolume = [](std::initializer_list const&& dims) + { return ITensor::volume(ITensor::makeShape(dims)); }; + auto slicedVolume = ITensor::volume(sliced->getShape()); + auto offset = 4 * getVolume({5, 5}) + 3 * getVolume({5}) + 3; + EXPECT_EQ(static_cast(sliced->data()) - static_cast(tensor->data()), + offset * (slicedVolume / sizeDim) * BufferDataType(dataType).getSize()); + + EXPECT_EQ(ITensor::volume(shape), getVolume({5, 5, 5}) * (slicedVolume / sizeDim)); + + EXPECT_THROW(ITensor::slice(tensor, {5, 5}, 2), std::runtime_error); + + EXPECT_THROW(ITensor::slice(tensor, {4, 3, 4}, 3), std::runtime_error); + + sliced = ITensor::slice(tensor, {4, 3, 3}, 0); + EXPECT_TRUE(sliced->shapeEquals({0, 5, 5})); + + sliced = ITensor::slice(tensor, {3}); + EXPECT_TRUE(sliced->shapeEquals({2, 5, 5, 5, 5})); + + sliced = ITensor::slice(tensor, {4, 3, 3}); + EXPECT_TRUE(sliced->shapeEquals({2, 5, 5})); + + auto theOne = ITensor::at(tensor, ITensor::makeShape({4, 3, 3})); + EXPECT_TRUE(theOne->shapeEquals({5, 5})); + + theOne = ITensor::at(tensor, {4, 3}); + EXPECT_TRUE(theOne->shapeEquals({5, 5, 5})); + + theOne = ITensor::at(tensor, {4, 4, 4, 4, 4}); + EXPECT_TRUE(theOne->shapeEquals({1})); + + ITensor::SharedConstPtr constTensor = tensor; + + auto constSliced = ITensor::slice(constTensor, {4, 3, 3}, 0); + EXPECT_TRUE(constSliced->shapeEquals({0, 5, 5})); + + constSliced = ITensor::slice(tensor, {1}); + EXPECT_TRUE(constSliced->shapeEquals({4, 5, 5, 5, 5})); + + constSliced = ITensor::slice(tensor, {4, 3, 2}); + EXPECT_TRUE(constSliced->shapeEquals({3, 5, 5})); + + auto theConstOne = ITensor::at(constTensor, ITensor::makeShape({4, 3, 3})); + EXPECT_TRUE(theConstOne->shapeEquals({5, 5})); + + theConstOne = ITensor::at(constTensor, {4, 3}); + EXPECT_TRUE(theConstOne->shapeEquals({5, 5, 5})); + + theConstOne = ITensor::at(constTensor, {4, 4, 4, 4, 4}); + EXPECT_TRUE(theConstOne->shapeEquals({1})); +} + +//! \brief Range shape in [begin, end). +class ShapeRange +{ +public: + ShapeRange(ITensor::Shape const& begin, ITensor::Shape const& end) + : mBegin(begin) + , mEnd(end) + { + TLLM_CHECK(mBegin.nbDims == mEnd.nbDims); + for (int i = 0; i < mEnd.nbDims; i++) + { + TLLM_CHECK(mBegin.d[i] <= mEnd.d[i]); + } + } + + ShapeRange(ITensor::Shape end) + : ShapeRange( + [](auto dims) + { + for (int i = 0; i < dims.nbDims; i++) + { + dims.d[i] = 0; + } + return dims; + }(end), + end) + { + } + + ShapeRange( + std::initializer_list const& begin, std::initializer_list const& end) + : ShapeRange(ITensor::makeShape(begin), ITensor::makeShape(end)) + { + } + + ShapeRange(std::initializer_list const& end) + : ShapeRange(ITensor::makeShape(end)) + { + } + + class Iterator : public std::iterator + { + friend ShapeRange; + + protected: + explicit Iterator(ITensor::Shape const& value, ShapeRange const& range) + : mValue(value) + , mRange(range) + { + } + + public: + Iterator& operator++() + { + auto counter = [](ITensor::DimType64& value, bool& carry, ITensor::DimType64 min, ITensor::DimType64 max) + { + value += carry ? 1 : 0; + carry = value == max; + }; + if (mValue.nbDims == 0) + { + return *this; + } + bool carry = true; + int i = mValue.nbDims; + do + { + i--; + counter(mValue.d[i], carry, mRange.mBegin.d[i], mRange.mEnd.d[i]); + } while (i > 0 && carry); + + if (!carry) + { + i++; + for (; i < mValue.nbDims; i++) + { + mValue.d[i] = mRange.mBegin.d[i]; + } + } + return *this; + } + + Iterator operator++(int) + { + Iterator retval = *this; + ++(*this); + return retval; + } + + bool operator==(Iterator const& other) const + { + return ITensor::shapeEquals(mValue, other.mValue); + } + + bool operator!=(Iterator const& other) const + { + return !(*this == other); + } + + reference operator*() const + { + return mValue; + } + + private: + ITensor::Shape mValue; + ShapeRange const& mRange; + }; + + Iterator begin() const + { + return Iterator(mBegin, *this); + } + + Iterator end() const + { + return Iterator(mEnd, *this); + } + +private: + ITensor::Shape const mBegin; + ITensor::Shape const mEnd; +}; + +TEST(ShapeRange, test) +{ + { + ITensor::Shape a = ITensor::makeShape({}); + ITensor::Shape b = ITensor::makeShape({}); + ShapeRange range(a, b); + EXPECT_TRUE(range.begin() == range.end()); + EXPECT_TRUE(ITensor::shapeEquals(*range.begin(), ITensor::makeShape({}))); + int count = 0; + for (auto const& v : range) + { + count++; + } + EXPECT_EQ(count, 0); + } + { + int count = 0; + for (auto const& v : ShapeRange({1, 1, 1}, {1, 1, 1})) + { + count++; + } + EXPECT_EQ(count, 0); + } + { + ITensor::Shape a = ITensor::makeShape({1, 1, 1}); + ITensor::Shape b = ITensor::makeShape({3, 3, 3}); + ShapeRange range(a, b); + EXPECT_TRUE(range.begin() != range.end()); + EXPECT_TRUE(ITensor::shapeEquals(a, *range.begin())); + EXPECT_TRUE(ITensor::shapeEquals(b, *range.end())); + int count = 0; + for (auto const& v : range) + { + count++; + } + EXPECT_EQ(count, 8); + } + { + int count = 0; + for (auto const& v : ShapeRange({2, 2, 2, 2})) + { + count++; + } + EXPECT_EQ(count, 16); + } + { + EXPECT_THROW(ShapeRange({0}, {1, 1}), std::runtime_error); + EXPECT_THROW(ShapeRange({2, 2}, {1, 1}), std::runtime_error); + } +} + +TEST(ITensorTest, TensorDimsSliceAt) +{ + auto shape = ITensor::makeShape({5, 5, 5, 5}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor(BufferManager::cpu(shape, dataType)); + + auto verify = [&shape, &tensor, &dataType](ITensor::Shape const& index) + { + auto blockAt = ITensor::at(tensor, index); + auto blockSliceRest = ITensor::slice(tensor, index); + auto blockSliceOne = ITensor::slice(tensor, index, 1); + auto blockSliceTwo = (shape.d[index.nbDims - 1] - index.d[index.nbDims - 1] >= 2) + ? std::make_optional(ITensor::slice(tensor, index, 2)) + : [&tensor, &index]() + { + EXPECT_THROW(ITensor::slice(tensor, index, 2), std::runtime_error); + return std::nullopt; + }(); + + { + auto strides = ITensor::strides(tensor->getShape()); + ITensor::DimType64 offset = 0; + for (SizeType32 i = 0; i < index.nbDims; i++) + { + offset += index.d[i] * strides.d[i]; + } + offset *= BufferDataType(dataType).getSize(); + auto base = static_cast(tensor->data()); + EXPECT_EQ(static_cast(blockAt->data()) - base, offset); + EXPECT_EQ(static_cast(blockSliceRest->data()) - base, offset); + EXPECT_EQ(static_cast(blockSliceOne->data()) - base, offset); + if (blockSliceTwo) + { + EXPECT_EQ(static_cast(blockSliceTwo.value()->data()) - base, offset); + } + } + { + auto blockShape = blockAt->getShape(); + ITensor::Shape goldenShape = ITensor::makeShape({1}); + if (shape.nbDims > index.nbDims) + { + goldenShape.nbDims = shape.nbDims - index.nbDims; + for (SizeType32 i = 0; i < goldenShape.nbDims; i++) + { + goldenShape.d[i] = shape.d[i + index.nbDims]; + } + } + EXPECT_TRUE(ITensor::shapeEquals(blockShape, goldenShape)); + } + { + auto blockShape = blockSliceRest->getShape(); + ITensor::Shape goldenShape; + goldenShape.nbDims = shape.nbDims - index.nbDims + 1; + goldenShape.d[0] = shape.d[0 + index.nbDims - 1] - index.d[0 + index.nbDims - 1]; + for (SizeType32 i = 1; i < goldenShape.nbDims; i++) + { + goldenShape.d[i] = shape.d[i + index.nbDims - 1]; + } + EXPECT_TRUE(ITensor::shapeEquals(blockShape, goldenShape)); + } + { + auto blockShape = blockSliceOne->getShape(); + ITensor::Shape goldenShape; + goldenShape.nbDims = shape.nbDims - index.nbDims + 1; + goldenShape.d[0] = 1; + for (SizeType32 i = 1; i < goldenShape.nbDims; i++) + { + goldenShape.d[i] = shape.d[i + index.nbDims - 1]; + } + EXPECT_TRUE(ITensor::shapeEquals(blockShape, goldenShape)); + } + if (blockSliceTwo) + { + auto blockShape = blockSliceTwo.value()->getShape(); + ITensor::Shape goldenShape; + goldenShape.nbDims = shape.nbDims - index.nbDims + 1; + goldenShape.d[0] = 2; + for (SizeType32 i = 1; i < goldenShape.nbDims; i++) + { + goldenShape.d[i] = shape.d[i + index.nbDims - 1]; + } + EXPECT_TRUE(ITensor::shapeEquals(blockShape, goldenShape)); + } + }; + + for (auto const& range : {ShapeRange({5}), ShapeRange({5, 5}), ShapeRange({5, 5, 5}), ShapeRange({5, 5, 5, 5})}) + { + for (auto const& index : range) + { + verify(index); + } + } + + for (auto& range : {ShapeRange({4}, {7}), ShapeRange({4, 4}, {7, 7}), ShapeRange({4, 4, 4}, {7, 7, 7}), + ShapeRange({4, 4, 4, 4}, {7, 7, 7, 7})}) + { + auto it = range.begin(); + for (it++; it != range.end(); ++it) + { + EXPECT_THROW(ITensor::at(tensor, *it), std::runtime_error); + EXPECT_THROW(ITensor::slice(tensor, *it), std::runtime_error); + EXPECT_THROW(ITensor::slice(tensor, *it, 1), std::runtime_error); + } + } +} + +TEST(BufferRangeTest, ConstType) +{ + auto shape = ITensor::makeShape({5, 5, 5, 5, 5}); + auto constexpr dataType = nvinfer1::DataType::kFLOAT; + ITensor::SharedPtr tensor(BufferManager::cpu(shape, dataType)); + ITensor::SharedConstPtr tensorConst = tensor; + + //// 1 //////////////////////////////////// + BufferRange tensor_RangeConst_WONT_ASSIGN(*tensor); + // tensorRange_WONT_ASSIGN[0] = 3.14159; + + //// 2 //////////////////////////////////// + BufferRange tensorRange(*tensor); + tensorRange[0] = 3.14159; + + //// 3 //////////////////////////////////// + // BufferRange TensorConst_Range_WONT_COMPILE(*tensorConst); + + //// 4 //////////////////////////////////// + BufferRange tensorConst_RangeConst_WONT_ASSIGN(*tensorConst); + // theConstOnerange_WONT_ASSIGN[0] = 1.1; + + BufferRange tensor_RangeConst(*tensor); + BufferRange tensorConst_RangeConst(*tensorConst); + + float acc = 3.14159; + for (auto& v : tensorRange) + { + v = acc * acc + 1.0; + } + for (SizeType32 i = 0; i < tensorRange.size(); i++) + { + EXPECT_EQ(tensorRange[i], tensor_RangeConst[i]); + EXPECT_EQ(tensorRange[i], tensorConst_RangeConst[i]); + } +} diff --git a/cpp/tests/runtime/samplingTest.cpp b/cpp/tests/runtime/samplingTest.cpp index 804e61bb3..b42153663 100644 --- a/cpp/tests/runtime/samplingTest.cpp +++ b/cpp/tests/runtime/samplingTest.cpp @@ -71,7 +71,6 @@ std::shared_ptr dynamicDecodeTest(BufferManager& int* gpuOutputIds = nullptr; int* gpuSequenceLengths = nullptr; int* gpuNewTokens = nullptr; - int* gpuNoRepeatNgramSize = nullptr; tk::FinishedState::UnderlyingType* gpuFinished = nullptr; gpuLogits = allocator->reMalloc(gpuLogits, batchSize * beamWidth * vocabSizePadded * sizeof(float)); @@ -79,7 +78,6 @@ std::shared_ptr dynamicDecodeTest(BufferManager& gpuOutputIds = allocator->reMalloc(gpuOutputIds, batchSize * beamWidth * maxSeqLength * sizeof(int)); gpuSequenceLengths = allocator->reMalloc(gpuSequenceLengths, batchSize * sizeof(int)); gpuNewTokens = allocator->reMalloc(gpuNewTokens, batchSize * beamWidth * sizeof(int)); - gpuNoRepeatNgramSize = allocator->reMalloc(gpuNoRepeatNgramSize, batchSize * sizeof(int)); gpuFinished = allocator->reMalloc(gpuFinished, batchSize * beamWidth * sizeof(tk::FinishedState::UnderlyingType)); cudaMemcpy(gpuLogits, cpuLogits.data(), cpuLogits.size() * sizeof(float), cudaMemcpyHostToDevice); @@ -87,15 +85,12 @@ std::shared_ptr dynamicDecodeTest(BufferManager& cudaMemcpy( gpuSequenceLengths, cpuSequenceLengths.data(), cpuSequenceLengths.size() * sizeof(int), cudaMemcpyHostToDevice); cudaMemcpy(gpuOutputIds, cpuOutputIds.data(), cpuOutputIds.size() * sizeof(int), cudaMemcpyHostToDevice); - cudaMemcpy(gpuNoRepeatNgramSize, cpuNoRepeatNgramSize.data(), cpuNoRepeatNgramSize.size() * sizeof(int), - cudaMemcpyHostToDevice); tc::Tensor logits{tc::MEMORY_GPU, tc::TYPE_FP32, {batchSize, beamWidth, vocabSizePadded}, gpuLogits}; tc::Tensor endIds{tc::MEMORY_GPU, tc::TYPE_INT32, {batchSize}, gpuEndIds}; tc::Tensor outputIds{tc::MEMORY_GPU, tc::TYPE_INT32, {batchSize, beamWidth, maxSeqLength}, gpuOutputIds}; tc::Tensor sequenceLengths{tc::MEMORY_GPU, tc::TYPE_INT32, {batchSize}, gpuSequenceLengths}; tc::Tensor newTokens{tc::MEMORY_GPU, tc::TYPE_INT32, {batchSize}, gpuNewTokens}; - tc::Tensor noRepeatNgramSize{tc::MEMORY_GPU, tc::TYPE_INT32, {batchSize}, gpuNoRepeatNgramSize}; tc::Tensor finished{tc::MEMORY_GPU, tc::TYPE_INT8, {batchSize, beamWidth}, gpuFinished}; auto const decodingMode = beamWidth == 1 ? tle::DecodingMode::TopKTopP() : tle::DecodingMode::BeamSearch(); @@ -103,13 +98,12 @@ std::shared_ptr dynamicDecodeTest(BufferManager& auto ddLayer = tl::DynamicDecodeLayer(decodingMode, decodingDomain, manager.getStream().get(), allocator); auto setupParams = std::make_shared(); - + setupParams->penaltyParams.noRepeatNgramSize = cpuNoRepeatNgramSize; ddLayer.setup(batchSize, beamWidth, nullptr, setupParams); auto forwardParams = std::make_shared( step, ite, maxInputLength, static_cast(maxSeqLength), sinkTokenLength, localBatchSize, endIds); forwardParams->logits = logits; - forwardParams->no_repeat_ngram_size = noRepeatNgramSize; auto outputParams = std::make_shared(outputIds); outputParams->sequence_length = sequenceLengths; diff --git a/cpp/tests/runtime/tllmBuffersTest.cpp b/cpp/tests/runtime/tllmBuffersTest.cpp index 07c476892..d0f2958a5 100644 --- a/cpp/tests/runtime/tllmBuffersTest.cpp +++ b/cpp/tests/runtime/tllmBuffersTest.cpp @@ -469,11 +469,9 @@ TEST_F(TllmBuffersTest, PinnedPoolAllocator) EXPECT_NE(it->tag, nullptr); EXPECT_EQ(it->size, expectedSize(c)); it = std::next(it); - EXPECT_EQ(it->tag, nullptr); - secondChunkSize = expectedSize(c) + it->size; - EXPECT_EQ(secondChunkSize, pool.getChunkSize()); - it = std::next(it); EXPECT_EQ(it, std::end(segments)); + secondChunkSize = expectedSize(c); + EXPECT_EQ(secondChunkSize, pool.getChunkSize()); } { diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index 0f485a124..4458be56d 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -64,19 +64,30 @@ COPY tensorrt_llm tensorrt_llm COPY 3rdparty 3rdparty COPY setup.py requirements.txt requirements-dev.txt ./ +# Create cache directories for pip and ccache +RUN mkdir -p /root/.cache/pip /root/.cache/ccache +ENV CCACHE_DIR=/root/.cache/ccache +# Build the TRT-LLM wheel ARG BUILD_WHEEL_ARGS="--clean --trt_root /usr/local/tensorrt --python_bindings --benchmarks" -RUN python3 scripts/build_wheel.py ${BUILD_WHEEL_ARGS} +RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/ccache \ + python3 scripts/build_wheel.py ${BUILD_WHEEL_ARGS} FROM ${DEVEL_IMAGE} as release +# Create a cache directory for pip +RUN mkdir -p /root/.cache/pip + WORKDIR /app/tensorrt_llm COPY --from=wheel /src/tensorrt_llm/build/tensorrt_llm*.whl . -RUN pip install tensorrt_llm*.whl --extra-index-url https://pypi.nvidia.com && \ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install tensorrt_llm*.whl --extra-index-url https://pypi.nvidia.com && \ rm tensorrt_llm*.whl COPY README.md ./ COPY docs docs COPY cpp/include include -RUN ln -sv $(python3 -c 'import site; print(f"{site.getsitepackages()[0]}/tensorrt_llm/libs")') lib && \ +RUN ln -sv $(python3 -c 'import site; print(f"{site.getsitepackages()[0]}/tensorrt_llm/bin")') bin && \ + test -f bin/executorWorker && \ + ln -sv $(python3 -c 'import site; print(f"{site.getsitepackages()[0]}/tensorrt_llm/libs")') lib && \ test -f lib/libnvinfer_plugin_tensorrt_llm.so && \ ln -sv lib/libnvinfer_plugin_tensorrt_llm.so lib/libnvinfer_plugin_tensorrt_llm.so.9 && \ echo "/app/tensorrt_llm/lib" > /etc/ld.so.conf.d/tensorrt_llm.conf && \ diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index bc4ce0cd5..2fa035e33 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -11,4 +11,11 @@ ARG GROUP_NAME=root RUN (getent group ${GROUP_ID} || groupadd --gid ${GROUP_ID} ${GROUP_NAME}) && \ (getent passwd ${USER_ID} || useradd --gid ${GROUP_ID} --uid ${USER_ID} --create-home --no-log-init --shell /bin/bash ${USER_NAME}) +RUN apt-get update && \ + apt-get install -y sudo && \ + adduser ${USER_NAME} sudo && \ + echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + USER ${USER_NAME} diff --git a/docker/Makefile b/docker/Makefile index c947d1ba5..1c9066274 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -17,6 +17,9 @@ ifeq ($(LOCAL_USER),1) IMAGE_TAG_SUFFIX ?= -$(USER_NAME) endif +# Set this to 1 to use the image from Jenkins as the image for the `devel` stage in the build phase +JENKINS_DEVEL ?= 0 + # Default stage of the docker multi-stage build STAGE ?= # Set this to define a custom image name and tag @@ -26,20 +29,19 @@ DOCKER_BUILD_OPTS ?= --pull DOCKER_BUILD_ARGS ?= DOCKER_PROGRESS ?= auto CUDA_ARCHS ?= -BUILD_WHEEL_ARGS ?= $(shell grep 'ARG BUILD_WHEEL_ARGS=' Dockerfile.multi | grep -o '=.*' | tr -d '="')$(if $(CUDA_ARCHS), --cuda_architectures $(CUDA_ARCHS)) +BUILD_WHEEL_OPTS ?= +BUILD_WHEEL_ARGS ?= $(shell grep 'ARG BUILD_WHEEL_ARGS=' Dockerfile.multi | grep -o '=.*' | tr -d '="')$(if $(CUDA_ARCHS), --cuda_architectures $(CUDA_ARCHS))$(if $(BUILD_WHEEL_OPTS), $(BUILD_WHEEL_OPTS)) TORCH_INSTALL_TYPE ?= skip CUDA_VERSION ?= CUDNN_VERSION ?= NCCL_VERSION ?= CUBLAS_VERSION ?= TRT_VERSION ?= -DEVEL_IMAGE ?= GIT_COMMIT ?= $(shell git rev-parse HEAD) TRT_LLM_VERSION ?= $(shell grep '^__version__' ../tensorrt_llm/version.py | grep -o '=.*' | tr -d '= "') define add_local_user docker build \ - --progress $(DOCKER_BUILD_OPTS) $(DOCKER_BUILD_ARGS) \ --progress $(DOCKER_PROGRESS) \ --build-arg BASE_IMAGE_WITH_TAG=$(1) \ --build-arg USER_ID=$(USER_ID) \ @@ -56,6 +58,7 @@ define rewrite_tag $(shell echo $(IMAGE_WITH_TAG) | sed "s/\/tensorrt-llm:/\/tensorrt-llm-staging:/g") endef +%_build: DEVEL_IMAGE = $(if $(findstring 1,$(JENKINS_DEVEL)),$(shell grep 'LLM_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')) %_build: @echo "Building docker image: $(IMAGE_WITH_TAG)" DOCKER_BUILDKIT=1 docker build $(DOCKER_BUILD_OPTS) $(DOCKER_BUILD_ARGS) \ diff --git a/docs/source/_templates/footer.html b/docs/source/_templates/footer.html new file mode 100644 index 000000000..164c30cef --- /dev/null +++ b/docs/source/_templates/footer.html @@ -0,0 +1,29 @@ +{% extends "!footer.html" %} +{%- block contentinfo %} +{{ super }} + + + +{% endblock %} diff --git a/docs/source/advanced/gpt-runtime.md b/docs/source/advanced/gpt-runtime.md index 08f82034e..65d76500e 100644 --- a/docs/source/advanced/gpt-runtime.md +++ b/docs/source/advanced/gpt-runtime.md @@ -122,7 +122,7 @@ value for a given parameter, the vector can be limited to a single element ***General*** * `temperature`, a vector of floating-point numbers to control the - modulation of logits when sampling new tokens. It can have any value `> 0.0f`. The default value is `1.0f`(no modulation). Note: the recommended way to enable greedy sampling is to set `temperature` to `1.0f` and `topK` to `1`. + modulation of logits when sampling new tokens. It can have any value `>= 0.0f`. The default value is `1.0f`(no modulation). * `minLength`, a vector of integers to set a lower-bound on the number of tokens generated. It can have any value `>= 0`. Value `0` has no effect, the first generated token can be EOS. The default value is `1` (at least one non-EOS token is generated). * `repetitionPenalty`, a vector of float-point numbers to penalize tokens @@ -133,6 +133,7 @@ value for a given parameter, the vector can be limited to a single element * `frequencyPenalty`, a vector of float-point numbers to penalize tokens already present in the sequence (dependent on the number of appearances). It is additive penalty. It can have any value, values `< 0.0f` encourage repetition, `> 0.0f` discourage it. The default value is `0.0f`(no effect). + * `noRepeatNgramSize`, a vector of integers. If set to int > 0, all ngrams of that size can only occur once. The parameters `repetitionPenalty`, `presencePenalty`, and `frequencyPenalty` are not mutually exclusive. diff --git a/docs/source/advanced/inference-request.md b/docs/source/advanced/inference-request.md index 6de596d31..3560608dd 100644 --- a/docs/source/advanced/inference-request.md +++ b/docs/source/advanced/inference-request.md @@ -24,6 +24,7 @@ Optional tensors that can be supplied to `InferenceRequest` are shown below. Def | `min_length` | [1] | `int32_t` | Sampling Config param: `minLength` | | `presence_penalty` | [1] | `float` | Sampling Config param: `presencePenalty` | | `frequency_penalty` | [1] | `float` | Sampling Config param: `frequencyPenalty` | +| `no_repeat_ngram_size` | [1] | `int32_t` | Sampling Config param: `noRepeatNgramSize` | | `random_seed` | [1] | `uint64_t` | Sampling Config param: `randomSeed` | | `end_id` | [1] | `int32_t` | End token Id. If not specified, defaults to -1 | | `pad_id` | [1] | `int32_t` | Pad token Id | diff --git a/docs/source/advanced/weight-streaming.md b/docs/source/advanced/weight-streaming.md index 0c796baca..9e8078ce3 100644 --- a/docs/source/advanced/weight-streaming.md +++ b/docs/source/advanced/weight-streaming.md @@ -6,7 +6,7 @@ TensorRT Weight Streaming can offload some weights to the CPU memory and stream This can reduce the weights size in GPU memory, therefore, we can run larger models or larger batch sizes in the same GPU memory budget. -During build time, build the engine with `--weight-streaming --strongly_typed --gemm_plugin disable` since Weight Streaming only supports strongly typed models and non-plugin weights. During runtime, run with `--gpu_weights_percent x` to config the percent of weights that remained on the GPU. `x` can be a value from `0.0` to `1.0`. +During build time, build the engine with `--weight-streaming --gemm_plugin disable` since Weight Streaming only supports strongly typed models and non-plugin weights. During runtime, run with `--gpu_weights_percent x` to config the percent of weights that remained on the GPU. `x` can be a value from `0.0` to `1.0`. Here is an example to run llama-7b with Weight Streaming: ```bash @@ -22,7 +22,6 @@ trtllm-build \ --checkpoint_dir /tmp/llama_7b/trt_ckpt/fp16/1-gpu/ \ --output_dir /tmp/llama_7b/trt_engines/fp16/1-gpu/ \ --weight_streaming \ - --strongly_typed \ --gemm_plugin disable \ --max_batch_size 128 \ --max_input_len 512 \ @@ -51,7 +50,6 @@ python3 benchmarks/python/benchmark.py \ --max_output_len 32 \ --gpu_weights_percent "0.0;0.3;0.6;1.0" \ --dtype float16 \ - --strongly_typed \ --csv \ --log_level verbose diff --git a/docs/source/conf.py b/docs/source/conf.py index 511d76299..2455ed8ee 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -14,10 +14,10 @@ sys.path.insert(0, os.path.abspath('../..')) project = 'tensorrt_llm' -copyright = '2023, NVidia' +copyright = '2024, NVidia' author = 'NVidia' branch_name = pygit2.Repository('.').head.shorthand - +html_show_sphinx = False # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/installation/windows.md b/docs/source/installation/windows.md index e1fcf1cf0..dbef47a29 100644 --- a/docs/source/installation/windows.md +++ b/docs/source/installation/windows.md @@ -21,9 +21,11 @@ We recommend checking out the [v0.10.0 tag](https://github.com/NVIDIA/TensorRT-L ./setup_env.ps1 [-skipCUDA] [-skipPython] ``` + 2. Close and re-open any existing PowerShell or Git Bash windows so they pick up the new `Path` modified by the `setup_env.ps1` script above. + 2. Install the dependencies one at a time. - 1. Install [Python 3.10](https://www.python.org/downloads/windows/). + 1. Install [Python 3.10](https://www.python.org/ftp/python/3.10.11/python-3.10.11-amd64.exe). 1. Select **Add python.exe to PATH** at the start of the installation. The installation may only add the `python` command, but not the `python3` command. 2. Navigate to the installation path `%USERPROFILE%\AppData\Local\Programs\Python\Python310` (`AppData` is a hidden folder) and copy `python.exe` to `python3.exe`. diff --git a/docs/source/performance/perf-best-practices.md b/docs/source/performance/perf-best-practices.md index 698fee476..db4695be3 100644 --- a/docs/source/performance/perf-best-practices.md +++ b/docs/source/performance/perf-best-practices.md @@ -179,6 +179,19 @@ by using the `--use_fused_mlp` argument with `trtllm-build`. When the workload is very small, or if you're using FP8 PTQ and the accuracy after enabling it does not satisfy your requirement, it is not recommended to enable that feature. +### GEMM + SwiGLU Fusion in Gated-MLP + +GEMM + SwiGLU fusion in Gated-MLP combines two Matmul operations and one SwiGLU +operation into a single kernel. It only supports FP8 on Hopper now. For FP8 PTQ, +the downside is slight reduction of accuracy because one of the quantization +scaling factors are discarded. + +If model is large and you are running it on Hopper with FP8 precision, it is +recommended to enable the feature by using the `--use_fused_mlp --gemm_swiglu_plugin fp8` +argument with `trtllm-build`. When the workload is very small, or the accuracy +after enabling it does not satisfy your requirement, it is not recommended to +enable that feature. + ### GEMM Plugin The GEMM plugin utilizes NVIDIA cuBLASLt to perform GEMM operations. On FP16 and diff --git a/docs/source/performance/perf-overview.md b/docs/source/performance/perf-overview.md index 0800b6696..76e99ff9b 100644 --- a/docs/source/performance/perf-overview.md +++ b/docs/source/performance/perf-overview.md @@ -390,7 +390,6 @@ trtllm-build --model_config /tmp/engines/gptj/ckpt_config.json \ --max_batch_size 64 \ --max_input_len 2048 \ --max_output_len 2048 \ - --strongly_typed ``` #### Throughput Benchmark @@ -460,7 +459,6 @@ trtllm-build --model_config /tmp/engines/llama/7b/ckpt_config.json \ --max_batch_size 64 \ --max_input_len 2048 \ --max_output_len 2048 \ - --strongly_typed ``` #### Throughput Benchmark @@ -534,7 +532,6 @@ trtllm-build --model_config /tmp/engines/llama/70b/ckpt_config.json \ --max_batch_size 64 \ --max_input_len 2048 \ --max_output_len 2048 \ - --strongly_typed ``` #### Throughput Benchmark @@ -631,7 +628,6 @@ do --max_batch_size $batch_size \ --max_input_len $isl \ --max_output_len $osl \ - --strongly_typed # Throughput benchmark mpirun -n 8 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir $engine_path --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len "${isl},${osl}" diff --git a/docs/source/reference/precision.md b/docs/source/reference/precision.md index 9aa185f0b..df1ea5731 100644 --- a/docs/source/reference/precision.md +++ b/docs/source/reference/precision.md @@ -134,6 +134,7 @@ This release of TensorRT-LLM contains the following examples: | GPT-NeMo | Y | Y | Y | . | . | . | . | . | . | | GPT-NeoX | Y | Y | Y | . | . | . | . | . | Y | | InternLM | Y | Y | Y | . | Y | Y | Y | . | . | +| InternLM2 | Y | Y | Y | . | . | . | . | . | . | | LLaMA | Y | Y | Y | Y | Y | Y | Y | Y | Y | | LLaMA-v2 | Y | Y | Y | Y | Y | Y | Y | Y | Y | | Mamba | Y | Y | Y | . | . | . | . | . | . | diff --git a/docs/source/reference/support-matrix.md b/docs/source/reference/support-matrix.md index 0c169dcd5..03fdaa43b 100644 --- a/docs/source/reference/support-matrix.md +++ b/docs/source/reference/support-matrix.md @@ -73,6 +73,7 @@ The following table shows the supported software for TensorRT-LLM. - [GPT-Nemo](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/gpt) - [GPT-NeoX](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/gptneox) - [InternLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/internlm) + - [InternLM2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/internlm2) - [LLaMA](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama) - [LLaMA-v2](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama) - [Mamba](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/mamba) diff --git a/examples/arctic/README.md b/examples/arctic/README.md index af24794ee..826c04ec4 100644 --- a/examples/arctic/README.md +++ b/examples/arctic/README.md @@ -70,7 +70,6 @@ trtllm-build --checkpoint_dir ./tmp/tllm_checkpoints/${ENGINE} \ --output_dir ./tmp/trt_engines/${ENGINE} \ --gpt_attention_plugin ${PREC_RAW} \ --gemm_plugin ${PREC_RAW} \ - --strongly_typed \ --workers ${TP} |& tee tmp/trt_engines/${ENGINE}_build.log ``` diff --git a/examples/baichuan/README.md b/examples/baichuan/README.md index b3e1d5c3c..18abcf0f8 100644 --- a/examples/baichuan/README.md +++ b/examples/baichuan/README.md @@ -219,7 +219,6 @@ python ../quantization/quantize.py --model_dir baichuan-inc/Baichuan-13B-Chat \ --kv_cache_dtype int8 \ --output_dir ./trt_ckpt/baichuan_int4wo_int8kv_tp1 \ --calib_size 512 \ - --strongly_typed ``` **INT8 KV cache + AWQ** diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt index 86fb26f0f..a2874e20c 100644 --- a/examples/baichuan/requirements.txt +++ b/examples/baichuan/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.15.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/bindings/executor/example_advanced.py b/examples/bindings/executor/example_advanced.py index 9bda42379..28fa028a6 100644 --- a/examples/bindings/executor/example_advanced.py +++ b/examples/bindings/executor/example_advanced.py @@ -67,8 +67,7 @@ def wait_for_responses(args: argparse.Namespace, request_ids: list[int], output_tokens[req_id][beam].extend(outTokens) else: raise RuntimeError( - str(req_id) + " encountered error:" + - response.get_error_msg()) + str(req_id) + " encountered error:" + response.error_msg) return output_tokens diff --git a/examples/bloom/README.md b/examples/bloom/README.md index 7571a89ab..c6506ff8d 100644 --- a/examples/bloom/README.md +++ b/examples/bloom/README.md @@ -158,7 +158,6 @@ python convert_checkpoint.py --model_dir ./bloom/560m/ \ trtllm-build --checkpoint_dir ./bloom/560m/trt_ckpt/int8/1-gpu/ \ --gemm_plugin float16 \ --output_dir ./bloom/560m/trt_engines/int8/1-gpu/ \ - --strongly_typed ``` @@ -206,7 +205,6 @@ trtllm-build --checkpoint_dir /tmp/bloom/3b/trt_ckpts/fp8/1-gpu/ \ --output_dir /tmp/bloom/3b/trt_engines/fp8/1-gpu/ \ --gemm_plugin float16 \ --use_fp8_context_fmha enable \ - --strongly_typed \ --workers 1 ``` diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt index 94f6e6fbb..56c8b895c 100644 --- a/examples/bloom/requirements.txt +++ b/examples/bloom/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/chatglm/README.md b/examples/chatglm/README.md index d7eead688..26e8911e6 100644 --- a/examples/chatglm/README.md +++ b/examples/chatglm/README.md @@ -363,7 +363,6 @@ python ../quantization/quantize.py --model_dir chatglm3_6b \ # ChatGLM3-6B: single-gpu engine with fp8 quantization, GPT Attention plugin, Gemm plugin trtllm-build --checkpoint_dir trt_ckpt/chatglm3_6b/fp8/1-gpu \ --gemm_plugin float16 \ - --strongly_typed \ --output_dir trt_engines/chatglm3_6b/fp8/1-gpu # Run inference. @@ -394,7 +393,6 @@ python examples/chatglm/convert_checkpoint.py --model_dir chatglm3-6b-128k \ python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/chatglm3-6b-128k/trt_ckpts \ --output_dir /tmp/chatglm3-6b-128k/trt_engines \ --gemm_plugin float16 \ - --strongly_typed \ --gather_all_token_logits \ --max_batch_size 8 \ --max_input_len 25600 @@ -431,7 +429,6 @@ python examples/chatglm/convert_checkpoint.py --model_dir chatglm3-6b-128k \ python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/chatglm3-6b-128k/trt_ckpts \ --output_dir /tmp/chatglm3-6b-128k/trt_engines \ --gemm_plugin float16 \ - --strongly_typed \ --gather_all_token_logits \ --max_batch_size 1 \ --max_input_len 12800 diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt index bdf5cb448..5a7ab552f 100644 --- a/examples/chatglm/requirements.txt +++ b/examples/chatglm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 protobuf diff --git a/examples/cogvlm/convert_checkpoint.py b/examples/cogvlm/convert_checkpoint.py index 2eba91908..ba5ce1595 100644 --- a/examples/cogvlm/convert_checkpoint.py +++ b/examples/cogvlm/convert_checkpoint.py @@ -12,17 +12,16 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer import tensorrt_llm -from tensorrt_llm.layers import MoeConfig from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import PretrainedConfig from tensorrt_llm.models.cogvlm.convert import convert_hf_cogvlm from tensorrt_llm.models.convert_utils import load_calib_dataset from tensorrt_llm.models.llama.convert import (capture_activation_range, + load_weights_from_gptq, + load_weights_from_hf_by_shard, + load_weights_from_meta_ckpt, smooth_llama_model) -from tensorrt_llm.models.llama.weight import (load_from_gptq_llama, - load_from_hf_checkpoint, - load_from_meta_llama) -from tensorrt_llm.models.modeling_utils import PretrainedConfig try: from transformers import LlavaConfig, LlavaForConditionalGeneration @@ -204,32 +203,6 @@ def parse_arguments(): type=int, default=1, help='The number of workers for converting checkpoint in parallel') - parser.add_argument( - '--moe_num_experts', - default=0, - type=int, - help='Specify the number of experts to use for MOE layers') - parser.add_argument( - '--moe_top_k', - default=0, - type=int, - help= - 'Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set' - ) - parser.add_argument( - '--moe_tp_mode', - default=MoeConfig.ParallelismMode.TENSOR_PARALLEL, - type=int, - help= - 'Controls how to distribute experts in TP. Check layers/moe.py for accepted values', - ) - parser.add_argument( - '--moe_renorm_mode', - default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE, - type=int, - help= - 'Controls renormalization after gate logits. Check layers/moe.py for accepted values', - ) parser.add_argument('--enable_pos_shift', default=False, action='store_true', @@ -271,9 +244,6 @@ def update_quantization_from_args(config: dict, args: argparse.Namespace): config['quantization'][ 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PLUGIN' - if args.use_weight_only and args.moe_config.has_moe(): - config['quantization']['exclude_modules'].append('router') - if args.int8_kv_cache: config['quantization']['kv_cache_quant_algo'] = 'INT8' @@ -324,10 +294,6 @@ def create_config_from_args(args: argparse.Namespace): 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, 'use_prompt_tuning': args.use_prompt_tuning, - 'moe_num_experts': args.moe_num_experts, - 'moe_top_k': args.moe_top_k, - 'moe_tp_mode': args.moe_tp_mode, - 'moe_normalization_mode': args.moe_renorm_mode, 'enable_pos_shift': args.enable_pos_shift, 'dense_context_fmha': args.dense_context_fmha, } @@ -405,16 +371,7 @@ def main(): args.rms_norm_eps = hf_config.rms_norm_eps args.vocab_size = hf_config.vocab_size args.n_positions = hf_config.max_position_embeddings - if hf_config.model_type == "mixtral": - # HF LLaMA-type models are implicitly using gated activation. - # With our MoE implementation, we must make it explicit - args.hidden_act = "swiglu" - args.moe_num_experts = getattr(hf_config, "num_local_experts", - args.moe_num_experts) - args.moe_top_k = getattr(hf_config, "num_experts_per_tok", - args.moe_top_k) - args.rotary_base = getattr(hf_config, "rope_theta", - args.rotary_base) + args.architecture = hf_config.architectures[0] args.vision_start = 1 args.vision_length = hf_config.vision_config['num_positions'] - 1 @@ -437,20 +394,11 @@ def main(): (int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1) // args.multiple_of) args.rms_norm_eps = meta_config["norm_eps"] - args.moe_num_experts = meta_config.get("moe", {}).get("num_experts", 0) - args.moe_top_k = meta_config.get("moe", {}).get("num_experts_per_tok", - 0) args.architecture = "LlamaForCausalLM" else: args.n_kv_head = args.n_kv_head or args.n_head args.architecture = "LlamaForCausalLM" - if args.moe_num_experts and args.moe_top_k == 0: - args.moe_top_k = 1 - args.moe_config = MoeConfig(args.moe_num_experts, args.moe_top_k, - args.moe_tp_mode, - args.moe_renorm_mode).validate() - if args.rotary_scaling is not None: # assert args.use_gpt_attention_plugin, "RoPE scaling is only supported through GPT attention plugin." rotary_scaling = { @@ -495,22 +443,23 @@ def covert_and_save(rank): pp_size=args.pp_size) if args.use_weight_only and args.weight_only_precision == 'int4_gptq': - weights = load_from_gptq_llama(args.modelopt_quant_ckpt_path, - args.n_layer, - args.vocab_size, - mapping, - dtype=args.dtype) + weights = load_weights_from_gptq( + args.modelopt_quant_ckpt_path, + PretrainedConfig.from_dict(copy.deepcopy(config)), + ) elif args.meta_ckpt_dir is not None: - weights = load_from_meta_llama( - args.meta_ckpt_dir, mapping, - PretrainedConfig.from_dict(copy.deepcopy(config))) + weights = load_weights_from_meta_ckpt( + args.meta_ckpt_dir, + PretrainedConfig.from_dict(copy.deepcopy(config)), + ) else: if args.load_by_shard: - weights = load_from_hf_checkpoint( - args.model_dir, mapping, - PretrainedConfig.from_dict(copy.deepcopy(config))) + weights = load_weights_from_hf_by_shard( + args.model_dir, + PretrainedConfig.from_dict(copy.deepcopy(config)), + ) else: if args.weight_only_precision == 'int8': @@ -535,8 +484,7 @@ def covert_and_save(rank): int8_kv_cache=args.int8_kv_cache, act_range=act_range, qkv_para=llama_qkv_para, - smoother=llama_smoother, - moe_config=args.moe_config) + smoother=llama_smoother) safetensors.torch.save_file( weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) diff --git a/examples/dbrx/README.md b/examples/dbrx/README.md index fd115e580..97539f291 100644 --- a/examples/dbrx/README.md +++ b/examples/dbrx/README.md @@ -186,7 +186,6 @@ trtllm-build --checkpoint_dir dbrx/trt_ckpt/int8kv/tp4 \ --gpt_attention_plugin float16 \ --gemm_plugin float16 \ --moe_plugin float16 \ - --strongly_typed \ --workers 4 \ --output_dir dbrx/trt_engines/int8kv/tp4 ``` diff --git a/examples/dbrx/convert_checkpoint.py b/examples/dbrx/convert_checkpoint.py index 06e3c0894..58de73320 100644 --- a/examples/dbrx/convert_checkpoint.py +++ b/examples/dbrx/convert_checkpoint.py @@ -689,7 +689,7 @@ def execute(workers, func, hf_model): 'block_embedding' ], }, - 'moe_config': { + 'moe': { "num_experts": args.moe_num_experts, "top_k": args.moe_top_k, "tp_mode": args.moe_tp_mode, @@ -701,10 +701,6 @@ def execute(workers, func, hf_model): 'pp_size': args.pp_size, }, 'clip_qkv': args.clip_qkv, - 'moe_num_experts': args.moe_num_experts, - 'moe_top_k': args.moe_top_k, - 'moe_tp_mode': args.moe_tp_mode, - 'moe_normalization_mode': args.moe_renorm_mode, 'dense_context_fmha': args.dense_context_fmha, } diff --git a/examples/dbrx/requirements.txt b/examples/dbrx/requirements.txt index bbfe74c71..baf4457b4 100644 --- a/examples/dbrx/requirements.txt +++ b/examples/dbrx/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/enc_dec/README.md b/examples/enc_dec/README.md index fd6b1802a..69f3c33b0 100644 --- a/examples/enc_dec/README.md +++ b/examples/enc_dec/README.md @@ -13,6 +13,8 @@ This document shows how to build and run an Encoder-Decoder (Enc-Dec) model in T - [Convert and Split Weights](#convert-and-split-weights) - [Build TensorRT engine(s)](#build-tensorrt-engines) - [Run](#run) + - [Run C++ runtime](#run-c-runtime) + - [Run Python runtime](#run-python-runtime) - [Benchmark](#benchmark) - [Run BART with LoRA](#run-bart-with-lora) - [Reminders](#reminders) @@ -80,7 +82,6 @@ python convert_checkpoint.py --model_type ${MODEL_TYPE} \ --output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ --tp_size ${TP_SIZE} \ --pp_size ${PP_SIZE} \ - --weight_data_type float32 \ --dtype ${INFERENCE_PRECISION} ``` @@ -154,7 +155,6 @@ python convert_checkpoint.py --model_type ${MODEL_TYPE} \ --output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ --tp_size ${TP_SIZE} \ --pp_size ${PP_SIZE} \ - --weight_data_type float32 \ --dtype ${INFERENCE_PRECISION} # Note: non-T5 models can enable FMHA for the encoder part, for FP16/BF16, the default is enabled @@ -192,12 +192,32 @@ trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION ``` - ### Run Run a TensorRT-LLM Enc-Dec model using the engines generated by build.py. Note that during model deployment, only the TensorRT engine files are needed. Previously downloaded model checkpoints and converted weights can be removed. +Different types of runtime are provided for encoder-decoder models. Following an order of serving performance and good usability, we recommend: +- (NEW) Python binding of C++ runtime w/ Paged KV Cache and Inflight Batching (IFB) +- Python runtime w/ Static Batching +- (NEW) C++ runtime w/ Paged KV Cache and Inflight Batching + +Please refer to the documentation for the details of [paged kv cache](../../docs/source/advanced/gpt-attention.md#paged-kv-cache) and [inflight batching](../../docs/source/advanced/gpt-attention.md#inflight-batching). + +#### Run C++ runtime +For good usability, Python binding of the C++ runtime is provided. You can use the high-level C++ `ModelRunner` under the `examples/` root folder. + +```python +# Inferencing via python binding of C++ runtime with inflight batching (IFB) +python3 ../run.py --engine_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE} --tokenizer_dir tmp/hf_models/${MODEL_NAME} --max_output_len 64 --input_text "translate English to German: The house is wonderful." +``` + +For pure C++ runtime, there is no example given yet. Please check the [`Executor`](../../cpp/include/tensorrt_llm/executor/executor.h) API to implement your own end-to-end workflow. It is highly recommended to leverage more encapsulated solutions such as the above C++ Python binding or [Triton backend](https://github.com/triton-inference-server/tensorrtllm_backend). + +#### Run Python runtime + +For pure Python runtime, you can still use the encoder-decoder specific script under `examples/enc_dec/`. + ```bash # Inferencing w/ single GPU greedy search, compare results with HuggingFace FP32 python3 run.py --engine_dir tmp/trt_engines/${MODEL_NAME}/${WORLD_SIZE}-gpu/${INFERENCE_PRECISION}/tp${TP_SIZE} --engine_name ${MODEL_NAME} --model_name tmp/hf_models/${MODEL_NAME} --max_new_token=64 --num_beams=1 --compare_hf_fp32 @@ -250,7 +270,6 @@ python convert_checkpoint.py --model_type bart \ --output_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION} \ --tp_size 1 \ --pp_size 1 \ - --weight_data_type float32 \ --dtype ${INFERENCE_PRECISION} ``` @@ -258,8 +277,8 @@ python convert_checkpoint.py --model_type bart \ ```bash -trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION}/tp1/pp1/encoder \ - --output_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/encoder \ +trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION}/encoder \ + --output_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/encoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ @@ -275,8 +294,8 @@ trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISIO --lora_dir tmp/hf_models/bart-large-cnn-samsum-lora/ \ --lora_target_modules attn_q attn_v -trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION}/tp1/pp1/decoder \ - --output_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/decoder \ +trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISION}/decoder \ + --output_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/decoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ @@ -298,7 +317,7 @@ trtllm-build --checkpoint_dir tmp/trt_models/bart-large-cnn/${INFERENCE_PRECISIO ```bash python run.py \ - --engine_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/ \ + --engine_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/ \ --engine_name bart-large-cnn \ --model_name tmp/hf_models/bart-large-cnn \ --max_new_token=64 \ @@ -311,7 +330,7 @@ python run.py \ ```bash python run.py \ - --engine_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/tp1/ \ + --engine_dir tmp/trt_engines/bart-large-cnn/1-gpu/${INFERENCE_PRECISION}/ \ --engine_name bart-large-cnn \ --model_name tmp/hf_models/bart-large-cnn \ --max_new_token=64 \ @@ -362,7 +381,6 @@ python convert_checkpoint.py --model_type nmt \ --output_dir tmp/trt_models/wmt14/${INFERENCE_PRECISION} \ --tp_size ${TP_SIZE} \ --pp_size ${PP_SIZE} \ - --weight_data_type float32 \ --dtype ${INFERENCE_PRECISION} # Build TensorRT engine(s) diff --git a/examples/enc_dec/convert_checkpoint.py b/examples/enc_dec/convert_checkpoint.py index a0e2d607f..4e9ce2e84 100755 --- a/examples/enc_dec/convert_checkpoint.py +++ b/examples/enc_dec/convert_checkpoint.py @@ -15,7 +15,6 @@ Pix2StructForConditionalGeneration, T5ForConditionalGeneration, VisionEncoderDecoderModel) -from tensorrt_llm._utils import str_dtype_to_torch from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType, MLPType) from tensorrt_llm.models import PretrainedConfig @@ -40,7 +39,6 @@ def parse_t5_config(args, hf_model): config["encoder"] = {} for key, val in hf_model.encoder.config.to_dict().items(): config["encoder"][key] = f"{val}" - config["encoder"]["weight_data_type"] = args.weight_data_type # manually set q_scaling to offset attention scaling's effect. # TODO: modify kernels to control whether to disable attention scaling @@ -51,7 +49,6 @@ def get_offset_q_scaling(config): config["decoder"] = {} for key, val in hf_model.decoder.config.to_dict().items(): config["decoder"][key] = f"{val}" - config["decoder"]["weight_data_type"] = args.weight_data_type config["structure"] = dict() config["structure"]["t5_with_bias"] = "false" @@ -111,8 +108,6 @@ def parse_t5_config_by_component(config, component, args): component_config.logits_dtype = config.get(component, 'logits_dtype', fallback='float32') - component_config.ckpt_weight_dtype = config.get(component, - 'weight_data_type') if component == 'encoder': component_config.n_layer = config.getint(component, 'num_layers') @@ -120,9 +115,6 @@ def parse_t5_config_by_component(config, component, args): component_config.relative_attention = config.get( 'structure', 'position_embedding_type') == 'relative' - component_config.ckpt_weight_dtype = config.get( - component, 'weight_data_type') - elif component == 'decoder': component_config.n_layer = config.getint(component, 'num_decoder_layers') @@ -318,7 +310,6 @@ def parse_nmt_config(args, model): config['encoder'] = dict() for key, val in fairseq_config.items(): config["encoder"][key] = f"{val}" - config["encoder"]["weight_data_type"] = args.weight_data_type config["encoder"]["q_scaling"] = '1' # NMT doesn't have final layernorm config['encoder']['has_model_final_layernorm'] = 'false' @@ -327,7 +318,6 @@ def parse_nmt_config(args, model): config['decoder'] = dict() for key, val in fairseq_config.items(): config["decoder"][key] = f"{val}" - config["decoder"]["weight_data_type"] = args.weight_data_type config["decoder"]["q_scaling"] = '1' config["decoder"]["rescale_before_lm_head"] = 'false' config['decoder']['has_model_final_layernorm'] = 'false' @@ -402,8 +392,6 @@ def parse_nmt_config_by_component(config, component, args): component, 'relative_attention_num_buckets', fallback=0) component_config.max_distance = config.getint( component, 'relative_attention_max_distance', fallback=0) - component_config.ckpt_weight_dtype = config.get(component, - 'weight_data_type') component_config.position_embedding_type = config.get( 'structure', 'position_embedding_type') component_config.logits_dtype = config.get(component, @@ -589,7 +577,6 @@ def parse_bart_config(args, hf_model): config['decoder'] = dict() for key, val in hf_model.model.decoder.config.to_dict().items(): config["decoder"][key] = f"{val}" - config["decoder"]["weight_data_type"] = args.weight_data_type config["decoder"]["q_scaling"] = '1' config["decoder"]["rescale_before_lm_head"] = str(False) config['decoder']['has_model_final_layernorm'] = str( @@ -612,7 +599,6 @@ def parse_bart_config(args, hf_model): config['encoder'] = dict() for key, val in hf_model.model.encoder.config.to_dict().items(): config["encoder"][key] = f"{val}" - config["encoder"]["weight_data_type"] = args.weight_data_type config["encoder"]["q_scaling"] = '1' # mBART has final layernorm, BART does not @@ -686,8 +672,6 @@ def parse_bart_config_by_component(config, component, args): component, 'relative_attention_num_buckets', fallback=0) component_config.max_distance = config.getint( component, 'relative_attention_max_distance', fallback=0) - component_config.ckpt_weight_dtype = config.get(component, - 'weight_data_type') component_config.max_lora_rank = config.getint(component, 'max_lora_rank', fallback=0) @@ -934,7 +918,6 @@ def get_offset_q_scaling(config) -> str: config["decoder"] = {} for key, val in hf_model.decoder.config.to_dict().items(): config["decoder"][key] = f"{val}" - config["decoder"]["weight_data_type"] = args.weight_data_type config["decoder"]["q_scaling"] = get_offset_q_scaling( hf_model.decoder.config) @@ -1007,7 +990,6 @@ def parse_pix2struct_config_by_component(config, component, args): args.encoder_hidden_size = config.getint('decoder', 'hidden_size') args.encoder_num_heads = config.getint('decoder', 'num_heads') args.encoder_head_size = config.getint('decoder', 'd_kv') - args.ckpt_weight_dtype = config.get(component, 'weight_data_type') args.position_embedding_type = config.get( 'structure', 'position_embedding_type') @@ -1181,10 +1163,8 @@ def get_model(args): def convert_checkpoint(args): model = get_model(args) - model = model.to(str_dtype_to_torch(args.weight_data_type)) - saved_dir = Path( - args.output_dir) / f"tp{args.tp_size}" / f"pp{args.pp_size}" + saved_dir = Path(args.output_dir) saved_dir.mkdir(parents=True, exist_ok=True) encoder_saved_dir = saved_dir / "encoder" @@ -1404,11 +1384,6 @@ def convert(worker_rank, world_size, args, model_config, convert_args, type=int, help="How many workers to spawn for conversion (default: 4)", default=4) - parser.add_argument("--weight_data_type", - type=str, - default="float32", - choices=["float32", "float16", - "bfloat16"]) # TODO: test support for bf16? parser.add_argument("--nougat", action="store_true", help="Model which uses vision encoder + mbart decoder") diff --git a/examples/enc_dec/run.py b/examples/enc_dec/run.py index 27f59258b..6837e1539 100644 --- a/examples/enc_dec/run.py +++ b/examples/enc_dec/run.py @@ -153,11 +153,10 @@ def parse_arguments(): action='store_true') parser.add_argument('--lora_dir', type=str, default=None, nargs="+") parser.add_argument('--lora_task_uids', type=str, default=None, nargs="+") - parser.add_argument( - "--output_encoder_npy", - help= - "Store tensors like encoder outputs used for testing enc-dec C++ runtime.", - action="store_true") + parser.add_argument("--output_npy", + type=str, + default=None, + help="Store input/output tensors C++ runtime testing") return parser.parse_args() @@ -222,11 +221,13 @@ def engine_setup(component): self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup( component='encoder') - # for Pipeline Parallelism in encoder - self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp( - self.encoder_runtime_mapping.tp_size, - self.encoder_runtime_mapping.pp_size, - self.encoder_runtime_mapping.rank) + self.nccl_comm = None + if self.encoder_runtime_mapping.has_pp(): + # for Pipeline Parallelism in encoder + self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp( + self.encoder_runtime_mapping.tp_size, + self.encoder_runtime_mapping.pp_size, + self.encoder_runtime_mapping.rank) # session setup self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine( @@ -579,8 +580,10 @@ def generate(self, encoder_input_lengths=encoder_input_lengths, return_dict=return_dict, cross_attention_mask=cross_attention_mask) - if return_encoder_output: - return output, encoder_output + + if return_dict and return_encoder_output: + output['encoder_output'] = encoder_output + return output @@ -643,8 +646,8 @@ def test_fairseq_models(args): eos_token_id=eos_token_id, debug_mode=args.debug_mode, ) - tok = time.time() torch.cuda.synchronize() + tok = time.time() if return_dict: tllm_output_ids = tllm_output['output_ids'] @@ -722,21 +725,7 @@ def test_fairseq_models(args): 'cuda') # [batch_size, padded_length] # by default int64, must cast to int32! otherwise C++ kernel will interpret as [a, 0, b, 0, c, 0, ...] - CPP_RESULTS_SAVED_DIR = 'cpp/tests/resources/data/enc_dec' if tensorrt_llm.mpi_rank() == 0: - if args.output_encoder_npy: - if not os.path.isdir(CPP_RESULTS_SAVED_DIR): - os.mkdir(os.path.join(CPP_RESULTS_SAVED_DIR)) - np_input_ids = tokenized_inputs.input_ids.type(torch.IntTensor) - np_input_ids = np_input_ids.numpy() - np.save(os.path.join(CPP_RESULTS_SAVED_DIR, 'enc_input_ids.npy'), - np_input_ids) - input_lengths = tokenized_inputs.attention_mask.sum(dim=1).type( - torch.IntTensor).numpy() - np.save( - os.path.join(CPP_RESULTS_SAVED_DIR, 'enc_input_lengths.npy'), - input_lengths) - print("--------------------------------------") print( f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}" @@ -811,7 +800,7 @@ def test_fairseq_models(args): print(f"HF E2E time {(tok-tik)*1000}ms") print("--------------------------------------") - return_dict = False # when set return_dict=True, get outputs by key + return_dict = True # when set return_dict=True, get outputs by key tik = time.time() tllm_output = tllm_model.generate( encoder_input_ids=input_ids, @@ -825,21 +814,16 @@ def test_fairseq_models(args): return_dict=return_dict, attention_mask=tokenized_inputs.attention_mask, time_encoder=True, - return_encoder_output=args.output_encoder_npy - and tensorrt_llm.mpi_rank() == 0) + return_encoder_output=args.output_npy and tensorrt_llm.mpi_rank() == 0) + torch.cuda.synchronize() tok = time.time() - if args.output_encoder_npy and tensorrt_llm.mpi_rank() == 0: - tllm_output, encoder_output = tllm_output - encoder_output = encoder_output.cpu().numpy() - np.save(os.path.join(CPP_RESULTS_SAVED_DIR, 'encoder_output.npy'), - encoder_output) - - if return_dict: - tllm_output_ids = tllm_output['output_ids'] - else: - tllm_output_ids = tllm_output if tensorrt_llm.mpi_rank() == 0: + if return_dict: + tllm_output_ids = tllm_output['output_ids'] + else: + tllm_output_ids = tllm_output + output_ids = tllm_output_ids[:, 0, :] output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) @@ -847,6 +831,7 @@ def test_fairseq_models(args): tokenizer.pad_token_id).sum(dim=1) output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum( dim=1) - decoder_input_lengths + print("--------------------------------------") print("TRT-LLM output_ids: ", output_ids) print("TRT-LLM output text: ", output_text) @@ -855,6 +840,34 @@ def test_fairseq_models(args): print("Precision:", inference_dtype) print("--------------------------------------") + # save input/output tensors for C++ runtime testing + if args.output_npy: + os.makedirs(args.output_npy, exist_ok=True) + + input_lengths = tokenized_inputs.attention_mask.sum(dim=1).type( + torch.IntTensor) + input_ids = tokenized_inputs.input_ids.type(torch.IntTensor) + input_ids_flatten = torch.cat([ + input_ids[i][:input_lengths[i]] + for i in range(len(input_lengths)) + ]) + encoder_output = tllm_output['encoder_output'].type(torch.float16) + + def save_npy(tensor, name): + np.save(os.path.join(args.output_npy, f'{name}.npy'), + tensor.cpu().numpy()) + + print( + f"Saving input/output tensors to {args.output_npy} for C++ runtime testing" + ) + save_npy(input_ids_flatten, 'input_ids') # [num_tokens] + save_npy(input_lengths, 'input_lengths') # [batch_size] + save_npy(encoder_output, + 'encoder_output') # [num_tokens, hidden_size] + save_npy( + output_ids, 'output_ids' + ) # [batch_size, max_output_tokens], max_output_tokens = decoder_input_tokens + max_new_tokens + # simple accuracy check if args.compare_hf_fp32: from difflib import SequenceMatcher diff --git a/examples/eval_long_context.py b/examples/eval_long_context.py index 190da0dec..74748ee4f 100644 --- a/examples/eval_long_context.py +++ b/examples/eval_long_context.py @@ -43,7 +43,7 @@ load_tokenizer, read_model_name) import tensorrt_llm -import tensorrt_llm.profiler +import tensorrt_llm.profiler as profiler from tensorrt_llm.logger import logger from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner @@ -173,8 +173,7 @@ def parse_arguments(args=None): parser.add_argument( "--task", type=str, - # choices=list(DATA_NAME_TO_MAX_NEW_TOKENS.keys()) + ["all"], - choices=['passkey'], + choices=['passkey', 'kv_retrieval'], required=True, help= "Which task to use. Note that \"all\" can only be used in `compute_scores.py`.", # noqa @@ -351,29 +350,32 @@ def main(args): args.stop_idx = len(examples) output_path = None - if args.output_dir is not None: - result_dir = Path(args.output_dir, model_name) - result_dir.mkdir(exist_ok=True, parents=True) - - if args.stop_idx is None: - output_path = (result_dir / f"preds_{data_name}.jsonl") - else: - output_path = ( - result_dir / - f"preds_{data_name}_{args.start_idx}-{args.stop_idx}.jsonl" # noqa - ) + if runtime_rank == 0: + if args.output_dir is not None: + result_dir = Path(args.output_dir, model_name) + result_dir.mkdir(exist_ok=True, parents=True) + + if args.stop_idx is None: + output_path = (result_dir / f"preds_{data_name}.jsonl") + else: + output_path = ( + result_dir / + f"preds_{data_name}_{args.start_idx}-{args.stop_idx}.jsonl" # noqa + ) prompt_template = None if args.use_prompt_template and model_name in DEFAULT_PROMPT_TEMPLATES: prompt_template = DEFAULT_PROMPT_TEMPLATES[model_name] - preds = [] - logger.info("==== Evaluation ====") - logger.info(f"# examples: {len(examples)}") - logger.info(f"Start index: {args.start_idx}") - logger.info(f"Stop index: {args.stop_idx}") - logger.info(f"Max tokens: {max_tokens}") + if runtime_rank == 0: + preds = [] + logger.info("==== Evaluation ====") + logger.info(f"# examples: {len(examples)}") + logger.info(f"Start index: {args.start_idx}") + logger.info(f"Stop index: {args.stop_idx}") + logger.info(f"Max tokens: {max_tokens}") assert args.batch_size == 1 + profiler.start('Evaluation') for i in range(args.start_idx, args.stop_idx): eg = examples[i] input_text = [create_prompt(eg, data_name, args.data_dir)] @@ -389,9 +391,10 @@ def main(args): model_version=model_version) input_lengths = [x.size(0) for x in batch_input_ids] - logger.debug(f"====== Example {i} ======") - logger.debug(f"input_text: {input_text}") - logger.debug(f"answer: {get_answer(eg, data_name)}") + if runtime_rank == 0: + logger.debug(f"====== Example {i} ======") + logger.debug(f"input_text: {input_text}") + logger.debug(f"answer: {get_answer(eg, data_name)}") outputs = runner.generate( batch_input_ids, max_new_tokens=max_tokens, @@ -420,28 +423,34 @@ def main(args): return_dict=True, medusa_choices=args.medusa_choices) torch.cuda.synchronize() - output_ids = outputs['output_ids'] - output_beams_list = [ - tokenizer.batch_decode(output_ids[batch_idx, :, - input_lengths[batch_idx]:], - skip_special_tokens=True) - for batch_idx in range(args.batch_size) - ] - - logger.debug(f"preds: {output_beams_list[0]}") - preds.append({ - "id": i, - "prediction": output_beams_list[0][0], - "ground_truth": get_answer(eg, data_name), - }) - if output_path is not None: - dump_jsonl(preds, output_path) - - logger.info("Compute the score") - acc = compute_scores(preds, args.task) - logger.info(f"accuracy of {len(preds)} examples: {acc}") - if args.tensorrt_llm_accuracy_threshold is not None: - assert acc >= args.tensorrt_llm_accuracy_threshold, f"acc ({acc}) < tensorrt_llm_accuracy_threshold ({args.tensorrt_llm_accuracy_threshold})" + if runtime_rank == 0: + output_ids = outputs['output_ids'] + output_beams_list = [ + tokenizer.batch_decode(output_ids[batch_idx, :, + input_lengths[batch_idx]:], + skip_special_tokens=True) + for batch_idx in range(args.batch_size) + ] + + logger.debug(f"preds: {output_beams_list[0]}") + preds.append({ + "id": i, + "prediction": output_beams_list[0][0], + "ground_truth": get_answer(eg, data_name), + }) + if output_path is not None: + dump_jsonl(preds, output_path) + profiler.stop('Evaluation') + + if runtime_rank == 0: + logger.info("Compute the score") + acc = compute_scores(preds, args.task) + logger.info( + f'Evaluation takes: {profiler.elapsed_time_in_sec("Evaluation")} sec.' + ) + logger.info(f"accuracy of {len(preds)} examples: {acc}") + if args.tensorrt_llm_accuracy_threshold is not None: + assert acc >= args.tensorrt_llm_accuracy_threshold, f"acc ({acc}) < tensorrt_llm_accuracy_threshold ({args.tensorrt_llm_accuracy_threshold})" if __name__ == "__main__": diff --git a/examples/falcon/README.md b/examples/falcon/README.md index aa6f535c8..f4736f5f9 100644 --- a/examples/falcon/README.md +++ b/examples/falcon/README.md @@ -248,7 +248,6 @@ python ../quantization/quantize.py --model_dir ./falcon/180b \ # Build trtllm engines from the trtllm checkpoint trtllm-build --checkpoint_dir ./falcon/180b/trt_ckpt/fp8/tp8-pp1 \ --gemm_plugin float16 \ - --strongly_typed \ --output_dir ./falcon/180b/trt_engines/fp8/tp8-pp1 \ --workers 8 diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt index f61c2b6a9..96bffe898 100644 --- a/examples/falcon/requirements.txt +++ b/examples/falcon/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 transformers>=4.31.0 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/gemma/README.md b/examples/gemma/README.md index c06d53cf1..1f8f83f1c 100644 --- a/examples/gemma/README.md +++ b/examples/gemma/README.md @@ -399,7 +399,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ --max_input_len 3000 \ --max_output_len 100 \ --enable_xqa enable \ - --strongly_type \ --lookup_plugin bfloat16 \ --output_dir ${ENGINE_PATH} @@ -669,7 +668,6 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ --max_input_len 3000 \ --max_output_len 100 \ --enable_xqa enable \ - --strongly_type \ --lookup_plugin bfloat16 \ --output_dir ${ENGINE_PATH} diff --git a/examples/gemma/convert_checkpoint.py b/examples/gemma/convert_checkpoint.py index aeee3a452..99f94e08d 100644 --- a/examples/gemma/convert_checkpoint.py +++ b/examples/gemma/convert_checkpoint.py @@ -825,7 +825,7 @@ def convert_from_checkpoint( tp_rank, dim=trt_llm_config.embedding_sharding_dim, ) - if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \ + if trt_llm_config.quant_mode.is_int8_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \ not trt_llm_config.quant_mode.has_int8_kv_cache(): # shape of embedding table: [V, K], V: vocab size, K: embedding dim @@ -1000,7 +1000,8 @@ def main(): quant_kwargs.update(quant_algo=quant_algo, kv_cache_quant_algo=kv_cache_quant_algo) if args.use_weight_only_with_precision: - if args.use_weight_only_with_precision.endswith("awq"): + if args.use_weight_only_with_precision.endswith( + "awq") or args.use_weight_only_with_precision.endswith("int4"): quant_kwargs.update(has_zero_point=False, pre_quant_scale=True, exclude_modules=[ diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt index 91aad4aff..41633e585 100644 --- a/examples/gemma/requirements.txt +++ b/examples/gemma/requirements.txt @@ -3,7 +3,7 @@ # WAR the new posting of "nvidia-cudnn-cu12~=9.0". # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9". nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64" -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 flax~=0.8.0 # jax[cuda12_pip]~=0.4.19; platform_system != "Windows" jax~=0.4.19; platform_system == "Windows" diff --git a/examples/gpt/README.md b/examples/gpt/README.md index 624b2019d..6508d3ed3 100644 --- a/examples/gpt/README.md +++ b/examples/gpt/README.md @@ -428,7 +428,6 @@ python3 convert_checkpoint.py --model_dir gpt2 \ --output_dir gpt2/trt_ckpt/int8kv/1-gpu trtllm-build --checkpoint_dir gpt2/trt_ckpt/int8kv/1-gpu \ - --strongly_typed \ --output_dir gpt2/trt_engines/int8kv/1-gpu ``` diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index 37a4dba88..a6e6776b5 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptj/requirements.txt b/examples/gptj/requirements.txt index eb20b2d7a..bc9c169fe 100644 --- a/examples/gptj/requirements.txt +++ b/examples/gptj/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt index f47ef94a3..e18d7b5a7 100644 --- a/examples/gptneox/requirements.txt +++ b/examples/gptneox/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 rouge_score~=0.1.2 evaluate~=0.4.1 diff --git a/examples/grok/README.md b/examples/grok/README.md new file mode 100644 index 000000000..943747ed7 --- /dev/null +++ b/examples/grok/README.md @@ -0,0 +1,87 @@ +# Grok-1 + +This document shows how to build and run grok-1 model in TensorRT-LLM on both single GPU, single node multi-GPU and multi-node multi-GPU. + +- [Grok1](#Grok-1) + - [Prerequisite](#prerequisite) + - [Hardware](#hardware) + - [Overview](#overview) + - [Support Matrix](#support-matrix) + - [Usage](#usage) + - [Build TensorRT engine(s)](#build-tensorrt-engines) + +## Prerequisite +First of all, please clone the official grok-1 code repo with below commands and install the dependencies. +```bash +git clone https://github.com/xai-org/grok-1.git /path/to/folder +``` +And then downloading the weights per [instructions](https://github.com/xai-org/grok-1?tab=readme-ov-file#downloading-the-weights). + +## Hardware +The grok-1 model requires a node with 8x80GB GPU memory(at least). + +## Overview + +The TensorRT-LLM Grok-1 implementation can be found in [tensorrt_llm/models/grok/model.py](../../tensorrt_llm/models/grok/model.py). The TensorRT-LLM Grok-1 example code is located in [`examples/grok`](./). There is one main file: + +* [`convert_checkpoint.py`](./convert_checkpoint.py) to convert the Grok-1 model into tensorrt-llm checkpoint format. + +In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation: + +* [`../run.py`](../run.py) to run the inference on an input text; +* [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. + +## Support Matrix + * INT8 Weight-Only + * Tensor Parallel + * STRONGLY TYPED + +## Usage + +The TensorRT-LLM Grok-1 example code locates at [examples/grok](./). It takes xai weights as input, and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference. + +### Build TensorRT engine(s) + +Please install required packages first to make sure the example uses matched `tensorrt_llm` version: + +```bash +pip install -r requirements.txt +``` + +Need to prepare the Grok-1 checkpoint by following the guides here https://github.com/xai-org/grok-1. + +TensorRT-LLM Grok-1 builds TensorRT engine(s) from Xai's checkpoints. + +Normally `trtllm-build` only requires single GPU, but if you've already got all the GPUs needed for inference, you could enable parallel building to make the engine building process faster by adding `--workers` argument. Please note that currently `workers` feature only supports single node. + + +Below is the step-by-step to run Grok-1 with TensorRT LLM. + +```bash +# Build the bfloat16 engine from xai official weights. +python convert_checkpoint.py --model_dir ./tmp/grok-1/ \ + --output_dir ./tllm_checkpoint_8gpus_bf16 \ + --dtype bfloat16 \ + --use_weight_only \ + --workers 8 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_8gpus_bf16 \ + --output_dir ./tmp/grok-1/trt_engines/bf16/8-gpus \ + --gpt_attention_plugin bfloat16 \ + --gemm_plugin bfloat16 \ + --moe_plugin bfloat16 \ + --paged_kv_cache enable \ + --remove_input_padding enable \ + --workers 8 \ + --strongly_typed + + +# Run Grok-1 with 8 GPUs +mpirun -n 8 --allow-run-as-root \ + python ../run.py \ + --max_output_len 50 \ + --input_text "The answer to life the universe and everything is of course" \ + --engine_dir ./tmp/grok-1/trt_engines/bf16/8-gpus \ + --max_output_len 50 --top_p 1 --top_k 8 --temperature 0.3 \ + --vocab_file ./tmp/grok-1/tokenizer.model +``` diff --git a/examples/grok/convert_checkpoint.py b/examples/grok/convert_checkpoint.py new file mode 100644 index 000000000..ea8a59946 --- /dev/null +++ b/examples/grok/convert_checkpoint.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import sys +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np + +import tensorrt_llm +from tensorrt_llm._utils import release_gc +from tensorrt_llm.layers import MoeConfig +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import GrokForCausalLM +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization import QuantAlgo + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, default=None) + parser.add_argument('--weights_dir', type=str, default=None) + + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument('--vocab_size', type=int, default=32000) + parser.add_argument('--n_positions', type=int, default=2048) + parser.add_argument('--n_layer', type=int, default=32) + parser.add_argument('--n_head', type=int, default=32) + parser.add_argument('--n_kv_head', type=int, default=None) + parser.add_argument('--n_embd', type=int, default=4096) + parser.add_argument('--inter_size', type=int, default=11008) + parser.add_argument('--rms_norm_eps', type=float, default=1e-06) + + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--disable_weight_only_quant_plugin', + default=False, + action="store_true", + help= + 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + + parser.add_argument('--load_by_shard', + action='store_true', + help='Load a pretrained model shard-by-shard.') + parser.add_argument('--hidden_act', type=str, default='silu') + + parser.add_argument('--rotary_base', type=float, default=10000.0) + + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_embedding_sharing', + action="store_true", + default=False, + help= + 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' + 'Note: the flag might not take effect when the criteria are not met.') + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + parser.add_argument( + '--moe_num_experts', + default=0, + type=int, + help='Specify the number of experts to use for MOE layers') + parser.add_argument( + '--moe_top_k', + default=0, + type=int, + help= + 'Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set' + ) + parser.add_argument( + '--moe_tp_mode', + default=MoeConfig.ParallelismMode.TENSOR_PARALLEL, + type=int, + help= + 'Controls how to distribute experts in TP. Check layers/moe.py for accepted values', + ) + parser.add_argument( + '--moe_renorm_mode', + default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE, + type=int, + help= + 'Controls renormalization after gate logits. Check layers/moe.py for accepted values', + ) + parser.add_argument( + '--save_config_only', + action="store_true", + default=False, + help= + 'Only save the model config w/o read and converting weights, be careful, this is for debug only' + ) + + args = parser.parse_args() + # changing the default to be consistent as the cli help said. + if args.moe_num_experts and args.moe_top_k == 0: + args.moe_top_k = 1 + return args + + +def args_to_quantization(args: argparse.Namespace) -> QuantConfig: + '''return config dict with quantization info based on the command line args + ''' + quant_config = QuantConfig() + quant_config.exclude_modules = [ + 'lm_head', 'router', 'vocab_embedding', 'position_embedding', + 'block_embedding' + ] + if args.use_weight_only: + if args.weight_only_precision == 'int8': + quant_config.quant_algo = QuantAlgo.W8A16 + + return quant_config + + +def args_to_build_options(args): + return { + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.use_embedding_sharing, + 'disable_weight_only_quant_plugin': + args.disable_weight_only_quant_plugin + } + + +def from_cli_args(args): + n_kv_head = args.n_kv_head if args.n_kv_head is not None else args.n_head + config = { + 'architecture': "LlamaForCausalLM", + 'dtype': args.dtype, + 'logits_dtype': 'float32', + 'num_hidden_layers': args.n_layer, + 'num_attention_heads': args.n_head, + 'hidden_size': args.n_embd, + 'intermediate_size': args.inter_size, + 'num_key_value_heads': n_kv_head, + 'vocab_size': args.vocab_size, + 'position_embedding_type': 'rope_gpt_neox', + 'max_position_embeddings': args.n_positions, + 'hidden_act': args.hidden_act, + 'rotary_base': args.rotary_base, + 'norm_epsilon': args.rms_norm_eps, + 'moe_num_experts': args.moe_num_experts, + 'moe_top_k': args.moe_top_k, + 'moe_tp_mode': args.moe_tp_mode, + 'moe_normalization_mode': args.moe_renorm_mode, + 'mapping': { + 'world_size': args.tp_size * args.pp_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size + }, + 'quantization': args_to_quantization(args).asdict() + } + config.update(args_to_build_options(args)) + return config + + +def preload_model(model_dir, weights_dir=None): + sys.path.append(model_dir) + from model import LanguageModelConfig, TransformerConfig + from runners import ModelRunner + if weights_dir and os.path.exists(weights_dir): + CKPT_PATH = weights_dir + else: + CKPT_PATH = os.path.join(model_dir, "checkpoints") + + grok_1_model = LanguageModelConfig( + vocab_size=128 * 1024, + pad_token=0, + eos_token=2, + sequence_len=8192, + embedding_init_scale=1.0, + output_multiplier_scale=0.5773502691896257, + embedding_multiplier_scale=78.38367176906169, + model=TransformerConfig( + emb_size=48 * 128, + widening_factor=8, + key_size=128, + num_q_heads=48, + num_kv_heads=8, + num_layers=64, + attn_output_multiplier=0.08838834764831845, + shard_activations=True, + # MoE. + num_experts=8, + num_selected_experts=2, + # Activation sharding. + data_axis="data", + model_axis="model", + ), + ) + + runner = ModelRunner( + model=grok_1_model, + bs_per_device=0.125, + checkpoint_path=CKPT_PATH, + ) + dummy_data = dict( + inputs=np.zeros((1, 256), dtype=np.int32), + targets=np.zeros((1, 256), dtype=np.int32), + ) + runner.transform_forward = True + runner.initialize(dummy_data, (1, 8), (1, 1)) + + params = runner.load_or_init(dummy_data) + + return params + + +def convert_and_save_xai(args): + model_dir = args.model_dir + load_by_shard = args.load_by_shard + world_size = args.tp_size * args.pp_size + # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. + # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, + # before the refactor is done. + override_fields = {'moe_tp_mode': args.moe_tp_mode} + quantization = args_to_quantization(args) + override_fields.update(args_to_build_options(args)) + + # When not loading by shard, preload one complete model and then slice per rank weights from this + # this saves the disk reloading time + xai_model = preload_model( + model_dir, args.weights_dir) if not args.load_by_shard else None + + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + grok = GrokForCausalLM.from_hugging_face( + model_dir, + args.dtype, + mapping=mapping, + quantization=quantization, + load_by_shard=load_by_shard, + override_fields=override_fields, + preloaded_model=xai_model, + ) + grok.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del grok + + execute(args.workers, [convert_and_save_rank] * world_size, args) + release_gc() + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + +def main(): + print(tensorrt_llm.__version__) + args = parse_arguments() + + args.tp_size * args.pp_size + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + if args.model_dir is None: # generate fake config.json + config = from_cli_args(args) + with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + else: # all other non-gptq paths from hf model + assert args.model_dir is not None + convert_and_save_xai(args) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/examples/grok/requirements.txt b/examples/grok/requirements.txt new file mode 100644 index 000000000..858f328d6 --- /dev/null +++ b/examples/grok/requirements.txt @@ -0,0 +1,10 @@ +-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.11.0.dev2024060400 +datasets==2.14.6 +evaluate~=0.4.1 +rouge_score~=0.1.2 +sentencepiece==0.2.0 +jax[cuda12-pip]==0.4.28 +jaxlib[cuda12-pip]==0.4.28 +dm_haiku==0.0.12 diff --git a/examples/hf_lora_convert.py b/examples/hf_lora_convert.py index a0a2ead8e..525be45fd 100644 --- a/examples/hf_lora_convert.py +++ b/examples/hf_lora_convert.py @@ -79,7 +79,15 @@ def convert_hf_model(model_dir, dtype, out_dir): saved_dir.mkdir(parents=True, exist_ok=True) with open(f"{model_dir}/adapter_config.json", "r") as f: config = json.load(f) - config["r"] + + rank = config.get("r") + alpha = config.get("lora_alpha") + use_rslora = config.get("use_rslora", False) + if use_rslora: + scale = alpha / np.sqrt(rank) + else: + scale = alpha / rank + lora_model = load_state_dict(get_model_path(model_dir, "adapter_model")) all_weights = get_all_lora_weights(lora_model) converted_weights = [] @@ -104,7 +112,8 @@ def convert_hf_model(model_dir, dtype, out_dir): elif dim0 < dim1 and inout == "out": adapter_size = dim0 w = w.transpose(1, 0) - + if inout == "out": + w = w * scale w = w.contiguous().flatten().to(dtype=str_dtype_to_torch(dtype)) in_out_weights.append(w) in_out_weights = torch.concatenate(in_out_weights).flatten() diff --git a/examples/high-level-api/README.md b/examples/high-level-api/README.md index 284b80551..9b073bb62 100644 --- a/examples/high-level-api/README.md +++ b/examples/high-level-api/README.md @@ -203,10 +203,12 @@ Thirdly, you need to prepare a Slurm script to submit the task, the script conta ```sh #SBATCH -N 2 # number of nodes #SBATCH --ntasks-per-node=4 +#SBATCH -p +# more sbatch options here... srun --container-image="" \ --mpi=pmix \ - ... \ # much details here + ... \ # more srun options here trtllm-hlapi-launch python3 .py ``` diff --git a/examples/high-level-api/requirements.txt b/examples/high-level-api/requirements.txt index cca7a9fc3..beccb4d64 100644 --- a/examples/high-level-api/requirements.txt +++ b/examples/high-level-api/requirements.txt @@ -1,2 +1,2 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 diff --git a/examples/infinitebench/construct_synthetic_dataset.py b/examples/infinitebench/construct_synthetic_dataset.py index b13fa0e2e..339aad750 100644 --- a/examples/infinitebench/construct_synthetic_dataset.py +++ b/examples/infinitebench/construct_synthetic_dataset.py @@ -46,10 +46,10 @@ def build_passkey(args): question = "What is the pass key?" # target_length = [ - # 1024 * 8, 1024 * 16, 1024 * 32, 1024 * 64, 1024 * 128, 1024 * 256 + # 1024 * 8, 1024 * 16, 1024 * 32, 1024 * 64, 1024 * 128, 1024 * 256, 1024 * 512, 1024 * 1024 # ] - num_noise = [326, 652, 1305, 2610, 5220, 10440] - step = [6, 12, 22, 45, 90, 180] + num_noise = [326, 652, 1305, 2610, 5220, 10440, 20880, 41760] + step = [6, 12, 22, 45, 90, 180, 360, 720] repeat_time = 5 step_i = step[args.test_level] num_noise_i = num_noise[args.test_level] @@ -85,7 +85,6 @@ def build_kv_retrieval(): with jsonlines.open("kv-retrieval-3000_keys.jsonl") as fin: for line in fin: - print(len(line["ordered_kv_records"])) # return 0 cnt += 1 if cnt == nsample[ii]: @@ -126,7 +125,7 @@ def build_kv_retrieval(): type=int, default=0, help= - "Test level between [0, 5] for task build_passkey and [0, 1] for task build_kv_retrieval. The larger number, the longer context" + "Test level between [0, 7] for task build_passkey and [0, 1] for task build_kv_retrieval. The larger number, the longer context" ) parser.add_argument( '--test_case', @@ -140,5 +139,10 @@ def build_kv_retrieval(): # os.system("git clone https://github.com/nelson-liu/lost-in-the-middle.git") # os.system("python3.10 -u lost-in-the-middle/scripts/make_kv_retrieval_data.py --num-keys 3000 --num-examples 500 --output-path kv-retrieval-3000_keys.jsonl.gz") # os.system("gzip -d kv-retrieval-3000_keys.jsonl.gz") - # build_kv_retrieval() - build_passkey(args) + + if args.test_case == "build_passkey": + build_passkey(args) + elif args.test_case == "build_kv_retrieval": + build_kv_retrieval() + else: + assert False diff --git a/examples/internlm/README.md b/examples/internlm/README.md index fe475c0f1..d9b9c0e93 100644 --- a/examples/internlm/README.md +++ b/examples/internlm/README.md @@ -131,7 +131,6 @@ python convert_checkpoint.py --model_dir ./internlm-chat-7b \ trtllm-build --checkpoint_dir ./internlm-chat-7b/smooth_internlm/int8_kv_cache/ \ --output_dir ./engine_outputs \ --gemm_plugin float16 \ - --strongly_typed ``` @@ -150,7 +149,6 @@ python convert_checkpoint.py --model_dir ./internlm-chat-20b \ trtllm-build --checkpoint_dir ./internlm-chat-20b/smooth_internlm/int8_kv_cache/ \ --output_dir ./engine_outputs \ --gemm_plugin float16 \ - --strongly_typed ``` diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt index 3c4515451..1c35cb3bf 100644 --- a/examples/internlm/requirements.txt +++ b/examples/internlm/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets==2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/internlm2/.gitignore b/examples/internlm2/.gitignore new file mode 100644 index 000000000..7ce339719 --- /dev/null +++ b/examples/internlm2/.gitignore @@ -0,0 +1,2 @@ +internlm* +tokenizer.model diff --git a/examples/internlm2/README.md b/examples/internlm2/README.md new file mode 100644 index 000000000..bed682819 --- /dev/null +++ b/examples/internlm2/README.md @@ -0,0 +1,201 @@ +# InternLM2 + +This document shows how to build and run InternLM2 7B / 20B models in TensorRT-LLM on both single GPU, single node multi-GPU and multi-node multi-GPU. + +## Overview + +The TensorRT-LLM InternLM2 implementation is based on the LLaMA model. The implementation can +be found in [model.py](../../tensorrt_llm/models/llama/model.py). +The TensorRT-LLM InternLM2 example code lies in [`examples/internlm2`](./): + +* [`convert_checkpoint.py`](./convert_checkpoint.py) converts the Huggingface Model of InternLM2 into TensorRT-LLM checkpoint. + + +In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation: + +* [`../run.py`](../run.py) to run the inference on an input text; +* [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. + +## Support Matrix + * FP16 / BF16 + * INT8 & INT4 Weight-Only + * Tensor Parallel + +## Usage + +The TensorRT-LLM InternLM2 example code locates at [examples/internlm](./). It takes HF weights as input, and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference. + +### Build TensorRT engine(s) + +Please install required packages first to make sure the example uses matched `tensorrt_llm` version: + +```bash +pip install -r requirements.txt +``` + +TensorRT-LLM InternLM2 builds TensorRT engine(s) from HF checkpoint. If no checkpoint directory is specified, TensorRT-LLM will build engine(s) with dummy weights. + +InternLM2 has released several checkpoints of different size or capabilities under https://huggingface.co/internlm. Users can pick any one repository and follow instructions to prepare the checkpoint. + +Below examples use [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) and [internlm2-chat-20b](https://huggingface.co/internlm/internlm2-chat-20b) and assume these repositories are cloned or linked under this directory, for example `./internlm2-chat-7b`. + +Normally `trtllm-build` only requires single GPU, but if you've already got all the GPUs needed for inference, you could enable parallel building to make the engine building process faster by adding `--workers` argument. Please note that currently `--workers` feature only supports single node. + +Here're some examples: + +```bash +# Build a single-GPU float16 engine from HF weights. +# gpt_attention_plugin is necessary in InternLM2. +# Try use_gemm_plugin to prevent accuracy issue. +cd examples/internlm2 + +# Convert the InternLM2 7B model using a single GPU and FP16. +python convert_checkpoint.py --model_dir ./internlm2-chat-7b/ \ + --dtype float16 \ + --output_dir ./internlm2-chat-7b/trt_engines/fp16/1-gpu/ +# Note: setting `--dtype bfloat16` to use bfloat16 precision. + +# BUild the InternLM2 7B model using a single GPU +trtllm-build --checkpoint_dir ./internlm2-chat-7b/trt_engines/fp16/1-gpu/ \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 + +# Convert the InternLM2 7B model using a single GPU and apply INT8 weight-only quantization.. +python convert_checkpoint.py --model_dir ./internlm2-chat-7b/ \ + --dtype float16 \ + --output_dir ./internlm2-chat-7b/trt_engines/int8/1-gpu/ \ + --use_weight_only \ + --weight_only_precision int8 + +trtllm-build --checkpoint_dir ./internlm2-chat-7b/trt_engines/int8/1-gpu/ \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 + +# Note: setting `--weight_only_precision int4` to use INT4 weight-only quantization + +# Build InternLM2 7B using 2-way tensor parallelism. +python convert_checkpoint.py --model_dir ./internlm2-chat-7b/ \ + --dtype float16 \ + --output_dir ./internlm2-chat-7b/trt_engines/fp16/2-gpu/ \ + --tp_size 2 + +trtllm-build --checkpoint_dir ./internlm2-chat-7b/trt_engines/fp16/2-gpu/ \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 + +# Build InternLM2 20B using 2-way tensor parallelism. +python convert_checkpoint.py --model_dir ./internlm2-chat-20b/ \ + --dtype bfloat16 \ + --output_dir ./internlm2-chat-20b/trt_engines/bf16/2-gpu/ \ + --tp_size 2 --workers 2 + +trtllm-build --checkpoint_dir ./internlm2-chat-7b/trt_engines/bf16/2-gpu/ \ + --output_dir ./engine_outputs \ + --gpt_attention_plugin bfloat16 \ + --gemm_plugin bfloat16 +``` + +#### INT8 weight only + +Examples: + +```bash +cd examples/internlm2 + +# For 7B models +python convert_checkpoint.py --model_dir ./internlm2-chat-7b \ + --output_dir ./internlm2-chat-7b/w8a16/ \ + --dtype float16 \ + --use_weight_only \ + --weight_only_precision int8 + +# Build 7B model with both INT8 weight-only +trtllm-build --checkpoint_dir ./internlm2-chat-7b/w8a16 \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 +``` + + +```bash +cd examples/internlm2 + +# For 20B models +python convert_checkpoint.py --model_dir ./internlm2-chat-20b \ + --output_dir ./internlm2-chat-20b/w8a16 \ + --dtype float16 \ + --use_weight_only \ + --weight_only_precision int8 + +# Build 20B model with both INT8 weight-only +trtllm-build --checkpoint_dir ./internlm2-chat-20b/w8a16 \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 \ +``` + +### Run + +To run a TensorRT-LLM InternLM2 model using the engines generated by `trtllm-build` + +```bash +# InternLM2 7B with fp16 +python ../run.py --max_output_len=120 \ + --input_text 'Tell me about yourself.' \ + --tokenizer_dir ./internlm2-chat-7b/ \ + --engine_dir=./internlm2-chat-7b/trt_engines/fp16/1-gpu/ + +# InternLM2 7B with bf16 +python ../run.py --max_output_len=120 \ + --input_text 'Tell me about yourself.' \ + --tokenizer_dir ./internlm2-chat-7b/ \ + --engine_dir=./internlm2-chat-7b/trt_engines/bf16/1-gpu/ + +# InternLM2 7B with int8 weight only quantization +python ../run.py --max_output_len=120 \ + --input_text 'Tell me about yourself.' \ + --tokenizer_dir ./internlm2-chat-7b/ \ + --engine_dir=./internlm2-chat-7b/trt_engines/weight_only/1-gpu/ + +# InternLM2 7B with fp16 and tensor parallelism +mpirun -n 2 --allow-run-as-root \ + python ../run.py --max_output_len=120 \ + --input_text 'Tell me about yourself.' \ + --tokenizer_dir ./internlm2-chat-7b/ \ + --engine_dir=./internlm2-chat-7b/trt_engines/fp16/2-gpu/ + +# InternLM2 20B with fp16 and tensor parallelism and pipeline parallelism +mpirun -n 4 --allow-run-as-root \ + python ../run.py --max_output_len=120 \ + --input_text 'Tell me about yourself.' \ + --tokenizer_dir ./internlm2-chat-7b/ \ + --engine_dir=./internlm2-chat-7b/trt_engines/bf16/4-gpu/ +``` + +### Summarization using the InternLM2 model + +```bash +# Run summarization using the InternLM2 7B model in FP16. +python ../summarize.py --test_trt_llm --test_hf \ + --hf_model_dir ./internlm2-chat-7b/ \ + --data_type fp16 \ + --engine_dir ./engine_outputs + +# Run summarization using the InternLM2 7B model quantized to w8a16. +python ../summarize.py --test_trt_llm --test_hf \ + --hf_model_dir ./internlm2-chat-7b/ \ + --data_type fp16 \ + --engine_dir ./engine_outputs + +# Run summarization using the InternLM2 7B model in FP16 using two GPUs. +mpirun -n 2 --allow-run-as-root \ + python ../summarize.py --test_trt_llm --test_hf \ + --hf_model_dir ./internlm2-chat-7b/ \ + --data_type fp16 \ + --engine_dir ./internlm2-chat-7b/trt_engines/fp16/2-gpu/ + +# Run summarization using the InternLM2 20B model in BF16 using 4 GPUs. +mpirun -n 4 --allow-run-as-root \ + python ../summarize.py --test_trt_llm --test_hf \ + --hf_model_dir ./internlm2-chat-20b/ \ + --data_type bf16 \ + --engine_dir ./internlm2-chat-20b/trt_engines/bf16/4-gpu/ +``` diff --git a/examples/internlm2/convert_checkpoint.py b/examples/internlm2/convert_checkpoint.py new file mode 100644 index 000000000..76f91afe7 --- /dev/null +++ b/examples/internlm2/convert_checkpoint.py @@ -0,0 +1,420 @@ +import argparse +import json +import os +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Dict, Optional + +import numpy as np +import safetensors +import torch +from einops import rearrange +from transformers import AutoConfig, AutoModelForCausalLM + +import tensorrt_llm +from tensorrt_llm._utils import release_gc +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.llama import convert + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, default=None) + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_embedding_sharing', + action="store_true", + default=False, + help= + 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' + 'Note: the flag might not take effect when the criteria are not met.') + + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + parser.add_argument('--log_level', type=str, default='info') + args = parser.parse_args() + + tensorrt_llm.logger.set_level(args.log_level) + return args + + +def get_qkv_weight(weight: torch.Tensor, + hidden_size: int, + num_heads: int, + tp_size: int, + tp_rank: int, + is_bias: bool, + num_kv_heads: Optional[int] = None) -> torch.Tensor: + """ Splits the QKV matrix according to tensor parallelism """ + head_size = hidden_size // num_heads + num_kv_groups = num_heads // num_kv_heads + mha_mode = num_kv_heads == num_heads + weight = rearrange(weight, + '(h gs d) dim -> h gs d dim', + gs=2 + num_kv_groups, + d=head_size) + q_w, k_w, v_w = torch.split(weight, [num_kv_groups, 1, 1], dim=1) + if is_bias: + q_w = q_w.ravel() + k_w = k_w.ravel() + v_w = v_w.ravel() + qkv_w = torch.cat((q_w, k_w, v_w)) + qkv_w = convert.split_qkv_bias_tp(qkv_w, num_heads, hidden_size, + tp_size, tp_rank) + else: + q_w = rearrange(q_w, 'h gs d dim -> (h gs d) dim') + k_w = rearrange(k_w, 'h gs d dim -> (h gs d) dim') + v_w = rearrange(v_w, 'h gs d dim -> (h gs d) dim') + if not mha_mode: + if num_kv_heads < tp_size: + k_w = convert.dup_kv_weight(k_w, num_kv_heads, tp_size) + v_w = convert.dup_kv_weight(v_w, num_kv_heads, tp_size) + assert (k_w.shape[0] % (tp_size * head_size)) == 0 + assert (v_w.shape[0] % (tp_size * head_size)) == 0 + wq = convert.split(q_w, tp_size, tp_rank) + wk = convert.split(k_w, tp_size, tp_rank) + wv = convert.split(v_w, tp_size, tp_rank) + qkv_w = torch.concat((wq, wk, wv)) + + else: + qkv_w = torch.cat([q_w, k_w, v_w], dim=0) + + qkv_w = convert.split_qkv_tp(qkv_w, num_heads, hidden_size, tp_size, + tp_rank) + return qkv_w + + +def get_tllm_linear_weight( + weight: torch.Tensor, + prefix: str, + bias: Optional[torch.Tensor] = None, + use_weight_only: bool = False, + plugin_weight_only_quant_type: torch.dtype = torch.int8 +) -> Dict[str, torch.Tensor]: + results = {} + if use_weight_only: + v = weight.t().contiguous() + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + v, plugin_weight_only_quant_type) + results[f'{prefix}.weight'] = processed_torch_weights + results[f'{prefix}.per_channel_scale'] = torch_weight_scales + else: + results[f'{prefix}.weight'] = weight.contiguous() + + if bias is not None: + results[f'{prefix}.bias'] = bias + + return results + + +def convert_from_hf(hf_model, + hf_config, + mapping: Mapping, + dtype: str = 'float32', + use_parallel_embedding: bool = False, + share_embedding_table: bool = False, + sharding_dim: int = 0, + use_weight_only: bool = False, + plugin_weight_only_quant_type: torch.dtype = torch.int8): + weights = {} + tik = time.time() + + model_params = dict(hf_model.named_parameters()) + dtype = getattr(torch, dtype) + num_attention_heads = hf_config.num_attention_heads + hidden_size = hf_config.hidden_size + vocab_size = hf_config.vocab_size + num_kv_heads = hf_config.num_key_value_heads + num_hidden_layers = hf_config.num_hidden_layers + layers_range = mapping.pp_layers(num_hidden_layers) + for l in layers_range: + prefix = f'model.layers.{l}' + tllm_prex = f'transformer.layers.{l - layers_range[0]}' + + qkv_weight = convert.get_weight(model_params, + f'{prefix}.attention.wqkv', dtype) + qkv_w = get_qkv_weight(qkv_weight, + hidden_size, + num_attention_heads, + mapping.tp_size, + mapping.tp_rank, + is_bias=False, + num_kv_heads=num_kv_heads) + + qkv_bias = None + if f'{prefix}.attention.wqkv.bias' in model_params: + qkv_bias = convert.get_bias(model_params, + f'{prefix}.attention.wqkv', dtype) + if qkv_bias is None: + qkv_b = None + else: + qkv_b = get_qkv_weight(qkv_bias, + hidden_size, + num_attention_heads, + mapping.tp_size, + mapping.tp_rank, + is_bias=True, + num_kv_heads=num_kv_heads) + weights.update( + get_tllm_linear_weight( + qkv_w, + f'{tllm_prex}.attention.qkv', + qkv_b, + use_weight_only, + plugin_weight_only_quant_type, + )) + + attn_dense_weight = convert.get_weight(model_params, + f'{prefix}.attention.wo', dtype) + attn_dense_w = convert.split_matrix_tp(attn_dense_weight, + mapping.tp_size, + mapping.tp_rank, + dim=1) + attn_dense_bias = None + if f'{prefix}.attention.wo.bias' in model_params: + attn_dense_bias = convert.get_bias(model_params, + f'{prefix}.attention.wo', dtype) + + weights.update( + get_tllm_linear_weight( + attn_dense_w, + f'{tllm_prex}.attention.dense', + attn_dense_bias, + use_weight_only, + plugin_weight_only_quant_type, + )) + + mlp_fc_weight = convert.get_weight(model_params, + f'{prefix}.feed_forward.w1', dtype) + mlp_fc_w = convert.split_matrix_tp(mlp_fc_weight, + mapping.tp_size, + mapping.tp_rank, + dim=0) + mlp_fc_b = None + weights.update( + get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', mlp_fc_b, + use_weight_only, + plugin_weight_only_quant_type)) + + mlp_proj_weight = convert.get_weight(model_params, + f'{prefix}.feed_forward.w2', dtype) + mlp_proj_w = convert.split_matrix_tp(mlp_proj_weight, + mapping.tp_size, + mapping.tp_rank, + dim=1) + mlp_proj_bias = None + weights.update( + get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj', + mlp_proj_bias, use_weight_only, + plugin_weight_only_quant_type)) + + mlp_gate_weight = convert.get_weight(model_params, + f'{prefix}.feed_forward.w3', dtype) + mlp_gate_w = convert.split_matrix_tp(mlp_gate_weight, + mapping.tp_size, + mapping.tp_rank, + dim=0) + mlp_gate_bias = None + weights.update( + get_tllm_linear_weight(mlp_gate_w, f'{tllm_prex}.mlp.gate', + mlp_gate_bias, use_weight_only, + plugin_weight_only_quant_type)) + + # Layer norms do not use tensor parallelism + input_ln_weight = convert.get_weight(model_params, + f'{prefix}.attention_norm', dtype) + weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_weight + + post_ln_weight = convert.get_weight(model_params, f'{prefix}.ffn_norm', + dtype) + weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_weight + + release_gc() + + embed_w = convert.get_weight(model_params, 'model.tok_embeddings', dtype) + if use_parallel_embedding: + embed_w = convert.split_matrix_tp(embed_w, + mapping.tp_size, + mapping.tp_rank, + dim=sharding_dim) + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = embed_w + lm_head_weights = convert.get_weight(model_params, 'output', dtype) + if mapping.is_last_pp_rank(): + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = convert.pad_vocab_size(vocab_size, + mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + + lm_head_weights = torch.from_numpy( + np.pad(lm_head_weights.detach().cpu().numpy(), + ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = convert.split_matrix_tp(lm_head_weights, + mapping.tp_size, + mapping.tp_rank, + dim=0) + ln_f_w = convert.get_weight(model_params, 'model.norm', dtype) + weights['transformer.ln_f.weight'] = ln_f_w + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights + + +if __name__ == '__main__': + args = parse_arguments() + world_size = args.tp_size * args.pp_size + + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + quant_algo = None + plugin_weight_only_quant_type = None + if args.use_weight_only and args.weight_only_precision == 'int8': + plugin_weight_only_quant_type = torch.int8 + quant_algo = 'W8A16' + elif args.use_weight_only and args.weight_only_precision == 'int4': + plugin_weight_only_quant_type = torch.quint4x2 + quant_algo = 'W4A16' + + hf_config = AutoConfig.from_pretrained(args.model_dir, + trust_remote_code=True) + config = { + 'architecture': hf_config.architectures[0], + 'dtype': args.dtype, + 'logits_dtype': 'float32', + 'num_hidden_layers': hf_config.num_hidden_layers, + 'num_attention_heads': hf_config.num_attention_heads, + 'num_key_value_heads': hf_config.num_key_value_heads, + 'hidden_size': hf_config.hidden_size, + 'intermediate_size': hf_config.intermediate_size, + 'norm_epsilon': hf_config.rms_norm_eps, + 'vocab_size': hf_config.vocab_size, + 'position_embedding_type': 'rope_gpt_neox', + 'rotary_base': hf_config.rope_theta, + 'max_position_embeddings': hf_config.max_position_embeddings, + 'hidden_act': hf_config.hidden_act, + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.use_embedding_sharing, + 'quantization': { + 'quant_algo': quant_algo, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'attn_bias': getattr(hf_config, 'bias', False), + 'rotary_scaling': getattr(hf_config, "rope_scaling", None) + } + + with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + + def covert_and_save(rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + + hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, + trust_remote_code=True, + torch_dtype="auto") + weights = convert_from_hf( + hf_model, + hf_config, + mapping, + dtype=args.dtype, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + share_embedding_table=args.use_embedding_sharing, + use_weight_only=args.use_weight_only, + plugin_weight_only_quant_type=plugin_weight_only_quant_type) + del hf_model + save_file = os.path.join(args.output_dir, f'rank{rank}.safetensors') + print(f'Saving to {save_file}') + safetensors.torch.save_file(weights, save_file) + + if args.workers == 1: + for rank in range(world_size): + covert_and_save(rank) + else: + with ThreadPoolExecutor(max_workers=args.workers) as p: + futures = [ + p.submit(covert_and_save, rank) for rank in range(world_size) + ] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') diff --git a/examples/internlm2/requirements.txt b/examples/internlm2/requirements.txt new file mode 100644 index 000000000..d27fa26c6 --- /dev/null +++ b/examples/internlm2/requirements.txt @@ -0,0 +1 @@ +einops diff --git a/examples/llama/README.md b/examples/llama/README.md index 66fd2f528..50f985de5 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -10,13 +10,16 @@ This document shows how to build and run a LLaMA model in TensorRT-LLM on both s - [LLaMA v2 Updates](#llama-v2-updates) - [LLaMA v3 Updates](#llama-v3-updates) - [Long context length](#long-context-length) - - [INT8 KV cache](#int8-kv-cache) - - [SmoothQuant](#smoothquant) - - [FP8 Post-Training Quantization](#fp8-post-training-quantization) - - [Groupwise quantization (AWQ/GPTQ)](#groupwise-quantization-awqgptq) - - [AWQ](#awq) - - [GPTQ](#gptq) + - [Long context evaluation](#long-context-evaluation) + - [1M long context test case](#1m-long-context-test-case) + - [INT8 KV cache](#int8-kv-cache) + - [SmoothQuant](#smoothquant) + - [FP8 Post-Training Quantization](#fp8-post-training-quantization) + - [Groupwise quantization (AWQ/GPTQ)](#groupwise-quantization-awqgptq) + - [AWQ](#awq) + - [GPTQ](#gptq) - [Run](#run) + - [Multi-GPU multi-node (MGMN) support](#multi-gpu-multi-node-mgmn-support) - [Summarization using the LLaMA model](#summarization-using-the-llama-model) - [Mistral v0.1](#mistral-v01) - [Running CodeLlama](#running-codellama) @@ -24,6 +27,7 @@ This document shows how to build and run a LLaMA model in TensorRT-LLM on both s - [Run](#run-1) - [Run LLaMa with LoRA](#run-llama-with-lora) - [Run LLaMa with several lora checkpoints](#run-llama-with-several-lora-checkpoints) + - [Run FP8 LLaMa with FP16 lora checkpoints](#run-fp8-llama-with-fp16-lora-checkpoints) - [Run LLaMa with StreamingLLM](#run-llama-with-streamingllm) ## Overview @@ -78,6 +82,8 @@ Normally `trtllm-build` only requires single GPU, but if you've already got all `--use_fused_mlp` enables GEMM horizontal fusion in gated MLP layer, which reduces input traffic and potentially improves performance. For FP8 PTQ, the downside is slight reduction of accuracy because one of the quantization scaling factors are discarded (accuracy 0.45734 vs 0.45755 for LLaMA-v2 7B using modelopt/examples/hf/instruct_eval/mmlu.py). +`--use_fused_mlp --use_gemm_swiglu_plugin ` fuses 2 GEMMs without biases and SwiGLU into one kernel. This is a preview feature and is only supported for dtype `fp8`. The supported architecture is SM90. + Here're some examples: ```bash @@ -270,7 +276,7 @@ A few LLaMA models are fine-tuned for long context length that TRT-LLM can suppo ```bash # Build 8-GPU engine with long context LLaMA model -python convert_checkpoint.py --meta_ckpt_dir ./tmp/LongAlpaca-70B/ \ +python convert_checkpoint.py --model_dir ./tmp/LongAlpaca-70B/ \ --output_dir ./tllm_checkpoint_8gpu_tp8 \ --dtype float16 \ --tp_size 8 \ @@ -326,6 +332,8 @@ git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-104 * Run examples with max_input_len 16384 +To evaluate the PPL of very long context, we need to enable `use_paged_context_fmha` and setup `max_num_tokens` to enable the chunked context inference, reducing the activation memory requirement. Also, we need to enable `gather_all_token_logits` to return the logits to compute the PPL. + ```bash python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ --output_dir /tmp/llama-3-8B-1048k/trt_ckpts \ @@ -334,10 +342,11 @@ python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-8B-Instruct-Gr python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-8B-1048k/trt_ckpts \ --output_dir /tmp/llama-3-8B-1048k/trt_engines \ --gemm_plugin float16 \ - --strongly_typed \ --gather_all_token_logits \ + --max_num_tokens 4096 \ --max_input_len 16384 \ - --max_output_len 10 + --max_output_len 10 \ + --use_paged_context_fmha enable python ./examples/summarize.py --test_trt_llm \ --tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ @@ -351,15 +360,119 @@ python ./examples/summarize.py --test_trt_llm \ * Run evaluation on passkey task +To evaluate the accuracy of very long context on `needle in haystack`, we need to enable `use_paged_context_fmha` and setup `max_num_tokens` to enable the chunked context inference, reducing the activation memory requirement. To save memory, we don't enable the `gather_all_token_logits` here because we don't need logits. + ```bash -python3 examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey +python3 examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey --test_level 4 + +python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-8B-1048k/trt_ckpts \ + --output_dir /tmp/llama-3-8B-1048k/trt_engines \ + --gemm_plugin float16 \ + --max_num_tokens 4096 \ + --max_input_len 131072 \ + --max_output_len 10 \ + --use_paged_context_fmha enable python examples/eval_long_context.py --task passkey \ --engine_dir /tmp/llama-3-8B-1048k/trt_engines \ --tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ - --stop_idx 20 \ - --max_input_length 12800 \ - --use_py_session + --stop_idx 10 \ + --max_input_length 131072 \ + --enable_chunked_context \ + --max_tokens_in_paged_kv_cache 131136 +``` + +* Run evaluation on kv_retrieval + +`kv_retrieval` is harder than `passkey` and is helpful to distinguish the model capability. + +To run the kv_retrieval, we need a third-party repo to prepare the keys. + +```bash +git clone git@github.com:nelson-liu/lost-in-the-middle.git +pip install -r lost-in-the-middle/requirements.txt +python -u lost-in-the-middle/scripts/make_kv_retrieval_data.py --num-keys 3000 --num-examples 500 --output-path kv-retrieval-3000_keys.jsonl.gz +gzip -d kv-retrieval-3000_keys.jsonl.gz +``` + +Prepare input data and run evaluation. + +```bash +python examples/infinitebench/construct_synthetic_dataset.py --test_case build_kv_retrieval --test_level 0 + +python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ + --output_dir /tmp/llama-3-8B-1048k/trt_ckpts \ + --dtype float16 \ + --tp_size 1 + +python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-8B-1048k/trt_ckpts \ + --output_dir /tmp/llama-3-8B-1048k/trt_engines \ + --gemm_plugin float16 \ + --max_num_tokens 4096 \ + --max_input_len 131072 \ + --max_output_len 10 \ + --use_paged_context_fmha enable + +python examples/eval_long_context.py --task kv_retrieval \ + --engine_dir /tmp/llama-3-8B-1048k/trt_engines \ + --tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ + --stop_idx 10 \ + --max_input_length 131072 \ + --enable_chunked_context \ + --max_tokens_in_paged_kv_cache 131136 \ + --tensorrt_llm_accuracy_threshold 0.6 +``` + +expected results: + +```bash +[05/28/2024-03:31:43] [TRT-LLM] [I] ==== Evaluation ==== +[05/28/2024-03:31:43] [TRT-LLM] [I] # examples: 500 +[05/28/2024-03:31:43] [TRT-LLM] [I] Start index: 0 +[05/28/2024-03:31:43] [TRT-LLM] [I] Stop index: 10 +[05/28/2024-03:31:43] [TRT-LLM] [I] Max tokens: 50 +[05/28/2024-03:34:50] [TRT-LLM] [I] Compute the score +10it [00:00, 131072.00it/s] +[05/28/2024-03:34:51] [TRT-LLM] [I] Evaluation takes: 187.19733428955078 sec. +[05/28/2024-03:34:51] [TRT-LLM] [I] accuracy of 10 examples: 0.6 +``` + +#### 1M long context test case + +```bash +git-lfs clone https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k/ + +python examples/infinitebench/construct_synthetic_dataset.py --test_case build_passkey --test_level 7 + +python examples/llama/convert_checkpoint.py --model_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ + --output_dir /tmp/llama-3-8B-1048k/trt_ckpts \ + --dtype float16 \ + --tp_size 4 + +python -m tensorrt_llm.commands.build --checkpoint_dir /tmp/llama-3-8B-1048k/trt_ckpts \ + --output_dir /tmp/llama-3-8B-1048k/trt_engines \ + --gemm_plugin float16 \ + --max_num_tokens 4096 \ + --max_input_len 1048576 \ + --max_output_len 10 \ + --use_paged_context_fmha enable \ + --workers 4 + +mpirun -n 4 --allow-run-as-root python examples/eval_long_context.py --task passkey \ + --engine_dir /tmp/llama-3-8B-1048k/trt_engines \ + --tokenizer_dir ./Llama-3-8B-Instruct-Gradient-1048k/ \ + --stop_idx 1 \ + --max_input_length 1048576 \ + --enable_chunked_context \ + --max_tokens_in_paged_kv_cache 1100000 +``` + +expected result: + +```bash +[05/27/2024-10:30:45] [TRT-LLM] [I] Compute the score +1it [00:00, 4215.38it/s] +[05/27/2024-10:30:45] [TRT-LLM] [I] accuracy of 1 examples: 1.0 ``` ### INT8 KV cache @@ -400,7 +513,6 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_int8_kv_wq \ --output_dir ./tmp/llama/7B/trt_engines/int8_kv_cache_weight_only/1-gpu \ --gemm_plugin auto \ --multi_block_mode enable \ - --strongly_typed ``` Test with `../summarize.py`: @@ -429,7 +541,6 @@ python ../quantization/quantize.py --model_dir /tmp/llama-7b-hf \ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_awq_int8_kv_cache \ --output_dir ./tmp/llama/7B/trt_engines/int8_kv_cache_int4_AWQ/1-gpu/ \ --gemm_plugin auto \ - --strongly_typed ``` Test with `../summarize.py`: @@ -499,7 +610,6 @@ python ../quantization/quantize.py --model_dir ./tmp/llama/70B \ trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu_fp8 \ --output_dir ./engine_outputs \ --gemm_plugin auto \ - --strongly_typed \ --workers 2 ``` @@ -596,6 +706,47 @@ python3 ../run.py --max_output_len=50 \ --engine_dir=./tmp/llama/7B/trt_engines/bf16/1-gpu/ ``` +### Multi-GPU multi-node (MGMN) support + +In MGMN case, you can still convert and build engines on a single node and then run the model on a multi-node environment, such as [Slurm](https://slurm.schedmd.com/documentation.html). + +For example, to build LLaMA 70B for 2 nodes with 8 GPUs per node, we can use 8-way tensor parallelism and 2-way pipeline parallelism: + +```bash +python convert_checkpoint.py --model_dir ./tmp/llama/70B/hf/ \ + --output_dir ./tllm_checkpoint_16gpu_tp8_pp2 \ + --dtype float16 \ + --tp_size 8 \ + --pp_size 2 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_16gpu_tp8_pp2 \ + --output_dir ./tmp/llama/70B/trt_engines/fp16/16-gpu/ \ + --workers 8 \ + --gemm_plugin auto +``` + +Note that `–-workers` is still set to 8 to build all engines within a single node. + +To run the LLaMA 70B model on 2 nodes via Slurm, you need to prepare a Slurm script to submit the task, the script contains the following lines: + +```bash +#SBATCH -N 2 +#SBATCH --ntasks-per-node=8 +#SBATCH -p +# more sbatch options here... + +srun --container-image= \ + --mpi=pmix \ + ... \ # more srun options here + python3 ../run.py --max_output_len=50 \ + --tokenizer_dir ./tmp/llama/70B/hf/ \ + --engine_dir=./tmp/llama/70B/trt_engines/fp16/16-gpu/ +``` + +Finally, you can submit the task with `sbatch .sh`. + +Considering the Slurm or other cluster management systems may be highly customized and the task-submit command may be variant, the forementioned example is for reference only. The key point is to submit the Python script with the MPI runtime, and TensorRT-LLM will take care of the rest. + ### Summarization using the LLaMA model ```bash @@ -847,6 +998,71 @@ Output [Text 5 Beam 0]: "ワシントン D.C." We can observe that `luotuo-lora-7b-0.1` produces correct answers on the first sentence and the fifth sentence (in Chinese), `Japanese-Alpaca-LoRA-7b-v0` produces correct answers on the sixth sentence (in Japanese). +### Run FP8 LLaMa with FP16 lora checkpoints + +In this section, we show how to run an FP8 llama model with multiple FP16 LoRA modules. + +* Quantize the llama model to fp8 from HF +```bash +BASE_LLAMA_MODEL=llama-7b-hf/ +python ../quantization/quantize.py --model_dir ${BASE_LLAMA_MODEL} \ + --dtype float16 \ + --qformat fp8 \ + --kv_cache_dtype fp8 \ + --output_dir ./tllm_checkpoint_1gpu_fp8 \ + --calib_size 512 +``` + +* Download the lora model, build engine, and run inference. +```bash +git-lfs clone https://huggingface.co/qychen/luotuo-lora-7b-0.1 +git-lfs clone https://huggingface.co/kunishou/Japanese-Alpaca-LoRA-7b-v0 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp8 \ + --output_dir /tmp/llama_7b_with_lora_qkv/trt_engines/fp8/1-gpu/ \ + --gemm_plugin auto \ + --lora_plugin auto \ + --max_batch_size 8 \ + --max_input_len 512 \ + --max_output_len 50 \ + --lora_dir "luotuo-lora-7b-0.1/" "Japanese-Alpaca-LoRA-7b-v0/" \ + --max_lora_rank 8 \ + --lora_target_modules attn_q attn_k attn_v + +python ../run.py --engine_dir "/tmp/llama_7b_with_lora_qkv/trt_engines/fp8/1-gpu/" \ + --max_output_len 10 \ + --tokenizer_dir ${BASE_LLAMA_MODEL} \ + --input_text "美国的首都在哪里? \n答案:" "美国的首都在哪里? \n答案:" "美国的首都在哪里? \n答案:" "アメリカ合衆国の首都はどこですか? \n答え:" "アメリカ合衆国の首都はどこですか? \n答え:" "アメリカ合衆国の首都はどこですか? \n答え:" \ + --lora_task_uids -1 0 1 -1 0 1 \ + --use_py_session --top_p 0.5 --top_k 0 +``` + +The results would be like + +```bash +Input [Text 0]: " 美国的首都在哪里? \n答案:" +Output [Text 0 Beam 0]: "Washington, D.C. +What is the" + +Input [Text 1]: " 美国的首都在哪里? \n答案:" +Output [Text 1 Beam 0]: "华盛顿。 +" + +Input [Text 2]: " 美国的首都在哪里? \n答案:" +Output [Text 2 Beam 0]: "沃尔沛。\n" + +Input [Text 3]: " アメリカ合衆国の首都はどこですか? \n答え:" +Output [Text 3 Beam 0]: "Washington, D.C. +Copyright " + +Input [Text 4]: " アメリカ合衆国の首都はどこですか? \n答え:" +Output [Text 4 Beam 0]: "华盛顿。 +" + +Input [Text 5]: " アメリカ合衆国の首都はどこですか? \n答え:" +Output [Text 5 Beam 0]: "ワシントン D.C." +``` + ## Run LLaMa with StreamingLLM * Build engine. Set `--streamingllm enable` to enable StreamingLLM. diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py index d542efe12..951d67ac9 100644 --- a/examples/llama/convert_checkpoint.py +++ b/examples/llama/convert_checkpoint.py @@ -1,7 +1,6 @@ import argparse import json import os -import sys import time import traceback from concurrent.futures import ThreadPoolExecutor, as_completed @@ -10,8 +9,10 @@ from tensorrt_llm._utils import release_gc from tensorrt_llm.layers import MoeConfig from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models import LLaMAForCausalLM -from tensorrt_llm.models.llama.weight import load_from_gptq_llama +from tensorrt_llm.models import LLaMAConfig, LLaMAForCausalLM +from tensorrt_llm.models.convert_utils import has_safetensors +from tensorrt_llm.models.llama.convert import (load_hf_llama, + load_weights_from_gptq) from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization import QuantAlgo @@ -106,10 +107,10 @@ def parse_arguments(): 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' ) parser.add_argument( - '--modelopt_quant_ckpt_path', + '--quant_ckpt_path', type=str, default=None, - help='Path of a quantized model checkpoint in .npz format') + help='Path of a quantized model checkpoint in .safetensors format') parser.add_argument( '--per_group', @@ -133,10 +134,6 @@ def parse_arguments(): help='Group size used in GPTQ quantization.' ) # AWQ is only supported by quantize.py script - parser.add_argument("--dataset-cache-dir", - type=str, - default=None, - help="cache dir to load the hugging face dataset") parser.add_argument("--load_model_on_cpu", action="store_true") parser.add_argument( '--use_parallel_embedding', @@ -212,7 +209,7 @@ def parse_arguments(): return args -def args_to_quantization(args: argparse.Namespace) -> QuantConfig: +def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: '''return config dict with quantization info based on the command line args ''' quant_config = QuantConfig() @@ -255,12 +252,12 @@ def convert_and_save_meta(args, rank): tp_size=args.tp_size, pp_size=args.pp_size, rank=rank) - assert not args_to_quantization(args).quant_mode.has_any_quant(), \ + assert not args_to_quant_config(args).quant_mode.has_any_quant(), \ "quantization from meta checkpoint or empty model were never supported" llama = LLaMAForCausalLM.from_meta_ckpt( args.meta_ckpt_dir, args.dtype, - mapping, + mapping=mapping, use_parallel_embedding=args.use_parallel_embedding, embedding_sharding_dim=args.embedding_sharding_dim) llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) @@ -293,58 +290,23 @@ def from_cli_args(args): 'hidden_act': args.hidden_act, 'rotary_base': args.rotary_base, 'norm_epsilon': args.rms_norm_eps, - 'moe_num_experts': args.moe_num_experts, - 'moe_top_k': args.moe_top_k, - 'moe_tp_mode': args.moe_tp_mode, - 'moe_normalization_mode': args.moe_renorm_mode, + 'moe': { + 'num_experts': args.moe_num_experts, + 'top_k': args.moe_top_k, + 'tp_mode': args.moe_tp_mode, + 'normalization_mode': args.moe_renorm_mode, + }, 'mapping': { 'world_size': args.tp_size * args.pp_size, 'tp_size': args.tp_size, 'pp_size': args.pp_size }, - 'quantization': args_to_quantization(args).asdict() + 'quantization': args_to_quant_config(args).to_dict() } config.update(args_to_build_options(args)) return config -def preload_model(model_dir, load_model_on_cpu): - use_safetensors = True - from transformers import AutoConfig, AutoModelForCausalLM - if "vila" in model_dir: - use_safetensors = False - sys.path.append(model_dir + "/../VILA") - from llava.model import LlavaLlamaConfig, LlavaLlamaModel # noqa - from transformers import AutoModel - model = AutoModel.from_pretrained( - model_dir, - device_map='auto', - trust_remote_code=True, - ) - return model.llm - - hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) - model_cls = AutoModelForCausalLM - if hf_config.model_type == "llava": - use_safetensors = False - from transformers import LlavaForConditionalGeneration - model_cls = LlavaForConditionalGeneration - use_safetensors = any( - [f.endswith(".safetensors") - for f in os.listdir(model_dir)]) and use_safetensors - if use_safetensors: - return None - model = model_cls.from_pretrained( - model_dir, - device_map='auto' if not load_model_on_cpu else 'cpu', - torch_dtype='auto', - trust_remote_code=True, - ) - if hf_config.model_type == "llava": - model = model.language_model - return model - - def convert_and_save_hf(args): model_dir = args.model_dir load_model_on_cpu = args.load_model_on_cpu @@ -354,9 +316,10 @@ def convert_and_save_hf(args): # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, # before the refactor is done. override_fields = {'moe_tp_mode': args.moe_tp_mode} - quantization = args_to_quantization(args) override_fields.update(args_to_build_options(args)) + quant_config = args_to_quant_config(args) + if args.smoothquant is not None or args.int8_kv_cache: assert not args.load_by_shard, "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported" assert not args.load_model_on_cpu, "When using quantization, TRT-LLM needs to load the model to GPU" @@ -367,17 +330,22 @@ def convert_and_save_hf(args): pp_size=args.pp_size) LLaMAForCausalLM.quantize(args.model_dir, args.output_dir, - quantization, dtype=args.dtype, mapping=mapping, + quant_config=quant_config, calib_dataset=args.calib_dataset, - override_fields=override_fields, - dataset_cache_dir=args.dataset_cache_dir) + **override_fields) else: # When not loading by shard, preload one complete model and then slice per rank weights from this # this saves the disk reloading time - hf_model = preload_model( - model_dir, load_model_on_cpu) if not args.load_by_shard else None + + hf_model = None + if "vila" in model_dir or "llava" in model_dir: + hf_model = load_hf_llama(model_dir, load_model_on_cpu) + elif not (args.load_by_shard or + (has_safetensors(model_dir) + and not quant_config.quant_mode.has_any_quant())): + hf_model = load_hf_llama(model_dir, load_model_on_cpu) def convert_and_save_rank(args, rank): mapping = Mapping(world_size=world_size, @@ -385,14 +353,12 @@ def convert_and_save_rank(args, rank): tp_size=args.tp_size, pp_size=args.pp_size) llama = LLaMAForCausalLM.from_hugging_face( - model_dir, + model_dir if hf_model is None else hf_model, args.dtype, mapping=mapping, - quantization=quantization, + quant_config=quant_config, load_by_shard=load_by_shard, - load_model_on_cpu=load_model_on_cpu, - override_fields=override_fields, - preloaded_model=hf_model, + **override_fields, ) llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) del llama @@ -406,15 +372,16 @@ def convert_and_save_gptq(args, rank): tp_size=args.tp_size, rank=rank, pp_size=args.pp_size) - llama = LLaMAForCausalLM.from_hugging_face( + config = LLaMAConfig.from_hugging_face( args.model_dir, args.dtype, mapping=mapping, - quantization=args_to_quantization(args), - skip_loading_weights=True) - weights = load_from_gptq_llama(llama.config, args.modelopt_quant_ckpt_path) - llama.load(weights) - llama.save_checkpoint(args.output_dir, rank == 0) + quant_config=args_to_quant_config(args), + ) + model = LLaMAForCausalLM(config) + weights = load_weights_from_gptq(args.quant_ckpt_path, config) + model.load(weights) + model.save_checkpoint(args.output_dir, rank == 0) def execute(workers, func, args): @@ -456,11 +423,11 @@ def main(): execute(args.workers, [convert_and_save_meta] * world_size, args) elif args.weight_only_precision == 'int4_gptq': assert args.model_dir is not None - assert args.modelopt_quant_ckpt_path is not None + assert args.quant_ckpt_path is not None execute(args.workers, [convert_and_save_gptq] * world_size, args) else: # all other non-gptq paths from hf model assert args.model_dir is not None - assert args.modelopt_quant_ckpt_path is None, "only gptq weights only needs this option" + assert args.quant_ckpt_path is None, "only gptq weights only needs this option" convert_and_save_hf(args) tok = time.time() diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index dc02169b6..ab2e413f3 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/llama/summarize_long.py b/examples/llama/summarize_long.py index a2056bd25..4e8ccb753 100644 --- a/examples/llama/summarize_long.py +++ b/examples/llama/summarize_long.py @@ -65,6 +65,13 @@ def parse_args(): parser.add_argument('--tensorrt_llm_rouge1_threshold', type=float, default=15.0) + parser.add_argument( + '--rouge_dir', + default=None, + type=str, + help= + "datasets.load_metrics('rouge') will attempt to pull rouge package from HF. Use cached rouge can avoid network outage of host or HF." + ) args = parser.parse_args() return args @@ -352,8 +359,10 @@ def main(args): # no ground truth, compare with hf if runtime_rank == 0 and args.test_hf and args.test_trt_llm: + rouge_dir = args.rouge_dir if args.rouge_dir and os.path.exists( + args.rouge_dir) else "rouge" metric_tensorrt_llm = [ - load_metric("rouge") for _ in range(args.num_beams) + load_metric(rouge_dir) for _ in range(args.num_beams) ] for i in range(args.num_beams): diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt index 8833ae3b7..08a958192 100644 --- a/examples/mamba/requirements.txt +++ b/examples/mamba/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate rouge_score diff --git a/examples/medusa/convert_checkpoint.py b/examples/medusa/convert_checkpoint.py index affba1bde..b369facfd 100644 --- a/examples/medusa/convert_checkpoint.py +++ b/examples/medusa/convert_checkpoint.py @@ -22,9 +22,9 @@ from tensorrt_llm._utils import str_dtype_to_torch from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import PretrainedConfig from tensorrt_llm.models.convert_utils import load_calib_dataset -from tensorrt_llm.models.llama.weight import load_from_hf_checkpoint -from tensorrt_llm.models.modeling_utils import PretrainedConfig +from tensorrt_llm.models.llama.convert import load_weights_from_hf_by_shard from tensorrt_llm.quantization import QuantAlgo try: @@ -578,9 +578,9 @@ def get_tllm_linear_weight(weight, postfix='weight'): results = {} if use_weight_only: - v = weight.t().contiguous() + v = weight.t().contiguous().cpu() processed_torch_weights, torch_weight_scales = \ - torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( v, plugin_weight_only_quant_type) results[prefix + postfix] = processed_torch_weights results[prefix + 'per_channel_scale'] = torch_weight_scales @@ -1167,8 +1167,8 @@ def covert_and_save(rank, convert_args): assert False, "Never supported" else: if args.load_by_shard: - weights = load_from_hf_checkpoint( - args.model_dir, mapping, PretrainedConfig.from_dict(config)) + weights = load_weights_from_hf_by_shard( + args.model_dir, PretrainedConfig.from_dict(config)) else: weights = convert_hf_llama( @@ -1193,8 +1193,21 @@ def load_medusa_hf(medusa_path: str, mapping=Mapping(), dtype='float32'): logger.info("Loading Medusa heads' weights ...") + is_ckpt_safetensors = False + ckpt_file = Path(medusa_path) / "medusa_lm_head.pt" - state_dict = torch.load(ckpt_file, map_location="cpu") + if not ckpt_file.exists(): + ckpt_file = Path( + medusa_path) / "medusa_lm_head.safetensors" + is_ckpt_safetensors = True + + if is_ckpt_safetensors: + logger.info("Safetensors Found ...") + from safetensors.torch import load_file + state_dict = load_file(ckpt_file) + else: + state_dict = torch.load(ckpt_file, map_location="cpu") + torch_dtype = str_dtype_to_torch(dtype) weights = {} @@ -1203,10 +1216,13 @@ def load_medusa_hf(medusa_path: str, w = state_dict[f"{h}.{l}.linear.weight"].clone().to( torch_dtype) - weights[ - 'medusa_heads.{}.medusa_layers.{}.linear.weight' - .format(h, l)] = split(w, mapping.tp_size, - mapping.tp_rank) + split_v = split(w, mapping.tp_size, mapping.tp_rank) + weights.update( + get_tllm_linear_weight( + split_v, + f'medusa_heads.{h}.medusa_layers.{l}.linear.', + None, args.use_weight_only, + plugin_weight_only_quant_type)) b = state_dict[f"{h}.{l}.linear.bias"].clone().to( torch_dtype) diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt index 6579db58e..3e9dbb6ec 100644 --- a/examples/medusa/requirements.txt +++ b/examples/medusa/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/mixtral/README.md b/examples/mixtral/README.md index e0c41296e..11f458d1f 100644 --- a/examples/mixtral/README.md +++ b/examples/mixtral/README.md @@ -133,7 +133,6 @@ python ../quantization/quantize.py --model_dir ./Mixtral-8x7B-v0.1 \ trtllm-build --checkpoint_dir ./tllm_checkpoint_mixtral_2gpu \ --output_dir ./engine_outputs \ --gemm_plugin float16 \ - --strongly_typed \ --workers 2 ``` diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt index 04cc1f2f3..b41bcc386 100644 --- a/examples/mixtral/requirements.txt +++ b/examples/mixtral/requirements.txt @@ -1,4 +1,4 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 transformers==4.38.2 accelerate==0.25.0 diff --git a/examples/mmlu.py b/examples/mmlu.py index 8815c0498..fdb94c0a1 100644 --- a/examples/mmlu.py +++ b/examples/mmlu.py @@ -58,7 +58,10 @@ from utils import load_tokenizer, read_model_name import tensorrt_llm -from tensorrt_llm.runtime import ModelRunner +from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner + +if PYTHON_BINDINGS: + from tensorrt_llm.runtime import ModelRunnerCpp os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -281,7 +284,8 @@ def __call__(self, prompt): top_k=top_k) output_ids = outputs[0, input_lengths[0]:] - elif isinstance(self.model, ModelRunner): + elif isinstance(self.model, ModelRunnerCpp) or isinstance( + self.model, ModelRunner): outputs = self.model.generate( batch_input_ids, max_new_tokens=output_len, @@ -389,9 +393,11 @@ def main(): if args.test_trt_llm: assert not args.test_hf, "Cannot test both TRT-LLM and HF" - model = ModelRunner.from_dir(args.engine_dir, - rank=runtime_rank, - debug_mode=args.debug_mode) + runner_cls = ModelRunner if (args.debug_mode + or not PYTHON_BINDINGS) else ModelRunnerCpp + model = runner_cls.from_dir(args.engine_dir, + rank=runtime_rank, + debug_mode=args.debug_mode) else: assert args.test_hf, "Must test either TRT-LLM or HF" if model_name == 'ChatGLMForCausalLM' and model_version == 'glm': diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt index eb20b2d7a..bc9c169fe 100644 --- a/examples/mpt/requirements.txt +++ b/examples/mpt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md index c3f033208..05f47db88 100644 --- a/examples/multimodal/README.md +++ b/examples/multimodal/README.md @@ -14,7 +14,7 @@ We first describe how to run each model on a single GPU. We then provide general - [Fuyu](#fuyu) - [Kosmos-2](#kosmos-2) - [LLaVA and VILA](#llava-and-vila) -- [Neva](#neva) +- [NeVA](#neva) - [Video NeVA](#video-neva) - [Nougat](#nougat) - [Enabling tensor parallelism for multi-GPU](#enabling-tensor-parallelism-for-multi-gpu) @@ -33,7 +33,6 @@ We first describe how to run each model on a single GPU. We then provide general --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \ --tp_size 1 \ --pp_size 1 \ - --weight_data_type float32 \ --dtype bfloat16 \ --max_multimodal_len 256 # 8 (max_batch_size) * 32 (num_visual_features) ``` @@ -41,8 +40,8 @@ We first describe how to run each model on a single GPU. We then provide general 2. Build TRT-LLM engine from TRT-LLM checkpoint ```bash - trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/tp1/pp1/encoder \ - --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1/encoder \ + trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/encoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/encoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ @@ -59,8 +58,8 @@ We first describe how to run each model on a single GPU. We then provide general --max_multimodal_len 256 # 8 (max_batch_size) * 32 (num_visual_features) # Same command for decoder but don't set --max_multimodal_len - trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/tp1/pp1/decoder \ - --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1/decoder \ + trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/decoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ @@ -79,7 +78,7 @@ We first describe how to run each model on a single GPU. We then provide general **NOTE**: `max_multimodal_len = max_batch_size * num_visual_features`, so if you change max_batch_size, max multimodal length **MUST** be changed accordingly. - The built T5 engines are located in `./tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1`. + The built T5 engines are located in `./tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16`. 3. Build TensorRT engines for visual components @@ -99,7 +98,7 @@ We first describe how to run each model on a single GPU. We then provide general --input_text "Question: which city is this? Answer:" \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1 + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 ``` ## BLIP2-OPT @@ -242,15 +241,14 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --output_dir tmp/trt_models/${MODEL_NAME}/float16 \ --tp_size 1 \ --pp_size 1 \ - --weight_data_type float32 \ --dtype float16 ``` 2. Build TRT-LLM engine from TRT-LLM checkpoint ```bash - trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/float16/tp1/pp1/decoder \ - --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/float16/tp1/decoder \ + trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/float16/decoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/float16/decoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ @@ -267,7 +265,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --max_input_len 1 ``` - The built deplot engines are located in `./tmp/trt_engines/${MODEL_NAME}/1-gpu/float16/tp1`. + The built deplot engines are located in `./tmp/trt_engines/${MODEL_NAME}/1-gpu/float16`. 3. Build TensorRT engines for visual components @@ -287,7 +285,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --input_text "" \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/float16/tp1 + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/float16 ``` ## Fuyu @@ -327,7 +325,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1 + --llm_engine_dir trt_engines/${MODEL_NAME}/1-gpu/bfloat16 ``` ## Kosmos-2 @@ -366,7 +364,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1 + --llm_engine_dir trt_engines/${MODEL_NAME}/1-gpu/bfloat16 ``` ## LLaVA and VILA @@ -659,12 +657,11 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \ --tp_size 1 \ --pp_size 1 \ - --weight_data_type float32 \ --dtype bfloat16 \ --nougat - trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/tp1/pp1/decoder \ - --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1/decoder \ + trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/decoder \ --paged_kv_cache disable \ --moe_plugin disable \ --enable_xqa disable \ @@ -688,7 +685,7 @@ Currently, CogVLM only support bfloat16 precision and doesn't support `remove_in python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ --visual_engine_dir visual_engines/${MODEL_NAME} \ - --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1 \ + --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ ``` Note: Nougat models usually do not need a text prompt. diff --git a/examples/nemotron/requirements.txt b/examples/nemotron/requirements.txt index eb20b2d7a..bc9c169fe 100644 --- a/examples/nemotron/requirements.txt +++ b/examples/nemotron/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/openai_triton/plugin_autogen/build_engine.py b/examples/openai_triton/plugin_autogen/build_engine.py index fcd052723..23b829f0b 100644 --- a/examples/openai_triton/plugin_autogen/build_engine.py +++ b/examples/openai_triton/plugin_autogen/build_engine.py @@ -2,6 +2,7 @@ import math # include plugins # yapf: disable +import os import sys import time from pathlib import Path @@ -17,7 +18,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.network import net_guard -sys.path.append('./tmp') +sys.path.append(os.environ.get('PLUGIN_GEN_WORKSPACE', './tmp')) from functional import fused_attention_kernel # isort:skip # yapf: enable @@ -113,8 +114,8 @@ def build_engine(builder: Builder, builder_config: BuilderConfig, print('dot:') print(network.to_dot()) - layer = network.get_layer_by_name( - "FmhaLayer/PLUGIN_V2_fused_attention_kernelPlugin_2").as_layer() + layer = network.get_layer_by_name(next( + network.get_layers()).name).as_layer() print('layer', layer.plugin.plugin_type) print('layer', layer.plugin.plugin_version) print('layer', layer.plugin.plugin_namespace) diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt index eb20b2d7a..bc9c169fe 100644 --- a/examples/opt/requirements.txt +++ b/examples/opt/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/phi/README.md b/examples/phi/README.md index 59a672ea0..84573ab5b 100644 --- a/examples/phi/README.md +++ b/examples/phi/README.md @@ -1,6 +1,8 @@ # Phi -This document explains how to build the [phi-2](https://huggingface.co/microsoft/phi-2), [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) and [Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) models using TensorRT-LLM and run on a single GPU. +This document explains how to build the [phi-2](https://huggingface.co/microsoft/phi-2), [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct), +[Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct), [Phi-3-small-8k-instruct](https://huggingface.co/microsoft/Phi-3-small-8k-instruct), and [Phi-3-small-128k-instruct](https://huggingface.co/microsoft/Phi-3-small-128k-instruct) +models using TensorRT-LLM and run on a single GPU. - [Phi](#phi) - [Overview](#overview) @@ -13,9 +15,10 @@ This document explains how to build the [phi-2](https://huggingface.co/microsoft ## Overview -The TensorRT-LLM Phi implementation can be found in [`tensorrt_llm/models/phi/model.py`](../../tensorrt_llm/models/phi/model.py) and [`tensorrt_llm/models/phi3/model.py`](../../tensorrt_llm/models/phi3/model.py). The TensorRT-LLM Phi example code is located in [`examples/phi`](./). There is one file: +The TensorRT-LLM Phi implementation can be found in [`tensorrt_llm/models/phi/model.py`](../../tensorrt_llm/models/phi/model.py) and [`tensorrt_llm/models/phi3/model.py`](../../tensorrt_llm/models/phi3/model.py). The TensorRT-LLM Phi example code is located in [`examples/phi`](./). There are two files: * [`convert_checkpoint.py`](./convert_checkpoint.py) to convert a checkpoint from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT-LLM format +* [`postprocess_quant_checkpoint.py`](./postprocess_quant_checkpoint.py) to post-process FP8 or INT8 SmoothQuant quantized checkpoints for Phi-3-small variants. In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation: @@ -25,14 +28,17 @@ In addition, there are two shared files in the parent folder [`examples`](../) f ## Support Matrix * FP16 * BF16 + * FP8 * Tensor Parallel ## Support Matrix -| Model Name | FP16 | BF16 | TP | -| :--------------: | :---: | :---: | :---: | -| phi-2 | Y | Y | Y | -| Phi-3-mini-4k-instruct | Y | Y | | -| Phi-3-mini-128k-instruct | Y | Y | | +| Model Name | FP16 | BF16 | FP8 | TP | +| :--------------: | :---: | :---: | :---: | :---: | +| phi-2 | Y | Y | | Y | +| Phi-3-mini-4k-instruct | Y | Y | | | +| Phi-3-mini-128k-instruct | Y | Y | | | +| Phi-3-small-8k-instruct | Y | Y | Y | Y | +| Phi-3-small-128k-instruct | Y | Y | Y | Y | * Model Name: the name of the model, the same as the name on HuggingFace * TP: Tensor Parallel @@ -48,7 +54,7 @@ pip install -r requirements.txt ``` ```bash -export MODEL_TYPE="phi-2" # or Phi-3-mini-4k-instruct, Phi-3-mini-128k-instruct +export MODEL_TYPE="phi-2" # or Phi-3-mini-4k-instruct, Phi-3-mini-128k-instruct, Phi-3-small-8k-instruct, Phi-3-small-128k-instruct python ./convert_checkpoint.py --model_type ${MODEL_TYPE} \ --model_dir "microsoft/${MODEL_TYPE}" \ --output_dir ./phi-checkpoint \ @@ -119,3 +125,37 @@ python3 ../summarize.py --engine_dir ./phi-engine-tp2 \ --check_accuracy \ --tensorrt_llm_rouge1_threshold 20 ``` + + +### 5. Quantization options for Phi-3-small + +Phi-3-small variants support post-training quantization to FP8 and INT8 SmoothQuant formats. + +FP8 checkpoints can be built as follows: + +```bash +DTYPE=bfloat16 +python3 ../quantization/quantize.py \ + --model_dir phi3-model \ + --output_dir ./phi3-checkpoint \ + --dtype $DTYPE \ + --qformat fp8 --kv_cache_dtype fp8 + +python3 postprocess_quant_checkpoint.py --checkpoint_dir ./phi3-checkpoint +``` + +INT8 checkpoints can be built as follows: + +```bash +DTYPE=bfloat16 +python3 ../quantization/quantize.py \ + --model_dir phi3-model \ + --output_dir ./phi3-checkpoint \ + --dtype $DTYPE \ + --qformat int8_sq --kv_cache_dtype int8 + +python3 postprocess_quant_checkpoint.py --checkpoint_dir ./phi3-checkpoint +``` + +The commands to [build TensorRT engines](#2-build-tensorrt-engines) from quantized checkpoints +and to run [summarization test](#3-summarization-using-the-phi-model) are same as those for unquantized checkpoints. diff --git a/examples/phi/convert_checkpoint.py b/examples/phi/convert_checkpoint.py index be66fb6e1..7e3ebc6ff 100644 --- a/examples/phi/convert_checkpoint.py +++ b/examples/phi/convert_checkpoint.py @@ -17,26 +17,61 @@ import time import tensorrt_llm -from tensorrt_llm.models import Phi3ForCausalLM, PhiForCausalLM +from tensorrt_llm.models import (Phi3ForCausalLM, Phi3SmallForCausalLM, + PhiForCausalLM) def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--model_dir', type=str, default=None) + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') parser.add_argument('--dtype', type=str, default='float16', choices=['float32', 'bfloat16', 'float16']) + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) parser.add_argument('--output_dir', type=str, default='tllm_checkpoint', help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument('--model_type', + type=str, + default='phi-2', + choices=[ + 'phi-2', 'Phi-3-mini-4k-instruct', + 'Phi-3-mini-128k-instruct', + 'Phi-3-small-8k-instruct', + 'Phi-3-small-128k-instruct' + ], + help='Model to be converted.') parser.add_argument( - '--model_type', - type=str, - default='phi-2', - choices=['phi-2', 'Phi-3-mini-4k-instruct', 'Phi-3-mini-128k-instruct'], - help='Model to be converted.') + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') args = parser.parse_args() return args @@ -45,15 +80,30 @@ def parse_arguments(): if __name__ == '__main__': print(tensorrt_llm.__version__) args = parse_arguments() + assert args.pp_size == 1, "Pipeline parallelism is not supported." tik = time.time() if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - modelForCausalLM = PhiForCausalLM if args.model_type == "phi-2" else Phi3ForCausalLM + modelForCausalLM = None + if args.model_type == 'phi-2': + modelForCausalLM = PhiForCausalLM + elif args.model_type in [ + 'Phi-3-mini-4k-instruct', 'Phi-3-mini-128k-instruct' + ]: + modelForCausalLM = Phi3ForCausalLM + elif args.model_type in [ + 'Phi-3-small-8k-instruct', 'Phi-3-small-128k-instruct' + ]: + modelForCausalLM = Phi3SmallForCausalLM + else: + assert False, "Invalid model type" + modelForCausalLM.convert_hf_checkpoint(args.model_dir, dtype=args.dtype, - output_dir=args.output_dir) + output_dir=args.output_dir, + args=args) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) diff --git a/examples/phi/postprocess_quant_checkpoint.py b/examples/phi/postprocess_quant_checkpoint.py new file mode 100644 index 000000000..8c9cc28bb --- /dev/null +++ b/examples/phi/postprocess_quant_checkpoint.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import time + +import safetensors +from safetensors.torch import save_file + +import tensorrt_llm +from tensorrt_llm.models.phi3.phi3small.convert import shuffle_qkv_weights + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--checkpoint_dir', type=str, default=None) + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + print(tensorrt_llm.__version__) + args = parse_arguments() + tensorrt_llm.logger.set_level('info') + + tik = time.time() + with open(f"{args.checkpoint_dir}/config.json", "r") as f: + config = json.load(f) + + weights = {} + with safetensors.safe_open(f"{args.checkpoint_dir}/rank0.safetensors", + framework="pt") as f: + for k in f.keys(): + weights[k] = f.get_tensor(k) + + # Transform QKV weights from custom Phi3Small format to TRT-LLM format + num_total_heads = config[ + 'num_attention_heads'] + 2 * config['num_key_value_heads'] + for key, value in weights.items(): + if "qkv." in key: + if 'scaling_factor' in key and value.shape[0] % num_total_heads != 0: + continue + weights[key] = shuffle_qkv_weights(value, config) + + save_file(weights, f'{args.checkpoint_dir}/rank0.safetensors') + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt index e57d3578d..58febb836 100644 --- a/examples/phi/requirements.txt +++ b/examples/phi/requirements.txt @@ -1,6 +1,7 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 einops~=0.7.0 +tiktoken==0.6.0 diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt index 3a3b76342..e6c752bf0 100644 --- a/examples/quantization/requirements.txt +++ b/examples/quantization/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets>=2.14.4 nemo-toolkit[all]<=1.20.0,>=1.18.0 rouge_score~=0.1.2 diff --git a/examples/qwen/README.md b/examples/qwen/README.md index ee8bb6957..bc2150788 100644 --- a/examples/qwen/README.md +++ b/examples/qwen/README.md @@ -10,6 +10,7 @@ This document shows how to build and run a [Qwen](https://huggingface.co/Qwen) m - [Build TensorRT engine(s)](#build-tensorrt-engines) - [INT8 KV cache](#int8-kv-cache) - [SmoothQuant](#smoothquant) + - [FP8 PTQ](#fp8-post-training-quantization) - [INT4-GPTQ](#int4-gptq) - [INT4-AWQ](#int4-awq) - [Run](#run) @@ -28,19 +29,20 @@ In addition, there are two shared files in the parent folder [`examples`](../) f * [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset. ## Support Matrix -| Model Name | FP16/BF16 | WO | AWQ | GPTQ | SQ | TP | PP | Arch | -| :-------------: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :-----: | -| Qwen-1_8B(-Chat) | Y | Y | Y* | Y | Y | Y | Y | Ampere+ | -| Qwen-7B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen-14B(-Chat) | Y | Y | Y* | Y | Y | Y | Y | Ampere+ | -| Qwen-72B(-Chat) | Y | Y | - | Y | Y | Y | Y | Ampere+ | -| Qwen1.5-0.5B(-Chat)| Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen1.5-1.8B(-Chat)| Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen1.5-4B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen1.5-7B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen1.5-14B(-Chat) | Y | Y | Y* | Y | Y | Y | Y | Ampere+ | -| Qwen1.5-32B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Ampere+ | -| Qwen1.5-72B(-Chat) | Y | Y | - | Y | Y | Y | Y | Ampere+ | +| Model Name | FP16/BF16 | FP8 | WO | AWQ | GPTQ | SQ | TP | PP | Arch | +| :-------------: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :-----: | +| Qwen-1_8B(-Chat) | Y | Y | Y | Y* | Y | Y | Y | Y | Ampere+ | +| Qwen-7B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen-14B(-Chat) | Y | Y | Y | Y* | Y | Y | Y | Y | Ampere+ | +| Qwen-72B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-0.5B(-Chat)| Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-1.8B(-Chat)| Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-4B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-7B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-14B(-Chat) | Y | Y | Y | Y* | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-32B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-72B(-Chat) | Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | +| Qwen1.5-110B(-Chat)| Y | Y | Y | Y | Y | Y | Y | Y | Ampere+ | *Please note that these models supports AWQ only with single GPU. @@ -187,7 +189,7 @@ INT8 KV cache could be enabled to reduce memory footprint. It will bring more pe For INT8 KV cache, [`convert_checkpoint.py`](./convert_checkpoint.py) features a `--int8_kv_cache` option. Setting `--int8_kv_cache` will calibrate the model, -and then export the scaling factors needed for INT8 KV cache inference. Remember to set `--strongly_typed` when building the engine if you are not using INT8 weight only quantization at the same time. +and then export the scaling factors needed for INT8 KV cache inference. Example: @@ -199,7 +201,6 @@ python convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_sq \ --output_dir ./engine_outputs \ - --strongly_typed \ --gemm_plugin float16 ``` @@ -241,6 +242,29 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_sq \ --gemm_plugin float16 ``` +#### FP8 Post-Training Quantization + +The examples below uses the NVIDIA Modelopt (AlgorithMic Model Optimization) toolkit for the model quantization process. + +First make sure Modelopt toolkit is installed (see [examples/quantization/README.md](/examples/quantization/README.md#preparation)) + + +```bash +# Quantize model into FP8 and export trtllm checkpoint +python ../quantization/quantize.py --model_dir ./tmp/Qwen/7B/ \ + --dtype float16 \ + --qformat fp8 \ + --kv_cache_dtype fp8 \ + --output_dir ./tllm_checkpoint_1gpu_fp8 \ + --calib_size 512 + +# Build trtllm engines from the trtllm checkpoint +# Enable fp8 context fmha to get further acceleration by setting `--use_fp8_context_fmha enable` +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp8 \ + --output_dir ./engine_outputs \ + --gemm_plugin float16 \ +``` + #### INT4-GPTQ You may find the official GPTQ quantized INT4 weights of Qwen-7B-Chat here: [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4). And you need to first install auto-gptq: ```bash diff --git a/examples/qwen/convert_checkpoint.py b/examples/qwen/convert_checkpoint.py index 4899dccce..e9478d2dc 100644 --- a/examples/qwen/convert_checkpoint.py +++ b/examples/qwen/convert_checkpoint.py @@ -251,7 +251,7 @@ def from_cli_args(args): 'tp_size': args.tp_size, 'pp_size': args.pp_size }, - 'quantization': args_to_quantization(args).asdict() + 'quantization': args_to_quantization(args).to_dict() } config.update(args_to_build_options(args)) return config diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt index 2a764e13c..0f5df6622 100644 --- a/examples/qwen/requirements.txt +++ b/examples/qwen/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 @@ -9,7 +9,7 @@ tiktoken einops # optional dependencies -gradio==3.40.1 +gradio==4.19.2 mdtex2html sse_starlette aiohttp_sse_client diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt index 12ddc9461..a5134248e 100644 --- a/examples/qwenvl/requirements.txt +++ b/examples/qwenvl/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 @@ -8,3 +8,5 @@ sentencepiece~=0.1.99 tiktoken einops auto-gptq +matplotlib +torchvision==0.17.1 diff --git a/examples/recurrentgemma/convert_checkpoint.py b/examples/recurrentgemma/convert_checkpoint.py index e91f33c74..a80eb1ac5 100644 --- a/examples/recurrentgemma/convert_checkpoint.py +++ b/examples/recurrentgemma/convert_checkpoint.py @@ -479,9 +479,11 @@ def main(): intermediate_size=ckpt_config["intermediate_size"], norm_epsilon=1e-6, position_embedding_type="rope_gpt_neox", - world_size=args.world_size, - tp_size=args.world_size, - pp_size=1, + mapping={ + 'world_size': args.world_size, + 'tp_size': args.world_size, + 'pp_size': 1 + }, gpus_per_node=8, quantization=quant_config, conv_kernel=4, diff --git a/examples/recurrentgemma/requirements.txt b/examples/recurrentgemma/requirements.txt index 8833ae3b7..08a958192 100644 --- a/examples/recurrentgemma/requirements.txt +++ b/examples/recurrentgemma/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.14.5 evaluate rouge_score diff --git a/examples/run.py b/examples/run.py index ad5232428..2857552f6 100644 --- a/examples/run.py +++ b/examples/run.py @@ -16,6 +16,7 @@ import argparse import ast import csv +import os from pathlib import Path import numpy as np @@ -210,6 +211,7 @@ def parse_arguments(args=None): action='store_true', help='Enables chunked context (only available with cpp session).', ) + parser.add_argument('--no_repeat_ngram_size', type=int, default=None) return parser.parse_args(args=args) @@ -368,7 +370,19 @@ def main(args): runtime_rank = tensorrt_llm.mpi_rank() logger.set_level(args.log_level) - model_name, model_version = read_model_name(args.engine_dir) + # different handling if encoder-decoder models + is_enc_dec = { + name + for name in os.listdir(args.engine_dir) + if os.path.isdir(os.path.join(args.engine_dir, name)) + } == {'encoder', 'decoder'} + if is_enc_dec: + logger.warning( + "This path is an encoder-decoder model. Using different handling.") + assert not args.use_py_session, "Encoder-decoder models don't have a unified python runtime, please use its own examples/enc_dec/run.py instead." + + model_name, model_version = read_model_name( + args.engine_dir) if not is_enc_dec else ("", "") if args.tokenizer_dir is None: logger.warning( "tokenizer_dir is not specified. Try to infer from model_name, but this may be incorrect." @@ -408,7 +422,17 @@ def main(args): num_prepend_vtokens=args.num_prepend_vtokens, model_name=model_name, model_version=model_version) - input_lengths = [x.size(0) for x in batch_input_ids] + + if is_enc_dec: + encoder_input_ids = batch_input_ids + decoder_input_ids = [ + torch.tensor([pad_id], dtype=torch.int32) for _ in batch_input_ids + ] # by default decoder_start_token_id for T5 + + input_lengths = [x.size(0) for x in decoder_input_ids + ] if is_enc_dec else [x.size(0) for x in batch_input_ids] + encoder_input_lengths = [x.size(0) + for x in encoder_input_ids] if is_enc_dec else None if not PYTHON_BINDINGS and not args.use_py_session: logger.warning( @@ -429,6 +453,8 @@ def main(args): lora_ckpt_source=args.lora_ckpt_source, gpu_weights_percent=args.gpu_weights_percent, ) + if not args.use_py_session: + runner_kwargs.update(is_enc_dec=is_enc_dec) if args.medusa_choices is not None: args.medusa_choices = ast.literal_eval(args.medusa_choices) assert args.temperature == 1.0, "Medusa should use temperature == 1.0" @@ -437,7 +463,8 @@ def main(args): if not args.use_py_session: runner_kwargs.update( max_batch_size=len(batch_input_ids), - max_input_len=max(input_lengths), + max_input_len=max( + encoder_input_lengths if is_enc_dec else input_lengths), max_output_len=args.max_output_len, max_beam_width=args.num_beams, max_attention_window_size=args.max_attention_window_size, @@ -452,7 +479,9 @@ def main(args): with torch.no_grad(): outputs = runner.generate( - batch_input_ids, + batch_input_ids=decoder_input_ids + if is_enc_dec else batch_input_ids, + encoder_input_ids=encoder_input_ids if is_enc_dec else None, max_new_tokens=args.max_output_len, max_attention_window_size=args.max_attention_window_size, sink_token_length=args.sink_token_length, @@ -476,6 +505,7 @@ def main(args): prompt_tasks=args.prompt_tasks, streaming=args.streaming, output_sequence_lengths=True, + no_repeat_ngram_size=args.no_repeat_ngram_size, return_dict=True, medusa_choices=args.medusa_choices) torch.cuda.synchronize() diff --git a/examples/sample_weight_stripping/README.md b/examples/sample_weight_stripping/README.md index f33e0af04..80ed27da5 100644 --- a/examples/sample_weight_stripping/README.md +++ b/examples/sample_weight_stripping/README.md @@ -211,7 +211,6 @@ python ../quantization/quantize.py --model_dir /llm-models/llama-models-v2/llama trtllm-build --checkpoint_dir ./llama2-70b-hf-fp8-tp2 \ --output_dir engines/llama2-70b-hf-fp8-tp2 \ --gemm_plugin float16 \ - --strongly_typed \ --workers 2 ``` diff --git a/examples/skywork/requirements.txt b/examples/skywork/requirements.txt index 8e67acd43..ec0d2d3df 100644 --- a/examples/skywork/requirements.txt +++ b/examples/skywork/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets~=2.16.1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/smaug/requirements.txt b/examples/smaug/requirements.txt index dc02169b6..ab2e413f3 100644 --- a/examples/smaug/requirements.txt +++ b/examples/smaug/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/summarize.py b/examples/summarize.py index 9c6faec2c..883e5a8b5 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -15,6 +15,7 @@ import argparse import ast +import os from pathlib import Path import evaluate @@ -139,8 +140,10 @@ def main(args): f.write(f'Tokenizer path: {args.tokenizer_dir}\n') # TODO: Add random_seed flag in gptj - metric_tensorrt_llm = [evaluate.load("rouge") for _ in range(num_beams)] - metric_hf = [evaluate.load("rouge") for _ in range(num_beams)] + rouge_dir = args.rouge_dir if args.rouge_dir and os.path.exists( + args.rouge_dir) else "rouge" + metric_tensorrt_llm = [evaluate.load(rouge_dir) for _ in range(num_beams)] + metric_hf = [evaluate.load(rouge_dir) for _ in range(num_beams)] for i in range(num_beams): metric_tensorrt_llm[i].seed = 0 metric_hf[i].seed = 0 @@ -149,7 +152,8 @@ def main(args): def _prepare_inputs(batch_input_texts, eval_task='summarize', - add_special_tokens=True): + add_special_tokens=True, + min_input_length=0): batch_size = len(batch_input_texts) append_str = ' TL;DR: ' if eval_task == 'summarize' else '' batch_input_ids = [] @@ -193,17 +197,23 @@ def _prepare_inputs(batch_input_texts, truncation=True, max_length=test_token_num).squeeze(0) - batch_input_ids.append(input_ids) + if input_ids.numel() > min_input_length: + batch_input_ids.append(input_ids) return batch_input_ids def eval_trt_llm(datapoint, eval_task='summarize', eval_ppl=False, - add_special_tokens=True): + add_special_tokens=True, + min_input_length=0): batch_size = len(datapoint[dataset_input_key]) batch_input_ids = _prepare_inputs(datapoint[dataset_input_key], eval_task=eval_task, - add_special_tokens=add_special_tokens) + add_special_tokens=add_special_tokens, + min_input_length=min_input_length) + batch_size = len(batch_input_ids) + if batch_size == 0: + return [], [], [], {} input_lengths = [x.size(0) for x in batch_input_ids] with torch.no_grad(): @@ -280,7 +290,8 @@ def eval_trt_llm(datapoint, def eval_hf(datapoint, eval_task='summarize', eval_ppl=False, - add_special_tokens=True): + add_special_tokens=True, + min_input_length=0): batch_size = len(datapoint[dataset_input_key]) if batch_size > 1: logger.warning( @@ -288,7 +299,11 @@ def eval_hf(datapoint, ) batch_input_ids = _prepare_inputs(datapoint[dataset_input_key], eval_task=eval_task, - add_special_tokens=add_special_tokens) + add_special_tokens=add_special_tokens, + min_input_length=min_input_length) + batch_size = len(batch_input_ids) + if batch_size == 0: + return [], [], [], [[] for _ in range(batch_size)] input_lengths = [x.size(0) for x in batch_input_ids] # Left padding for HF max_length = max(input_lengths) @@ -413,7 +428,8 @@ def eval_hf(datapoint, output, *_ = eval_trt_llm(datapoint, eval_task=args.eval_task, eval_ppl=args.eval_ppl, - add_special_tokens=args.add_special_tokens) + add_special_tokens=args.add_special_tokens, + min_input_length=args.min_input_length) if runtime_rank == 0 and args.eval_task != "eval_context_ppl": logger.info( "---------------------------------------------------------") @@ -440,7 +456,11 @@ def eval_hf(datapoint, datapoint, eval_task=args.eval_task, eval_ppl=args.eval_ppl, - add_special_tokens=args.add_special_tokens) + add_special_tokens=args.add_special_tokens, + min_input_length=args.min_input_length) + if output_tensorrt_llm == []: + data_point_idx += max_batch_size + continue profiler.stop('tensorrt_llm') if runtime_rank == 0: input_lengths = lengths_info['input_lengths'] @@ -520,7 +540,8 @@ def eval_hf(datapoint, output, *_ = eval_hf(datapoint, eval_task=args.eval_task, eval_ppl=args.eval_ppl, - add_special_tokens=args.add_special_tokens) + add_special_tokens=args.add_special_tokens, + min_input_length=args.min_input_length) if runtime_rank == 0 and args.eval_task != "eval_context_ppl": logger.info( "---------------------------------------------------------") @@ -547,9 +568,12 @@ def eval_hf(datapoint, datapoint, eval_task=args.eval_task, eval_ppl=args.eval_ppl, - add_special_tokens=args.add_special_tokens) + add_special_tokens=args.add_special_tokens, + min_input_length=args.min_input_length) profiler.stop('hf') - + if output_hf == []: + data_point_idx += max_batch_size + continue if runtime_rank == 0: seq_lengths = [len(tokens) for tokens in token_list] total_output_token_count_hf += sum(seq_lengths) @@ -611,8 +635,8 @@ def eval_hf(datapoint, f" Per-token perplexity: {np.mean(ppls_trt_llm[beam_idx])}" ) if args.check_accuracy and beam_idx == 0: - assert np.mean(ppls_trt_llm[beam_idx] - ) < args.tensorrt_llm_ppl_threshold + avg_ppl = np.mean(ppls_trt_llm[beam_idx]) + assert avg_ppl < args.tensorrt_llm_ppl_threshold, f"[FAILED] average PPL ({avg_ppl}) is larger than threshold ({args.tensorrt_llm_ppl_threshold})" if test_hf: np.random.seed(0) # rouge score use sampling to compute the score logger.info( @@ -690,6 +714,11 @@ def eval_hf(datapoint, parser.add_argument('--max_ite', type=int, default=20) parser.add_argument('--output_len', type=int, default=100) parser.add_argument('--max_input_length', type=int, default=923) + parser.add_argument( + '--min_input_length', + type=int, + default=0, + help='skip the sentences which are shorter than min_input_length.') parser.add_argument( '--max_attention_window_size', type=int, @@ -780,6 +809,15 @@ def eval_hf(datapoint, default=None, nargs="+", help="The list of LoRA task uids; use -1 to disable the LoRA module") + + parser.add_argument( + '--rouge_dir', + default=None, + type=str, + help= + "evaluate.load('rouge') will attempt to pull rouge package from HF. Use cached rouge can avoid network outage of host or HF." + ) + args = parser.parse_args() main(args) diff --git a/examples/utils.py b/examples/utils.py index f16d2faed..27f043496 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -12,12 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json from pathlib import Path from typing import Optional -from transformers import AutoTokenizer, T5Tokenizer +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer from tensorrt_llm.builder import get_engine_version @@ -30,6 +29,7 @@ 'GPTJForCausalLM': 'EleutherAI/gpt-j-6b', 'GPTNeoXForCausalLM': 'EleutherAI/gpt-neox-20b', 'InternLMForCausalLM': 'internlm/internlm-chat-7b', + 'InternLM2ForCausalLM': 'internlm/internlm2-chat-7b', 'LlamaForCausalLM': 'meta-llama/Llama-2-7b-hf', 'MPTForCausalLM': 'mosaicml/mpt-7b', 'PhiForCausalLM': 'microsoft/phi-2', @@ -37,9 +37,17 @@ 'QWenForCausalLM': 'Qwen/Qwen-7B', } +INTERNLM_META_INSTRUCTION = """You are an AI assistant whose name is InternLM (书生·浦语). +- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless. +- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文. +""" + DEFAULT_PROMPT_TEMPLATES = { 'InternLMForCausalLM': "<|User|>:{input_text}\n<|Bot|>:", + 'InternLM2ForCausalLM': + "<|im_start|>system\n" + INTERNLM_META_INSTRUCTION + + "<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n", 'QWenForCausalLM': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n", } @@ -97,13 +105,18 @@ def load_tokenizer(tokenizer_dir: Optional[str] = None, padding_side='left', truncation_side='left', legacy=False) + elif model_name == 'Grok1ModelForCausalLM': + tokenizer = LlamaTokenizer(vocab_file=vocab_file, + padding_side='left', + truncation_side='left', + legacy=False, + use_fast=False) else: # For gpt-next, directly load from tokenizer.model tokenizer = T5Tokenizer(vocab_file=vocab_file, padding_side='left', truncation_side='left', legacy=False) - if model_name == 'QWenForCausalLM' and model_version == 'qwen': with open(Path(tokenizer_dir) / "generation_config.json") as f: gen_config = json.load(f) diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index cd128d416..a9a1dca3c 100644 --- a/examples/whisper/requirements.txt +++ b/examples/whisper/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://pypi.nvidia.com -tensorrt_llm==0.11.0.dev2024052800 +tensorrt_llm==0.11.0.dev2024060400 tiktoken datasets kaldialign diff --git a/requirements-dev-windows.txt b/requirements-dev-windows.txt index 84ec402dd..819e0e3eb 100644 --- a/requirements-dev-windows.txt +++ b/requirements-dev-windows.txt @@ -14,3 +14,6 @@ pytest-xdist rouge_score cloudpickle typing-extensions==4.8.0 +jsonlines==4.0.0 +jieba==0.42.1 +rouge==1.0.1 diff --git a/requirements.txt b/requirements.txt index 259a73da8..905fa1b44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ tensorrt==10.0.1 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html#rel-24-04 uses 2.3.0a0. torch>=2.3.0a,<=2.3.0 nvidia-modelopt~=0.11,<0.12 -transformers>=4.38.2 +transformers==4.40.2 wheel optimum evaluate diff --git a/scripts/build_cpp_examples.py b/scripts/build_cpp_examples.py new file mode 100644 index 000000000..c41c135f1 --- /dev/null +++ b/scripts/build_cpp_examples.py @@ -0,0 +1,84 @@ +import argparse +import contextlib +import logging +import os +import platform +import shutil +import subprocess +from os import PathLike +from pathlib import Path + + +@contextlib.contextmanager +def working_directory(path: PathLike): + """Changes working directory and returns to previous on exit.""" + prev_cwd = Path.cwd() + os.chdir(path) + try: + yield + finally: + os.chdir(prev_cwd) + + +def build_cpp_examples(build_dir: PathLike, trt_dir: PathLike, + loglevel: int) -> None: + logging.basicConfig(level=loglevel, + format='%(asctime)s - %(levelname)s - %(message)s') + # Convert input paths to pathlib.Path objects + build_dir = Path(build_dir) + trt_dir = Path(trt_dir) + trt_include_dir = trt_dir / "include" + trt_lib_dir = trt_dir / "lib" + + assert trt_include_dir.is_dir() + assert trt_lib_dir.is_dir() + + def cmake_parse(path: PathLike) -> str: + return str(path).replace("\\", "/") + + # Remove the build directory if it exists + if build_dir.exists(): + logging.info(f"Removed directory: {build_dir}") + shutil.rmtree(build_dir) + + # Create the build directory + build_dir.mkdir(parents=True, exist_ok=True) + + # Change to the build directory + with working_directory(build_dir): + # Run CMake with the specified TensorRT directories + generator = ["-GNinja"] if platform.system() == "Windows" else [] + generate_command = [ + 'cmake', '-S', '..', '-B', '.', + f'-DTRT_LIB_DIR={cmake_parse(trt_lib_dir)}', + f'-DTRT_INCLUDE_DIR={cmake_parse(trt_include_dir)}' + ] + generator + logging.info(f"Executing {generate_command}") + subprocess.run(generate_command, check=True) + + # Build the project using make + build_command = ["cmake", "--build", ".", "--config", "Release"] + logging.info(f"Executing {build_command}") + subprocess.run(build_command, check=True) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Build C++ examples') + parser.add_argument('--build-dir', + default='examples/cpp/executor/build', + help='Build directory path') + parser.add_argument('--trt-dir', + default='/usr/local/tensorrt', + help='TensorRT directory path') + parser.add_argument('-v', + '--verbose', + help="verbose", + action="store_const", + dest="loglevel", + const=logging.DEBUG, + default=logging.INFO) + cli = parser.parse_args() + + args = vars(cli) + print(args) # Log on Jenkins instance. + build_cpp_examples(**args) diff --git a/scripts/build_cpp_examples.sh b/scripts/build_cpp_examples.sh deleted file mode 100755 index 899870fc4..000000000 --- a/scripts/build_cpp_examples.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash -set -e - -BUILD_DIR="examples/cpp/executor/build" -TRT_DIR="/usr/local/tensorrt" - -rm -rf $BUILD_DIR -mkdir -p $BUILD_DIR -pushd ${BUILD_DIR} - -cmake .. -DTRT_LIB_DIR=${TRT_DIR}/lib -DTRT_INCLUDE_DIR=${TRT_DIR}/include -make -j - -popd diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index 2c30eff24..a52e801be 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -39,7 +39,8 @@ def working_directory(path): os.chdir(prev_cwd) -def main(build_type: str = "Release", +def main(*, + build_type: str = "Release", build_dir: Path = None, dist_dir: Path = None, cuda_architectures: str = None, @@ -50,6 +51,7 @@ def main(build_type: str = "Release", nccl_root: str = None, clean: bool = False, use_ccache: bool = False, + fast_build: bool = False, cpp_only: bool = False, install: bool = False, skip_building_wheel: bool = False, @@ -149,6 +151,9 @@ def main(build_type: str = "Release", f"-DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache" ) + if fast_build: + cmake_def_args.append(f"-DFAST_BUILD=ON") + build_pyt = "OFF" if cpp_only else "ON" th_common_lib = "" if cpp_only else "th_common" build_pybind = "OFF" if cpp_only else "ON" @@ -304,6 +309,14 @@ def get_pybind_lib(): default=False, action="store_true", help="Use ccache compiler driver") + parser.add_argument( + "--fast_build", + "-f", + default=False, + action="store_true", + help= + "Skip compiling some kernels to accelerate compilation -- for development only" + ) parser.add_argument("--job_count", "-j", const=cpu_count(), diff --git a/tensorrt_llm/_common.py b/tensorrt_llm/_common.py index 879f1e289..9c630f96f 100644 --- a/tensorrt_llm/_common.py +++ b/tensorrt_llm/_common.py @@ -51,6 +51,8 @@ def _init(log_level: object = None) -> None: logger.info('Skipping TensorRT-LLM init.') return + logger.info('Starting TensorRT-LLM init.') + # load plugin lib _load_plugin_lib() diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index b072911a5..dc79ad229 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import copy import gc +import inspect import json import math import struct @@ -161,6 +162,13 @@ def str_dtype_to_torch(dtype): return ret +_torch_dtype_to_str_dict = {v: k for k, v in _str_to_torch_dtype_dict.items()} + + +def torch_dtype_to_str(dtype): + return _torch_dtype_to_str_dict[dtype] + + _str_to_trt_dtype_dict = dict(float16=trt.float16, float32=trt.float32, int64=trt.int64, @@ -437,6 +445,21 @@ def set_obj_attrs( setattr(obj, key, value) +def get_init_params(obj, cls=None): + """ + Get all parameters in object's __init__. + Use cls's __init__ as filter if cls provided. + """ + names = None + if cls is not None: + names = set(list(inspect.signature(cls.__init__).parameters)[1:]) + return { + name: getattr(obj, name) + for name in list(inspect.signature(obj.__class__.__init__).parameters) + [1:] if names is None or name in names + } + + def release_gc(): ''' Release memory allocated by PyTorch and Python garbage collector explicitly and immediately. This could be used when some states might be kept in memory even after the variables are deleted. diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 3cd8a04b2..6aa015718 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -89,7 +89,10 @@ def to_dict(self) -> Dict: class Builder(): - _ALLOWED_PRECISIONS = ['float32', 'float16', 'bfloat16'] + _ALLOWED_PRECISIONS = [ + 'float32', 'float16', 'bfloat16', trt.DataType.HALF, trt.DataType.FLOAT, + trt.DataType.BF16 + ] def __init__(self): super().__init__() @@ -246,6 +249,9 @@ def _add_optimization_profile(self, network: Network, assert isinstance(builder_config, BuilderConfig) assert isinstance(network, Network) input_tensors = network._inputs + if len(input_tensors) == 0: + logger.warning("There are no inputs in the network!") + return num_profiles = len(list(input_tensors.values())[0].profiles) for i in range(num_profiles): logger.debug(f'Adding optimization profile {i+1}/{num_profiles}') @@ -653,39 +659,6 @@ def get_engine_version(engine_dir: str) -> Union[None, str]: return config['version'] -def optimize_model_with_config(model: PretrainedModel, - build_config: BuildConfig): - use_auto_parallel = build_config.auto_parallel_config.enabled - - if build_config.plugin_config.moe_plugin is None: - model = optimize_model(model, use_ootb_moe=True) - - if model.config.architecture not in ["EncoderModel", "DecoderModel"]: - model = optimize_model( - model, - use_fused_mlp=(build_config.use_fused_mlp - and not use_auto_parallel), - use_prompt_tuning=(build_config.max_prompt_embedding_table_size > - 0)) - - if model.config.architecture in ["RecurrentGemmaForCausalLM"]: - model = optimize_model(model, use_fused_rg_lru=True) - - if build_config.plugin_config.lora_plugin is not None: - model.use_lora(build_config.lora_config) - model = optimize_model( - model, - use_lora=True, - max_lora_rank=build_config.lora_config.max_lora_rank, - ) - - if model.config.quantization.quant_algo == QuantAlgo.FP8 and build_config.plugin_config.use_fp8_context_fmha: - model = optimize_model(model, use_fp8_context_fmha=True) - - model = optimize_model(model, use_unfused_qkv_gemm=use_auto_parallel) - return model - - def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: '''Build engine from given model and optimization options specified in the build_config WARNING: this function may change the given \p model object state in some optimization passes @@ -746,7 +719,38 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: raise RuntimeError( "Paged Context FMHA doesn't work with int8 kv cache currently.") - model = optimize_model_with_config(model, build_config) + use_auto_parallel = build_config.auto_parallel_config.enabled + gemm_swiglu_plugin = build_config.plugin_config.gemm_swiglu_plugin + if gemm_swiglu_plugin: + if not build_config.use_fused_mlp: + raise RuntimeError( + "GemmSwiGLU plugin requires --use_fused_mlp flag") + if gemm_swiglu_plugin not in ["fp8"]: + raise RuntimeError( + f"GemmSwiGLU plugin currently has limited support: fp8 only, " + f"got: {gemm_swiglu_plugin}") + + if build_config.plugin_config.lora_plugin is not None: + model.use_lora(build_config.lora_config) + + is_enc_dec = model.config.architecture in ["EncoderModel", "DecoderModel"] + model = optimize_model( + model, + use_ootb_moe=build_config.plugin_config.moe_plugin is None, + use_fused_mlp=(build_config.use_fused_mlp and not is_enc_dec + and not use_auto_parallel), + gemm_swiglu_plugin_dtype=gemm_swiglu_plugin, + use_fused_rg_lru=model.config.architecture + in ["RecurrentGemmaForCausalLM"], + use_unfused_qkv_gemm=use_auto_parallel, + use_prompt_tuning=(build_config.max_prompt_embedding_table_size > 0 + and not is_enc_dec), + use_lora=build_config.plugin_config.lora_plugin is not None, + max_lora_rank=build_config.lora_config.max_lora_rank, + use_fp8_context_fmha=( + model.config.quantization.quant_algo == QuantAlgo.FP8 + and build_config.plugin_config.use_fp8_context_fmha), + ) builder = Builder() builder_config = builder.create_builder_config( @@ -781,18 +785,10 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: network.plugin_config.weight_only_quant_matmul_plugin = model.config.dtype if use_smooth_quant and model.config.quantization.use_plugin_sq: network.plugin_config.set_smooth_quant_plugins() - # we will remove this later when XQA supports quantized fp8 output. - if network.plugin_config.use_fp8_context_fmha: - network.plugin_config.enable_xqa = False - logger.warning( - "The XQA kernel is disabled by default as it doesn't support fp8 attention plugin output currently." - ) nccl_plugin = model.config.dtype if model.config.mapping.world_size > 1 else None network.plugin_config.set_nccl_plugin( nccl_plugin, network.plugin_config.use_custom_all_reduce) - use_auto_parallel = build_config.auto_parallel_config.enabled - with net_guard(network): # Prepare network.set_named_parameters(model.named_parameters()) diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index 1d1fc120a..ee420d494 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -30,9 +30,9 @@ from ..builder import BuildConfig, Engine, build from ..logger import logger from ..lora_manager import LoraConfig, LoraManager -from ..models import PretrainedConfig +from ..models import MODEL_MAP, PretrainedConfig from ..models.modeling_utils import (WEIGHT_LOADER_MODELS, QuantConfig, - SpeculativeDecodingMode, load_model) + SpeculativeDecodingMode) from ..plugin import PluginConfig, add_plugin_argument from ..quantization import QuantAlgo @@ -121,15 +121,7 @@ def parse_arguments(): action='store_true', default=False, help='Gather generation logits') - parser.add_argument( - '--strongly_typed', - action='store_true', - default=False, - help= - 'This option is introduced with TensorRT 9.1.0.1+ and will reduce the engine building time. ' - 'It\'s not expected to see performance or accuracy regression after enable this flag. ' - 'Note that, we may remove this flag in the future, and enable the feature by default.' - ) + parser.add_argument('--builder_opt', type=int, default=None) parser.add_argument('--logits_dtype', type=str, @@ -266,17 +258,7 @@ def build_model(build_config: BuildConfig, model_config: Union[str, PretrainedConfig] = None, model_cls=None, **kwargs) -> Engine: - if ckpt_dir is not None: - model_config = PretrainedConfig.from_json_file( - os.path.join(ckpt_dir, 'config.json')) - else: - assert model_config is not None - if isinstance(model_config, PretrainedConfig): - model_config = model_config - else: - model_config = PretrainedConfig.from_json_file(model_config) - - preprocess_model_config(model_config, **kwargs) + model_config = copy.deepcopy(model_config) logits_dtype = kwargs.get('logits_dtype') if logits_dtype is not None: @@ -307,7 +289,14 @@ def build_model(build_config: BuildConfig, rank_config = copy.deepcopy(model_config) rank_config.set_rank(rank) - model = load_model(rank_config, ckpt_dir, model_cls) + + assert architecture in MODEL_MAP, \ + f"Unsupported model architecture: {architecture}" + model_cls = MODEL_MAP[architecture] + if ckpt_dir is None: + model = model_cls(rank_config) + else: + model = model_cls.from_checkpoint(ckpt_dir, config=rank_config) is_checkpoint_pruned = getattr(rank_config, 'is_pruned', False) if build_config.plugin_config.lora_plugin is not None: @@ -353,14 +342,14 @@ def parallel_build(ckpt_dir_or_model_config: str, log_level: str = 'info', model_cls=None, **kwargs): - ckpt_dir = ckpt_dir_or_model_config if ckpt_dir_or_model_config.lower().endswith('.json'): - model_config = PretrainedConfig.from_json_file(ckpt_dir_or_model_config) + config_path = ckpt_dir_or_model_config ckpt_dir = None else: - model_config = PretrainedConfig.from_json_file( - os.path.join(ckpt_dir_or_model_config, 'config.json')) + config_path = os.path.join(ckpt_dir_or_model_config, 'config.json') + ckpt_dir = ckpt_dir_or_model_config + model_config = PretrainedConfig.from_json_file(config_path) preprocess_model_config(model_config, **kwargs) if build_config.auto_parallel_config.enabled: @@ -450,6 +439,7 @@ def main(): cluster_config = dict(cluster_key=args.cluster_key) else: cluster_config = infer_cluster_config() + build_config = BuildConfig.from_dict( { 'max_input_len': args.max_input_len, @@ -462,7 +452,7 @@ def main(): args.max_prompt_embedding_table_size, 'gather_context_logits': args.gather_context_logits, 'gather_generation_logits': args.gather_generation_logits, - 'strongly_typed': args.strongly_typed, + 'strongly_typed': True, 'builder_opt': args.builder_opt, 'weight_sparsity': args.weight_sparsity, 'profiling_verbosity': args.profiling_verbosity, diff --git a/tensorrt_llm/commands/refit.py b/tensorrt_llm/commands/refit.py index f75e582d9..238372def 100644 --- a/tensorrt_llm/commands/refit.py +++ b/tensorrt_llm/commands/refit.py @@ -16,7 +16,6 @@ from tensorrt_llm._utils import np_dtype_to_trt from tensorrt_llm.builder import EngineConfig, optimize_model_with_config from tensorrt_llm.models import MODEL_MAP -from tensorrt_llm.models.modeling_utils import load_model from ..logger import logger @@ -44,7 +43,13 @@ def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str, tik = time.time() rank_config = copy.deepcopy(engine_config.pretrained_config) rank_config.set_rank(rank) - model = load_model(rank_config, checkpoint_dir) + + architecture = rank_config.architecture + assert architecture in MODEL_MAP, \ + f"Unsupported model architecture: {architecture}" + model_cls = MODEL_MAP[architecture] + model = model_cls.from_checkpoint(checkpoint_dir, config=rank_config) + tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) logger.info(f'Load checkpoint(s) time: {t}') diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index ab2fc40e1..46cc393c0 100644 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -30,7 +30,8 @@ from ._utils import (bf16_array, bool_array, dim_resolve_negative, dim_to_trt_axes, dims_array, fp16_array, fp32_array, int32_array, int64_array, np_dtype_to_trt, - str_dtype_to_trt, trt_dtype_to_np, trt_gte_10) + str_dtype_to_trt, trt_dtype_to_np, trt_dtype_to_str, + trt_gte_10) from .network import PluginInfo, set_np_weight, set_plugin_info from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper from .quantization import QuantMode @@ -675,6 +676,7 @@ class AttentionMaskType(IntEnum): causal = 1 bidirectional = 2 bidirectionalglm = 3 # TODO: merge this mask into bidirectional + blocksparse = 4 class LayerNormType(IntEnum): @@ -1034,6 +1036,80 @@ def matmul(input: Tensor, return output +def gemm_swiglu(input: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + scale_d0: float = 1.0, + scale_d1: float = 1.0, + scale_output: float = 1.0) -> Tensor: + ''' + Add a matrix multiplication, followed by SwiGLU (`x * SiLU(gate)`) operation. + + The second SwiGLU operation takes the preceding tensor, splits it into two halves + along the last dimension, applies SiLU to the second half and multiply the results. The + behaviour is undefined if the last dimension is not even. + + Parameters: + input : Tensor + The first tensor (often called A). + + weight : Tensor + The second tensor (often called B). + + bias : Optional[Tensor] + The per-channel bias. The plugin with fp8 dtype does not support bias yet. + + scale_d0 : float + The scale for dequantizing x, used for fp8 + + scale_d1 : float + The scale for dequantizing gate, used for fp8 + + scale_output : float + The scale for quantizing output, used for fp8 + + Returns: + The tensor produced by the inserted layer. + ''' + plg_creator = trt.get_plugin_registry().get_plugin_creator( + 'GemmSwiglu', '1', TRT_LLM_PLUGIN_NAMESPACE) + assert plg_creator is not None + + p_dtype = default_net().plugin_config.gemm_swiglu_plugin + if p_dtype == "fp8": + assert bias == None, "fp8 gemm_swiglu does not support bias yet" + + pf_type = trt.PluginField( + "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32), + trt.PluginFieldType.INT32) + pf_has_bias = trt.PluginField( + "has_bias", np.array(np.int8(0 if bias is None else 1), np.int8), + trt.PluginFieldType.INT8) + pf_scale_d0 = trt.PluginField("scale_d0", + np.array(scale_d0, dtype=np.float32), + trt.PluginFieldType.FLOAT32) + pf_scale_d1 = trt.PluginField("scale_d1", + np.array(scale_d1, dtype=np.float32), + trt.PluginFieldType.FLOAT32) + pf_scale_output = trt.PluginField("scale_output", + np.array(scale_output, dtype=np.float32), + trt.PluginFieldType.FLOAT32) + + pfc = trt.PluginFieldCollection( + [pf_type, pf_has_bias, pf_scale_d0, pf_scale_d1, pf_scale_output]) + gemm_swiglu_plug = plg_creator.create_plugin("gemm_swiglu", pfc) + + # TODO(anchengc) pass nullptr when no bias + if bias is None: + bias = constant( + np.zeros([weight.shape[0]], dtype=trt_dtype_to_np(input.dtype))) + plug_inputs = [input.trt_tensor, weight.trt_tensor, bias.trt_tensor] + + layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_swiglu_plug) + + return _create_tensor(layer.get_output(0), layer) + + def constant(ndarray: np.ndarray) -> Tensor: ''' Add a constant layer. @@ -1230,32 +1306,33 @@ def categorical_sample(probs: Tensor, rand_data: Tensor = None) -> Tensor: return samples -def conditional(condition: Tensor, true_input: Tensor, - false_input: Tensor) -> Tensor: +class Conditional: ''' Add an operation to conditionally execute two code paths/subgraphs. - Parameters: - condition : Tensor - The condition tensor. If the condition is true, the operation will - return the true_input tensor, otherwise the false_input tensor. + Usage: + 1. conditional = Conditional(condition) + 2. input_1_ = conditional.add_input(input_1) + ... + input_n_ = conditional.add_input(input_n) + 3. Construct the graph to get true_output_value and false_output_value using input_1_, ..., input_n_ + 4. output = conditional.add_output(true_output_value, false_output_value) + ''' - true_input : Tensor - The tensor to return if the condition is true. + def __init__(self, condition: Tensor): + self.layer = default_trtnet().add_if_conditional() + if condition.ndim() > 0: + condition = view(condition, []) + self.layer.set_condition(condition.trt_tensor) - false_input : Tensor - The tensor to return if the condition is false. - ''' - if condition.ndim() > 0: - condition = view(condition, []) - cond_trt_ = condition.trt_tensor - layer = default_trtnet().add_if_conditional() - layer.set_condition(cond_trt_) - true_subgraph = layer.add_input(true_input.trt_tensor) - false_subgraph = layer.add_input(false_input.trt_tensor) - output = layer.add_output(true_subgraph.get_output(0), - false_subgraph.get_output(0)) - return _create_tensor(output.get_output(0), output) + def add_input(self, input: Tensor) -> Tensor: + in_node = self.layer.add_input(input.trt_tensor) + return _create_tensor(in_node.get_output(0), in_node) + + def add_output(self, true_value: Tensor, false_value: Tensor) -> Tensor: + out_node = self.layer.add_output(true_value.trt_tensor, + false_value.trt_tensor) + return _create_tensor(out_node.get_output(0), out_node) # TODO: support step. @@ -2143,7 +2220,7 @@ def masked_select(input: Tensor, mask: Tensor) -> Tensor: return _create_tensor(gather_layer.get_output(0), gather_layer) -def cumsum(input: Tensor, dim: int) -> Tensor: +def cumsum(input: Tensor, dim: int, prefer_plugin: bool = True) -> Tensor: ''' Add an operation to calculate inclusive cumulative sum of elements of a tensor in a given dimension. @@ -2176,6 +2253,9 @@ def cumsum(input: Tensor, dim: int) -> Tensor: The dimension to calculate the inclusive cumulative sum. Negative value is supported. + prefer_plugin : bool + Whether to use the cumsumLastDim plugin if dim is last dim. + Returns: The tensor containing the inclusive cumulative sum of input. ''' @@ -2185,33 +2265,52 @@ def cumsum(input: Tensor, dim: int) -> Tensor: dim = dim_resolve_negative(dim, input.ndim())[0] - if (dim == input.ndim() - 1) and input.size(-1) > 0: - old_shape = shape(input) - if input.ndim() != 2: - input_2d = input.view([-1, input.size(-1)]) + if (dim == input.ndim() - 1): + if prefer_plugin: + last_dim = input.size(-1) + if last_dim == -1: # dynamic? + last_dim = shape(input, -1) + old_shape = shape(input) + if input.ndim() == 1: + input_2d = unsqueeze( + input, 0) # special handling of rank-1 dynamic tensor + elif input.ndim() != 2: + input_2d = input.view(concat([-1, last_dim]), + zero_is_placeholder=False) + else: + input_2d = input + cumsum_last_dim_plg_creator = trt.get_plugin_registry( + ).get_plugin_creator('CumsumLastDim', '1', TRT_LLM_PLUGIN_NAMESPACE) + assert cumsum_last_dim_plg_creator is not None + input_length = trt.PluginField( + "input_length", np.array(input_2d.size(-1), dtype=np.int32), + trt.PluginFieldType.INT32) + pf_type = trt.PluginField("type_id", + np.array([int(input_2d.dtype)], np.int32), + trt.PluginFieldType.INT32) + pfc = trt.PluginFieldCollection([input_length, pf_type]) + cumsum_last_dim_plug = cumsum_last_dim_plg_creator.create_plugin( + "cumsum_last_dim", pfc) + plug_inputs = [input_2d] + plug_inputs = [i.trt_tensor for i in plug_inputs] + layer = default_trtnet().add_plugin_v2(plug_inputs, + cumsum_last_dim_plug) + _add_plugin_info(layer, cumsum_last_dim_plg_creator, + "cumsum_last_dim", pfc) + output = _create_tensor(layer.get_output(0), layer) + output = output.view(old_shape, zero_is_placeholder=False) + return output else: - input_2d = input - cumsum_last_dim_plg_creator = trt.get_plugin_registry( - ).get_plugin_creator('CumsumLastDim', '1', TRT_LLM_PLUGIN_NAMESPACE) - assert cumsum_last_dim_plg_creator is not None - input_length = trt.PluginField( - "input_length", np.array(input_2d.size(-1), dtype=np.int32), - trt.PluginFieldType.INT32) - pf_type = trt.PluginField("type_id", - np.array([int(input_2d.dtype)], np.int32), - trt.PluginFieldType.INT32) - pfc = trt.PluginFieldCollection([input_length, pf_type]) - cumsum_last_dim_plug = cumsum_last_dim_plg_creator.create_plugin( - "cumsum_last_dim", pfc) - plug_inputs = [input_2d] - plug_inputs = [i.trt_tensor for i in plug_inputs] - layer = default_trtnet().add_plugin_v2(plug_inputs, - cumsum_last_dim_plug) - _add_plugin_info(layer, cumsum_last_dim_plg_creator, "cumsum_last_dim", - pfc) - output = _create_tensor(layer.get_output(0), layer) - output = output.view(old_shape) - return output + # credit to Apple + reduction_length = shape(input, -1) + reduction_range = arange(constant_to_tensor_(0, to_array=False), + reduction_length, + dtype='int32') + lower_triangle = cast( + unsqueeze(reduction_range, 0) <= unsqueeze(reduction_range, 1), + dtype=input.dtype) + output = sum(unsqueeze(input, -2) * lower_triangle, dim=-1) + return output else: slice_shape = [] for i in range(input.ndim()): @@ -3076,6 +3175,39 @@ def geglu(x: Tensor) -> Tensor: return a * gelu(b) +def quick_gelu(x: Tensor) -> Tensor: + return x * sigmoid(1.702 * x) + + +def gegelu(x: Tensor, limit: Optional[float] = None) -> Tensor: + # a, b = x[..., ::2], x[..., 1::2] + ndim = x.ndim() + a_starts = [0 for i in range(ndim)] + b_starts = [1 if i == (ndim - 1) else 0 for i in range(ndim)] + shapes = concat([ + shape(x, i) / 2 if i == (ndim - 1) else shape(x, i) for i in range(ndim) + ]) + strides = [2 if i == (ndim - 1) else 1 for i in range(ndim)] + + a = slice(x, a_starts, shapes, strides) + b = slice(x, b_starts, shapes, strides) + + if limit is not None: + a = clip(a, alpha=float(-1e20), beta=limit) + b = clip(b, alpha=-limit, beta=limit) + + # C = B + 1 + const1 = arange(constant(int32_array(1)), constant(int32_array(2)), + trt_dtype_to_str(b.dtype)) + for _ in range(ndim - 1): + const1 = expand_dims(const1, 0) + + b_shape = concat([shape(b, i) for i in range(ndim)]) + const1_arr = expand(const1, b_shape) + + return quick_gelu(a) * (b + const1_arr) + + def group_norm(input: Tensor, num_groups: int, weight: Optional[Tensor] = None, @@ -3964,6 +4096,8 @@ def create_sinusoidal_positions_long_rope( theta: float = 10000.0, scaling_short_factors: Tensor = 1.0, scaling_long_factors: Tensor = 1.0, + short_mscale=None, + long_mscale=None, dtype=np.float32): def _calc_mscale(scale): @@ -3971,16 +4105,19 @@ def _calc_mscale(scale): return 1.0 return math.sqrt(1 + math.log(scale) / math.log(num_orig_pos)) - mscale = _calc_mscale(num_pos / num_orig_pos) + if short_mscale is None: + short_mscale = _calc_mscale(num_pos / num_orig_pos) + long_mscale = short_mscale - def _compute_sinusoidal_positions(scale_factors, - for_attention_plugin=True): + def _compute_sinusoidal_positions(scale_factors, is_short, + for_attention_plugin): inv_freq = 1 / (scale_factors * (theta**(np.arange(0, dim, 2) / dim)).astype(dtype)) sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos, dtype=dtype), inv_freq, dtype=dtype) + if for_attention_plugin: sinusoid_inp = np.expand_dims(sinusoid_inp, axis=-1) concat = np.concatenate( @@ -3989,13 +4126,17 @@ def _compute_sinusoidal_positions(scale_factors, concat = np.concatenate( (np.sin(sinusoid_inp), np.cos(sinusoid_inp)), axis=1) concat = np.expand_dims(concat, axis=0) + + mscale = short_mscale if is_short else long_mscale return concat.astype(dtype) * mscale return _compute_sinusoidal_positions( - scaling_short_factors, False), _compute_sinusoidal_positions( - scaling_long_factors, False), _compute_sinusoidal_positions( - scaling_short_factors, True), _compute_sinusoidal_positions( - scaling_long_factors, True), mscale + scaling_short_factors, True, False), _compute_sinusoidal_positions( + scaling_long_factors, + False, False), _compute_sinusoidal_positions( + scaling_short_factors, True, + True), _compute_sinusoidal_positions( + scaling_long_factors, False, True), short_mscale @staticmethod def rotate_every_two(tensor: Tensor) -> Tensor: @@ -4206,9 +4347,11 @@ def gpt_attention( rotary_embedding_base: float = 10000.0, rotary_embedding_scale_type: RotaryScalingType = RotaryScalingType.none, rotary_embedding_scaling_factors: Optional[Tensor] = None, - rotary_embedding_m_scale: Optional[float] = None, + rotary_embedding_short_m_scale: Optional[float] = None, + rotary_embedding_long_m_scale: Optional[float] = None, rotary_embedding_scale: float = 1.0, rotary_embedding_max_positions: int = 1024, + rotary_embedding_original_max_positions: int = 1024, position_embedding_type: PositionEmbeddingType = PositionEmbeddingType. learned_absolute, rotary_cos_sin: Optional[Tensor] = None, @@ -4218,6 +4361,10 @@ def gpt_attention( kv_cache_quant_mode: QuantMode = QuantMode(0), max_context_length: Optional[int] = None, mask_type: AttentionMaskType = AttentionMaskType.causal, + block_sparse_block_size: int = 64, + block_sparse_homo_head_pattern: bool = False, + block_sparse_num_local_blocks: int = 16, + block_sparse_vertical_stride: int = 8, alibi_slopes: Optional[Tensor] = None, tp_size: int = 1, tp_rank: int = 0, @@ -4370,6 +4517,19 @@ def gpt_attention( * tensorrt_llm.layers.AttentionMaskType.causal for GPT, * tensorrt_llm.layers.AttentionMaskType.bidirectional for ChatGLM-6B, * tensorrt_llm.layers.AttentionMaskType.bidirectionalglm for GLM-10B, + * tensorrt_llm.layers.AttentionMaskType.blocksparse for Phi-3-small, + + block_sparse_block_size: int + Block size in block sparse attention + + block_sparse_homo_head_pattern: bool + Do all attention heads share same vertical stride pattern? + + block_sparse_num_local_blocks: int + Number of active blocks near diagonal + + block_sparse_vertical_stride: int + Stride of active blocks in vertical dimension alibi_slopes: Tensor The ALiBi slopes. The ALiBi bias is computed on-the-fly in the kernel @@ -4501,14 +4661,22 @@ def gpt_attention( "rotary_embedding_scale", np.array(rotary_embedding_scale, dtype=np.float32), trt.PluginFieldType.FLOAT32) - rotary_embedding_m_scale = trt.PluginField( - "rotary_embedding_m_scale", - np.array(rotary_embedding_m_scale, dtype=np.float32), + rotary_embedding_short_m_scale = trt.PluginField( + "rotary_embedding_short_m_scale", + np.array(rotary_embedding_short_m_scale, dtype=np.float32), + trt.PluginFieldType.FLOAT32) + rotary_embedding_long_m_scale = trt.PluginField( + "rotary_embedding_long_m_scale", + np.array(rotary_embedding_long_m_scale, dtype=np.float32), trt.PluginFieldType.FLOAT32) rotary_embedding_max_positions = trt.PluginField( "rotary_embedding_max_positions", np.array(rotary_embedding_max_positions, dtype=np.int32), trt.PluginFieldType.INT32) + rotary_embedding_original_max_positions = trt.PluginField( + "rotary_embedding_original_max_positions", + np.array(rotary_embedding_original_max_positions, dtype=np.int32), + trt.PluginFieldType.INT32) position_embedding_type = trt.PluginField( "position_embedding_type", np.array(int(position_embedding_type), dtype=np.int8), @@ -4532,6 +4700,22 @@ def gpt_attention( mask_type = trt.PluginField("mask_type", np.array([int(mask_type)], np.int32), trt.PluginFieldType.INT32) + block_sparse_block_size = trt.PluginField( + "block_sparse_block_size", np.array([block_sparse_block_size], + np.int32), + trt.PluginFieldType.INT32) + block_sparse_homo_head_pattern = trt.PluginField( + "block_sparse_homo_head_pattern", + np.array(np.int8(block_sparse_homo_head_pattern), np.int8), + trt.PluginFieldType.INT8) + block_sparse_num_local_blocks = trt.PluginField( + "block_sparse_num_local_blocks", + np.array([block_sparse_num_local_blocks], np.int32), + trt.PluginFieldType.INT32) + block_sparse_vertical_stride = trt.PluginField( + "block_sparse_vertical_stride", + np.array([block_sparse_vertical_stride], np.int32), + trt.PluginFieldType.INT32) multi_block_mode = trt.PluginField( "multi_block_mode", np.array(np.int8(default_net().plugin_config.multi_block_mode), @@ -4598,9 +4782,12 @@ def gpt_attention( unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, - rotary_embedding_m_scale, rotary_embedding_max_positions, tp_size, - tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode, + rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, + rotary_embedding_max_positions, rotary_embedding_original_max_positions, + tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type, multi_block_mode, enable_xqa, kv_cache_quant_mode_field, remove_input_padding, mask_type, + block_sparse_block_size, block_sparse_homo_head_pattern, + block_sparse_num_local_blocks, block_sparse_vertical_stride, paged_kv_cache, tokens_per_block, pf_type, max_context_length, qkv_bias_enabled, do_cross_attention_field, max_distance, pos_shift_enabled, dense_context_fmha, use_paged_context_fmha_field, @@ -5063,6 +5250,7 @@ def gather_last_token_logits(hidden_states: Tensor, last_token_ids: Tensor, 'gelu_new': gelu, 'gelu_fast': gelu, 'geglu': geglu, + 'gegelu': gegelu, 'identity': identity, 'silu': silu, 'softplus': softplus, diff --git a/tensorrt_llm/hlapi/llm.py b/tensorrt_llm/hlapi/llm.py index c8a1f4eea..7b88c4780 100644 --- a/tensorrt_llm/hlapi/llm.py +++ b/tensorrt_llm/hlapi/llm.py @@ -21,18 +21,18 @@ from ..executor import GenerationExecutor, GenerationResult from ..logger import logger from ..mapping import Mapping -from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig, - load_model) +from ..models import MODEL_MAP +from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig from ..module import Module from .mpi_session import (MpiCommSession, MPINodeState, MpiPoolSession, MpiSession, external_mpi_comm_available) from .tokenizer import TokenizerBase, TransformersTokenizer from .utils import (GenerationOutput, GpuArch, OutputConfig, SamplingConfig, - file_with_glob_exists, file_with_suffix_exists, - get_device_count, print_colored, print_traceback_on_error, - suppress_runtime_log) + download_hf_model, file_with_glob_exists, + file_with_suffix_exists, get_device_count, init_log_level, + print_colored, print_traceback_on_error) -suppress_runtime_log( +init_log_level( ) # This should be called before importing the following cpp-runtime modules from ..bindings.executor import CapacitySchedulerPolicy @@ -98,7 +98,7 @@ class ModelConfig: # ``model_dir`` helps to locate a local model, the format of the model is determined by the model file itself. # Either HF model, TensorRT-LLM checkpoints or TensorRT-LLM engine format is supported. - model_dir: str + model_dir: Optional[str] = None # ``model`` could either the model directory or a in-memory model. # If ``model`` specifies the model kind like "llama-7B", etc. The model will be download automatically from third-party @@ -118,13 +118,12 @@ class ModelConfig: repr=False) def __post_init__(self): - if self.model: - raise NotImplementedError("model is not supported yet.") - - model_path = Path(self.model_dir) - if not model_path.exists(): + if not (self.model_dir or self.model): + raise ValueError("Either model_dir or model should be provided.") + if self.model_dir and self.model: raise ValueError( - f"model_dir of path {self.model_dir} does not exist.") + "Only one of model_dir or model should be provided, provided both." + ) self._engine_config: Optional[EngineConfig] = None @@ -139,15 +138,22 @@ def __post_init__(self): **infer_cluster_config(), ) - # Load parallel_config from the engine. - self.model_format = ModelLoader.get_model_format(self.model_dir) - if self.model_format is _ModelFormatKind.TLLM_ENGINE: - self._load_config_from_engine(Path(self.model_dir)) + if self.model_dir: + model_path = Path(self.model_dir) + if not model_path.exists(): + raise ValueError( + f"model_dir of path {self.model_dir} does not exist.") + + # Load parallel_config from the engine. + self.model_format = ModelLoader.get_model_format(self.model_dir) + if self.model_format is _ModelFormatKind.TLLM_ENGINE: + self._load_config_from_engine(Path(self.model_dir)) - # Load parallel_config from the checkpoint. - if ModelLoader.get_model_format( - self.model_dir) is _ModelFormatKind.TLLM_CKPT: - self._load_config_from_ckpt(Path(self.model_dir)) + # Load parallel_config from the checkpoint. + if self.model_format is _ModelFormatKind.TLLM_CKPT: + self._load_config_from_ckpt(Path(self.model_dir)) + else: + self.model_format = _ModelFormatKind.HF def _update_plugin_config(self, key: str, value: Any): if key == 'use_paged_context_fmha': @@ -245,6 +251,7 @@ def __init__(self, config: ModelConfig, *, tokenizer: Optional[TokenizerBase] = None, + dtype: str = 'auto', kv_cache_config: Optional[KvCacheConfig] = None, streaming_llm: Union[bool, StreamingLLMParam] = False, async_engine_tmp_dir: Optional[str] = None, @@ -255,6 +262,10 @@ def __init__(self, The model config for the model. tokenizer (TokenizerBase): User provided tokenizer, will override the default one if exists in the HF model or TRT-LLM engine. + dtype (str): + The data type for the model weights and activations (non-quantized). You can + (1) explicitly specify `float16`, `bfloat16` or `float32`; or + (2) implicitly specify `auto` (default), then `dtype` will be automatically inferred from the source model. However, if the source `dtype` is `float32`, will use `float16` instead. kv_cache_config (KvCacheConfig): The config for the paged KV cache. streaming_llm (bool, StreamingLLMParam): @@ -284,6 +295,7 @@ def __init__(self, self.config = config self._tokenizer = tokenizer + self.dtype = dtype self.async_engine_tmp_dir = async_engine_tmp_dir self.kv_cache_config = kv_cache_config # TODO[chunweiy]: add doc for enable_streaming_llm @@ -601,8 +613,7 @@ def get_default_sampling_config(self) -> Optional[SamplingConfig]: ) def _build_model(self): - model_format = ModelLoader.get_model_format(self.config.model_dir) - + model_format = self.config.model_format self._engine_dir = self.config.model_dir def get_engine_dir(): @@ -624,6 +635,7 @@ def get_engine_dir(): LLM._node_build_task, self.config, self._tokenizer, + self.dtype, self._workspace.name, build_config=self.config.build_config, convert_checkpoint_options=self._convert_checkpoint_options, @@ -637,6 +649,7 @@ def get_engine_dir(): with ModelLoader( self.config, tokenizer=self._tokenizer, + dtype=self.dtype, workspace=self._workspace.name, build_config=self.config.build_config, convert_checkpoint_options=self. @@ -683,6 +696,7 @@ def get_engine_dir(): def _node_build_task( config: ModelConfig, tokenizer: Optional[TokenizerBase] = None, + dtype: str = 'auto', workspace: Optional[str] = None, build_config: Optional[BuildConfig] = None, convert_checkpoint_options: Optional[dict] = None) -> bool: @@ -691,6 +705,7 @@ def _node_build_task( with ModelLoader(config, tokenizer=tokenizer, + dtype=dtype, workspace=workspace, build_config=build_config, convert_checkpoint_options=convert_checkpoint_options @@ -795,11 +810,13 @@ class ModelLoader: def __init__(self, config: ModelConfig, tokenizer: Optional[TokenizerBase], + dtype: str = 'auto', workspace: Optional[str] = None, build_config: Optional[BuildConfig] = None, convert_checkpoint_options: Optional[dict] = None): self.config = config self.tokenizer = tokenizer + self.dtype = dtype self.workspace = workspace assert build_config @@ -843,12 +860,14 @@ def __init__(self, if self.config.model_dir is None: ''' Download HF model if necessary ''' - # TODO[chunweiy]: Support HF model download - raise NotImplementedError() + if self.config.model is None: + raise ValueError( + "Either model_dir or model should be provided to ModelConfig." + ) + self._model_pipeline.append( + ("Downloading HF model", self._download_hf_model)) - if self._model_dir is None: - raise ValueError("The model_dir is not set yet.") - self._model_format = ModelLoader.get_model_format(self._model_dir) + self._model_format = self.config.model_format if self._model_format is _ModelFormatKind.HF: ''' HF -> TRT checkpoints -> engine ''' @@ -982,7 +1001,11 @@ def get_model_format(model_dir: str) -> _ModelFormatKind: def _download_hf_model(self): ''' Download HF model from third-party model hub like www.modelscope.cn or huggingface. ''' - raise NotImplementedError() + assert self.workspace is not None + assert isinstance(self.config.model, str) + self._model_dir = download_hf_model(self.config.model) + self.config.model_dir = self._model_dir + print_colored(f"Downloaded model to {self._model_dir}\n", 'grey') def _load_model_from_hf(self): ''' Load a TRT-LLM model from a HF model. ''' @@ -1005,28 +1028,32 @@ def _load_model_from_hf(self): f"Unsupported model architecture: {model_arch}, " f"only {', '.join(model2struct.keys())} are supported now.") + model_cls = model2struct[model_arch] + if self.config.quant_config.quant_mode.has_any_quant(): assert self.workspace is not None checkpoint_dir = f"{self.workspace}/quantized-checkpoint" if self.rank == 0: - model2struct[model_arch].quantize( + model_cls.quantize( self._model_dir, checkpoint_dir, - self.config.quant_config, + dtype=self.dtype, mapping=self.mapping, + quant_config=self.config.quant_config, ) if self.config.parallel_config.is_multi_gpu: mpi_barrier() - self.model = model2struct[model_arch].from_checkpoint( - checkpoint_dir, rank=self.mapping.rank) + self.model = model_cls.from_checkpoint(checkpoint_dir, + rank=self.mapping.rank) else: - self.model = model2struct[model_arch].from_hugging_face( + self.model = model_cls.from_hugging_face( self._model_dir, + dtype=self.dtype, mapping=self.mapping, - quantization=self.config.quant_config, + quant_config=self.config.quant_config, load_model_on_cpu= True, # TODO:TRTLLM-195 to enhance the weights loading memory usage and chose best location - override_fields=self.convert_checkpoint_options, + **self.convert_checkpoint_options, ) self.pretrained_config = self.model.config @@ -1035,11 +1062,16 @@ def _load_model_from_hf(self): def _load_model_from_ckpt(self): ''' Load a TRT-LLM model from checkpoint. ''' - model_config = PretrainedConfig.from_json_file( + self.pretrained_config = PretrainedConfig.from_json_file( os.path.join(self._model_dir, 'config.json')) - model_config.mapping = self.mapping - self.model = load_model(model_config, self._model_dir) - self.pretrained_config = model_config + self.pretrained_config.mapping = self.mapping + + architecture = self.pretrained_config.architecture + assert architecture in MODEL_MAP, \ + f"Unsupported model architecture: {architecture}" + model_cls = MODEL_MAP[architecture] + self.model = model_cls.from_checkpoint(self._model_dir, + config=self.pretrained_config) self._model_info = _ModelInfo.from_pretrained_config( self.pretrained_config) diff --git a/tensorrt_llm/hlapi/utils.py b/tensorrt_llm/hlapi/utils.py index 9ce349e8c..6bdf828d9 100644 --- a/tensorrt_llm/hlapi/utils.py +++ b/tensorrt_llm/hlapi/utils.py @@ -1,15 +1,21 @@ +import hashlib import os import signal import sys +import tempfile import traceback from dataclasses import dataclass, field from functools import wraps from pathlib import Path from typing import List, Optional, Union +import filelock +import huggingface_hub import torch +from huggingface_hub import snapshot_download from tensorrt_llm.bindings import executor as tllme +from tensorrt_llm.logger import set_level from ..bindings.executor import OutputConfig @@ -187,13 +193,11 @@ def is_directory_empty(directory: Path) -> bool: return not any(directory.iterdir()) -def suppress_runtime_log(): - ''' Suppress the runtime log if the environment variable is not set. ''' - +def init_log_level(): + ''' Set the log level if the environment variable is not set. ''' if "TLLM_LOG_LEVEL" not in os.environ: - os.environ["TLLM_LOG_LEVEL"] = "ERROR" - if "TLLM_LOG_FIRST_RANK_ONLY" not in os.environ: - os.environ["TLLM_LOG_FIRST_RANK_ONLY"] = "ON" + set_level("warning") + os.environ["TLLM_LOG_LEVEL"] = "WARNING" def sigint_handler(signal, frame): @@ -204,3 +208,26 @@ def sigint_handler(signal, frame): # Register the signal handler to handle SIGINT # This helps to deal with user's Ctrl+C signal.signal(signal.SIGINT, sigint_handler) +# Use the system temporary directory to share the cache +temp_dir = tempfile.gettempdir() + + +def get_file_lock(model_name: str, + cache_dir: Optional[str] = None) -> filelock.FileLock: + # Hash the model name to avoid invalid characters in the lock file path + hashed_model_name = hashlib.sha256(model_name.encode()).hexdigest() + + cache_dir = cache_dir or temp_dir + os.makedirs(cache_dir, exist_ok=True) + + lock_file_path = os.path.join(cache_dir, f"{hashed_model_name}.lock") + + return filelock.FileLock(lock_file_path) + + +def download_hf_model(model_name: str) -> Path: + with get_file_lock(model_name): + hf_folder = snapshot_download( + model_name, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE) + return Path(hf_folder) diff --git a/tensorrt_llm/layers/__init__.py b/tensorrt_llm/layers/__init__.py index e9a4c1098..3511baf59 100644 --- a/tensorrt_llm/layers/__init__.py +++ b/tensorrt_llm/layers/__init__.py @@ -14,8 +14,9 @@ # limitations under the License. from .activation import Mish from .attention import (Attention, AttentionMaskType, AttentionParams, - BertAttention, CogVLMAttention, KeyValueCacheParams, - PositionEmbeddingType, SpecDecodingParams) + BertAttention, BlockSparseAttnParams, CogVLMAttention, + KeyValueCacheParams, PositionEmbeddingType, + SpecDecodingParams) from .cast import Cast from .conv import Conv1d, Conv2d, ConvTranspose2d from .embedding import Embedding, PromptTuningEmbedding @@ -57,6 +58,7 @@ 'AttentionParams', 'SpecDecodingParams', 'KeyValueCacheParams', + 'BlockSparseAttnParams', 'Lora', 'LoraParams', 'LoraRuntimeParams', diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py index 536212009..5dde13747 100644 --- a/tensorrt_llm/layers/attention.py +++ b/tensorrt_llm/layers/attention.py @@ -21,12 +21,12 @@ from .._common import default_net, precision from .._utils import (fp32_array, int32_array, is_same_dtype, trt_dtype_to_np, trt_dtype_to_str, trt_gte_10) -from ..functional import (AttentionMaskType, PositionEmbeddingType, - RopeEmbeddingUtils, RotaryScalingType, Tensor, arange, - bert_attention, cast, clip, concat, conditional, - constant, embedding, expand, expand_dims, expand_mask, - generate_alibi_biases, generate_alibi_slopes, - gpt_attention, matmul) +from ..functional import (ACT2FN, AttentionMaskType, Conditional, + PositionEmbeddingType, RopeEmbeddingUtils, + RotaryScalingType, Tensor, arange, bert_attention, + cast, clip, concat, constant, embedding, expand, + expand_dims, expand_mask, generate_alibi_biases, + generate_alibi_slopes, gpt_attention, matmul) from ..functional import max as fmax from ..functional import (minimum, repeat_interleave, shape, slice, softmax, split, unsqueeze, where) @@ -236,52 +236,68 @@ def is_valid(self, gpt_attention_plugin): return True +class BlockSparseAttnParams: + + def __init__(self, + block_size: int = 64, + homo_head_pattern: bool = False, + num_local_blocks: int = 16, + vertical_stride: int = 8): + self.block_size = block_size + self.homo_head_pattern = homo_head_pattern + self.num_local_blocks = num_local_blocks + self.vertical_stride = vertical_stride + + class Attention(Module): - def __init__( - self, - *, - local_layer_idx, - hidden_size, - num_attention_heads, - num_kv_heads=None, - max_position_embeddings=1024, - num_layers=1, - apply_query_key_layer_scaling=False, - attention_head_size=None, - qk_layernorm=False, - inner_layernorm=False, - eps=1e-05, - attention_mask_type=AttentionMaskType.padding, - bias=True, - dtype=None, - position_embedding_type=PositionEmbeddingType.learned_absolute, - rotary_embedding_base=10000.0, - rotary_embedding_scaling=None, - rotary_embedding_percentage=1.0, - rope_scaling_short_factors=None, - rope_scaling_long_factors=None, - original_max_position_embeddings=1024, - tp_group=None, - tp_size=1, - tp_rank=0, - quant_mode: QuantMode = QuantMode(0), - q_scaling=1.0, - cross_attention=False, - relative_attention=False, - max_distance=0, - num_buckets=0, - dense_bias=None, - clip_qkv=None, - alibi_bias_max=8, - skip_cross_qkv=False, - ): + def __init__(self, + *, + local_layer_idx, + hidden_size, + num_attention_heads, + num_kv_heads=None, + max_position_embeddings=1024, + num_layers=1, + apply_query_key_layer_scaling=False, + attention_head_size=None, + qk_layernorm=False, + inner_layernorm=False, + eps=1e-05, + attention_mask_type=AttentionMaskType.padding, + bias=True, + dtype=None, + position_embedding_type=PositionEmbeddingType.learned_absolute, + rotary_embedding_base=10000.0, + rotary_embedding_scaling=None, + rotary_embedding_percentage=1.0, + rope_scaling_short_factors=None, + rope_scaling_long_factors=None, + rope_scaling_short_mscale=None, + rope_scaling_long_mscale=None, + original_max_position_embeddings=1024, + tp_group=None, + tp_size=1, + tp_rank=0, + quant_mode: QuantMode = QuantMode(0), + q_scaling=1.0, + cross_attention=False, + relative_attention=False, + max_distance=0, + num_buckets=0, + dense_bias=None, + clip_qkv=None, + alibi_bias_max=8, + skip_cross_qkv=False, + max_attn_value=0.0, + block_sparse_params=None): super().__init__() - self.layer_idx = local_layer_idx + self.local_layer_idx = local_layer_idx self.cross_attention = cross_attention self.attention_mask_type = attention_mask_type self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size + self.num_kv_heads = num_kv_heads assert num_attention_heads % tp_size == 0, \ "num_attention_heads must be divisible by tp_size" self.num_attention_heads = num_attention_heads // tp_size @@ -291,14 +307,15 @@ def __init__( self.hidden_size = hidden_size self.attention_hidden_size = self.attention_head_size * self.num_attention_heads self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings self.bias = bias self.tp_group = tp_group self.tp_size = tp_size self.tp_rank = tp_rank self.dtype = dtype + self.dense_bias = dense_bias if dense_bias is None: - dense_bias = bias - self.unfuse_qkv_gemm = False + self.dense_bias = bias self.num_layers = num_layers self.apply_query_key_layer_scaling = apply_query_key_layer_scaling @@ -343,12 +360,19 @@ def __init__( self.max_position_embeddings, original_max_position_embeddings, self.rotary_embedding_dim, self.rotary_embedding_base, rope_scaling_short_factors, - rope_scaling_long_factors) + rope_scaling_long_factors, rope_scaling_short_mscale, rope_scaling_long_mscale) + + if rope_scaling_short_mscale is not None: + assert rope_scaling_long_mscale is not None + short_mscale = rope_scaling_short_mscale + long_mscale = rope_scaling_long_mscale + else: + short_mscale = long_mscale = mscale + rope_scaling_short_factors = np.array( rope_scaling_short_factors).reshape(1, -1) rope_scaling_long_factors = np.array( rope_scaling_long_factors).reshape(1, -1) - self.original_max_position_embeddings = original_max_position_embeddings self.register_parameter( 'embed_positions_short_factors', @@ -365,7 +389,8 @@ def __init__( 'embed_positions_long_factors_for_attention_plugin', Parameter(embed_positions_long_factors_for_attention_plugin, dtype='float32')) - self.mscale = mscale + self.short_mscale = short_mscale + self.long_mscale = long_mscale self.register_parameter( 'rope_scaling_short_factors', Parameter(rope_scaling_short_factors, dtype='float32')) @@ -402,9 +427,13 @@ def __init__( Parameter(alibi_slopes, dtype='float32')) self.quant_mode = quant_mode + self.max_attn_value = max_attn_value self.register_parameter('kv_cache_scaling_factor', None) self.register_parameter('attention_output_orig_quant_scale', None) + self.block_sparse_params = block_sparse_params if block_sparse_params is not None else BlockSparseAttnParams( + ) + # The output feature size is therefore (h/tp + 2*kvh/tp) * d, where h is num_heads, # d is head_size, kvh is the num_kv_heads and tp is tensor_parallel_size. # In ColumnLinear op, the output dim is calculated by (h + 2*kvh) * d / tp, @@ -425,11 +454,14 @@ def __init__( self.dense = RowLinear(tp_size * self.num_attention_heads * self.attention_head_size, hidden_size, - bias=dense_bias, + bias=self.dense_bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size) + # see optimize_model's add_lora for LoRA initialization + self.qkv_lora = None + # per-layer relative attention table if relative_attention: self.rel_attn_table = Parameter(shape=(num_attention_heads // @@ -482,7 +514,7 @@ def forward(self, qkv_lora_params = lora_layer_params.get_runtime_params( 0, "cross_attn_qkv") - unfuse_qkv_gemm = self.unfuse_qkv_gemm + unfuse_qkv_gemm = self.qkv is None if unfuse_qkv_gemm: qkv_gemm = [self.q, self.k, self.v] qkv = [gemm(hidden_states) for gemm in qkv_gemm] @@ -610,13 +642,22 @@ def forward(self, if self.cross_attention and encoder_output: assert isinstance(encoder_output, Tensor) + encoder_output_tensor = None + cross_qkv_reuse_tensor = None + if self.skip_cross_qkv: + conditional = Conditional(cross_kv_cache_gen) + encoder_output_tensor = conditional.add_input(encoder_output) + cross_qkv_reuse_tensor = conditional.add_input(cross_qkv_reuse) + else: + encoder_output_tensor = encoder_output + ## True branch: context phase, compute cross qkv - cross_qkv_true = self.qkv(encoder_output, qkv_lora_params) + cross_qkv_true = self.qkv(encoder_output_tensor, qkv_lora_params) if default_net( ).plugin_config.lora_plugin and qkv_lora_params is None and lora_layer_params is not None: cross_q_lora, cross_k_lora, cross_v_lora = self.qkv_lora( - encoder_output, + encoder_output_tensor, qkv_lora_runtime_params, is_cross_attention=True) cross_qkv_lora = concat( @@ -628,13 +669,13 @@ def forward(self, ## False branch: generation phase, no compute but need to obey shape constraints # because TRT's IfConditional requires the output shape of two subgraphs to be identical # our 1st attempt was to stack encoder_output [B, S, H] or [N, H] --> cross qkv [B, S, 3*H] or [N, 3*H], but it still introduces unnecessary concat. A better solution is to create a dummy torch tensor `cross_qkv_resue` with the correct shape and reuse it in every generation step - cross_qkv_false = cross_qkv_reuse + cross_qkv_false = cross_qkv_reuse_tensor ## End False branch # IfConditional layer if self.skip_cross_qkv: - cross_qkv = conditional(cross_kv_cache_gen, cross_qkv_true, - cross_qkv_false) + cross_qkv = conditional.add_output(cross_qkv_true, + cross_qkv_false) else: cross_qkv = cross_qkv_true @@ -643,7 +684,8 @@ def forward(self, past_key_value = kv_cache_params.past_key_value[1] assert self.attention_mask_type in [ AttentionMaskType.causal, AttentionMaskType.bidirectional, - AttentionMaskType.bidirectionalglm + AttentionMaskType.bidirectionalglm, + AttentionMaskType.blocksparse ], 'Plugin only support masked MHA.' # KV cache scales. @@ -702,7 +744,12 @@ def forward(self, rotary_cos_sin = self.embed_positions_for_gpt_attention.value if self.position_embedding_type.is_rope( ) else None rope_scaling_factors = None - mscale = self.mscale if self.position_embedding_type == PositionEmbeddingType.long_rope else None + + if self.position_embedding_type == PositionEmbeddingType.long_rope: + short_mscale, long_mscale = self.short_mscale, self.long_mscale + else: + short_mscale, long_mscale = None, None + context, past_key_value = gpt_attention( qkv=qkv, past_key_value=past_key_value, @@ -715,7 +762,7 @@ def forward(self, context_lengths=attention_params.context_lengths, cache_indirection=kv_cache_params.cache_indirection, host_request_types=attention_params.host_request_types, - layer_idx=self.layer_idx, + layer_idx=self.local_layer_idx, num_heads=self.num_attention_heads, num_kv_heads=self.num_attention_kv_heads, hidden_size_per_head=self.attention_head_size, @@ -724,9 +771,12 @@ def forward(self, rotary_embedding_base=self.rotary_embedding_base, rotary_embedding_scale_type=self.rotary_embedding_scale_type, rotary_embedding_scaling_factors=rope_scaling_factors, - rotary_embedding_m_scale=mscale, + rotary_embedding_short_m_scale=short_mscale, + rotary_embedding_long_m_scale=long_mscale, rotary_embedding_scale=self.rotary_embedding_scale, rotary_embedding_max_positions=self.max_position_embeddings, + rotary_embedding_original_max_positions=self. + original_max_position_embeddings, position_embedding_type=self.position_embedding_type, rotary_cos_sin=rotary_cos_sin, kv_orig_quant_scale=kv_orig_quant_scale, @@ -736,6 +786,13 @@ def forward(self, kv_cache_quant_mode=self.quant_mode, max_context_length=attention_params.max_context_length, mask_type=self.attention_mask_type, + block_sparse_block_size=self.block_sparse_params.block_size, + block_sparse_homo_head_pattern=self.block_sparse_params. + homo_head_pattern, + block_sparse_num_local_blocks=self.block_sparse_params. + num_local_blocks, + block_sparse_vertical_stride=self.block_sparse_params. + vertical_stride, alibi_slopes=alibi_slopes, tp_size=self.tp_size, tp_rank=self.tp_rank, @@ -763,7 +820,7 @@ def forward(self, spec_decoding_position_offsets, spec_decoding_packed_mask=spec_decoding_params. spec_decoding_packed_mask, - ) + qk_tanh_scale=self.max_attn_value) else: # plain TensorRT mode @@ -1065,6 +1122,9 @@ def transpose_for_scores(x, if not norm_before_bmm1: attention_scores = attention_scores / (self.q_scaling * self.norm_factor) + if self.max_attn_value > 0: + attention_scores = self.max_attn_value * ACT2FN['tanh']( + attention_scores / self.max_attn_value) if self.attention_mask_type in [ AttentionMaskType.causal, @@ -1188,6 +1248,9 @@ def __init__(self, tp_group=tp_group, tp_size=tp_size) + # see optimize_model's add_lora for LoRA initialization + self.qkv_lora = None + # per-layer relative attention table if relative_attention: self.rel_attn_table = Parameter(shape=(num_attention_heads // @@ -1374,7 +1437,7 @@ def __init__( self.vis_dense = RowLinear(tp_size * self.num_attention_heads * self.attention_head_size, hidden_size, - bias=dense_bias, + bias=self.dense_bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size) @@ -1460,7 +1523,7 @@ def forward(self, context_lengths=attention_params.context_lengths, cache_indirection=kv_cache_params.cache_indirection, host_request_types=attention_params.host_request_types, - layer_idx=self.layer_idx, + layer_idx=self.local_layer_idx, num_heads=self.num_attention_heads, num_kv_heads=self.num_attention_kv_heads, hidden_size_per_head=self.attention_head_size, diff --git a/tensorrt_llm/layers/linear.py b/tensorrt_llm/layers/linear.py index c4640d6c4..894fa20e0 100644 --- a/tensorrt_llm/layers/linear.py +++ b/tensorrt_llm/layers/linear.py @@ -130,6 +130,7 @@ def __init__(self, self.dtype = dtype self.pad_lda = pad_lda + self.share_weight = share_weight if not share_weight: self.weight = Parameter(shape=(self.out_features, self.in_features), dtype=dtype) @@ -152,16 +153,16 @@ def __init__(self, else: self.register_parameter('bias', None) - # see add_lora in tensorrt_llm/models/modeling_utils.py for LoRA initialization + # see optimize_model's add_lora for LoRA initialization self.lora = None - def multiply_gather( - self, - x, - weight, - gemm_plugin: Optional[str] = None, - use_fp8: bool = False, - lora_runtime_params: Optional[LoraRuntimeParams] = None): + def multiply_gather(self, + x, + weight, + gemm_plugin: Optional[str] = None, + use_fp8: bool = False, + lora_runtime_params: Optional[LoraRuntimeParams] = None, + lora_hidden_state: Optional[Tensor] = None): hidden_state = x if gemm_plugin: x = _gemm_plugin(x, @@ -175,7 +176,8 @@ def multiply_gather( if default_net( ).plugin_config.lora_plugin and lora_runtime_params is not None: - x = x + self.lora(hidden_state, + x = x + self.lora(hidden_state if lora_hidden_state is None else + lora_hidden_state, lora_runtime_params=lora_runtime_params) if self.bias is not None: @@ -190,12 +192,14 @@ def multiply_gather( def forward(self, x, - lora_runtime_params: Optional[LoraRuntimeParams] = None): + lora_runtime_params: Optional[LoraRuntimeParams] = None, + lora_hidden_state: Optional[Tensor] = None): return self.multiply_gather( x, self.weight.value, gemm_plugin=default_net().plugin_config.gemm_plugin, - lora_runtime_params=lora_runtime_params) + lora_runtime_params=lora_runtime_params, + lora_hidden_state=lora_hidden_state) def weight_loader(self, mapping: Mapping, param: Parameter, loaded_weight: torch.Tensor): @@ -280,20 +284,20 @@ def __init__(self, else: self.register_parameter('bias', None) - # see add_lora in tensorrt_llm/models/modeling_utils.py for LoRA initialization + # see optimize_model's add_lora for LoRA initialization self.lora = None self.tp_group = tp_group self.tp_size = tp_size self.strict_dtype = self.dtype if strict_dtype else None - def multiply_reduce( - self, - x, - weight, - gemm_plugin: Optional[str] = None, - use_fp8: bool = False, - lora_runtime_params: Optional[LoraRuntimeParams] = None): + def multiply_reduce(self, + x, + weight, + gemm_plugin: Optional[str] = None, + use_fp8: bool = False, + lora_runtime_params: Optional[LoraRuntimeParams] = None, + lora_hidden_state: Optional[Tensor] = None): hidden_state = x if gemm_plugin: x = _gemm_plugin(x, @@ -307,7 +311,8 @@ def multiply_reduce( if default_net( ).plugin_config.lora_plugin and lora_runtime_params is not None: - x = x + self.lora(hidden_state, + x = x + self.lora(hidden_state if lora_hidden_state is None else + lora_hidden_state, lora_runtime_params=lora_runtime_params) if self.tp_size > 1 and self.tp_group is not None: @@ -321,12 +326,14 @@ def multiply_reduce( def forward(self, x, - lora_runtime_params: Optional[LoraRuntimeParams] = None): + lora_runtime_params: Optional[LoraRuntimeParams] = None, + lora_hidden_state: Optional[Tensor] = None): return self.multiply_reduce( x, self.weight.value, gemm_plugin=default_net().plugin_config.gemm_plugin, - lora_runtime_params=lora_runtime_params) + lora_runtime_params=lora_runtime_params, + lora_hidden_state=lora_hidden_state) def weight_loader(self, mapping: Mapping, param: Parameter, loaded_weight: torch.Tensor): diff --git a/tensorrt_llm/layers/mlp.py b/tensorrt_llm/layers/mlp.py index 7871aedfd..9deebc7df 100644 --- a/tensorrt_llm/layers/mlp.py +++ b/tensorrt_llm/layers/mlp.py @@ -12,10 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import tensorrt as trt -from ..functional import ACT2FN, concat +from .._common import default_net +from ..functional import ACT2FN, cast, concat, gemm_swiglu from ..module import Module from ..quantization import QuantMode +from ..quantization.functional import quantize +from ..quantization.layers import FP8Linear, FP8RowLinear from .linear import ColumnLinear, RowLinear from .lora import LoraRuntimeParams from .normalization import LayerNorm @@ -40,9 +44,12 @@ def __init__( if hidden_act not in ACT2FN: raise ValueError( 'unsupported activation function: {}'.format(hidden_act)) - fc_output_size = 2 * ffn_hidden_size if hidden_act == 'swiglu' else ffn_hidden_size + fc_output_size = 2 * ffn_hidden_size if hidden_act in [ + 'swiglu', 'gegelu' + ] else ffn_hidden_size self.inner_layernorm = LayerNorm(ffn_hidden_size, dtype=dtype, eps=eps) if inner_layernorm else None + self.fc = ColumnLinear(hidden_size, fc_output_size, bias=bias, @@ -57,12 +64,17 @@ def __init__( tp_group=tp_group, tp_size=tp_size) + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size self.hidden_act = hidden_act self.dtype = dtype self.bias = bias + self.tp_group = tp_group + self.tp_size = tp_size self.quant_mode = quant_mode + self.eps = eps - def forward(self, hidden_states, lora_layer_params=None): + def forward(self, hidden_states, lora_layer_params=None, gegelu_limit=None): mlp_fc_lora_params = None if lora_layer_params is not None: mlp_fc_lora_params = lora_layer_params.get_runtime_params( @@ -74,7 +86,10 @@ def forward(self, hidden_states, lora_layer_params=None): 0, "mlp_4h_to_h") inter = self.fc(hidden_states, mlp_fc_lora_params) - inter = ACT2FN[self.hidden_act](inter) + if self.hidden_act == 'gegelu': + inter = ACT2FN[self.hidden_act](inter, gegelu_limit) + else: + inter = ACT2FN[self.hidden_act](inter) if self.inner_layernorm is not None: inter = self.inner_layernorm(inter) output = self.proj(inter, lora_runtime_params=mlp_proj_lora_params) @@ -160,6 +175,8 @@ def __init__( tp_group=None, tp_size=1, quant_mode=QuantMode(0), + inner_layernorm=False, + eps=1e-05, ): super().__init__() self.hidden_size = hidden_size @@ -180,6 +197,8 @@ def __init__( tp_size=self.tp_size, gather_output=False, ) + self.inner_layernorm = LayerNorm(ffn_hidden_size, dtype=dtype, + eps=eps) if inner_layernorm else None self.proj = RowLinear(ffn_hidden_size, hidden_size, bias=bias, @@ -187,7 +206,67 @@ def __init__( tp_group=tp_group, tp_size=tp_size) - def forward(self, hidden_states, lora_layer_params=None): + # see optimize_model's add_lora for LoRA initialization + self.lora = None + + def fc_gate_plugin(self, hidden_states, lora_layer_params=None): + # Combine the following pattern + # + # SiLU(FC(x)) + Gate(x) + # + # into: + # + # SwiGLU(FusedFC(x)) + p_dtype = default_net().plugin_config.gemm_swiglu_plugin + use_fp8 = p_dtype == 'fp8' + assert use_fp8, "gemm_swiglu_plugin only supports fp8 now" + + if lora_layer_params is not None: + mlp_fc_lora_params = lora_layer_params.get_runtime_params( + 0, "mlp_h_to_4h") + mlp_gate_lora_params = lora_layer_params.get_runtime_params( + 0, "mlp_gate") + + if mlp_fc_lora_params is not None or mlp_gate_lora_params is not None: + raise NotImplementedError( + f"LoRA not yet implemented for gemm_swiglu_plugin") + + if self.hidden_act != 'silu': + raise NotImplementedError( + f"Activation {self.hidden_act} not yet implemented for gemm_swiglu_plugin" + ) + + if self.bias: + raise NotImplementedError( + f"bias not yet implemented for gemm_swiglu_plugin fp8") + + assert isinstance( + self.fused_fc, + FP8Linear), "fp8 gemm_swiglu only supports fp8 weights" + assert isinstance( + self.proj, + FP8RowLinear), "fp8 gemm_swiglu only supports fp8 weights" + assert self.fused_fc.weight.shape == ( + self.hidden_size, self.ffn_hidden_size * 2 // + self.tp_size), "fp8 gemm_swiglu only supports (k, n) weights" + + scale_d0 = (self.fused_fc.weights_scaling_factor.raw_value.item() * + self.fused_fc.activation_scaling_factor.raw_value.item()) + scale_d1 = scale_d0 + scale_output = 1.0 / self.proj.activation_scaling_factor.raw_value.item( + ) + activation_scaling_factor = cast( + self.fused_fc.activation_scaling_factor.value, self.dtype) + if hidden_states.dtype != trt.fp8: + hidden_states = quantize(hidden_states, activation_scaling_factor, + 'fp8') + + inter = gemm_swiglu(hidden_states, self.fused_fc.weight.value, None, + scale_d0, scale_d1, scale_output) + + return inter + + def fc_gate(self, hidden_states, lora_layer_params=None): # Combine the following pattern # # SiLU(FC(x)) + Gate(x) @@ -221,8 +300,8 @@ def forward(self, hidden_states, lora_layer_params=None): host_context_lengths, max_context_length=mlp_fc_lora_params.max_context_length) - mlp_fc_lora, mlp_gate_lora = self.mlp_in_lora( - hidden_states, mlp_in_lora_params) + mlp_fc_lora, mlp_gate_lora = self.lora(hidden_states, + mlp_in_lora_params) mlp_in_result = concat([mlp_gate_lora, mlp_fc_lora], dim=mlp_fc_lora.rank() - 1) inter = inter + mlp_in_result @@ -235,6 +314,17 @@ def forward(self, hidden_states, lora_layer_params=None): raise NotImplementedError( f"Activation {self.hidden_act} not yet implemented for FusedGatedMLP" ) + return inter + + def forward(self, hidden_states, lora_layer_params=None): + if default_net().plugin_config.gemm_swiglu_plugin: + assert self.dtype == 'float16', f"Currently limited support, got {self.dtype}" + inter = self.fc_gate_plugin(hidden_states, lora_layer_params) + else: + inter = self.fc_gate(hidden_states, lora_layer_params) + + if self.inner_layernorm is not None: + inter = self.inner_layernorm(inter) mlp_proj_lora_params = None if lora_layer_params is not None: diff --git a/tensorrt_llm/layers/moe.py b/tensorrt_llm/layers/moe.py index 8ce5b143a..5a0fdc475 100644 --- a/tensorrt_llm/layers/moe.py +++ b/tensorrt_llm/layers/moe.py @@ -12,15 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from dataclasses import dataclass +from dataclasses import asdict, dataclass from enum import IntEnum from typing import List, Type, Union import numpy as np import tensorrt as trt -from tensorrt_llm._utils import str_dtype_to_trt +from tensorrt_llm._utils import get_init_params, str_dtype_to_trt from tensorrt_llm.layers.lora import LoraParams from .._common import default_net, default_trtnet @@ -74,6 +73,13 @@ def validate(self) -> "MoeConfig": def has_moe(self) -> bool: return self.num_experts > 1 + @classmethod + def from_dict(cls, config: dict): + return cls(**config) + + def to_dict(self): + return asdict(self) + def _moe_plugin(moe_config, hidden_states, @@ -449,13 +455,7 @@ def to(self, config=None) -> "MixtureOfExperts": from ..quantization.quantize import quantize - # initialize subclass with all parameters in __init__ of base class - new_moe = moe_cls( - **{ - name: getattr(self, name) - for name in list( - inspect.signature(MixtureOfExperts.__init__).parameters)[1:] - }) + new_moe = moe_cls(**get_init_params(self)) if config is not None: quantize(new_moe, config.quantization) new_moe.load_weights(self) diff --git a/tensorrt_llm/layers/normalization.py b/tensorrt_llm/layers/normalization.py index 9aa4d95c9..50f39b779 100644 --- a/tensorrt_llm/layers/normalization.py +++ b/tensorrt_llm/layers/normalization.py @@ -37,6 +37,7 @@ def __init__(self, self.register_parameter('bias', None) self.eps = eps + self.dtype = dtype def forward(self, x): weight = 1. if self.weight is None else self.weight.value @@ -62,6 +63,7 @@ def __init__(self, self.register_parameter('weight', None) self.eps = eps + self.dtype = dtype def forward(self, x): weight = None if self.weight is None else self.weight.value diff --git a/tensorrt_llm/layers/recurrent.py b/tensorrt_llm/layers/recurrent.py index 95154bff6..0145d43c6 100644 --- a/tensorrt_llm/layers/recurrent.py +++ b/tensorrt_llm/layers/recurrent.py @@ -120,15 +120,13 @@ def __init__(self, num_heads=1, dtype=None, tp_group=None, - tp_size=1, - tp_rank=0): + tp_size=1): super().__init__() self.lru_width = lru_width self.dtype = dtype self.num_heads = num_heads self.tp_group = tp_group self.tp_size = tp_size - self.tp_rank = tp_rank self.recurrent_param = Parameter(shape=(self.lru_width // self.tp_size, ), @@ -184,8 +182,7 @@ def __init__(self, num_heads=1, dtype=None, tp_group=None, - tp_size=1, - tp_rank=0): + tp_size=1): super().__init__() self.lru_width = lru_width self.tp_size = tp_size @@ -240,7 +237,6 @@ def __init__( dtype=None, tp_group=None, tp_size=1, - tp_rank=0, ): super().__init__() self.width = width @@ -273,8 +269,7 @@ def __init__( num_heads=num_heads, dtype=dtype, tp_group=tp_group, - tp_size=tp_size, - tp_rank=tp_rank) + tp_size=tp_size) self.linear_out = RowLinear(self.lru_width, self.width, diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index e8bff6f4d..328a9e6bd 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -565,6 +565,7 @@ def load_from_model_dir(uid, model_dir, hf_config): all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component) rank = int(hf_config["r"]) + rs_lora = bool(hf_config.get("use_rslora", False)) self._lora_uid_to_low_ranks[uid] = {} self._lora_weights_pointers_list[uid] = {} @@ -629,7 +630,10 @@ def load_from_model_dir(uid, model_dir, hf_config): t_in = t_in.cuda().contiguous() t_out = t_out.cuda().contiguous() - scale = float(hf_config["lora_alpha"]) / rank + if rs_lora: + scale = float(hf_config["lora_alpha"]) / np.sqrt(rank) + else: + scale = float(hf_config["lora_alpha"]) / rank t_out = t_out * scale t_in = t_in.to(str_dtype_to_torch(dtype)) t_out = t_out.to(str_dtype_to_torch(dtype)) diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index 5799ccf1d..200b0d58f 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -111,3 +111,16 @@ def ep_experts(self, num_experts: int) -> List[int]: experts_range = range(self.tp_rank * experts_per_rank, (self.tp_rank + 1) * experts_per_rank) return list(experts_range) + + @classmethod + def from_dict(cls, mapping: dict): + return cls(**mapping) + + def to_dict(self): + return { + 'world_size': self.world_size, + 'rank': self.rank, + 'gpus_per_node': self.gpus_per_node, + 'tp_size': self.tp_size, + 'pp_size': self.pp_size + } diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index 83518d784..b09e1d505 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -17,23 +17,30 @@ BertForSequenceClassification, BertModel) from .bloom.model import BloomForCausalLM, BloomModel from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel +from .cogvlm.config import CogVLMConfig from .cogvlm.model import CogVLMForCausalLM +from .dbrx.config import DbrxConfig from .dbrx.model import DbrxForCausalLM from .dit.model import DiT from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder from .falcon.model import FalconForCausalLM, FalconModel from .gemma.model import GemmaForCausalLM +from .gpt.config import GPTConfig from .gpt.model import GPTForCausalLM, GPTModel from .gptj.model import GPTJForCausalLM, GPTJModel from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel +from .grok.model import GrokForCausalLM +from .llama.config import LLaMAConfig from .llama.model import LLaMAForCausalLM, LLaMAModel from .mamba.model import MambaForCausalLM +from .medusa.config import MedusaConfig from .medusa.model import MedusaForCausalLm from .modeling_utils import (PretrainedConfig, PretrainedModel, SpeculativeDecodingMode) from .mpt.model import MPTForCausalLM, MPTModel from .opt.model import OPTForCausalLM, OPTModel from .phi3.model import Phi3ForCausalLM, Phi3Model +from .phi3.phi3small.model import Phi3SmallForCausalLM, Phi3SmallModel from .phi.model import PhiForCausalLM, PhiModel from .qwen.model import QWenForCausalLM from .recurrentgemma.model import RecurrentGemmaForCausalLM @@ -47,12 +54,15 @@ 'DiT', 'FalconForCausalLM', 'FalconModel', + 'GPTConfig', 'GPTModel', 'GPTForCausalLM', 'OPTForCausalLM', 'OPTModel', + 'LLaMAConfig', 'LLaMAForCausalLM', 'LLaMAModel', + 'MedusaConfig', 'MedusaForCausalLm', 'GPTJModel', 'GPTJForCausalLM', @@ -60,8 +70,10 @@ 'GPTNeoXForCausalLM', 'PhiModel', 'Phi3Model', + 'Phi3SmallModel', 'PhiForCausalLM', 'Phi3ForCausalLM', + 'Phi3SmallForCausalLM', 'ChatGLMForCausalLM', 'ChatGLMModel', 'BaichuanForCausalLM', @@ -76,8 +88,10 @@ 'MPTModel', 'SkyworkForCausalLM', 'GemmaForCausalLM', + 'DbrxConfig', 'DbrxForCausalLM', 'RecurrentGemmaForCausalLM', + 'CogVLMConfig', 'CogVLMForCausalLM', 'SpeculativeDecodingMode', ] @@ -89,6 +103,7 @@ 'FalconForCausalLM': FalconForCausalLM, 'PhiForCausalLM': PhiForCausalLM, 'Phi3ForCausalLM': Phi3ForCausalLM, + 'Phi3SmallForCausalLM': Phi3SmallForCausalLM, 'MambaForCausalLM': MambaForCausalLM, 'GPTNeoXForCausalLM': GPTNeoXForCausalLM, 'GPTJForCausalLM': GPTJForCausalLM, @@ -98,7 +113,9 @@ 'MistralForCausalLM': LLaMAForCausalLM, 'MixtralForCausalLM': LLaMAForCausalLM, 'ArcticForCausalLM': LLaMAForCausalLM, + 'Grok1ModelForCausalLM': GrokForCausalLM, 'InternLMForCausalLM': LLaMAForCausalLM, + 'InternLM2ForCausalLM': LLaMAForCausalLM, 'MedusaForCausalLM': MedusaForCausalLm, 'BaichuanForCausalLM': BaichuanForCausalLM, 'SkyworkForCausalLM': LLaMAForCausalLM, diff --git a/tensorrt_llm/models/cogvlm/config.py b/tensorrt_llm/models/cogvlm/config.py new file mode 100644 index 000000000..dbeac68e3 --- /dev/null +++ b/tensorrt_llm/models/cogvlm/config.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from ..modeling_utils import PretrainedConfig + + +class CogVLMConfig(PretrainedConfig): + + def __init__(self, + *, + mlp_bias: bool = False, + attn_bias: bool = False, + rotary_base: float = 10000.0, + rotary_scaling: Optional[dict] = None, + vision_start: int = 1, + vision_length: int = 1225, + **kwargs): + self.mlp_bias = mlp_bias + self.attn_bias = attn_bias + self.rotary_base = rotary_base + self.rotary_scaling = rotary_scaling + self.vision_start = vision_start + self.vision_length = vision_length + super().__init__(**kwargs) + + def to_dict(self): + output = super().to_dict() + # Serialize the fields added in CogVLMConfig + output['mlp_bias'] = self.mlp_bias + output['attn_bias'] = self.attn_bias + output['rotary_base'] = self.rotary_base + output['rotary_scaling'] = self.rotary_scaling + output['vision_start'] = self.vision_start + output['vision_length'] = self.vision_length + return output diff --git a/tensorrt_llm/models/cogvlm/convert.py b/tensorrt_llm/models/cogvlm/convert.py index 010892795..75d1e17ae 100644 --- a/tensorrt_llm/models/cogvlm/convert.py +++ b/tensorrt_llm/models/cogvlm/convert.py @@ -26,8 +26,7 @@ def convert_hf_cogvlm(hf_model, int8_kv_cache=False, act_range=[], qkv_para=[], - smoother=[], - moe_config=None): + smoother=[]): weights = {} tik = time.time() diff --git a/tensorrt_llm/models/cogvlm/model.py b/tensorrt_llm/models/cogvlm/model.py index ec2cc6cb9..3b2122785 100644 --- a/tensorrt_llm/models/cogvlm/model.py +++ b/tensorrt_llm/models/cogvlm/model.py @@ -17,9 +17,8 @@ from ..._utils import pad_vocab_size from ...functional import (Tensor, concat, maximum, minimum, recv, send, shape, slice) -from ...layers import (MOE, AttentionMaskType, CogVLMAttention, ColumnLinear, - Embedding, GatedMLP, MoeConfig, PromptTuningEmbedding, - RmsNorm) +from ...layers import (AttentionMaskType, CogVLMAttention, ColumnLinear, + Embedding, GatedMLP, PromptTuningEmbedding, RmsNorm) from ...mapping import Mapping from ...module import Module from ...plugin import init_all_reduce_helper @@ -27,12 +26,13 @@ from ...quantization import QuantMode from ...top_model_mixin import TopModelMixin from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, - PretrainedConfig, QuantConfig) + QuantConfig) +from .config import CogVLMConfig class CogvlmDecoderLayer(Module): - def __init__(self, config: PretrainedConfig, layer_idx: int): + def __init__(self, config: CogVLMConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.config = config @@ -63,42 +63,25 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size - ClsMLP = GatedMLP - mlp_kwargs = {} - if config.moe_num_experts > 1: - ClsMLP = MOE - mlp_kwargs = { - "moe_config": - MoeConfig( - config.moe_num_experts, - config.moe_top_k, - config.moe_tp_mode, - config.moe_normalization_mode, - ), - "tp_rank": - config.mapping.tp_rank, - } self.vision_start = config.vision_start self.vision_length = config.vision_length self.hidden_size = config.hidden_size - self.mlp = ClsMLP(hidden_size=config.hidden_size, - ffn_hidden_size=mlp_hidden_size, - hidden_act=config.hidden_act, - dtype=config.dtype, - bias=config.mlp_bias, - tp_group=config.mapping.tp_group, - tp_size=config.mapping.tp_size, - quant_mode=config.quant_mode, - **mlp_kwargs) - self.vis_mlp = ClsMLP(hidden_size=config.hidden_size, - ffn_hidden_size=mlp_hidden_size, - hidden_act=config.hidden_act, - dtype=config.dtype, - bias=config.mlp_bias, - tp_group=config.mapping.tp_group, - tp_size=config.mapping.tp_size, - quant_mode=config.quant_mode, - **mlp_kwargs) + self.mlp = GatedMLP(hidden_size=config.hidden_size, + ffn_hidden_size=mlp_hidden_size, + hidden_act=config.hidden_act, + dtype=config.dtype, + bias=config.mlp_bias, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + quant_mode=config.quant_mode) + self.vis_mlp = GatedMLP(hidden_size=config.hidden_size, + ffn_hidden_size=mlp_hidden_size, + hidden_act=config.hidden_act, + dtype=config.dtype, + bias=config.mlp_bias, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + quant_mode=config.quant_mode) self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype) @@ -156,7 +139,7 @@ def forward(self, class CogvlmModel(Module): - def __init__(self, config: PretrainedConfig) -> None: + def __init__(self, config: CogVLMConfig) -> None: super().__init__() init_all_reduce_helper() @@ -231,9 +214,9 @@ def forward(self, class CogVLMForCausalLM(DecoderModelForCausalLM, TopModelMixin): + config_class = CogVLMConfig - def __init__(self, config: PretrainedConfig): - self.check_config(config) + def __init__(self, config: CogVLMConfig): transformer = CogvlmModel(config) vocab_size_padded = pad_vocab_size(config.vocab_size, config.mapping.tp_size) @@ -251,19 +234,6 @@ def __init__(self, config: PretrainedConfig): self.mapping = config.mapping super().__init__(config, transformer, lm_head) - def check_config(self, config): - config.set_if_not_exist('mlp_bias', False) - config.set_if_not_exist('attn_bias', False) - config.set_if_not_exist('rotary_base', 10000.0) - config.set_if_not_exist('rotary_scaling', None) - config.set_if_not_exist('moe_num_experts', 0) - config.set_if_not_exist('moe_top_k', 0) - config.set_if_not_exist('moe_tp_mode', - MoeConfig.ParallelismMode.TENSOR_PARALLEL) - config.set_if_not_exist( - 'moe_normalization_mode', - MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE) - @classmethod def from_hugging_face(cls, hf_model_dir, diff --git a/tensorrt_llm/models/convert_utils.py b/tensorrt_llm/models/convert_utils.py index 200200ac3..45cca2fc3 100644 --- a/tensorrt_llm/models/convert_utils.py +++ b/tensorrt_llm/models/convert_utils.py @@ -202,6 +202,10 @@ def iterate_shard_files(model_dir: Union[Path, str], yield shard_file +def has_safetensors(model_dir: str): + return len(list(Path(model_dir).glob('*.safetensors'))) > 0 + + DEFAULT_HF_DATASET_META = { 'ccdv/cnn_dailymail': ('3.0.0', 'train', 'article'), 'cnn_dailymail': ('3.0.0', 'train', 'article'), diff --git a/tensorrt_llm/models/dbrx/config.py b/tensorrt_llm/models/dbrx/config.py new file mode 100644 index 000000000..643d6c3ff --- /dev/null +++ b/tensorrt_llm/models/dbrx/config.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +from ...layers import MoeConfig +from ..modeling_utils import PretrainedConfig + + +class DbrxConfig(PretrainedConfig): + + def __init__(self, + *, + bias: bool = False, + clip_qkv: Optional[float] = None, + rotary_base: float = 500000.0, + rotary_scaling: Optional[dict] = None, + moe: Optional[Union[MoeConfig, dict]] = None, + **kwargs): + self.bias = bias + self.clip_qkv = clip_qkv + self.rotary_base = rotary_base + self.rotary_scaling = rotary_scaling + if moe is None: + # Legacy MOE config fields + moe = MoeConfig( + num_experts=kwargs.pop('moe_num_experts', 0), + top_k=kwargs.pop('moe_top_k', 0), + tp_mode=kwargs.pop('moe_tp_mode', + MoeConfig.ParallelismMode.TENSOR_PARALLEL), + normalization_mode=kwargs.pop( + 'moe_normalization_mode', + MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)) + elif isinstance(moe, dict): + moe = MoeConfig.from_dict(moe) + assert isinstance(moe, MoeConfig) + self.moe = moe.validate() + + super().__init__(**kwargs) + + def to_dict(self): + output = super().to_dict() + # Serialize the fields added in DbrxConfig + output['bias'] = self.bias + output['clip_qkv'] = self.clip_qkv + output['rotary_base'] = self.rotary_base + output['rotary_scaling'] = self.rotary_scaling + output['moe'] = self.moe.to_dict() + return output diff --git a/tensorrt_llm/models/dbrx/model.py b/tensorrt_llm/models/dbrx/model.py index d678e43d6..623ec52f6 100644 --- a/tensorrt_llm/models/dbrx/model.py +++ b/tensorrt_llm/models/dbrx/model.py @@ -16,15 +16,15 @@ from ..._utils import pad_vocab_size from ...functional import Tensor, recv, send from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, - Embedding, GatedMLP, LayerNorm, MoeConfig) + Embedding, GatedMLP, LayerNorm) from ...module import Module -from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, - PretrainedConfig) +from ..modeling_utils import DecoderLayerList, DecoderModelForCausalLM +from .config import DbrxConfig class DbrxDecoderLayer(Module): - def __init__(self, config: PretrainedConfig, layer_idx: int): + def __init__(self, config: DbrxConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.config = config @@ -54,18 +54,11 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): ClsMLP = GatedMLP mlp_kwargs = {} - if config.moe_config['num_experts'] > 1: + if config.moe.has_moe(): ClsMLP = MOE mlp_kwargs = { - "moe_config": - MoeConfig( - config.moe_config['num_experts'], - config.moe_config['top_k'], - config.moe_config['tp_mode'], - config.moe_config['normalization_mode'], - ), - "tp_rank": - config.mapping.tp_rank, + "moe_config": config.moe, + "tp_rank": config.mapping.tp_rank, } self.mlp = ClsMLP(hidden_size=config.hidden_size, @@ -119,7 +112,7 @@ def forward(self, class DbrxModel(Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, config: DbrxConfig): super().__init__() self.config = config @@ -171,9 +164,9 @@ def forward(self, class DbrxForCausalLM(DecoderModelForCausalLM): + config_class = DbrxConfig - def __init__(self, config: PretrainedConfig): - self.check_config(config) + def __init__(self, config: DbrxConfig): transformer = DbrxModel(config) vocab_size_padded = pad_vocab_size(config.vocab_size, config.mapping.tp_size) @@ -190,16 +183,3 @@ def __init__(self, config: PretrainedConfig): self.quant_mode = config.quant_mode self.mapping = config.mapping super().__init__(config, transformer, lm_head) - - def check_config(self, config): - config.set_if_not_exist('bias', False) - config.set_if_not_exist('clip_qkv', None) - config.set_if_not_exist('rotary_base', 500000.0) - config.set_if_not_exist('rotary_scaling', None) - config.set_if_not_exist('moe_num_experts', 0) - config.set_if_not_exist('moe_top_k', 0) - config.set_if_not_exist('moe_tp_mode', - MoeConfig.ParallelismMode.TENSOR_PARALLEL) - config.set_if_not_exist( - 'moe_normalization_mode', - MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE) diff --git a/tensorrt_llm/models/enc_dec/model.py b/tensorrt_llm/models/enc_dec/model.py index 33d0a8942..3bb477abb 100644 --- a/tensorrt_llm/models/enc_dec/model.py +++ b/tensorrt_llm/models/enc_dec/model.py @@ -22,8 +22,8 @@ from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType, MLPType, PositionEmbeddingType, Tensor, - assertion, gather_last_token_logits, gelu, - maximum, minimum, recv, send, shape, + assertion, cast, gather_last_token_logits, + gelu, maximum, minimum, recv, send, shape, transpose) from tensorrt_llm.layers import (MLP, Attention, AttentionMaskType, AttentionParams, BertAttention, ColumnLinear, @@ -1211,7 +1211,7 @@ def prepare_inputs(self, max_output_len_range = [0, (max_output_len + 1) // 2, max_output_len] encoder_num_tokens_range = [ - 1, + 0, # 0 for generation phase, >0 for context phase (max_encoder_input_len * max_batch_size + 1) // 2, max_encoder_input_len * max_batch_size, ] @@ -1225,7 +1225,9 @@ def prepare_inputs(self, # No enable_two_optimization_profiles support yet encoder_input_len_range = [ - 1, (max_encoder_input_len + 1) // 2, max_encoder_input_len + 0, # 0 for generation phase, >0 for context phase + (max_encoder_input_len + 1) // 2, + max_encoder_input_len ] past_key_value = [] sequence_length = None @@ -1797,7 +1799,7 @@ def forward(self, x: Tensor, input_lengths=None): x = self.conv2(x) x = gelu(x) x = transpose(x, 2, 1) - x = x + self.positional_embedding.value + x = x + cast(self.positional_embedding.value, x.dtype) hidden_states = x for encoder_layer in self.encoder_layers: diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py index 20dc24eae..4c4753547 100644 --- a/tensorrt_llm/models/gemma/model.py +++ b/tensorrt_llm/models/gemma/model.py @@ -251,7 +251,7 @@ def from_hugging_face(cls, 'rotary_base': getattr(cfg, 'rotary_base', 10000.0), 'rotary_scaling': getattr(cfg, 'rotary_scaling', None), 'norm_epsilon': cfg.rms_norm_eps, - 'quantization': quantization.asdict(), + 'quantization': quantization.to_dict(), 'mapping': { 'world_size': mapping.world_size, 'tp_size': mapping.world_size, diff --git a/tensorrt_llm/models/gpt/config.py b/tensorrt_llm/models/gpt/config.py new file mode 100644 index 000000000..5dc25b592 --- /dev/null +++ b/tensorrt_llm/models/gpt/config.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +from ...layers import MoeConfig +from ..modeling_utils import PretrainedConfig + + +class GPTConfig(PretrainedConfig): + + def __init__(self, + *, + bias: bool = True, + q_scaling: float = 1.0, + embedding_scale: Optional[float] = None, + apply_query_key_layer_scaling: bool = False, + rotary_pct: float = 1.0, + rotary_base: float = 10000.0, + rotary_scaling: Optional[dict] = None, + moe: Optional[Union[MoeConfig, dict]] = None, + **kwargs): + self.bias = bias + self.q_scaling = q_scaling + self.embedding_scale = embedding_scale + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.rotary_pct = rotary_pct + self.rotary_base = rotary_base + self.rotary_scaling = rotary_scaling + if moe is None: + # Legacy MOE config fields + moe = MoeConfig( + num_experts=kwargs.pop('moe_num_experts', 0), + top_k=kwargs.pop('moe_top_k', 0), + tp_mode=kwargs.pop('moe_tp_mode', + MoeConfig.ParallelismMode.TENSOR_PARALLEL), + normalization_mode=kwargs.pop( + 'moe_normalization_mode', + MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)) + elif isinstance(moe, dict): + moe = MoeConfig.from_dict(moe) + assert isinstance(moe, MoeConfig) + self.moe = moe.validate() + + super().__init__(**kwargs) + + def to_dict(self): + output = super().to_dict() + # Serialize the fields added in GPTConfig + output['bias'] = self.bias + output['q_scaling'] = self.q_scaling + output['embedding_scale'] = self.embedding_scale + output[ + 'apply_query_key_layer_scaling'] = self.apply_query_key_layer_scaling + output['rotary_pct'] = self.rotary_pct + output['rotary_base'] = self.rotary_base + output['rotary_scaling'] = self.rotary_scaling + output['moe'] = self.moe.to_dict() + return output diff --git a/tensorrt_llm/models/gpt/model.py b/tensorrt_llm/models/gpt/model.py index a8a46f739..2b52123e0 100644 --- a/tensorrt_llm/models/gpt/model.py +++ b/tensorrt_llm/models/gpt/model.py @@ -22,8 +22,8 @@ from ...lora_manager import LoraConfig, use_lora from ...module import Module from ...quantization import QuantMode -from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, - PretrainedConfig) +from ..modeling_utils import DecoderLayerList, DecoderModelForCausalLM +from .config import GPTConfig def MLPFactory(hidden_size, @@ -67,7 +67,7 @@ def MLPFactory(hidden_size, class GPTDecoderLayer(Module): - def __init__(self, config: PretrainedConfig, layer_idx: int): + def __init__(self, config: GPTConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.config = config @@ -111,20 +111,13 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size self.norm_before_bmm1 = config.norm_before_bmm1 if hasattr( config, "norm_before_bmm1") else False - moe_config = MoeConfig() - if config.moe_num_experts > 1: - moe_config = MoeConfig( - config.moe_num_experts, - config.moe_top_k, - config.moe_tp_mode, - config.moe_normalization_mode, - ) + self.mlp = MLPFactory(hidden_size=config.hidden_size, ffn_hidden_size=mlp_hidden_size, hidden_act=config.hidden_act, dtype=config.dtype, bias=config.bias, - moe_config=moe_config, + moe_config=config.moe, tp_group=tp_group, tp_size=tp_size, tp_rank=tp_rank, @@ -155,6 +148,7 @@ def forward(self, hidden_states, attention_mask=attention_mask, use_cache=use_cache, + spec_decoding_params=spec_decoding_params, kv_cache_params=kv_cache_params, attention_params=attention_params, lora_layer_params=lora_layer_params, @@ -179,7 +173,7 @@ def forward(self, class GPTModel(Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, config: GPTConfig): super().__init__() self.mapping = config.mapping self.position_embedding_type = config.position_embedding_type @@ -250,9 +244,9 @@ def forward(self, class GPTForCausalLM(DecoderModelForCausalLM): + config_class = GPTConfig - def __init__(self, config: PretrainedConfig): - self.check_config(config) + def __init__(self, config: GPTConfig): transformer = GPTModel(config) if config.mapping.is_last_pp_rank(): @@ -269,21 +263,5 @@ def __init__(self, config: PretrainedConfig): lm_head = None super().__init__(config, transformer, lm_head) - def check_config(self, config: PretrainedConfig): - config.set_if_not_exist('bias', True) - config.set_if_not_exist('q_scaling', 1) - config.set_if_not_exist('embedding_scale', None) - config.set_if_not_exist('apply_query_key_layer_scaling', False) - config.set_if_not_exist('rotary_pct', 1.0) - config.set_if_not_exist('rotary_base', 10000.0) - config.set_if_not_exist('rotary_scaling', None) - config.set_if_not_exist('moe_num_experts', 0) - config.set_if_not_exist('moe_top_k', 0) - config.set_if_not_exist('moe_tp_mode', - MoeConfig.ParallelismMode.TENSOR_PARALLEL) - config.set_if_not_exist( - 'moe_normalization_mode', - MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE) - def use_lora(self, lora_config: LoraConfig): use_lora(self, lora_config) diff --git a/tensorrt_llm/models/grok/__init__.py b/tensorrt_llm/models/grok/__init__.py new file mode 100644 index 000000000..71bf6d298 --- /dev/null +++ b/tensorrt_llm/models/grok/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorrt_llm/models/grok/convert.py b/tensorrt_llm/models/grok/convert.py new file mode 100644 index 000000000..1fc3e2455 --- /dev/null +++ b/tensorrt_llm/models/grok/convert.py @@ -0,0 +1,494 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from pathlib import Path +from typing import Optional + +import jax +import torch +from jax import dlpack as jax_dlpack +from torch.utils import dlpack as torch_dlpack + +from ..._utils import pad_vocab_size, release_gc +from ...layers import MoeConfig +from ...logger import logger +from ...quantization import QuantAlgo +from ..modeling_utils import PretrainedConfig, QuantConfig, optimize_model + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return torch.chunk(v, tp_size)[idx].contiguous() + else: + return torch.chunk(v, tp_size, dim=dim)[idx] + + +def split_matrix_tp(v, tensor_parallel, rank, dim): + return split(v, tensor_parallel, rank, dim=dim) + + +def get_weight(config, prefix, dtype, postfix='.weight'): + if config[prefix + postfix].dtype != dtype: + config[prefix + postfix].data = config[prefix + postfix].to(dtype) + return config[prefix + postfix].detach().cpu() + + +def get_jax_weight(config, prefix, dtype, postfix='.weight', key_name='scale'): + return torch.as_tensor((config[prefix + postfix][key_name])._value, + dtype=dtype).T + + +def get_jax_weight_scale(params, key, rank): + jax_obj = params[key]['w'] + jax_scales = jax.device_put(jax_obj.scales, device=jax.devices('cpu')[0]) + # jax_scales = jax.device_put(jax_obj.scales, device=jax.devices('gpu')[rank]) + torch_scales = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(jax_scales)) + return torch.as_tensor(jax_obj.weight._value, + dtype=torch.int8), torch_scales + + +def get_tllm_linear_weight( + weight, + torch_weight_scales, + prefix, + plugin_weight_only_quant_type=torch.int8, + postfix='weight', +): + results = {} + processed_weight = torch.ops.trtllm.preprocess_weights_for_mixed_gemm( + weight.contiguous(), plugin_weight_only_quant_type, torch.bfloat16) + results[prefix + postfix] = processed_weight + + results[prefix + 'per_channel_scale'] = torch_weight_scales.contiguous() + + return results + + +def convert_grok(hf_model, + config, + mapping, + vocab_size=32000, + dtype='float32', + use_parallel_embedding=False, + sharding_dim=0, + use_weight_only=False, + share_embedding_table=False, + use_gemm_woq_plugin=False, + plugin_weight_only_quant_type=torch.int8, + moe_config=None): + + weights = {} + tik = time.time() + tensor_parallel = mapping.tp_size + model_params = hf_model + dtype = getattr(torch, dtype) + + num_attention_heads = config['num_attention_heads'] + hidden_size = config['hidden_size'] + hidden_size // num_attention_heads + + layers_range = mapping.pp_layers(config['num_hidden_layers']) + + # layers_range = mapping.pp_layers(2) + + def convert_layer(l): + prefix = f'transformer/decoder_layer_{l}/' + print(prefix) + tllm_prex = f'transformer.layers.{l - layers_range[0]}.' + + q_weight, q_scale = get_jax_weight_scale( + model_params, prefix + 'multi_head_attention/query', + mapping.tp_rank) + k_weight, k_scale = get_jax_weight_scale( + model_params, prefix + 'multi_head_attention/key', mapping.tp_rank) + v_weight, v_scale = get_jax_weight_scale( + model_params, prefix + 'multi_head_attention/value', + mapping.tp_rank) + + wq = split(q_weight, mapping.tp_size, mapping.tp_rank, dim=1) + wk = split(k_weight, mapping.tp_size, mapping.tp_rank, dim=1) + wv = split(v_weight, mapping.tp_size, mapping.tp_rank, dim=1) + qs = split(q_scale, mapping.tp_size, mapping.tp_rank, dim=1) + ks = split(k_scale, mapping.tp_size, mapping.tp_rank, dim=1) + vs = split(v_scale, mapping.tp_size, mapping.tp_rank, dim=1) + split_v = torch.concat((wq, wk, wv), dim=1) + scale_v = torch.concat((qs, ks, vs), dim=1) + + weights.update( + get_tllm_linear_weight(split_v, scale_v.squeeze(), + tllm_prex + 'attention.qkv.', + plugin_weight_only_quant_type)) + + attn_dense_weight, attn_dense_scales = get_jax_weight_scale( + model_params, prefix + 'multi_head_attention/linear', + mapping.tp_rank) + + split_v = split_matrix_tp(attn_dense_weight, + tensor_parallel, + mapping.tp_rank, + dim=0) + split_scales = split_matrix_tp(attn_dense_scales, + tensor_parallel, + mapping.tp_rank, + dim=0) + + weights.update( + get_tllm_linear_weight(split_v, split_scales.squeeze(), + tllm_prex + 'attention.dense.', + plugin_weight_only_quant_type)) + + if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL: + mapping.ep_experts(moe_config.num_experts) + + w3, s3 = get_jax_weight_scale( + model_params, f'transformer/decoder_layer_{l}/moe/linear_v', + mapping.tp_rank) + + w2, s2 = get_jax_weight_scale( + model_params, f'transformer/decoder_layer_{l}/moe/linear_1', + mapping.tp_rank) + + w1, s1 = get_jax_weight_scale( + model_params, f'transformer/decoder_layer_{l}/moe/linear', + mapping.tp_rank) + + if moe_config.tp_mode == moe_config.ParallelismMode.TENSOR_PARALLEL: + + w3_split = split(w3, mapping.tp_size, mapping.tp_rank, dim=2) + w2_split = split(w2, mapping.tp_size, mapping.tp_rank, dim=1) + w1_split = split(w1, mapping.tp_size, mapping.tp_rank, dim=2) + + s3_split = split(s3, mapping.tp_size, mapping.tp_rank, dim=2) + s2_split = split(s2, mapping.tp_size, mapping.tp_rank, dim=1) + s1_split = split(s1, mapping.tp_size, mapping.tp_rank, dim=2) + + weights.update( + get_tllm_linear_weight(w2_split, + s2_split.reshape(moe_config.num_experts, -1), + tllm_prex + 'mlp.proj.', + plugin_weight_only_quant_type)) + + weights.update( + get_tllm_linear_weight( + torch.concat([w3_split, w1_split], dim=-1), + torch.concat([s3_split, s1_split], + dim=-1).reshape(moe_config.num_experts, -1), + tllm_prex + 'mlp.fc.', + plugin_weight_only_quant_type, + )) + + moe_experts_gate_weights = get_jax_weight(model_params, + prefix + 'router', + torch.float32, + postfix='', + key_name='w').contiguous() + + weights[tllm_prex + 'mlp.router.weight'] = moe_experts_gate_weights + + # Layer norms do not use tensor parallelism + input_ln_weight = get_jax_weight(model_params, + prefix + 'rms_norm', + dtype, + postfix='') + weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight + + post_attn_weight = get_jax_weight(model_params, + prefix + 'rms_norm_1', + dtype, + postfix='') + weights[tllm_prex + 'post_attn_layernorm.weight'] = post_attn_weight + + post_ln_weight = get_jax_weight(model_params, + prefix + 'rms_norm_2', + dtype, + postfix='') + weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight + + post_mlp_weight = get_jax_weight(model_params, + prefix + 'rms_norm_3', + dtype, + postfix='') + weights[tllm_prex + 'post_mlp_layernorm.weight'] = post_mlp_weight + + for l in layers_range: + convert_layer(l) + release_gc() + + v = get_jax_weight(model_params, + 'language_model/in_out_embed', + dtype, + postfix='', + key_name='embeddings').T + tie_word_embeddings = config['tie_word_embeddings'] + if tie_word_embeddings: + # lm_head.weight has the same weights as embedding + if mapping.is_last_pp_rank(): + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + + v = torch.nn.functional.pad(v, (0, pad_width, 0, 0), 'constant', + 0) + weights['lm_head.weight'] = split(v, mapping.tp_size, + mapping.tp_rank) + + if use_parallel_embedding: + v = split_matrix_tp(v, + mapping.tp_size, + mapping.tp_rank, + dim=sharding_dim) + + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = v + + ln_f_w = get_jax_weight(model_params, + 'language_model/rms_norm', + dtype, + postfix='') + weights['transformer.ln_f.weight'] = ln_f_w + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights + + +def create_config_from_xai(dtype, + mapping, + quantization: QuantConfig = None, + override_fields: dict = {}): + config = {} + hf_config = { + "architectures": ["Grok1ModelForCausalLM"], + "vocab_size": 131072, + "hidden_size": 6144, + "intermediate_size": 32768, + "num_hidden_layers": 64, + "num_attention_heads": 48, + "num_key_value_heads": 8, + "attn_output_multiplier": 0.08838834764831845, + "embedding_multiplier_scale": 78.38367176906169, + "output_multiplier_scale": 0.5773502691896257, + "max_attn_value": 30.0, + "max_position_embeddings": 8192, + "rms_norm_eps": 1e-5, + "use_cache": True, + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": True, + "num_experts_per_tok": 2, + "num_experts": 8, + "output_router_logits": False, + "router_aux_loss_coef": 0.001, + "torch_dtype": "bfloat16", + "transformers_version": "4.35.0" + } + # same for from_meta and from_cli_args + n_head = hf_config['num_attention_heads'] + inter_size = hf_config['intermediate_size'] + n_layer = hf_config['num_hidden_layers'] + # n_layer = 2 + n_embd = hf_config['hidden_size'] + n_kv_head = hf_config['num_key_value_heads'] + rms_norm_eps = hf_config['rms_norm_eps'] + vocab_size = hf_config['vocab_size'] + n_positions = hf_config['max_position_embeddings'] + hidden_act = 'geglu' + config['rotary_scaling'] = None + rotary_base = 10000.0 + + config[ + 'moe_normalization_mode'] = MoeConfig.ExpertScaleNormalizationMode.NONE + + moe_num_experts = hf_config['num_experts'] + + moe_top_k = hf_config['num_experts_per_tok'] + moe_tp_mode = MoeConfig.ParallelismMode.TENSOR_PARALLEL + + attn_output_multiplier = hf_config['attn_output_multiplier'] + embedding_multiplier_scale = hf_config['embedding_multiplier_scale'] + + output_multiplier_scale = hf_config['output_multiplier_scale'] + max_attn_value = hf_config['max_attn_value'] + + architecture = hf_config['architectures'][0] + + attn_bias = False + + config.update({ + 'architecture': + architecture, + 'dtype': + dtype, + 'logits_dtype': + 'float32', + 'num_hidden_layers': + n_layer, + 'num_attention_heads': + n_head, + 'hidden_size': + n_embd, + 'intermediate_size': + inter_size, + 'num_key_value_heads': + n_kv_head, + 'vocab_size': + vocab_size, + 'position_embedding_type': + 'rope_gpt_neox', + 'max_position_embeddings': + n_positions, + 'hidden_act': + hidden_act, + 'rotary_base': + rotary_base, + 'norm_epsilon': + rms_norm_eps, + 'moe_num_experts': + moe_num_experts, + 'moe_top_k': + moe_top_k, + 'moe_tp_mode': + moe_tp_mode, + 'moe_normalization_mode': + MoeConfig.ExpertScaleNormalizationMode.NONE, + #TODO: should have directly map from the Mapping object to the TRT-LLM checkpoint fields + 'mapping': { + 'world_size': mapping.tp_size * mapping.pp_size, + 'tp_size': mapping.tp_size, + 'pp_size': mapping.pp_size + }, + 'attn_bias': + attn_bias, + "attn_output_multiplier": + attn_output_multiplier, + "embedding_multiplier_scale": + embedding_multiplier_scale, + "output_multiplier_scale": + output_multiplier_scale, + "max_attn_value": + max_attn_value, + "tie_word_embeddings": + True, + }) + + config['quantization'] = quantization.to_dict() + config.update(override_fields) + + moe_config = MoeConfig(config['moe_num_experts'], config['moe_top_k'], + config['moe_tp_mode'], + config['moe_normalization_mode']).validate() + use_weight_only = config['quantization']['quant_algo'] in [ + QuantAlgo.W8A16, QuantAlgo.W4A16, QuantAlgo.FP8 + ] + if use_weight_only and moe_config.has_moe(): + config['quantization']['exclude_modules'].append('router') + + return config + + +def from_hugging_face(cls, + model_dir, + dtype, + *, + mapping, + quantization: QuantConfig = None, + override_fields={}, + skip_loading_weights=False, + preloaded_model=None): + ''' Create a LLaMAForCausalLM object from give parameters + ''' + assert model_dir is not None + if isinstance(model_dir, Path): # some code relies on this as string + model_dir = str(model_dir) + + if override_fields.get('share_embedding_table', False): + logger.warning( + "Llama model does not support share_embedding_table; setting share_embedding_table=False" + ) + override_fields['share_embedding_table'] = False + + config = create_config_from_xai(dtype, + mapping, + quantization, + override_fields=override_fields) + + pretrained_config = PretrainedConfig.from_dict(config) + pretrained_config.set_rank(mapping.rank) # TODO:remove this hack + + grok = cls.from_config(pretrained_config) + grok = optimize_model( + grok, + use_parallel_embedding=pretrained_config.use_parallel_embedding, + share_embedding_table=pretrained_config.share_embedding_table, + ) + + if skip_loading_weights: + return grok + + weights = load_weights_from_xai(config=config, + mapping=mapping, + model=preloaded_model) + + grok.load(weights) + return grok + + +def quantize(dtype, + model_dir, + output_dir, + mapping, + quantization: QuantConfig, + *, + override_fields, + dataset_cache_dir: Optional[str] = None): + ''' + Quantize the save the model as TRT-LLM checkpoint to output_dir + ''' + pass #The official grok-1 model is published under int8 wo format, we don't need to quantize again. + + +def load_weights_from_xai(*, config, mapping, model): + assert model is not None + plugin_weight_only_quant_type = None # the value does not matter when use_weight_only is False + quant_algo = config['quantization']['quant_algo'] + assert quant_algo == QuantAlgo.W8A16 + plugin_weight_only_quant_type = torch.int8 + + moe_config = MoeConfig(config['moe_num_experts'], config['moe_top_k'], + config['moe_tp_mode'], + config['moe_normalization_mode']).validate() + + use_weight_only = quant_algo in [QuantAlgo.W8A16] + + weights = convert_grok( + model, + config, + mapping, + vocab_size=config['vocab_size'], + dtype=config['dtype'], + use_weight_only=use_weight_only, + use_gemm_woq_plugin=not config.get('disable_weight_only_quant_plugin', + False), + plugin_weight_only_quant_type=plugin_weight_only_quant_type, + use_parallel_embedding=config.get('use_parallel_embedding', False), + sharding_dim=config.get('embedding_sharding_dim', 0), + share_embedding_table=config.get('share_embedding_table', False), + moe_config=moe_config) + return weights diff --git a/tensorrt_llm/models/grok/model.py b/tensorrt_llm/models/grok/model.py new file mode 100644 index 000000000..d33bbca62 --- /dev/null +++ b/tensorrt_llm/models/grok/model.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from ..._utils import pad_vocab_size +from ...functional import Tensor, recv, send +from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, + Embedding, MoeConfig, PositionEmbeddingType, RmsNorm) +from ...lora_manager import LoraConfig, use_lora +from ...mapping import Mapping +from ...module import Module +from ...plugin import init_all_reduce_helper +from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, + PretrainedConfig, QuantConfig) + + +class GrokDecoderLayer(Module): + + def __init__(self, config: PretrainedConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.config = config + + self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + layers_range = config.mapping.pp_layers(config.num_hidden_layers) + local_layer_idx = layer_idx - layers_range[0] + self.attention = Attention( + local_layer_idx=local_layer_idx, + hidden_size=config.hidden_size, + attention_head_size=config.head_size, + num_attention_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + dtype=config.dtype, + attention_mask_type=AttentionMaskType.causal, + bias=config.attn_bias, + position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + rotary_embedding_base=config.rotary_base, + rotary_embedding_scaling=config.rotary_scaling, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + tp_rank=config.mapping.tp_rank, + quant_mode=config.quant_mode, + max_attn_value=config.max_attn_value) + + mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size + self.post_attn_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + self.post_mlp_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + mlp_kwargs = {} + assert config.moe_num_experts > 1, "Grok model is a MoE model." + ClsMLP = MOE + mlp_kwargs = { + "moe_config": + MoeConfig( + config.moe_num_experts, + config.moe_top_k, + config.moe_tp_mode, + config.moe_normalization_mode, + ), + "tp_rank": + config.mapping.tp_rank, + } + self.mlp = ClsMLP(hidden_size=config.hidden_size, + ffn_hidden_size=mlp_hidden_size, + hidden_act=config.hidden_act, + dtype=config.dtype, + bias=config.mlp_bias, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + quant_mode=config.quant_mode, + **mlp_kwargs) + self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + def forward(self, + hidden_states, + attention_mask=None, + use_cache=False, + spec_decoding_params=None, + kv_cache_params=None, + attention_params=None, + lora_layer_params=None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + attention_output = self.attention( + hidden_states, + attention_mask=attention_mask, + use_cache=use_cache, + spec_decoding_params=spec_decoding_params, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + lora_layer_params=lora_layer_params) + + if use_cache: + attention_output, presents = attention_output + + attention_output = self.post_attn_layernorm(attention_output) + hidden_states = residual + attention_output + + residual_attn = hidden_states + + # regular llama/mixtral layers + hidden_states = self.post_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, + lora_layer_params=lora_layer_params) + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual_attn + hidden_states + + if use_cache: + return (hidden_states, presents) + return hidden_states + + +class GrokModel(Module): + + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() + init_all_reduce_helper() + + self.mapping = config.mapping + if self.mapping.is_first_pp_rank(): + self.vocab_embedding = Embedding(config.vocab_size, + config.hidden_size, + dtype=config.dtype) + + self.layers = DecoderLayerList(GrokDecoderLayer, config) + + self.embedding_multiplier_scale = config.embedding_multiplier_scale + + if self.mapping.is_last_pp_rank(): + self.ln_f = RmsNorm(normalized_shape=config.hidden_size, + eps=config.norm_epsilon, + dtype=config.dtype) + + def forward(self, + input_ids, + position_ids=None, + use_cache=False, + attention_mask=None, + spec_decoding_params=None, + kv_cache_params=None, + attention_params=None, + hidden_states=None, + prompt_embedding_table: Optional[Tensor] = None, + prompt_tasks: Optional[Tensor] = None, + prompt_vocab_size: Optional[Tensor] = None, + lora_params=None): + + ptuning_args = [ + prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + + if self.mapping.is_first_pp_rank(): + hidden_states = self.vocab_embedding(input_ids, *ptuning_args) + hidden_states *= 78.38367176906169 + else: + hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) + + hidden_states = self.layers.forward( + hidden_states, + use_cache=use_cache, + attention_mask=attention_mask, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + lora_params=lora_params, + spec_decoding_params=spec_decoding_params) + + if use_cache: + hidden_states, presents = hidden_states + + if self.mapping.is_last_pp_rank(): + hidden_states = self.ln_f(hidden_states) + else: + hidden_states = send(hidden_states, self.mapping.next_pp_rank()) + + if use_cache: + return (hidden_states, tuple(presents)) + return hidden_states + + +class GrokForCausalLM(DecoderModelForCausalLM): + + def __init__(self, config: PretrainedConfig): + self.check_config(config) + transformer = GrokModel(config) + vocab_size_padded = pad_vocab_size(config.vocab_size, + config.mapping.tp_size) + if config.mapping.is_last_pp_rank(): + lm_head = ColumnLinear(config.hidden_size, + vocab_size_padded, + bias=False, + dtype=config.dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + gather_output=True) + else: + lm_head = None + self.quant_mode = config.quant_mode + self.mapping = config.mapping + super().__init__(config, transformer, lm_head) + + def check_config(self, config): + config.set_if_not_exist('mlp_bias', False) + config.set_if_not_exist('attn_bias', False) + config.set_if_not_exist('rotary_base', 10000.0) + config.set_if_not_exist('rotary_scaling', None) + config.set_if_not_exist('moe_num_experts', 0) + config.set_if_not_exist('moe_top_k', 0) + config.set_if_not_exist('moe_tp_mode', + MoeConfig.ParallelismMode.TENSOR_PARALLEL) + config.set_if_not_exist('moe_normalization_mode', + MoeConfig.ExpertScaleNormalizationMode.NONE) + + @classmethod + def from_hugging_face(cls, + hf_model_dir, + dtype='float16', + mapping: Optional[Mapping] = None, + **kwargs): + from . import convert + if mapping is None: + mapping = Mapping() + grok = convert.from_hugging_face( + cls, + hf_model_dir, + dtype, + mapping=mapping, + quantization=kwargs.get('quantization', QuantConfig()), + override_fields=kwargs.get('override_fields', {}), + skip_loading_weights=kwargs.get('skip_loading_weights', False), + preloaded_model=kwargs.get('preloaded_model', None)) + return grok + + def default_plugin_config(self, **kwargs): + plugin_config = super().default_plugin_config(**kwargs) + if self.quant_mode.is_int4_weight_only_per_group(): + plugin_config.set_weight_only_groupwise_quant_matmul_plugin() + return plugin_config + + @classmethod + def quantize( + cls, + hf_model_dir, + output_dir, + quant_config: QuantConfig, + *, + dtype='float16', + mapping: Optional[Mapping] = None, + calib_batches=512, + calib_batch_size=1, + random_seed=1234, + tokenizer_max_seq_length=2048, + **kwargs, + ): + pass + + def use_lora(self, lora_config: LoraConfig): + use_lora(self, lora_config) diff --git a/tensorrt_llm/models/grok/weight.py b/tensorrt_llm/models/grok/weight.py new file mode 100644 index 000000000..7446952ad --- /dev/null +++ b/tensorrt_llm/models/grok/weight.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +import numpy as np +import torch + + +def gen_suffix(rank, use_smooth_quant, quant_per_channel): + suffix = f"{rank}.bin" + if use_smooth_quant: + sq_prefix = "int8." + if quant_per_channel: + sq_prefix += "col." + suffix = sq_prefix + suffix + return suffix + + +def extract_layer_idx(name): + ss = name.split('.') + for s in ss: + if s.isdigit(): + return s + return None + + +def split(v: Union[np.ndarray, torch.Tensor], + tp_size: int, + tp_rank: int, + dim=0): + if tp_size == 1: + return v + assert len(v.shape) > 1 or dim == 0 + if isinstance(v, np.ndarray): + return np.ascontiguousarray( + np.split(v, tp_size, axis=dim)[tp_rank].copy()) + else: + assert v.shape[dim] % tp_size == 0, \ + 'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.' + split_size = v.shape[dim] // tp_size + return v.split(split_size, dim=dim)[tp_rank].clone().detach() + + +def dup_kv_weight(v, num_head, tp_size): + assert tp_size % num_head == 0 + reps = tp_size // num_head + head_size = v.shape[0] // num_head + v = v.reshape(num_head, head_size, + -1)[:, None, :, :].expand(num_head, reps, head_size, + v.shape[1]) + return v.reshape(num_head * reps * head_size, -1).clone().detach() diff --git a/tensorrt_llm/models/llama/config.py b/tensorrt_llm/models/llama/config.py new file mode 100644 index 000000000..9233dc19e --- /dev/null +++ b/tensorrt_llm/models/llama/config.py @@ -0,0 +1,230 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import sys +from pathlib import Path +from typing import Optional, Union + +import torch + +from ..._utils import torch_dtype_to_str +from ...layers import MoeConfig +from ...logger import logger +from ...mapping import Mapping +from ..modeling_utils import PretrainedConfig, QuantConfig + + +class LLaMAConfig(PretrainedConfig): + + def __init__(self, + *, + mlp_bias: bool = False, + attn_bias: bool = False, + rotary_base: float = 10000.0, + rotary_scaling: Optional[dict] = None, + residual_mlp: bool = False, + disable_weight_only_quant_plugin: bool = False, + moe: Optional[Union[MoeConfig, dict]] = None, + **kwargs): + self.mlp_bias = mlp_bias + self.attn_bias = attn_bias + self.rotary_base = rotary_base + self.rotary_scaling = rotary_scaling + self.residual_mlp = residual_mlp + self.disable_weight_only_quant_plugin = disable_weight_only_quant_plugin + if moe is None: + # Legacy MOE config fields + moe = MoeConfig( + num_experts=kwargs.pop('moe_num_experts', 0), + top_k=kwargs.pop('moe_top_k', 0), + tp_mode=kwargs.pop('moe_tp_mode', + MoeConfig.ParallelismMode.TENSOR_PARALLEL), + normalization_mode=kwargs.pop( + 'moe_normalization_mode', + MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)) + elif isinstance(moe, dict): + moe = MoeConfig.from_dict(moe) + assert isinstance(moe, MoeConfig) + self.moe = moe.validate() + + super().__init__(**kwargs) + + def to_dict(self): + output = super().to_dict() + # Serialize the fields added in LLaMAConfig + output['mlp_bias'] = self.mlp_bias + output['attn_bias'] = self.attn_bias + output['rotary_base'] = self.rotary_base + output['rotary_scaling'] = self.rotary_scaling + output['residual_mlp'] = self.residual_mlp + output[ + 'disable_weight_only_quant_plugin'] = self.disable_weight_only_quant_plugin + output['moe'] = self.moe.to_dict() + return output + + @classmethod + def from_hugging_face( + cls, + hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + import transformers + + if isinstance(hf_config_or_dir, transformers.PretrainedConfig): + hf_config = hf_config_or_dir + else: + hf_config_dir = str(hf_config_or_dir) + if "vila" in hf_config_dir: + sys.path.append(hf_config_dir + "/../VILA") + from llava.model import LlavaConfig, LlavaLlamaForCausalLM + transformers.AutoConfig.register("llava_llama", LlavaConfig) + transformers.AutoModelForCausalLM.register( + LlavaConfig, LlavaLlamaForCausalLM) + + hf_config = transformers.AutoConfig.from_pretrained( + hf_config_dir, trust_remote_code=True) + if hf_config.model_type == "llava": + # LLaVA = Vision model + Llama LLM + # We load a llava config and use its' text config as llama config + hf_config = LlavaConfig.from_pretrained( + hf_config_dir).text_config + if hf_config.model_type == "llava_llama": + hf_config.llm_cfg["architecture"] = hf_config.llm_cfg[ + "architectures"] + hf_config.llm_cfg["dtype"] = hf_config.llm_cfg["torch_dtype"] + hf_config = PretrainedConfig.from_dict(hf_config.llm_cfg) + + num_key_value_heads = getattr(hf_config, "num_key_value_heads", + hf_config.num_attention_heads) + hidden_act = hf_config.hidden_act + attn_bias = getattr(hf_config, 'bias', False) or getattr( + hf_config, 'attention_bias', False) + rotary_scaling = getattr(hf_config, "rope_scaling", None) + rotary_base = getattr(hf_config, "rope_theta", 10000.0) + residual_mlp = getattr(hf_config, "parallel_attn_mlp_res", False) + disable_weight_only_quant_plugin = kwargs.pop( + 'disable_weight_only_quant_plugin', False) + + if hf_config.model_type == "mixtral" or hf_config.model_type == "arctic": + # HF LLaMA-type models are implicitly using gated activation. + # With our MoE implementation, we must make it explicit + hidden_act = "swiglu" + moe_normalization_mode = MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE + else: + moe_normalization_mode = None + moe_num_experts = getattr(hf_config, "num_local_experts", 0) + moe_top_k = getattr(hf_config, "num_experts_per_tok", 0) + moe_tp_mode = kwargs.pop('moe_tp_mode', + MoeConfig.ParallelismMode.TENSOR_PARALLEL) + moe_config = MoeConfig(num_experts=moe_num_experts, + top_k=moe_top_k, + tp_mode=moe_tp_mode, + normalization_mode=moe_normalization_mode) + moe_config.validate() + + if dtype == 'auto': + dtype = getattr(hf_config, 'torch_dtype', None) + if dtype is None: + dtype = 'float16' + if isinstance(dtype, torch.dtype): + dtype = torch_dtype_to_str(dtype) + if dtype == 'float32': + dtype = 'float16' + if dtype == 'bfloat16' and torch.cuda.get_device_properties( + 0).major < 8: + logger.warning( + "Pre SM 80 GPUs do not support bfloat16, fallback to float16") + dtype = 'float16' + + return cls( + architecture='LlamaForCausalLM', + dtype=dtype, + num_hidden_layers=hf_config.num_hidden_layers, + num_attention_heads=hf_config.num_attention_heads, + hidden_size=hf_config.hidden_size, + intermediate_size=hf_config.intermediate_size, + num_key_value_heads=num_key_value_heads, + vocab_size=hf_config.vocab_size, + position_embedding_type='rope_gpt_neox', + max_position_embeddings=hf_config.max_position_embeddings, + hidden_act=hidden_act, + norm_epsilon=hf_config.rms_norm_eps, + attn_bias=attn_bias, + rotary_base=rotary_base, + rotary_scaling=rotary_scaling, + residual_mlp=residual_mlp, + disable_weight_only_quant_plugin=disable_weight_only_quant_plugin, + moe=moe_config, + mapping=mapping, + quantization=quant_config, + **kwargs) + + @classmethod + def from_meta_ckpt(cls, + meta_ckpt_dir: str, + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + + with open(Path(meta_ckpt_dir, "params.json")) as fp: + meta_config: dict = json.load(fp) + + n_embd = meta_config["dim"] + n_head = meta_config["n_heads"] + n_kv_head = meta_config.get("n_kv_heads", n_head) + vocab_size = meta_config.get("vocab_size", 32000) + + # Reset vocab_size to 32000 for LLama v2 checkpoint. + if vocab_size == -1: + vocab_size = 32000 + + if "hidden_dim" in meta_config: + inter_size = meta_config["hidden_dim"] + else: + multiple_of = meta_config.get("multiple_of", 1) + n_embd_ = int(4 * n_embd * 2 / 3) + ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier", 1) + inter_size = multiple_of * ( + (int(n_embd_ * ffn_dim_multiplier) + multiple_of - 1) // + multiple_of) + + if dtype == 'auto': + dtype = 'bfloat16' + if dtype == 'bfloat16' and torch.cuda.get_device_properties( + 0).major < 8: + logger.warning( + "Pre SM 80 GPUs do not support bfloat16, fallback to float16") + dtype = 'float16' + + # meta checkpoint don't have vocab_size|hidden_act|rotary_base specified, use same default value as HF + return cls(architecture="LlamaForCausalLM", + dtype=dtype, + num_hidden_layers=meta_config["n_layers"], + num_attention_heads=n_head, + hidden_size=n_embd, + intermediate_size=inter_size, + num_key_value_heads=n_kv_head, + vocab_size=vocab_size, + position_embedding_type='rope_gpt_neox', + max_position_embeddings=2048, + hidden_act='silu', + rotary_base=meta_config.get('rope_theta', 10000), + norm_epsilon=meta_config["norm_eps"], + mapping=mapping, + quantization=quant_config, + **kwargs) diff --git a/tensorrt_llm/models/llama/convert.py b/tensorrt_llm/models/llama/convert.py index ba096999f..8ef2cd3d9 100644 --- a/tensorrt_llm/models/llama/convert.py +++ b/tensorrt_llm/models/llama/convert.py @@ -1,12 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy import functools import json import os +import sys import time from collections import defaultdict from pathlib import Path -from typing import Optional +from typing import List, Optional +import numpy as np import safetensors import torch import torch.nn as nn @@ -15,24 +31,13 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.pytorch_utils import Conv1D -from ..._utils import pad_vocab_size, release_gc -from ...layers import MoeConfig +from ..._utils import pad_vocab_size, release_gc, str_dtype_to_torch from ...logger import logger -from ...mapping import Mapping from ...quantization import QuantAlgo -from ..convert_utils import load_calib_dataset -from ..modeling_utils import PretrainedConfig, QuantConfig, optimize_model -from .weight import load_from_hf_checkpoint, load_from_hf_safetensors - -try: - from transformers import LlavaConfig, LlavaForConditionalGeneration -except ImportError: - pass - -try: - pass -except ImportError: - pass +from ..convert_utils import (iterate_shard_files, load_calib_dataset, + load_state_dict, retrieved_layer_index_from_name) +from ..modeling_utils import PretrainedConfig +from .config import LLaMAConfig def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): @@ -608,9 +613,10 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): results[prefix + 'per_channel_scale'] = torch.Tensor(cur_per_channel_value).to( torch.float32).reshape(col_shape).contiguous().cuda() - results[prefix + 'act_scale'] = torch.Tensor( - vals['scale_y_quant_orig']).to(torch.float32).contiguous().cuda() - results[last_prefix] = torch.Tensor(vals['scale_x_orig_quant']).to( + results[prefix + 'act_scale'] = torch.Tensor([[ + vals['scale_y_quant_orig'] + ]]).to(torch.float32).contiguous().cuda() + results[last_prefix] = torch.Tensor([vals['scale_x_orig_quant']]).to( torch.float32).contiguous().cuda() if smoother_value is not None: @@ -626,39 +632,67 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): return results -def convert_hf_llama(hf_model, - mapping, - vocab_size=32000, - dtype='float32', - use_parallel_embedding=False, - sharding_dim=0, - use_weight_only=False, - share_embedding_table=False, - residual_mlp=False, - use_gemm_woq_plugin=False, - plugin_weight_only_quant_type=torch.int8, - use_smooth_quant=False, - per_channel=False, - per_token=False, - int8_kv_cache=False, - act_range=[], - qkv_para=[], - smoother=[], - moe_config=None): +def load_hf_llama(model_dir: str, load_model_on_cpu: bool = False): + if "vila" in model_dir: + sys.path.append(model_dir + "/../VILA") + from llava.model import LlavaLlamaConfig, LlavaLlamaModel # noqa + from transformers import AutoModel + model = AutoModel.from_pretrained( + model_dir, + device_map='auto', + trust_remote_code=True, + ) + return model.llm + + hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + model_cls = AutoModelForCausalLM + if hf_config.model_type == "llava": + from transformers import LlavaForConditionalGeneration + model_cls = LlavaForConditionalGeneration + model = model_cls.from_pretrained( + model_dir, + device_map='auto' if not load_model_on_cpu else 'cpu', + torch_dtype='auto', + trust_remote_code=True, + ) + if hf_config.model_type == "llava": + model = model.language_model + return model + + +def load_weights_from_hf_model(hf_model, + config: LLaMAConfig, + act_range: Optional[dict] = None, + qkv_para: Optional[dict] = None, + smoother: Optional[dict] = None): + quant_algo = config.quantization.quant_algo + use_weight_only = quant_algo in [QuantAlgo.W8A16, QuantAlgo.W4A16] + if quant_algo == QuantAlgo.W8A16: + plugin_weight_only_quant_type = torch.int8 + elif quant_algo == QuantAlgo.W4A16: + plugin_weight_only_quant_type = torch.quint4x2 + else: + plugin_weight_only_quant_type = None + use_gemm_woq_plugin = (not config.disable_weight_only_quant_plugin) + + use_smooth_quant = config.quantization.use_plugin_sq + per_channel = use_smooth_quant and 'PER_CHANNEL' in quant_algo + per_token = use_smooth_quant and 'PER_TOKEN' in quant_algo + int8_kv_cache = config.quantization.kv_cache_quant_algo == QuantAlgo.INT8 + if use_smooth_quant or int8_kv_cache: + assert act_range is not None + assert qkv_para is not None + assert smoother is not None weights = {} tik = time.time() - tensor_parallel = mapping.tp_size model_params = dict(hf_model.named_parameters()) - dtype = getattr(torch, dtype) - num_attention_heads = hf_model.config.num_attention_heads - hidden_size = hf_model.config.hidden_size - head_size = hidden_size // num_attention_heads - intermediate_size = hf_model.config.intermediate_size - num_key_value_heads = getattr(hf_model.config, 'num_key_value_heads', - num_attention_heads) - mha_mode = (num_key_value_heads == num_attention_heads) - layers_range = mapping.pp_layers(hf_model.config.num_hidden_layers) + dtype = getattr(torch, config.dtype) + + mapping = config.mapping + moe_config = config.moe + mha_mode = (config.num_key_value_heads == config.num_attention_heads) + layers_range = config.mapping.pp_layers(config.num_hidden_layers) def convert_layer(l): prefix = f'model.layers.{l}.' @@ -668,14 +702,16 @@ def convert_layer(l): v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype) if not mha_mode: - if num_key_value_heads < tensor_parallel: + if config.num_key_value_heads < mapping.tp_size: # duplicate the KV heads up to tensor_parallel - k_weight = dup_kv_weight(k_weight, num_key_value_heads, - tensor_parallel) - v_weight = dup_kv_weight(v_weight, num_key_value_heads, - tensor_parallel) - assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0 - assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0 + k_weight = dup_kv_weight(k_weight, config.num_key_value_heads, + mapping.tp_size) + v_weight = dup_kv_weight(v_weight, config.num_key_value_heads, + mapping.tp_size) + assert (k_weight.shape[0] % + (mapping.tp_size * config.head_size)) == 0 + assert (v_weight.shape[0] % + (mapping.tp_size * config.head_size)) == 0 wq = split(q_weight, mapping.tp_size, mapping.tp_rank) wk = split(k_weight, mapping.tp_size, mapping.tp_rank) @@ -686,8 +722,9 @@ def convert_layer(l): else: qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) - split_v = split_qkv_tp(qkv_weight, num_attention_heads, hidden_size, - tensor_parallel, mapping.tp_rank) + split_v = split_qkv_tp(qkv_weight, config.num_attention_heads, + config.hidden_size, mapping.tp_size, + mapping.tp_rank) if prefix + 'self_attn.q_proj.bias' in model_params: # only used in Internlm 7B models @@ -695,9 +732,10 @@ def convert_layer(l): k_bias = get_bias(model_params, prefix + 'self_attn.k_proj', dtype) v_bias = get_bias(model_params, prefix + 'self_attn.v_proj', dtype) qkv_bias = torch.cat((q_bias, k_bias, v_bias)) - split_bias_v = split_qkv_bias_tp(qkv_bias, num_attention_heads, - hidden_size, tensor_parallel, - mapping.tp_rank) + split_bias_v = split_qkv_bias_tp(qkv_bias, + config.num_attention_heads, + config.hidden_size, + mapping.tp_size, mapping.tp_rank) else: split_bias_v = None @@ -711,7 +749,8 @@ def convert_layer(l): qkv_weight = qkv_weight.reshape(local_dim, local_dim + 2 * kv_hidden_size) else: - qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size) + qkv_weight = qkv_weight.reshape(config.hidden_size, 3, + config.hidden_size) int8_weights = generate_int8(qkv_weight, act_range.get(prefix + @@ -722,8 +761,8 @@ def convert_layer(l): weights.update( get_tllm_linear_sq_weight(int8_weights, tllm_prex + 'attention.qkv.', - [1, qkv_out_dim // tensor_parallel], - tensor_parallel, + [1, qkv_out_dim // mapping.tp_size], + mapping.tp_size, is_qkv=True, bias=split_bias_v, per_token=per_token, @@ -764,7 +803,7 @@ def convert_layer(l): attn_dense_weight = get_weight(model_params, prefix + 'self_attn.o_proj', dtype) split_v = split_matrix_tp(attn_dense_weight, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=1) @@ -780,8 +819,8 @@ def convert_layer(l): weights.update( get_tllm_linear_sq_weight( int8_weights, - tllm_prex + 'attention.dense.', [1, hidden_size], - tensor_parallel, + tllm_prex + 'attention.dense.', [1, config.hidden_size], + mapping.tp_size, is_qkv=False, bias=attn_dense_bias, per_token=per_token, @@ -789,7 +828,7 @@ def convert_layer(l): last_prefix=tllm_prex + 'attention.quantization_scaling_factor', smoother_value=smoother[(prefix + 'self_attn.o_proj')], - smoother_shape=[1, hidden_size // tensor_parallel], + smoother_shape=[1, config.hidden_size // mapping.tp_size], rank=mapping.tp_rank, cat_dim=0)) else: @@ -799,8 +838,7 @@ def convert_layer(l): plugin_weight_only_quant_type, dtype, use_gemm_woq_plugin)) - if moe_config and moe_config.has_moe(): - + if moe_config.has_moe(): rank_experts = list(range(moe_config.num_experts)) if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL: rank_experts = mapping.ep_experts(moe_config.num_experts) @@ -845,7 +883,7 @@ def convert_layer(l): plugin_weight_only_quant_type, dtype, use_gemm_woq_plugin)) - if residual_mlp: + if config.residual_mlp: residual_mlp_gate_weights = get_weight( model_params, prefix + 'residual_mlp.w3', dtype) if use_smooth_quant: @@ -857,8 +895,8 @@ def convert_layer(l): get_tllm_linear_sq_weight( int8_weights, tllm_prex + 'residual_mlp.gate.', - [1, hidden_size // tensor_parallel], - tensor_parallel, + [1, config.hidden_size // mapping.tp_size], + mapping.tp_size, is_qkv=False, per_token=per_token, per_channel=per_channel, @@ -870,7 +908,7 @@ def convert_layer(l): cat_dim=-1)) else: split_v = split_matrix_tp(residual_mlp_gate_weights, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=0) weights.update( @@ -893,8 +931,8 @@ def convert_layer(l): get_tllm_linear_sq_weight( int8_weights, tllm_prex + 'residual_mlp.fc.', - [1, hidden_size // tensor_parallel], - tensor_parallel, + [1, config.hidden_size // mapping.tp_size], + mapping.tp_size, is_qkv=False, per_token=per_token, per_channel=per_channel, @@ -906,7 +944,7 @@ def convert_layer(l): cat_dim=-1)) else: split_v = split_matrix_tp(residual_mlp_fc_weight, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=0) weights.update( @@ -927,20 +965,23 @@ def convert_layer(l): weights.update( get_tllm_linear_sq_weight( int8_weights, - tllm_prex + 'residual_mlp.proj.', [1, hidden_size], - tensor_parallel, + tllm_prex + 'residual_mlp.proj.', + [1, config.hidden_size], + mapping.tp_size, is_qkv=False, per_token=per_token, per_channel=per_channel, last_prefix=tllm_prex + 'residual_mlp.quantization_scaling_factor', smoother_value=smoother[prefix + 'residual_mlp.w2'], - smoother_shape=[1, hidden_size // tensor_parallel], + smoother_shape=[ + 1, config.hidden_size // mapping.tp_size + ], rank=mapping.tp_rank, cat_dim=0)) else: split_v = split_matrix_tp(residual_mlp_proj_weight, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=1) weights.update( @@ -965,7 +1006,7 @@ def convert_layer(l): mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj', dtype) split_v = split_matrix_tp(mlp_gate_weight, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=0) if use_smooth_quant: @@ -977,8 +1018,8 @@ def convert_layer(l): get_tllm_linear_sq_weight( int8_weights, tllm_prex + 'mlp.gate.', - [1, intermediate_size // tensor_parallel], - tensor_parallel, + [1, config.intermediate_size // mapping.tp_size], + mapping.tp_size, is_qkv=False, per_token=per_token, per_channel=per_channel, @@ -997,7 +1038,7 @@ def convert_layer(l): mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj', dtype) split_v = split_matrix_tp(mlp_fc_weight, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=0) @@ -1009,8 +1050,8 @@ def convert_layer(l): get_tllm_linear_sq_weight( int8_weights, tllm_prex + 'mlp.fc.', - [1, intermediate_size // tensor_parallel], - tensor_parallel, + [1, config.intermediate_size // mapping.tp_size], + mapping.tp_size, is_qkv=False, per_token=per_token, per_channel=per_channel, @@ -1029,7 +1070,7 @@ def convert_layer(l): mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj', dtype) split_v = split_matrix_tp(mlp_proj_weight, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=1) @@ -1040,8 +1081,8 @@ def convert_layer(l): weights.update( get_tllm_linear_sq_weight( int8_weights, - tllm_prex + 'mlp.proj.', [1, hidden_size], - tensor_parallel, + tllm_prex + 'mlp.proj.', [1, config.hidden_size], + mapping.tp_size, is_qkv=False, per_token=per_token, per_channel=per_channel, @@ -1049,7 +1090,7 @@ def convert_layer(l): 'mlp.quantization_scaling_factor', smoother_value=smoother[prefix + 'mlp.down_proj'], smoother_shape=[ - 1, intermediate_size // tensor_parallel + 1, config.intermediate_size // mapping.tp_size ], rank=mapping.tp_rank, cat_dim=0)) @@ -1069,7 +1110,7 @@ def convert_layer(l): prefix + 'post_attention_layernorm', dtype) weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight - if residual_mlp: + if config.residual_mlp: residual_ln_weight = get_weight(model_params, prefix + 'residual_layernorm', dtype) @@ -1091,46 +1132,41 @@ def convert_layer(l): if hf_model.config.tie_word_embeddings: # lm_head.weight has the same weights as embedding if mapping.is_last_pp_rank(): - if vocab_size % mapping.tp_size != 0: + if config.vocab_size % mapping.tp_size != 0: # padding - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size + vocab_size_padded = pad_vocab_size(config.vocab_size, + mapping.tp_size) + pad_width = vocab_size_padded - config.vocab_size v = torch.nn.functional.pad(v, (0, 0, 0, pad_width), 'constant', 0) weights['lm_head.weight'] = split(v, mapping.tp_size, mapping.tp_rank) - if use_parallel_embedding: + if config.use_parallel_embedding: v = split_matrix_tp(v, mapping.tp_size, mapping.tp_rank, - dim=sharding_dim) + dim=config.embedding_sharding_dim) if mapping.is_first_pp_rank(): weights['transformer.vocab_embedding.weight'] = v - # if not use_parallel_embedding: - # weights['transformer.vocab_embedding.weight'] = embed_w - # else: - # assert hf_model.config.vocab_size % tensor_parallel == 0 - # weights['transformer.vocab_embedding.weight'] = split_matrix_tp( - # embed_w, tensor_parallel, rank - lm_head_weights = get_weight(model_params, 'lm_head', dtype) if mapping.is_last_pp_rank(): - if vocab_size % mapping.tp_size != 0: + if config.vocab_size % mapping.tp_size != 0: # padding - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size + vocab_size_padded = pad_vocab_size(config.vocab_size, + mapping.tp_size) + pad_width = vocab_size_padded - config.vocab_size lm_head_weights = torch.nn.functional.pad(lm_head_weights, (0, 0, 0, pad_width), 'constant', value=0) weights['lm_head.weight'] = split_matrix_tp(lm_head_weights, - tensor_parallel, + mapping.tp_size, mapping.tp_rank, dim=0) ln_f_w = get_weight(model_params, 'model.norm', dtype) @@ -1143,9 +1179,8 @@ def convert_layer(l): def smooth_quant(model, - model_dir, - calib_dataset, - dataset_cache_dir, + tokenizer, + dataset, smoothquant: Optional[float] = None): assert model is not None act_range = {} @@ -1153,14 +1188,6 @@ def smooth_quant(model, # smoother for inputs of self_attn.o_proj and mlp.down_proj llama_smoother = {} - os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( - "TOKENIZERS_PARALLELISM", "false") - tokenizer = AutoTokenizer.from_pretrained(model_dir, - trust_remote_code=True, - use_fast=False, - padding_side='left') - dataset = load_calib_dataset(calib_dataset, cache_dir=dataset_cache_dir) - act_range = capture_activation_range(model, tokenizer, dataset) if smoothquant is not None: smooth_llama_model(model, act_range, smoothquant, llama_qkv_para, @@ -1168,288 +1195,859 @@ def smooth_quant(model, return act_range, llama_qkv_para, llama_smoother -def create_config_from_hugging_face(hf_model, - dtype, - mapping, - quantization: QuantConfig = None, - override_fields: dict = {}): - config = {} - hf_config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True) - if hf_config.model_type == "llava": - # LLaVA = Vision model + Llama LLM - # We load a llava config and use its' text config as llama config - hf_config = LlavaConfig.from_pretrained(hf_model).text_config - if hf_config.model_type == "llava_llama": - hf_config.llm_cfg["architecture"] = hf_config.llm_cfg["architectures"] - hf_config.llm_cfg["dtype"] = hf_config.llm_cfg["torch_dtype"] - hf_config = PretrainedConfig.from_dict(hf_config.llm_cfg) - # TODO: directly assign the hf_config fields to the config dict w/o creating these local vars - # same for from_meta and from_cli_args - n_head = hf_config.num_attention_heads - inter_size = hf_config.intermediate_size - n_layer = hf_config.num_hidden_layers - n_embd = hf_config.hidden_size - n_kv_head = getattr(hf_config, "num_key_value_heads", n_head) - rms_norm_eps = hf_config.rms_norm_eps - vocab_size = hf_config.vocab_size - n_positions = hf_config.max_position_embeddings - hidden_act = hf_config.hidden_act - config['rotary_scaling'] = getattr(hf_config, "rope_scaling", None) - rotary_base = getattr(hf_config, "rope_theta", 10000.0) - config['residual_mlp'] = getattr(hf_config, "parallel_attn_mlp_res", False) - if hf_config.model_type == "mixtral" or hf_config.model_type == "arctic": - # HF LLaMA-type models are implicitly using gated activation. - # With our MoE implementation, we must make it explicit - hidden_act = "swiglu" - config[ - 'moe_normalization_mode'] = MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE - else: - config['moe_normalization_mode'] = None - moe_num_experts = getattr(hf_config, "num_local_experts", 0) - moe_top_k = getattr(hf_config, "num_experts_per_tok", 0) - moe_tp_mode = MoeConfig.ParallelismMode.TENSOR_PARALLEL - architecture = hf_config.architectures[0] - # VILA model, force to use llama config - if hf_config.model_type == "llava_llama": - architecture = "LlamaForCausalLM" - attn_bias = getattr(hf_config, 'bias', False) or getattr( - hf_config, 'attention_bias', False) - - config.update({ - 'architecture': architecture, - 'dtype': dtype, - 'logits_dtype': 'float32', - 'num_hidden_layers': n_layer, - 'num_attention_heads': n_head, - 'hidden_size': n_embd, - 'intermediate_size': inter_size, - 'num_key_value_heads': n_kv_head, - 'vocab_size': vocab_size, - 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': n_positions, - 'hidden_act': hidden_act, - 'rotary_base': rotary_base, - 'norm_epsilon': rms_norm_eps, - 'moe_num_experts': moe_num_experts, - 'moe_top_k': moe_top_k, - 'moe_tp_mode': moe_tp_mode, - #TODO: should have directly map from the Mapping object to the TRT-LLM checkpoint fields - 'mapping': { - 'world_size': mapping.tp_size * mapping.pp_size, - 'tp_size': mapping.tp_size, - 'pp_size': mapping.pp_size - }, - 'attn_bias': attn_bias, - }) - config['quantization'] = quantization.asdict() - config.update(override_fields) - - moe_config = MoeConfig(config['moe_num_experts'], config['moe_top_k'], - config['moe_tp_mode'], - config['moe_normalization_mode']).validate() - use_weight_only = config['quantization']['quant_algo'] in [ - QuantAlgo.W8A16, QuantAlgo.W4A16, QuantAlgo.FP8 - ] - if use_weight_only and moe_config.has_moe(): - config['quantization']['exclude_modules'].append('router') - - return config - - -def from_hugging_face(cls, - model_dir, - dtype, - *, - mapping, - quantization: QuantConfig = None, - load_by_shard=False, - load_model_on_cpu=False, - override_fields={}, - skip_loading_weights=False, - preloaded_model=None): - ''' Create a LLaMAForCausalLM object from give parameters - ''' - assert model_dir is not None - if isinstance(model_dir, Path): # some code relies on this as string - model_dir = str(model_dir) - - if override_fields.get('share_embedding_table', False): - logger.warning( - "Llama model does not support share_embedding_table; setting share_embedding_table=False" - ) - override_fields['share_embedding_table'] = False - - config = create_config_from_hugging_face(model_dir, - dtype, - mapping, - quantization, - override_fields=override_fields) - - pretrained_config = PretrainedConfig.from_dict(config) - pretrained_config.set_rank(mapping.rank) # TODO:remove this hack - - llama = cls.from_config(pretrained_config) - llama = optimize_model( - llama, - use_parallel_embedding=pretrained_config.use_parallel_embedding, - share_embedding_table=pretrained_config.share_embedding_table, - ) - - if skip_loading_weights: - return llama - - model = preloaded_model - if model is None and not load_by_shard: # when load by shard, no need to create complete hf model - have_safetensors = any( - [f.endswith(".safetensors") for f in os.listdir(model_dir)]) - hf_config = AutoConfig.from_pretrained(model_dir, - trust_remote_code=True) - if hf_config.model_type == "llava": - hf_llava = LlavaForConditionalGeneration.from_pretrained( - model_dir, torch_dtype="auto") - model = hf_llava.language_model - else: - # TODO: Remove WAR after `load_from_hf_safetensors` supports weight-only quantization - if not have_safetensors or config['quantization'][ - 'quant_algo'] is not None: - model = AutoModelForCausalLM.from_pretrained( - model_dir, - device_map='auto' if not load_model_on_cpu else 'cpu', - torch_dtype='auto', - trust_remote_code=True, - ) - - if load_by_shard: - weights = load_from_hf_checkpoint(model_dir, mapping, pretrained_config) - elif model is not None: - weights = load_weights_from_hf(config=config, - mapping=mapping, - model=model) - else: - weights = load_from_hf_safetensors(model_dir=model_dir, - config=pretrained_config, - mapping=mapping) - - llama.load(weights) - return llama - - -def quantize(dtype, - model_dir, - output_dir, - mapping, - quantization: QuantConfig, - *, - calib_dataset='cnn_dailymail', - override_fields={}, - dataset_cache_dir: Optional[str] = None): +def quantize(hf_model_dir: str, + output_dir: str, + config: LLaMAConfig, + calib_dataset='cnn_dailymail'): ''' Quantize the save the model as TRT-LLM checkpoint to output_dir ''' #TODO: currently only smooth quant and kv cache quantization are supported, needs to support mode quant algorithm calling modelopt - config = create_config_from_hugging_face(model_dir, - dtype, - mapping, - quantization, - override_fields=override_fields) with open(os.path.join(output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) + json.dump(config.to_dict(), f, indent=4) + + mapping = config.mapping assert mapping.rank == -1, "You shall call quantize only once in one rank, assert rank==-1 for precaution" - act_range = {} - llama_qkv_para = {} - # smoother for inputs of self_attn.o_proj and mlp.down_proj - llama_smoother = {} - model = None - assert config['quantization']['quant_algo'] == quantization.quant_algo - int8_kv_cache = quantization.kv_cache_quant_algo == QuantAlgo.INT8 - use_smooth_quant = quantization.quant_algo is not None and quantization.quant_algo.startswith( - 'W8A8_SQ') + quant_config = config.quantization + + use_smooth_quant = quant_config.use_plugin_sq + int8_kv_cache = quant_config.kv_cache_quant_algo == QuantAlgo.INT8 assert use_smooth_quant or int8_kv_cache, "Call from_hugging_face when there is no quantization" if use_smooth_quant: - assert quantization.smoothquant_val is not None, "A smooth value must be specified when using smooth quant" + assert quant_config.smoothquant_val is not None, "A smooth value must be specified when using smooth quant" - assert model_dir is not None + assert hf_model_dir is not None ## only load and call smooth quant routine once for all ranks - hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(hf_model_dir, trust_remote_code=True) assert "llava" not in hf_config.model_type, "Smooth quant llava/vila is not supported yet" - model = AutoModelForCausalLM.from_pretrained( - model_dir, + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_dir, device_map='auto', torch_dtype='auto' if not use_smooth_quant else torch.float16, trust_remote_code=True) - act_range, llama_qkv_para, llama_smoother = smooth_quant( - model, model_dir, calib_dataset, dataset_cache_dir, - quantization.smoothquant_val) + + os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( + "TOKENIZERS_PARALLELISM", "false") + tokenizer = AutoTokenizer.from_pretrained(hf_model_dir, + trust_remote_code=True, + use_fast=False, + padding_side='left') + + dataset = load_calib_dataset(calib_dataset) + + act_range, qkv_para, smoother = smooth_quant(hf_model, tokenizer, dataset, + quant_config.smoothquant_val) for rank in range(mapping.world_size): # To avoid changing the mapping arg in-place, also the given mapping from caller is rank agnostic, since quantize is called from only one rank - ranked_mapping = Mapping(world_size=mapping.world_size, - rank=rank, - tp_size=mapping.tp_size, - pp_size=mapping.pp_size) - weights = load_weights_from_hf( + config = copy.deepcopy(config) + config.set_rank(rank) + weights = load_weights_from_hf_model( + hf_model, config=config, - mapping=ranked_mapping, - model=model, - # for smooth quant only act_range=act_range, - llama_qkv_para=llama_qkv_para, - llama_smoother=llama_smoother, + qkv_para=qkv_para, + smoother=smoother, ) safetensors.torch.save_file( weights, os.path.join(output_dir, f'rank{rank}.safetensors')) del weights -def load_weights_from_hf(*, - config, - mapping, - model, - act_range={}, - llama_qkv_para={}, - llama_smoother={}): - #TODO: simplify the parameters here - - assert model is not None - plugin_weight_only_quant_type = None # the value does not matter when use_weight_only is False - quant_algo = config['quantization']['quant_algo'] +class QkvWeightHelper: + """ A helper utility for loading QKV weights from sharded files. """ + + def __init__(self, config: PretrainedConfig): + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.tp_size = config.mapping.tp_size + self.tp_rank = config.mapping.tp_rank + self.is_mha = self.num_heads == self.num_kv_heads + self._qkv_weights = {} + + @staticmethod + def is_qkv_weight(name): + for k in ['q_proj', 'k_proj', 'v_proj']: + if 'self_attn' in name and k in name: + return True + return False + + def add_weight(self, i: int, name: str, weight: torch.Tensor): + if 'q_proj' in name: + tag = 'q' + elif 'k_proj' in name: + tag = 'k' + elif 'v_proj' in name: + tag = 'v' + else: + raise ValueError(f'Got an unexpected parameter of name {name}') + if i not in self._qkv_weights: + self._qkv_weights[i] = {} + self._qkv_weights[i][tag] = weight + + def is_qkv_prepared(self, layer_idx): + if layer_idx not in self._qkv_weights: + return False + weights = self._qkv_weights[layer_idx] + return 'q' in weights and 'k' in weights and 'v' in weights + + def split_qkv_weights(self, layer_idx): + if not self.is_qkv_prepared(layer_idx): + return None + weights = self._qkv_weights.pop(layer_idx) # to prevent memory leak. + q, k, v = (torch.tensor(weights[t]) for t in ['q', 'k', 'v']) + + if not self.is_mha: + head_size = self.hidden_size // self.num_heads + if self.num_kv_heads < self.tp_size: + # duplicate the KV heads up to tensor_parallel + k = dup_kv_weight(k, self.num_kv_heads, self.tp_size) + v = dup_kv_weight(v, self.num_kv_heads, self.tp_size) + assert k.shape[0] % (self.tp_size * head_size) == 0 + assert v.shape[0] % (self.tp_size * head_size) == 0 + wq = split(q, self.tp_size, self.tp_rank) + wk = split(k, self.tp_size, self.tp_rank) + wv = split(v, self.tp_size, self.tp_rank) + fused_qkv = torch.cat((wq, wk, wv), dim=0) + else: + qkv = torch.cat([q, k, v], dim=0) + qkv = qkv.reshape(3, q.shape[0], q.shape[1]) + fused_qkv = split(qkv, self.tp_size, self.tp_rank, dim=1) + fused_qkv = fused_qkv.reshape(3 * (q.shape[0] // self.tp_size), + q.shape[1]) + return fused_qkv + + +def load_weights_from_hf_by_shard(model_dir: str, config: LLaMAConfig): + '''Weights-only quantization is the only supported quantization recipe here.''' + logger.info('Loading weights from HF LLaMA...') + quant_algo = config.quantization.quant_algo + use_weight_only = quant_algo in [QuantAlgo.W8A16, QuantAlgo.W4A16] if quant_algo == QuantAlgo.W8A16: plugin_weight_only_quant_type = torch.int8 elif quant_algo == QuantAlgo.W4A16: plugin_weight_only_quant_type = torch.quint4x2 + else: + plugin_weight_only_quant_type = None + + weights = {} + tik = time.time() + dtype = getattr(torch, config.dtype) + + mapping = config.mapping + moe_config = config.moe + assert not moe_config.has_moe(), "MoE does not support sharded load" + + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(model_dir) + + quant_mode = config.quant_mode + if quant_mode.is_int8_weight_only(): + plugin_weight_only_quant_type = torch.int8 + elif quant_mode.is_int4_weight_only(): + plugin_weight_only_quant_type = torch.quint4x2 + else: + plugin_weight_only_quant_type = None + use_weight_only = quant_mode.is_weight_only() - moe_config = MoeConfig(config['moe_num_experts'], config['moe_top_k'], - config['moe_tp_mode'], - config['moe_normalization_mode']).validate() + layers_range = mapping.pp_layers(config.num_hidden_layers) - use_weight_only = quant_algo in [QuantAlgo.W8A16, QuantAlgo.W4A16] - use_smooth_quant = quant_algo is not None and quant_algo.startswith( - 'W8A8_SQ') - per_channel_sq = use_smooth_quant and 'PER_CHANNEL' in quant_algo - per_token_sq = use_smooth_quant and 'PER_TOKEN' in quant_algo - use_int8_kv_cache = config['quantization'][ - 'kv_cache_quant_algo'] == QuantAlgo.INT8 - weights = convert_hf_llama( - model, - mapping, - vocab_size=config['vocab_size'], - dtype=config['dtype'], - use_weight_only=use_weight_only, - use_gemm_woq_plugin=not config.get('disable_weight_only_quant_plugin', - False), - plugin_weight_only_quant_type=plugin_weight_only_quant_type, - use_parallel_embedding=config.get('use_parallel_embedding', False), - sharding_dim=config.get('embedding_sharding_dim', 0), - share_embedding_table=config.get('share_embedding_table', False), - residual_mlp=config['residual_mlp'], - use_smooth_quant=use_smooth_quant, - per_channel=per_channel_sq, - per_token=per_token_sq, - int8_kv_cache=use_int8_kv_cache, - act_range=act_range, - qkv_para=llama_qkv_para, - smoother=llama_smoother, - moe_config=moe_config) + qkv_weight_helper = QkvWeightHelper(config) + + for model_file in iterate_shard_files(model_dir, + rank=mapping.tp_rank, + progress_bar=False): + logger.debug(f'Loading file {str(model_file)}...') + model_params = load_state_dict(model_file, dtype=dtype) + for name, param in model_params.items(): + logger.debug(f'Converting weight {name}...') + layer_idx = retrieved_layer_index_from_name(name) + if layer_idx is None: + layer = None + else: + if layer_idx not in layers_range: + continue + tllm_prex = f'transformer.layers.{layer_idx}.' + + if 'model.embed_tokens.weight' in name: + if hf_config.tie_word_embeddings: + # lm_head.weight has the same weights as embedding + if mapping.is_last_pp_rank(): + + if config.vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size( + config.vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - config.vocab_size + param = torch.from_numpy( + np.pad(param.detach().cpu().numpy(), + ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = split( + param, mapping.tp_size, mapping.tp_rank) + if config.use_parallel_embedding: + param = split(param, mapping.tp_size, mapping.tp_rank, + config.embedding_sharding_dim) + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = param + elif 'model.norm.weight' in name: + if mapping.is_last_pp_rank(): + weights['transformer.ln_f.weight'] = param + elif 'lm_head.weight' in name: + if mapping.is_last_pp_rank(): + if config.vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size( + config.vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - config.vocab_size + param = torch.from_numpy( + np.pad(param.detach().cpu().numpy(), + ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = split(param, mapping.tp_size, + mapping.tp_rank) + elif 'input_layernorm.weight' in name: + weights[tllm_prex + 'input_layernorm.weight'] = param + elif 'post_attention_layernorm.weight' in name: + weights[tllm_prex + 'post_layernorm.weight'] = param + elif qkv_weight_helper.is_qkv_weight(name): + qkv_weight_helper.add_weight(layer_idx, name, param) + if not qkv_weight_helper.is_qkv_prepared(layer_idx): + continue + split_v = qkv_weight_helper.split_qkv_weights(layer_idx) + if use_weight_only: + param = split_v.transpose() + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + param, plugin_weight_only_quant_type) + weights[tllm_prex + + 'attention.qkv.weight'] = processed_torch_weights + weights[ + tllm_prex + + 'attention.qkv.per_channel_scale'] = torch_weight_scales + else: + weights[tllm_prex + 'attention.qkv.weight'] = split_v + elif 'self_attn.o_proj.weight' in name: + split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=1) + if use_weight_only: + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + split_v.transpose(), plugin_weight_only_quant_type) + weights[tllm_prex + + 'attention.dense.weight'] = processed_torch_weights + weights[ + tllm_prex + + 'attention.dense.per_channel_scale'] = torch_weight_scales + else: + weights[tllm_prex + 'attention.dense.weight'] = split_v + elif 'mlp.up_proj.weight' in name: + split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=0) + if use_weight_only: + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + split_v.transpose(), plugin_weight_only_quant_type) + weights[tllm_prex + + 'mlp.gate.weight'] = processed_torch_weights + weights[tllm_prex + + 'mlp.gate.per_channel_scale'] = torch_weight_scales + else: + weights[tllm_prex + 'mlp.gate.weight'] = split_v + elif 'mlp.down_proj.weight' in name: + split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=1) + if use_weight_only: + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + split_v.transpose(), plugin_weight_only_quant_type) + weights[tllm_prex + + 'mlp.proj.weight'] = processed_torch_weights + weights[tllm_prex + + 'mlp.proj.per_channel_scale'] = torch_weight_scales + else: + weights[tllm_prex + 'mlp.proj.weight'] = split_v + + elif 'mlp.gate_proj.weight' in name: + split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=0) + if use_weight_only: + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + split_v.transpose(), plugin_weight_only_quant_type) + layer.mlp.fc.weight.value = processed_torch_weights + layer.mlp.fc.per_channel_scale.value = torch_weight_scales + weights[tllm_prex + + 'mlp.fc.weight'] = processed_torch_weights + weights[tllm_prex + + 'mlp.fc.per_channel_scale'] = torch_weight_scales + else: + weights[tllm_prex + 'mlp.fc.weight'] = split_v + + del model_params + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logger.info(f'Weights loaded. Total time: {t}') + return weights + + +def load_weights_from_hf_safetensors(model_dir: str, config: LLaMAConfig): + logger.info('Loading weights from Huggingface LLaMA safetensors...') + tik = time.time() + import json + import os + + import safetensors + weights = {} + + model_dir = model_dir if model_dir.endswith("/") else model_dir + "/" + safetensors_map = {} + try: + with open(model_dir + "model.safetensors.index.json", 'r') as fr: + sharding_map = json.load(fr) + for k, v in sharding_map['weight_map'].items(): + safetensors_map[k] = int(v[6:11]) - 1 + except FileNotFoundError: + pass + shard_files = [] + for name in os.listdir(model_dir): + if name.endswith(".safetensors"): + shard_files.append(name) + shard_files.sort() + safetensors_ptrs = [ + safetensors.safe_open(model_dir + shard_file, + framework="pt", + device="cpu") for shard_file in shard_files + ] + + mapping = config.mapping + num_hidden_layers = config.num_hidden_layers + vocab_size = config.vocab_size + pad_vocab = vocab_size % mapping.tp_size != 0 + vocab_size_padded = pad_vocab_size(config.vocab_size, mapping.tp_size) + dtype = config.dtype + + moe_config = config.moe + + model_prefix = "model." + key_list = [ + "embed_tokens.weight", # vocab_embedding + "lm_head.weight", # lm_head + "norm.weight", # ln_f + "self_attn.", # attention.qkv + "_proj.weight", # qkv suffix + "self_attn.o_proj.weight", # attention.dense + "mlp.up_proj.weight", # mlp.gate + "mlp.down_proj.weight", # mlp.proj + "mlp.gate_proj.weight", # mlp.fc + "input_layernorm.weight", # input_layernorm + "post_attention_layernorm.weight", # post_layernorm + ] + + torch_dtype = str_dtype_to_torch(dtype) + + def load(key, tp_dim=-1, no_prefix=0): + if not no_prefix: + key = model_prefix + key + ptr_idx = safetensors_map[key] if key in safetensors_map else 0 + if tp_dim == -1: + res = safetensors_ptrs[ptr_idx].get_tensor(key) + else: + tensor_slice = safetensors_ptrs[ptr_idx].get_slice(key) + tensor_shape = tensor_slice.get_shape() + if tensor_shape[tp_dim] % mapping.tp_size != 0: + logger.error( + "Current weight shape is invalid for mapping.tp_size=" + + str(mapping.tp_size)) + slice_width = tensor_shape[tp_dim] // mapping.tp_size + if tp_dim == 0: + res = tensor_slice[slice_width * mapping.tp_rank:slice_width * + (mapping.tp_rank + 1), :] + elif tp_dim == 1: + res = tensor_slice[:, + slice_width * mapping.tp_rank:slice_width * + (mapping.tp_rank + 1)] + else: + assert False, "Invalid TP dim" + return res.to(torch_dtype).contiguous( + ) if "block_sparse_moe.gate" not in key else res.to(torch.float32) + + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = load( + key_list[0], config.embedding_sharding_dim + if config.use_parallel_embedding else -1) # vocab_embedding + + if mapping.is_last_pp_rank(): + v = load(key_list[1], -1, 1) if pad_vocab else load(key_list[1], 0, + 1) # lm_head + if pad_vocab: + v = torch.nn.functional.pad( + v, (0, 0, 0, vocab_size_padded - vocab_size), 'constant', 0) + v = split(v, mapping.tp_size, mapping.tp_rank) + weights['lm_head.weight'] = v + weights['transformer.ln_f.weight'] = load(key_list[2]) # ln_f + + layers_range = mapping.pp_layers(num_hidden_layers) + for l in layers_range: + layer_idx = l - layers_range[0] + prefix = f'layers.{l}.' + tllm_prex = f'transformer.layers.{layer_idx}' + + # Attention + qkv_list = [] + for comp in ["q", "k", "v"]: + comp_part = load(prefix + key_list[3] + comp + key_list[4], 0) + qkv_list.append(comp_part) + weights[f'{tllm_prex}.attention.qkv.weight'] = torch.cat(qkv_list, 0) + weights[f'{tllm_prex}.attention.dense.weight'] = load( + prefix + key_list[5], 1) # attention.dense + + # MLP + if not moe_config.has_moe(): + weights[f'{tllm_prex}.mlp.gate.weight'] = load( + prefix + key_list[6], 0) # mlp.gate + weights[f'{tllm_prex}.mlp.proj.weight'] = load( + prefix + key_list[7], 1) # mlp.proj + weights[f'{tllm_prex}.mlp.fc.weight'] = load( + prefix + key_list[8], 0) # mlp.fc + + else: + weights[f'{tllm_prex}.mlp.router.weight'] = load( + prefix + 'block_sparse_moe.gate.weight') + rank_experts = list(range(moe_config.num_experts)) + if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL: + rank_experts = mapping.ep_experts(moe_config.num_experts) + + expert_weight_list = [] + for suffix in range(3): + tp_dim = -1 + if moe_config.tp_mode == moe_config.ParallelismMode.TENSOR_PARALLEL: + tp_dim = 1 if suffix == 1 else 0 + expert_weight_list.append( + torch.stack( + list( + load( + prefix + + f'block_sparse_moe.experts.{expert}.w{suffix + 1}.weight', + tp_dim=tp_dim) for expert in rank_experts))) + + w1 = expert_weight_list[0] + w2 = expert_weight_list[1] + w3 = expert_weight_list[2] + + weights[f'{tllm_prex}.mlp.fc.weight'] = \ + torch.concat([w3, w1], dim=-2).contiguous() + weights[f'{tllm_prex}.mlp.proj.weight'] = w2.contiguous() + + weights[f'{tllm_prex}.input_layernorm.weight'] = load( + prefix + key_list[9]) # input_layernorm + weights[f'{tllm_prex}.post_layernorm.weight'] = load( + prefix + key_list[10]) # post_layernorm + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logger.info(f'Weights loaded. Total time: {t}') + + return weights + + +def load_weights_from_gptq(quant_ckpt_path: str, config: LLaMAConfig): + logger.info('Loading weights from groupwise GPTQ LLaMA safetensors...') + weights = {} + tik = time.time() + + num_hidden_layers = config.num_hidden_layers + vocab_size = config.vocab_size + dtype = config.dtype + mapping = config.mapping + + gptq_llama = safetensors.safe_open(quant_ckpt_path, + framework="pt", + device=0) + gptq_prefix = "model." + gptq_suffix_list = [".qweight", ".qzeros", ".scales"] + gptq_key_list = [ + "embed_tokens.weight", # vocab_embedding + "lm_head.weight", # lm_head + "norm.weight", # ln_f + "self_attn.", # attention.qkv + "_proj", # qkv suffix + "self_attn.o_proj", # attention.dense + "mlp.up_proj", # mlp.gate + "mlp.down_proj", # mlp.proj + "mlp.gate_proj", # mlp.fc + "input_layernorm.weight", # input_layernorm + "post_attention_layernorm.weight", # post_layernorm + ] + split_sym = "." + + packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 + preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm + torch_dtype = str_dtype_to_torch(dtype) + + def load(key, no_prefix=0): + if no_prefix: + return gptq_llama.get_tensor(key) + else: + return gptq_llama.get_tensor(gptq_prefix + key) + + def torch_split(v, dim): + if v.shape[dim] % mapping.tp_size != 0: + logger.error( + "Current weight shape is invalid for mapping.tp_size=" + + str(mapping.tp_size)) + assert False, "Invalid TP size" + return v.split(v.shape[dim] // mapping.tp_size, + dim=dim)[mapping.tp_rank] + + def unpack_int32_into_int8(w_packed): + # Unpack inputs packed in int32/float32 into uint4 and store them in int8 format + w_packed_int4x2 = w_packed.contiguous().view(torch.uint8) + w_unpacked = torch.zeros(w_packed_int4x2.shape[0], + w_packed_int4x2.shape[1] * 2, + dtype=torch.int8) + w_unpacked[:, ::2] = w_packed_int4x2 % 16 + w_unpacked[:, 1::2] = w_packed_int4x2 // 16 + return w_unpacked.contiguous() + + def process_and_assign_weight(v: List[torch.Tensor], + tllm_prex: str, + tp_dim: int = -1): + if tp_dim == -1: + qweight_int32, qzeros_int32, scales_fp16 = [ + item.cpu() for item in v + ] + else: + qweight_int32, qzeros_int32, scales_fp16 = [ + torch_split(item, tp_dim).cpu() for item in v + ] + + USE_UINT4_INPUT = 1 # Set to true if checkpoint store UINT4 weights + USE_GPTQ_FOR_LLAMA = 1 # GPTQ-for-LLaMA added 1 to zeros + + qweight_unpacked_int8 = unpack_int32_into_int8( + qweight_int32.T).T.contiguous() - 8 + qweight_interleaved = preprocessor(packer(qweight_unpacked_int8), + torch.quint4x2, + torch.float16).view(torch.float16) + # zeros = zeros * scales + qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) + if not USE_UINT4_INPUT: + # Correcting UINT4 values back to INT4 order + mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0] + mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0] + qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive + zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT - + USE_GPTQ_FOR_LLAMA) * scales_fp16 + zeros_x_scales_fp16 = zeros_x_scales_fp16.half() + + results = { + f'{tllm_prex}.weight': qweight_interleaved, + f'{tllm_prex}.weights_scaling_factor': scales_fp16, + f'{tllm_prex}.zero': zeros_x_scales_fp16, + } + return results + + # Load weights from GPTQ checkpoint into TRT-LLM module + # 1. vocab_embedding + v = load(gptq_key_list[0]) + if mapping.is_first_pp_rank(): + # tensorrt_llm_llama.vocab_embedding.weight.value = v.to( + # torch_dtype).cpu().numpy() + weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype) + # 2. lm_head + v = load(gptq_key_list[1], "no_prefix") + if mapping.is_last_pp_rank(): + # tensorrt_llm_llama.lm_head.weight.value = torch_split( + # v, 0).to(torch_dtype).cpu().numpy() + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) + pad_width = vocab_size_padded - vocab_size + v = torch.from_numpy( + np.pad(v.detach().cpu().numpy(), ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = torch_split(v, 0).to(torch_dtype) + + # 3. ln_f + v = load(gptq_key_list[2]) + if mapping.is_last_pp_rank(): + # tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() + weights['transformer.ln_f.weight'] = v.to(torch_dtype) + # 4. Weights inside each layer + layers_range = mapping.pp_layers(num_hidden_layers) + for l in layers_range: + layer_idx = l - layers_range[0] + prefix = "layers" + split_sym + str(layer_idx) + split_sym + logger.info(f'Process weights in layer: {layer_idx}') + # layer = tensorrt_llm_llama.layers[layer_idx] + tllm_prex = f'transformer.layers.{l-layers_range[0]}' + # 4.1 attention.qkv + qkv_weight_list = [] + for suf in gptq_suffix_list: + qkv_list = [] + for comp in ["q", "k", "v"]: + comp_part = load(prefix + gptq_key_list[3] + comp + + gptq_key_list[4] + suf) + comp_part = torch_split(comp_part, 1) + qkv_list.append(comp_part) + qkv_weight_list.append(torch.cat(qkv_list, dim=1)) + + # process_and_assign_weight(layer.attention.qkv, qkv_weight_list) + weights.update( + process_and_assign_weight(qkv_weight_list, + f'{tllm_prex}.attention.qkv')) + # 4.2 attention.dense + v = [load(prefix + gptq_key_list[5] + suf) for suf in gptq_suffix_list] + # process_and_assign_weight(layer.attention.dense, v, 0) + weights.update( + process_and_assign_weight(v, + f'{tllm_prex}.attention.dense', + tp_dim=0)) + # 4.3 mlp.gate + v = [load(prefix + gptq_key_list[6] + suf) for suf in gptq_suffix_list] + # process_and_assign_weight(layer.mlp.gate, v, 1) + weights.update( + process_and_assign_weight(v, f'{tllm_prex}.mlp.gate', tp_dim=1)) + # 4.4 mlp.proj + v = [load(prefix + gptq_key_list[7] + suf) for suf in gptq_suffix_list] + # process_and_assign_weight(layer.mlp.proj, v, 0) + weights.update( + process_and_assign_weight(v, f'{tllm_prex}.mlp.proj', tp_dim=0)) + # 4.5 mlp.fc + v = [load(prefix + gptq_key_list[8] + suf) for suf in gptq_suffix_list] + # process_and_assign_weight(layer.mlp.fc, v, 1) + weights.update( + process_and_assign_weight(v, f'{tllm_prex}.mlp.fc', tp_dim=1)) + # 4.6 input_layernorm + v = load(prefix + gptq_key_list[9]) + # layer.input_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() + weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype) + + # 4.7 post_layernorm + v = load(prefix + gptq_key_list[10]) + # layer.post_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() + weights[f'{tllm_prex}.post_layernorm.weight'] = v.to(torch_dtype) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logger.info(f'Weights loaded. Total time: {t}') + + return weights + + +def load_weights_from_meta_ckpt(meta_ckpt_dir: str, config: LLaMAConfig): + torch_dtype = str_dtype_to_torch(config.dtype) + mapping = config.mapping + weights = {} + + def gather_ckpts(ckpts): + gathered = {} + for k in ckpts[0]: + d = 0 + if any([n in k for n in ["wo", "w2", "tok"]]): + d = 1 + if "norm" in k or "rope" in k: # no TP + gathered[k] = ckpts[0][k].clone() + else: + gathered[k] = torch.cat([pt[k] for pt in ckpts], dim=d).clone() + return gathered + + def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank): + split_ckpt = {} + for k, v in ckpt.items(): + d = 0 + if any(n in k for n in + ["wo", "feed_forward.w2", "tok", "feed_forward.gate"]): + d = 1 + if "norm" in k or "rope" in k: # no TP + split_ckpt[k] = v.clone() + elif config.num_key_value_heads < mapping.tp_size and any( + n in k for n in ["wk", "wv"]): + assert mapping.tp_size % config.num_key_value_heads == 0 + # special case: we need to duplicate KV head + tmp = dup_kv_weight(v, config.num_key_value_heads, + mapping.tp_size) + split_ckpt[k] = torch.split(tmp, + tmp.shape[d] // ranks_per_ckpt, + dim=d)[ckpt_rank].clone() + else: + split_ckpt[k] = torch.split(v, + v.shape[d] // ranks_per_ckpt, + dim=d)[ckpt_rank].clone() + return split_ckpt + + def get_current_weights(num_ckpts): + if num_ckpts > mapping.tp_size: + # combine ckpts + assert (num_ckpts % mapping.tp_size) == 0 + nf = num_ckpts // mapping.tp_size + fs = nf * mapping.tp_rank + file_ids = list(range(fs, fs + nf)) + ckpts = [] + for f in file_ids: + ckpt = torch.load(Path(meta_ckpt_dir, + f"consolidated.{f:02d}.pth"), + map_location="cpu") + ckpts.append(ckpt) + return gather_ckpts(ckpts) + elif num_ckpts < mapping.tp_size: + # split ckpt + assert (mapping.tp_size % num_ckpts) == 0 + ranks_per_ckpt = mapping.tp_size // num_ckpts + ckpt_fid = mapping.tp_rank // ranks_per_ckpt + ckpt_rank = mapping.tp_rank % ranks_per_ckpt + nH_per_ckpt = config.num_attention_heads // num_ckpts + assert (nH_per_ckpt % ranks_per_ckpt) == 0 + ckpt = torch.load(Path(meta_ckpt_dir, + f"consolidated.{ckpt_fid:02d}.pth"), + map_location="cpu") + return split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank) + + # num_ckpts == tensor_parallel, 1:1 mapping from files to TP + return torch.load(Path(meta_ckpt_dir, + f"consolidated.{mapping.tp_rank:02d}.pth"), + map_location="cpu") + + def permute(w, nH, d, dH): + # due to MQA's wk, nH*dH != d could be true + return w.view(nH, dH // 2, 2, d).transpose(1, 2).reshape(nH * dH, d) + + def extract_layer_idx(name): + ss = name.split('.') + for s in ss: + if s.isdigit(): + return s + return None + + if not hasattr(load_weights_from_meta_ckpt, "saved_embed"): + load_weights_from_meta_ckpt.saved_embed = None + + def combine_embeddings(embeds, num_ckpts): + if len(embeds) == 1: + return embeds[0] + assert [ + embeds[i].shape == embeds[i + 1].shape + for i in range(len(embeds) - 1) + ] + if embeds[0].shape[0] == config.vocab_size // num_ckpts: + merge_dim = 0 + elif embeds[0].shape[1] == config.hidden_size // num_ckpts: + merge_dim = 1 + else: + logger.error("Unable to infer embedding split dimension") + assert False, "Unable to infer embedding split dimension" + return torch.cat(embeds, dim=merge_dim) + + def gather_embedding(cur_embed, name: str, num_ckpts): + if mapping.tp_size == 1: + # even if num_ckpts > 1, get_current_weights will already have it gathered + return cur_embed + if load_weights_from_meta_ckpt.saved_embed is None: + embeds = [None] * num_ckpts + for i in range(num_ckpts): + ckpt = torch.load(Path(meta_ckpt_dir, + f"consolidated.{i:02d}.pth"), + map_location="cpu") + embeds[i] = ckpt[name] + embed = combine_embeddings(embeds, num_ckpts).to(torch_dtype) + load_weights_from_meta_ckpt.saved_embed = embed + + return load_weights_from_meta_ckpt.saved_embed + + logger.info('Loading weights from Meta LLaMA checkpoints ...') + tik = time.time() + + num_kv_heads = config.num_key_value_heads + mha_mode = (num_kv_heads == config.num_attention_heads) + + ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth")) + num_ckpts = len(ckpts) + # llama/llama2 doesn't have MQA. So, simplifying loader logic by not worrying about it. + assert num_kv_heads > 1 or num_kv_heads >= num_ckpts, \ + f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints." + + head_size = config.hidden_size // config.num_attention_heads + ckpt = get_current_weights(num_ckpts) + layers_range = mapping.pp_layers(config.num_hidden_layers) + + for l in layers_range: + prefix = f'layers.{l}.attention.' + q_weight = permute(ckpt[prefix + 'wq.weight'].clone(), + nH=(config.num_attention_heads // mapping.tp_size), + d=config.hidden_size, + dH=head_size) + if num_kv_heads < mapping.tp_size and num_ckpts >= mapping.tp_size: + assert mapping.tp_size % num_kv_heads == 0 + assert False, "Not supported yet" + k_weight = permute(ckpt[prefix + 'wk.weight'].clone(), + nH=((num_kv_heads + mapping.tp_size - 1) // + mapping.tp_size), + d=config.hidden_size, + dH=head_size) + v_weight = ckpt[prefix + 'wv.weight'].clone() + + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + ckpt[prefix + 'qkv.weight'] = qkv_weight + + for k, v in ckpt.items(): + dtype = torch_dtype if 'feed_forward.gate' not in k else torch.float32 + + v = v.to(dtype) + if "tok_embeddings" in k: + if not config.use_parallel_embedding: + v = gather_embedding(v, k, num_ckpts) + elif config.embedding_sharding_dim == 0: + # this needs a gather and then resplit along different dims + v = gather_embedding(v, k, num_ckpts) + v = split(v, mapping.tp_size, mapping.tp_rank, 0) + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = v + elif "output" in k: + if mapping.is_last_pp_rank(): + if config.vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = pad_vocab_size(config.vocab_size, + mapping.tp_size) + pad_width = vocab_size_padded - config.vocab_size + v = torch.from_numpy( + np.pad(v.detach().cpu().numpy(), + ((0, pad_width), (0, 0)), + 'constant', + constant_values=0)) + weights['lm_head.weight'] = v + elif k == "norm.weight": + if mapping.is_last_pp_rank(): + weights['transformer.ln_f.weight'] = v + else: + # layer specific weights + layer_idx = extract_layer_idx(k) + + if layer_idx is None or int(layer_idx) not in layers_range: + continue + idx = int(layer_idx) - layers_range[0] + tllm_prex = f'transformer.layers.{idx}.' + + if 'attention_norm.weight' in k: + weights[tllm_prex + 'input_layernorm.weight'] = v + elif 'ffn_norm.weight' in k: + weights[tllm_prex + 'post_layernorm.weight'] = v + elif 'feed_forward.w3.weight' in k: + weights[tllm_prex + 'mlp.gate.weight'] = v + elif 'feed_forward.w2.weight' in k: + weights[tllm_prex + 'mlp.proj.weight'] = v + elif 'feed_forward.w1.weight' in k: + weights[tllm_prex + 'mlp.fc.weight'] = v + elif 'attention.wo.weight' in k: + weights[tllm_prex + 'attention.dense.weight'] = v + elif 'attention.qkv.weight' in k: + weights[tllm_prex + 'attention.qkv.weight'] = v + elif 'feed_forward.gate' in k: + weights[tllm_prex + 'mlp.router.weight'] = v + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logger.info(f'Weights loaded. Total time: {t}') return weights diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py index cfec9735b..999c6a451 100644 --- a/tensorrt_llm/models/llama/model.py +++ b/tensorrt_llm/models/llama/model.py @@ -12,27 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -from pathlib import Path -from typing import Optional +from typing import Optional, Union from ..._utils import pad_vocab_size from ...functional import Tensor, non_gated_version, recv, send from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear, - Embedding, GatedMLP, MoeConfig, PositionEmbeddingType, - RmsNorm) + Embedding, GatedMLP, PositionEmbeddingType, RmsNorm) from ...lora_manager import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ...plugin import init_all_reduce_helper from ...quantization import W8A8_SQ_PLUGIN_LIST, QuantAlgo +from ..convert_utils import has_safetensors from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, - PretrainedConfig, QuantConfig) + QuantConfig, preprocess_weights) +from .config import LLaMAConfig +from .convert import (load_hf_llama, load_weights_from_hf_by_shard, + load_weights_from_hf_model, + load_weights_from_hf_safetensors, + load_weights_from_meta_ckpt) class LLaMADecoderLayer(Module): - def __init__(self, config: PretrainedConfig, layer_idx: int): + def __init__(self, config: LLaMAConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.config = config @@ -65,18 +68,11 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): ClsMLP = GatedMLP mlp_kwargs = {} - if config.moe_num_experts > 1: + if config.moe.has_moe(): ClsMLP = MOE mlp_kwargs = { - "moe_config": - MoeConfig( - config.moe_num_experts, - config.moe_top_k, - config.moe_tp_mode, - config.moe_normalization_mode, - ), - "tp_rank": - config.mapping.tp_rank, + "moe_config": config.moe, + "tp_rank": config.mapping.tp_rank, } self.mlp = ClsMLP(hidden_size=config.hidden_size, @@ -172,7 +168,7 @@ def forward(self, class LLaMAModel(Module): - def __init__(self, config: PretrainedConfig) -> None: + def __init__(self, config: LLaMAConfig) -> None: super().__init__() init_all_reduce_helper() @@ -235,9 +231,9 @@ def forward(self, class LLaMAForCausalLM(DecoderModelForCausalLM): + config_class = LLaMAConfig - def __init__(self, config: PretrainedConfig): - self.check_config(config) + def __init__(self, config: LLaMAConfig): transformer = LLaMAModel(config) vocab_size_padded = pad_vocab_size(config.vocab_size, config.mapping.tp_size) @@ -255,40 +251,53 @@ def __init__(self, config: PretrainedConfig): self.mapping = config.mapping super().__init__(config, transformer, lm_head) - def check_config(self, config): - config.set_if_not_exist('mlp_bias', False) - config.set_if_not_exist('attn_bias', False) - config.set_if_not_exist('rotary_base', 10000.0) - config.set_if_not_exist('rotary_scaling', None) - config.set_if_not_exist('moe_num_experts', 0) - config.set_if_not_exist('moe_top_k', 0) - config.set_if_not_exist('moe_tp_mode', - MoeConfig.ParallelismMode.TENSOR_PARALLEL) - config.set_if_not_exist( - 'moe_normalization_mode', - MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE) - @classmethod - def from_hugging_face(cls, - hf_model_dir, - dtype='float16', - mapping: Optional[Mapping] = None, - **kwargs): - from . import convert - if mapping is None: - mapping = Mapping() - llama = convert.from_hugging_face( + def from_hugging_face( cls, - hf_model_dir, - dtype, - mapping=mapping, - quantization=kwargs.get('quantization', QuantConfig()), - load_by_shard=kwargs.get('load_by_shard', False), - load_model_on_cpu=kwargs.get('load_model_on_cpu', False), - override_fields=kwargs.get('override_fields', {}), - skip_loading_weights=kwargs.get('skip_loading_weights', False), - preloaded_model=kwargs.get('preloaded_model', None)) - return llama + hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'], + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + ''' Create a LLaMAForCausalLM object from give parameters + ''' + import transformers + + load_by_shard = kwargs.pop('load_by_shard', False) + load_model_on_cpu = kwargs.pop('load_model_on_cpu', False) + + assert hf_model_or_dir is not None + use_preloading = isinstance(hf_model_or_dir, + transformers.PreTrainedModel) + if use_preloading: + hf_model = hf_model_or_dir + hf_config_or_dir = hf_model.config + else: + hf_model_dir = hf_model_or_dir + hf_config_or_dir = hf_model_or_dir + + config = LLaMAConfig.from_hugging_face(hf_config_or_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + **kwargs) + + if use_preloading: + assert not load_by_shard + weights = load_weights_from_hf_model(hf_model, config) + elif load_by_shard: + weights = load_weights_from_hf_by_shard(hf_model_dir, config) + elif has_safetensors( + hf_model_dir) and not config.quant_mode.has_any_quant(): + weights = load_weights_from_hf_safetensors(hf_model_dir, config) + else: + hf_model = load_hf_llama(hf_model_dir, load_model_on_cpu) + weights = load_weights_from_hf_model(hf_model, config) + preprocess_weights(weights, config) + + model = LLaMAForCausalLM(config) + model.load(weights) + return model def default_plugin_config(self, **kwargs): plugin_config = super().default_plugin_config(**kwargs) @@ -298,98 +307,61 @@ def default_plugin_config(self, **kwargs): @classmethod def from_meta_ckpt(cls, - meta_ckpt_dir, - dtype, - mapping, - use_parallel_embedding: Optional[bool] = False, - embedding_sharding_dim: Optional[int] = 0): - meta_config = None - with open(Path(meta_ckpt_dir, "params.json")) as fp: - meta_config: dict = json.load(fp) - assert meta_config is not None - config = {} - n_embd = meta_config["dim"] - n_head = meta_config["n_heads"] - n_kv_head = meta_config.get("n_kv_heads", n_head) - vocab_size = meta_config.get("vocab_size", 32000) - - # Reset vocab_size to 32000 for LLama v2 checkpoint. - if vocab_size == -1: - vocab_size = 32000 - - if "hidden_dim" in meta_config: - inter_size = meta_config["hidden_dim"] - else: - multiple_of = meta_config.get("multiple_of", 1) - n_embd_ = int(4 * n_embd * 2 / 3) - ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier", 1) - inter_size = multiple_of * ( - (int(n_embd_ * ffn_dim_multiplier) + multiple_of - 1) // - multiple_of) - # meta checkpoint don't have vocab_size|hidden_act|rotary_base specified, use same default value as HF - config.update({ - 'architecture': "LlamaForCausalLM", - 'dtype': dtype, - 'logits_dtype': 'float32', - 'num_hidden_layers': meta_config["n_layers"], - 'num_attention_heads': n_head, - 'hidden_size': n_embd, - 'intermediate_size': inter_size, - 'num_key_value_heads': n_kv_head, - 'vocab_size': vocab_size, - 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': 2048, - 'hidden_act': 'silu', - 'rotary_base': meta_config.get('rope_theta', 10000), - 'norm_epsilon': meta_config["norm_eps"], - 'mapping': { - 'world_size': mapping.tp_size * mapping.pp_size, - 'tp_size': mapping.tp_size, - 'pp_size': mapping.pp_size, - }, - }) - pretrained_config = PretrainedConfig.from_dict(config) - pretrained_config.use_parallel_embedding = use_parallel_embedding - pretrained_config.embedding_sharding_dim = embedding_sharding_dim - pretrained_config.set_rank(mapping.rank) - - llama = cls(pretrained_config) - from .weight import load_from_meta_llama - weights = load_from_meta_llama(meta_ckpt_dir, mapping, - pretrained_config) - llama.load(weights) - return llama + meta_ckpt_dir: str, + dtype: str = 'auto', + mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + **kwargs): + config = LLaMAConfig.from_meta_ckpt(meta_ckpt_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + **kwargs) + + weights = load_weights_from_meta_ckpt(meta_ckpt_dir, config) + preprocess_weights(weights, config) + + model = LLaMAForCausalLM(config) + model.load(weights) + return model @classmethod def quantize( cls, - hf_model_dir, - output_dir, - quant_config: QuantConfig, - *, - dtype='float16', + hf_model_dir: str, + output_dir: str, + dtype: str = 'auto', mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + *, calib_dataset='cnn_dailymail', calib_batches=512, calib_batch_size=1, + calib_max_seq_length=512, random_seed=1234, tokenizer_max_seq_length=2048, **kwargs, ): - DEFAULT_Modelopt_FLOW = [ + DEFAULT_MODELOPT_FLOW = [ QuantAlgo.W4A16_AWQ, QuantAlgo.FP8, QuantAlgo.W8A8_SQ_PER_CHANNEL, QuantAlgo.W4A8_AWQ ] - use_modelopt_quantization = quant_config.quant_algo in DEFAULT_Modelopt_FLOW - if use_modelopt_quantization: + config = LLaMAConfig.from_hugging_face(hf_model_dir, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + **kwargs) + + if quant_config.quant_algo in DEFAULT_MODELOPT_FLOW: super().quantize(hf_model_dir, output_dir, - quant_config, - dtype=dtype, - mapping=mapping, + dtype=config.dtype, + mapping=config.mapping, + quant_config=config.quantization, calib_dataset=calib_dataset, calib_batches=calib_batches, calib_batch_size=calib_batch_size, + calib_max_seq_length=calib_max_seq_length, random_seed=random_seed, tokenizer_max_seq_length=tokenizer_max_seq_length) else: @@ -404,16 +376,10 @@ def quantize( assert is_valid_native_quant, f"Internal error: shall call Modelopt for this quantization {quant_config}" from . import convert - convert.quantize( - dtype, - hf_model_dir, - output_dir, - mapping, - quant_config, - calib_dataset=calib_dataset, - override_fields=kwargs.get('override_fields', {}), - dataset_cache_dir=kwargs.get('dataset_cache_dir', None), - ) + convert.quantize(hf_model_dir, + output_dir, + config=config, + calib_dataset=calib_dataset) def use_lora(self, lora_config: LoraConfig): use_lora(self, lora_config) diff --git a/tensorrt_llm/models/llama/weight.py b/tensorrt_llm/models/llama/weight.py deleted file mode 100644 index 0edebf09a..000000000 --- a/tensorrt_llm/models/llama/weight.py +++ /dev/null @@ -1,1183 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import configparser -import time -from pathlib import Path -from typing import List, Union - -import numpy as np -import torch -from safetensors import safe_open - -from ..._utils import (numpy_to_torch, pad_vocab_size, str_dtype_to_torch, - torch_to_numpy) -from ...layers import MoeConfig -from ...logger import logger -from ...mapping import Mapping -from ...quantization import QuantMode -from ..convert_utils import (iterate_shard_files, load_state_dict, - retrieved_layer_index_from_name) -from ..modeling_utils import PretrainedConfig - - -def gen_suffix(rank, use_smooth_quant, quant_per_channel): - suffix = f"{rank}.bin" - if use_smooth_quant: - sq_prefix = "int8." - if quant_per_channel: - sq_prefix += "col." - suffix = sq_prefix + suffix - return suffix - - -def extract_layer_idx(name): - ss = name.split('.') - for s in ss: - if s.isdigit(): - return s - return None - - -def split(v: Union[np.ndarray, torch.Tensor], - tp_size: int, - tp_rank: int, - dim=0): - if tp_size == 1: - return v - assert len(v.shape) > 1 or dim == 0 - if isinstance(v, np.ndarray): - return np.ascontiguousarray( - np.split(v, tp_size, axis=dim)[tp_rank].copy()) - else: - assert v.shape[dim] % tp_size == 0, \ - 'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.' - split_size = v.shape[dim] // tp_size - return v.split(split_size, dim=dim)[tp_rank].clone().detach() - - -def dup_kv_weight(v, num_head, tp_size): - assert tp_size % num_head == 0 - reps = tp_size // num_head - head_size = v.shape[0] // num_head - v = v.reshape(num_head, head_size, - -1)[:, None, :, :].expand(num_head, reps, head_size, - v.shape[1]) - return v.reshape(num_head * reps * head_size, -1).clone().detach() - - -def parse_bin_config(ini_file): - gpt_config = configparser.ConfigParser() - gpt_config.read(ini_file) - - n_embd = gpt_config.getint('llama', 'hidden_size') - n_head = gpt_config.getint('llama', 'num_attention_heads') - n_layer = gpt_config.getint('llama', 'num_hidden_layers') - n_positions = gpt_config.getint('llama', 'max_position_embeddings') - vocab_size = gpt_config.getint('llama', 'vocab_size') - hidden_act = gpt_config.get('llama', 'hidden_act') - inter_size = gpt_config.getint('llama', 'intermediate_size', fallback=None) - n_kv_head = gpt_config.getint('llama', 'num_key_value_heads', fallback=None) - - if inter_size is None: - inter_size = 4 * n_embd - - return n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head - - -class QkvWeightHelper: - """ A helper utility for loading QKV weights from sharded files. """ - - def __init__(self, config: PretrainedConfig): - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.tp_size = config.mapping.tp_size - self.tp_rank = config.mapping.tp_rank - self.is_mha = self.num_heads == self.num_kv_heads - self._qkv_weights = {} - - @staticmethod - def is_qkv_weight(name): - for k in ['q_proj', 'k_proj', 'v_proj']: - if 'self_attn' in name and k in name: - return True - return False - - def add_weight(self, i: int, name: str, weight: torch.Tensor): - if 'q_proj' in name: - tag = 'q' - elif 'k_proj' in name: - tag = 'k' - elif 'v_proj' in name: - tag = 'v' - else: - raise ValueError(f'Got an unexpected parameter of name {name}') - if i not in self._qkv_weights: - self._qkv_weights[i] = {} - self._qkv_weights[i][tag] = weight - - def is_qkv_prepared(self, layer_idx): - if layer_idx not in self._qkv_weights: - return False - weights = self._qkv_weights[layer_idx] - return 'q' in weights and 'k' in weights and 'v' in weights - - def split_qkv_weights(self, layer_idx): - if not self.is_qkv_prepared(layer_idx): - return None - weights = self._qkv_weights.pop(layer_idx) # to prevent memory leak. - q, k, v = (torch.tensor(weights[t]) for t in ['q', 'k', 'v']) - - if not self.is_mha: - head_size = self.hidden_size // self.num_heads - if self.num_kv_heads < self.tp_size: - # duplicate the KV heads up to tensor_parallel - k = dup_kv_weight(k, self.num_kv_heads, self.tp_size) - v = dup_kv_weight(v, self.num_kv_heads, self.tp_size) - assert k.shape[0] % (self.tp_size * head_size) == 0 - assert v.shape[0] % (self.tp_size * head_size) == 0 - wq = split(q, self.tp_size, self.tp_rank) - wk = split(k, self.tp_size, self.tp_rank) - wv = split(v, self.tp_size, self.tp_rank) - fused_qkv = torch.cat((wq, wk, wv), dim=0) - else: - qkv = torch.cat([q, k, v], dim=0) - qkv = qkv.reshape(3, q.shape[0], q.shape[1]) - fused_qkv = split(qkv, self.tp_size, self.tp_rank, dim=1) - fused_qkv = fused_qkv.reshape(3 * (q.shape[0] // self.tp_size), - q.shape[1]) - return fused_qkv - - -def load_from_hf_checkpoint(model_dir, mapping=Mapping(), config=None): - '''Weights-only quantization is the only supported quantization recipe here.''' - logger.info('Loading weights from HF LLaMA...') - tik = time.time() - weights = {} - dtype = config.dtype - if isinstance(dtype, str): - dtype = str_dtype_to_torch(dtype) - - moe_config = MoeConfig(config.moe_num_experts, config.moe_top_k, - config.moe_tp_mode, config.moe_normalization_mode) - assert not moe_config.has_moe(), "MoE does not support sharded load" - - model_dir = Path(model_dir) - - from transformers import AutoConfig - hf_config = AutoConfig.from_pretrained(model_dir) - - quant_mode = config.quant_mode - if quant_mode.is_int8_weight_only(): - plugin_weight_only_quant_type = torch.int8 - elif quant_mode.is_int4_weight_only(): - plugin_weight_only_quant_type = torch.quint4x2 - use_weight_only = quant_mode.is_weight_only() - - layers_range = mapping.pp_layers(config.num_hidden_layers) - - qkv_weight_helper = QkvWeightHelper(config) - - for model_file in iterate_shard_files(model_dir, - rank=mapping.tp_rank, - progress_bar=False): - logger.debug(f'Loading file {str(model_file)}...') - model_params = load_state_dict(model_file, dtype=dtype) - for name, param in model_params.items(): - logger.debug(f'Converting weight {name}...') - layer_idx = retrieved_layer_index_from_name(name) - if layer_idx is None: - layer = None - else: - if layer_idx not in layers_range: - continue - tllm_prex = f'transformer.layers.{layer_idx}.' - - if 'model.embed_tokens.weight' in name: - if hf_config.tie_word_embeddings: - # lm_head.weight has the same weights as embedding - if mapping.is_last_pp_rank(): - - if config.vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = pad_vocab_size( - config.vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - config.vocab_size - param = torch.from_numpy( - np.pad(param.detach().cpu().numpy(), - ((0, pad_width), (0, 0)), - 'constant', - constant_values=0)) - weights['lm_head.weight'] = split( - param, mapping.tp_size, mapping.tp_rank) - if config.use_parallel_embedding: - param = split(param, mapping.tp_size, mapping.tp_rank, - config.embedding_sharding_dim) - if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = param - elif 'model.norm.weight' in name: - if mapping.is_last_pp_rank(): - weights['transformer.ln_f.weight'] = param - elif 'lm_head.weight' in name: - if mapping.is_last_pp_rank(): - if config.vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = pad_vocab_size( - config.vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - config.vocab_size - param = torch.from_numpy( - np.pad(param.detach().cpu().numpy(), - ((0, pad_width), (0, 0)), - 'constant', - constant_values=0)) - weights['lm_head.weight'] = split(param, mapping.tp_size, - mapping.tp_rank) - elif 'input_layernorm.weight' in name: - weights[tllm_prex + 'input_layernorm.weight'] = param - elif 'post_attention_layernorm.weight' in name: - weights[tllm_prex + 'post_layernorm.weight'] = param - elif qkv_weight_helper.is_qkv_weight(name): - qkv_weight_helper.add_weight(layer_idx, name, param) - if not qkv_weight_helper.is_qkv_prepared(layer_idx): - continue - split_v = qkv_weight_helper.split_qkv_weights(layer_idx) - if use_weight_only: - param = split_v.transpose() - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - param, plugin_weight_only_quant_type) - weights[tllm_prex + - 'attention.qkv.weight'] = processed_torch_weights - weights[ - tllm_prex + - 'attention.qkv.per_channel_scale'] = torch_weight_scales - else: - weights[tllm_prex + 'attention.qkv.weight'] = split_v - elif 'self_attn.o_proj.weight' in name: - split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=1) - if use_weight_only: - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - split_v.transpose(), plugin_weight_only_quant_type) - weights[tllm_prex + - 'attention.dense.weight'] = processed_torch_weights - weights[ - tllm_prex + - 'attention.dense.per_channel_scale'] = torch_weight_scales - else: - weights[tllm_prex + 'attention.dense.weight'] = split_v - elif 'mlp.up_proj.weight' in name: - split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=0) - if use_weight_only: - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - split_v.transpose(), plugin_weight_only_quant_type) - weights[tllm_prex + - 'mlp.gate.weight'] = processed_torch_weights - weights[tllm_prex + - 'mlp.gate.per_channel_scale'] = torch_weight_scales - else: - weights[tllm_prex + 'mlp.gate.weight'] = split_v - elif 'mlp.down_proj.weight' in name: - split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=1) - if use_weight_only: - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - split_v.transpose(), plugin_weight_only_quant_type) - weights[tllm_prex + - 'mlp.proj.weight'] = processed_torch_weights - weights[tllm_prex + - 'mlp.proj.per_channel_scale'] = torch_weight_scales - else: - weights[tllm_prex + 'mlp.proj.weight'] = split_v - - elif 'mlp.gate_proj.weight' in name: - split_v = split(param, mapping.tp_size, mapping.tp_rank, dim=0) - if use_weight_only: - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - split_v.transpose(), plugin_weight_only_quant_type) - layer.mlp.fc.weight.value = processed_torch_weights - layer.mlp.fc.per_channel_scale.value = torch_weight_scales - weights[tllm_prex + - 'mlp.fc.weight'] = processed_torch_weights - weights[tllm_prex + - 'mlp.fc.per_channel_scale'] = torch_weight_scales - else: - weights[tllm_prex + 'mlp.fc.weight'] = split_v - - del model_params - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Weights loaded. Total time: {t}') - return weights - - -def load_from_hf_llama(tensorrt_llm_llama: 'LLaMAForCausalLM', - hf_llama, - mapping=Mapping(), - dtype='float32', - use_gemm_woq_plugin=True): - logger.info('Loading weights from HF LLaMA...') - tik = time.time() - - quant_mode = getattr(tensorrt_llm_llama, 'quant_mode', QuantMode(0)) - if quant_mode.is_int8_weight_only(): - plugin_weight_only_quant_type = torch.int8 - elif quant_mode.is_int4_weight_only(): - plugin_weight_only_quant_type = torch.quint4x2 - use_weight_only = quant_mode.is_weight_only() - num_kv_heads = tensorrt_llm_llama.config.num_key_value_heads - mha_mode = (num_kv_heads == tensorrt_llm_llama.config.num_attention_heads) - - model_params = dict(hf_llama.named_parameters()) - # concatenate, duplicate and reshape q, k, v -> qkv - for l in range(hf_llama.config.num_hidden_layers): - prefix = f'model.layers.{l}.self_attn.' - q_weight = model_params[prefix + 'q_proj.weight'] - k_weight = model_params[prefix + 'k_proj.weight'] - v_weight = model_params[prefix + 'v_proj.weight'] - if not mha_mode: - head_size = tensorrt_llm_llama.config.hidden_size // tensorrt_llm_llama.config.num_attention_heads - if num_kv_heads < mapping.tp_size: - # duplicate the KV heads up to tensor_parallel - k_weight = dup_kv_weight(k_weight, num_kv_heads, - mapping.tp_size) - v_weight = dup_kv_weight(v_weight, num_kv_heads, - mapping.tp_size) - assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0 - assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0 - qkv_weight = [q_weight, k_weight, v_weight] - else: - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) - - model_params[prefix + 'qkv_proj.weight'] = qkv_weight - - moe_config = MoeConfig(tensorrt_llm_llama.config.moe_num_experts, - tensorrt_llm_llama.config.moe_top_k, - tensorrt_llm_llama.config.moe_tp_mode, - tensorrt_llm_llama.config.moe_normalization_mode) - # concatenate MoE gated activations & stack experts - for l in range(hf_llama.config.num_hidden_layers): - - if not moe_config.has_moe(): - continue - - rank_experts = list(range(moe_config.num_experts)) - if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL: - rank_experts = mapping.ep_experts(moe_config.num_experts) - for suffix in ["w1", "w2", "w3"]: - model_params[f'model.layers.{l}.block_sparse_moe.experts.{suffix}.weight'] = \ - torch.stack(list(model_params[f'model.layers.{l}.block_sparse_moe.experts.{expert}.{suffix}.weight'] - for expert in rank_experts)) - - w3 = model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w3.weight'] - w2 = model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w2.weight'] - w1 = model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w1.weight'] - if moe_config.tp_mode == moe_config.ParallelismMode.TENSOR_PARALLEL: - w3 = split(w3, mapping.tp_size, mapping.tp_rank, dim=1) - w2 = split(w2, mapping.tp_size, mapping.tp_rank, dim=2) - w1 = split(w1, mapping.tp_size, mapping.tp_rank, dim=1) - # concat w3 and w1 for gated expert - model_params[f'model.layers.{l}.block_sparse_moe.experts.w3w1.weight'] = \ - torch.concat([w3, w1], dim=-2) - model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w2.weight'] = w2 - - torch_dtype = str_dtype_to_torch(dtype) - layers_range = mapping.pp_layers(hf_llama.config.num_hidden_layers) - - vocab_size = hf_llama.config.vocab_size - weights = {} - for k, v in model_params.items(): - t_dtype = torch_dtype if "block_sparse_moe.gate" not in k else torch.float32 - if isinstance(v, list): - v = [torch_to_numpy(vv.to(t_dtype).detach().cpu()) for vv in v] - else: - v = torch_to_numpy(v.to(t_dtype).detach().cpu()) - if 'model.embed_tokens.weight' in k: - if hf_llama.config.tie_word_embeddings: - # lm_head.weight has the same weights as embedding - if mapping.is_last_pp_rank(): - if vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = pad_vocab_size( - vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size - v = torch.from_numpy( - np.pad(v.detach().cpu().numpy(), - ((0, pad_width), (0, 0)), - 'constant', - constant_values=0)) - weights['lm_head.weight'] = split(v, mapping.tp_size, - mapping.tp_rank) - - if tensorrt_llm_llama.config.use_parallel_embedding: - v = split(v, mapping.tp_size, mapping.tp_rank, - tensorrt_llm_llama.config.embedding_sharding_dim) - if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = v - elif 'model.norm.weight' in k: - if mapping.is_last_pp_rank(): - weights['transformer.ln_f.weight'] = v - - elif 'lm_head.weight' in k: - if mapping.is_last_pp_rank(): - if vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = tensorrt_llm_llama.lm_head.out_features * mapping.tp_size - pad_width = vocab_size_padded - vocab_size - v = np.pad(v, ((0, pad_width), (0, 0)), - 'constant', - constant_values=0) - - weights['lm_head.weight'] = split(v, mapping.tp_size, - mapping.tp_rank) - else: - layer_idx = extract_layer_idx(k) - if layer_idx is None or int(layer_idx) not in layers_range: - continue - idx = int(layer_idx) - layers_range[0] - if 'input_layernorm.weight' in k: - weights['transformer.layers.{}.input_layernorm.weight'.format( - idx)] = v - elif 'post_attention_layernorm.weight' in k: - weights['transformer.layers.{}.post_layernorm.weight'.format( - idx)] = v - - elif 'self_attn.qkv_proj.weight' in k: - if not mha_mode: - assert isinstance(v, list) and len(v) == 3 - wq = split(v[0], mapping.tp_size, mapping.tp_rank) - wk = split(v[1], mapping.tp_size, mapping.tp_rank) - wv = split(v[2], mapping.tp_size, mapping.tp_rank) - split_v = np.concatenate((wq, wk, wv)) - else: - q_emb = v.shape[0] // 3 - model_emb = v.shape[1] - v = v.reshape(3, q_emb, model_emb) - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) - split_v = split_v.reshape(3 * (q_emb // mapping.tp_size), - model_emb) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(v), plugin_weight_only_quant_type) - if not use_gemm_woq_plugin: - weights['transformer.layers.{}.attention.qkv.weight'. - format(idx)] = v - else: - weights['transformer.layers.{}.attention.qkv.weight'. - format(idx)] = processed_torch_weights - - weights[ - 'transformer.layers.{}.attention.qkv.per_channel_scale'. - format(idx)] = torch_weight_scales - else: - weights['transformer.layers.{}.attention.qkv.weight'.format( - idx)] = split_v - - elif 'self_attn.o_proj.weight' in k: - # dst = tensorrt_llm_llama.layers[idx].attention.dense.weight - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(v), plugin_weight_only_quant_type) - if not use_gemm_woq_plugin: - weights['transformer.layers.{}.attention.dense.weight'. - format(idx)] = v - else: - weights['transformer.layers.{}.attention.dense.weight'. - format(idx)] = processed_torch_weights - - weights[ - 'transformer.layers.{}.attention.dense.per_channel_scale' - .format(idx)] = torch_weight_scales - - else: - weights['transformer.layers.{}.attention.dense.weight'. - format(idx)] = split_v - - elif 'mlp.up_proj.weight' in k: - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(v), plugin_weight_only_quant_type) - - if not use_gemm_woq_plugin: - weights['transformer.layers.{}.mlp.gate.weight'.format( - idx)] = v - else: - weights['transformer.layers.{}.mlp.gate.weight'.format( - idx)] = processed_torch_weights - - weights['transformer.layers.{}.mlp.gate.per_channel_scale'. - format(idx)] = torch_weight_scales - else: - weights['transformer.layers.{}.mlp.gate.weight'.format( - idx)] = split_v - - elif 'mlp.down_proj.weight' in k: - # dst = tensorrt_llm_llama.layers[idx].mlp.proj.weight - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(v), plugin_weight_only_quant_type) - if not use_gemm_woq_plugin: - weights['transformer.layers.{}.mlp.proj.weight'.format( - idx)] = v - else: - weights['transformer.layers.{}.mlp.proj.weight'.format( - idx)] = processed_torch_weights - - weights['transformer.layers.{}.mlp.proj.per_channel_scale'. - format(idx)] = torch_weight_scales - else: - weights['transformer.layers.{}.mlp.proj.weight'.format( - idx)] = split_v - elif 'mlp.gate_proj.weight' in k: - # dst = tensorrt_llm_llama.layers[idx].mlp.fc.weight - split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0) - if use_weight_only: - v = np.ascontiguousarray(split_v.transpose()) - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(v), plugin_weight_only_quant_type) - - if not use_gemm_woq_plugin: - weights['transformer.layers.{}.mlp.fc.weight'.format( - idx)] = v - else: - weights['transformer.layers.{}.mlp.fc.weight'.format( - idx)] = processed_torch_weights - - weights['transformer.layers.{}.mlp.fc.per_channel_scale'. - format(idx)] = torch_weight_scales - else: - # dst.value = np.ascontiguousarray(split_v) - weights['transformer.layers.{}.mlp.fc.weight'.format( - idx)] = split_v - elif 'experts.w2.weight' in k: - # Note: no need for splitting, it's already been done above - split_v = v - if use_weight_only: - v = np.ascontiguousarray( - np.transpose(split_v, axes=(0, 2, 1))) - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(v), plugin_weight_only_quant_type) - weights['transformer.layers.{}.mlp.proj.weight'.format( - idx)] = processed_torch_weights - weights['transformer.layers.{}.mlp.proj.per_channel_scale'. - format(idx)] = torch_weight_scales - - else: - weights['transformer.layers.{}.mlp.proj.weight'.format( - idx)] = v - elif 'experts.w3w1.weight' in k: - # Note: no need for splitting, it's already been done above - split_v = v - if use_weight_only: - v = np.ascontiguousarray( - np.transpose(split_v, axes=(0, 2, 1))) - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - numpy_to_torch(v), plugin_weight_only_quant_type) - weights['transformer.layers.{}.mlp.fc.weight'.format( - idx)] = processed_torch_weights - weights['transformer.layers.{}.mlp.fc.per_channel_scale'. - format(idx)] = torch_weight_scales - - else: - weights['transformer.layers.{}.mlp.fc.weight'.format( - idx)] = v - - elif 'block_sparse_moe.gate' in k: - v = split(v, mapping.tp_size, mapping.tp_rank, dim=-1) - weights['transformer.layers.{}.mlp.router.weight'.format( - idx)] = v - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Weights loaded. Total time: {t}') - return weights - - -def load_from_gptq_llama(config: PretrainedConfig, quant_ckpt_path): - logger.info('Loading weights from groupwise GPTQ LLaMA safetensors...') - weights = {} - tik = time.time() - - num_hidden_layers = config.num_hidden_layers - vocab_size = config.vocab_size - dtype = config.dtype - mapping = config.mapping - - gptq_llama = safe_open(quant_ckpt_path, framework="pt", device=0) - gptq_prefix = "model." - gptq_suffix_list = [".qweight", ".qzeros", ".scales"] - gptq_key_list = [ - "embed_tokens.weight", # vocab_embedding - "lm_head.weight", # lm_head - "norm.weight", # ln_f - "self_attn.", # attention.qkv - "_proj", # qkv suffix - "self_attn.o_proj", # attention.dense - "mlp.up_proj", # mlp.gate - "mlp.down_proj", # mlp.proj - "mlp.gate_proj", # mlp.fc - "input_layernorm.weight", # input_layernorm - "post_attention_layernorm.weight", # post_layernorm - ] - split_sym = "." - - packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 - preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm - torch_dtype = str_dtype_to_torch(dtype) - - def load(key, no_prefix=0): - if no_prefix: - return gptq_llama.get_tensor(key) - else: - return gptq_llama.get_tensor(gptq_prefix + key) - - def torch_split(v, dim): - if v.shape[dim] % mapping.tp_size != 0: - logger.error( - "Current weight shape is invalid for mapping.tp_size=" + - str(mapping.tp_size)) - assert False, "Invalid TP size" - return v.split(v.shape[dim] // mapping.tp_size, - dim=dim)[mapping.tp_rank] - - def unpack_int32_into_int8(w_packed): - # Unpack inputs packed in int32/float32 into uint4 and store them in int8 format - w_packed_int4x2 = w_packed.contiguous().view(torch.uint8) - w_unpacked = torch.zeros(w_packed_int4x2.shape[0], - w_packed_int4x2.shape[1] * 2, - dtype=torch.int8) - w_unpacked[:, ::2] = w_packed_int4x2 % 16 - w_unpacked[:, 1::2] = w_packed_int4x2 // 16 - return w_unpacked.contiguous() - - def process_and_assign_weight(v: List[torch.Tensor], - tllm_prex: str, - tp_dim: int = -1): - if tp_dim == -1: - qweight_int32, qzeros_int32, scales_fp16 = [ - item.cpu() for item in v - ] - else: - qweight_int32, qzeros_int32, scales_fp16 = [ - torch_split(item, tp_dim).cpu() for item in v - ] - - USE_UINT4_INPUT = 1 # Set to true if checkpoint store UINT4 weights - USE_GPTQ_FOR_LLAMA = 1 # GPTQ-for-LLaMA added 1 to zeros - - qweight_unpacked_int8 = unpack_int32_into_int8( - qweight_int32.T).T.contiguous() - 8 - qweight_interleaved = preprocessor(packer(qweight_unpacked_int8), - torch.quint4x2, - torch.float16).view(torch.float16) - # zeros = zeros * scales - qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) - if not USE_UINT4_INPUT: - # Correcting UINT4 values back to INT4 order - mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0] - mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0] - qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive - zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT - - USE_GPTQ_FOR_LLAMA) * scales_fp16 - zeros_x_scales_fp16 = zeros_x_scales_fp16.half() - - results = { - f'{tllm_prex}.weight': qweight_interleaved, - f'{tllm_prex}.weights_scaling_factor': scales_fp16, - f'{tllm_prex}.zero': zeros_x_scales_fp16, - } - return results - - # Load weights from GPTQ checkpoint into TRT-LLM module - # 1. vocab_embedding - v = load(gptq_key_list[0]) - if mapping.is_first_pp_rank(): - # tensorrt_llm_llama.vocab_embedding.weight.value = v.to( - # torch_dtype).cpu().numpy() - weights['transformer.vocab_embedding.weight'] = v.to(torch_dtype) - # 2. lm_head - v = load(gptq_key_list[1], "no_prefix") - if mapping.is_last_pp_rank(): - # tensorrt_llm_llama.lm_head.weight.value = torch_split( - # v, 0).to(torch_dtype).cpu().numpy() - if vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size - v = torch.from_numpy( - np.pad(v.detach().cpu().numpy(), ((0, pad_width), (0, 0)), - 'constant', - constant_values=0)) - weights['lm_head.weight'] = torch_split(v, 0).to(torch_dtype) - - # 3. ln_f - v = load(gptq_key_list[2]) - if mapping.is_last_pp_rank(): - # tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() - weights['transformer.ln_f.weight'] = v.to(torch_dtype) - # 4. Weights inside each layer - layers_range = mapping.pp_layers(num_hidden_layers) - for l in layers_range: - layer_idx = l - layers_range[0] - prefix = "layers" + split_sym + str(layer_idx) + split_sym - logger.info(f'Process weights in layer: {layer_idx}') - # layer = tensorrt_llm_llama.layers[layer_idx] - tllm_prex = f'transformer.layers.{l-layers_range[0]}' - # 4.1 attention.qkv - qkv_weight_list = [] - for suf in gptq_suffix_list: - qkv_list = [] - for comp in ["q", "k", "v"]: - comp_part = load(prefix + gptq_key_list[3] + comp + - gptq_key_list[4] + suf) - comp_part = torch_split(comp_part, 1) - qkv_list.append(comp_part) - qkv_weight_list.append(torch.cat(qkv_list, dim=1)) - - # process_and_assign_weight(layer.attention.qkv, qkv_weight_list) - weights.update( - process_and_assign_weight(qkv_weight_list, - f'{tllm_prex}.attention.qkv')) - # 4.2 attention.dense - v = [load(prefix + gptq_key_list[5] + suf) for suf in gptq_suffix_list] - # process_and_assign_weight(layer.attention.dense, v, 0) - weights.update( - process_and_assign_weight(v, - f'{tllm_prex}.attention.dense', - tp_dim=0)) - # 4.3 mlp.gate - v = [load(prefix + gptq_key_list[6] + suf) for suf in gptq_suffix_list] - # process_and_assign_weight(layer.mlp.gate, v, 1) - weights.update( - process_and_assign_weight(v, f'{tllm_prex}.mlp.gate', tp_dim=1)) - # 4.4 mlp.proj - v = [load(prefix + gptq_key_list[7] + suf) for suf in gptq_suffix_list] - # process_and_assign_weight(layer.mlp.proj, v, 0) - weights.update( - process_and_assign_weight(v, f'{tllm_prex}.mlp.proj', tp_dim=0)) - # 4.5 mlp.fc - v = [load(prefix + gptq_key_list[8] + suf) for suf in gptq_suffix_list] - # process_and_assign_weight(layer.mlp.fc, v, 1) - weights.update( - process_and_assign_weight(v, f'{tllm_prex}.mlp.fc', tp_dim=1)) - # 4.6 input_layernorm - v = load(prefix + gptq_key_list[9]) - # layer.input_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() - weights[f'{tllm_prex}.input_layernorm.weight'] = v.to(torch_dtype) - - # 4.7 post_layernorm - v = load(prefix + gptq_key_list[10]) - # layer.post_layernorm.weight.value = v.to(torch_dtype).cpu().numpy() - weights[f'{tllm_prex}.post_layernorm.weight'] = v.to(torch_dtype) - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Weights loaded. Total time: {t}') - - return weights - - -def load_from_meta_llama(meta_ckpt_dir, mapping, config): - torch_dtype = str_dtype_to_torch(config.dtype) - weights = {} - - def gather_ckpts(ckpts): - gathered = {} - for k in ckpts[0]: - d = 0 - if any([n in k for n in ["wo", "w2", "tok"]]): - d = 1 - if "norm" in k or "rope" in k: # no TP - gathered[k] = ckpts[0][k].clone() - else: - gathered[k] = torch.cat([pt[k] for pt in ckpts], dim=d).clone() - return gathered - - def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank): - split_ckpt = {} - for k, v in ckpt.items(): - d = 0 - if any(n in k for n in - ["wo", "feed_forward.w2", "tok", "feed_forward.gate"]): - d = 1 - if "norm" in k or "rope" in k: # no TP - split_ckpt[k] = v.clone() - elif config.num_key_value_heads < mapping.tp_size and any( - n in k for n in ["wk", "wv"]): - assert mapping.tp_size % config.num_key_value_heads == 0 - # special case: we need to duplicate KV head - tmp = dup_kv_weight(v, config.num_key_value_heads, - mapping.tp_size) - split_ckpt[k] = torch.split(tmp, - tmp.shape[d] // ranks_per_ckpt, - dim=d)[ckpt_rank].clone() - else: - split_ckpt[k] = torch.split(v, - v.shape[d] // ranks_per_ckpt, - dim=d)[ckpt_rank].clone() - return split_ckpt - - def get_current_weights(num_ckpts): - if num_ckpts > mapping.tp_size: - # combine ckpts - assert (num_ckpts % mapping.tp_size) == 0 - nf = num_ckpts // mapping.tp_size - fs = nf * mapping.tp_rank - file_ids = list(range(fs, fs + nf)) - ckpts = [] - for f in file_ids: - ckpt = torch.load(Path(meta_ckpt_dir, - f"consolidated.{f:02d}.pth"), - map_location="cpu") - ckpts.append(ckpt) - return gather_ckpts(ckpts) - elif num_ckpts < mapping.tp_size: - # split ckpt - assert (mapping.tp_size % num_ckpts) == 0 - ranks_per_ckpt = mapping.tp_size // num_ckpts - ckpt_fid = mapping.tp_rank // ranks_per_ckpt - ckpt_rank = mapping.tp_rank % ranks_per_ckpt - nH_per_ckpt = config.num_attention_heads // num_ckpts - assert (nH_per_ckpt % ranks_per_ckpt) == 0 - ckpt = torch.load(Path(meta_ckpt_dir, - f"consolidated.{ckpt_fid:02d}.pth"), - map_location="cpu") - return split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank) - - # num_ckpts == tensor_parallel, 1:1 mapping from files to TP - return torch.load(Path(meta_ckpt_dir, - f"consolidated.{mapping.tp_rank:02d}.pth"), - map_location="cpu") - - def permute(w, nH, d, dH): - # due to MQA's wk, nH*dH != d could be true - return w.view(nH, dH // 2, 2, d).transpose(1, 2).reshape(nH * dH, d) - - def extract_layer_idx(name): - ss = name.split('.') - for s in ss: - if s.isdigit(): - return s - return None - - if not hasattr(load_from_meta_llama, "saved_embed"): - load_from_meta_llama.saved_embed = None - - def combine_embeddings(embeds, num_ckpts): - if len(embeds) == 1: - return embeds[0] - assert [ - embeds[i].shape == embeds[i + 1].shape - for i in range(len(embeds) - 1) - ] - if embeds[0].shape[0] == config.vocab_size // num_ckpts: - merge_dim = 0 - elif embeds[0].shape[1] == config.hidden_size // num_ckpts: - merge_dim = 1 - else: - logger.error("Unable to infer embedding split dimension") - assert False, "Unable to infer embedding split dimension" - return torch.cat(embeds, dim=merge_dim) - - def gather_embedding(cur_embed, name: str, num_ckpts): - if mapping.tp_size == 1: - # even if num_ckpts > 1, get_current_weights will already have it gathered - return cur_embed - if load_from_meta_llama.saved_embed is None: - embeds = [None] * num_ckpts - for i in range(num_ckpts): - ckpt = torch.load(Path(meta_ckpt_dir, - f"consolidated.{i:02d}.pth"), - map_location="cpu") - embeds[i] = ckpt[name] - embed = combine_embeddings(embeds, num_ckpts).to(torch_dtype) - load_from_meta_llama.saved_embed = embed - - return load_from_meta_llama.saved_embed - - logger.info('Loading weights from Meta LLaMA checkpoints ...') - tik = time.time() - - num_kv_heads = config.num_key_value_heads - mha_mode = (num_kv_heads == config.num_attention_heads) - - ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth")) - num_ckpts = len(ckpts) - # llama/llama2 doesn't have MQA. So, simplifying loader logic by not worrying about it. - assert num_kv_heads > 1 or num_kv_heads >= num_ckpts, \ - f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints." - - head_size = config.hidden_size // config.num_attention_heads - ckpt = get_current_weights(num_ckpts) - layers_range = mapping.pp_layers(config.num_hidden_layers) - - for l in layers_range: - prefix = f'layers.{l}.attention.' - q_weight = permute(ckpt[prefix + 'wq.weight'].clone(), - nH=(config.num_attention_heads // mapping.tp_size), - d=config.hidden_size, - dH=head_size) - if num_kv_heads < mapping.tp_size and num_ckpts >= mapping.tp_size: - assert mapping.tp_size % num_kv_heads == 0 - assert False, "Not supported yet" - k_weight = permute(ckpt[prefix + 'wk.weight'].clone(), - nH=((num_kv_heads + mapping.tp_size - 1) // - mapping.tp_size), - d=config.hidden_size, - dH=head_size) - v_weight = ckpt[prefix + 'wv.weight'].clone() - - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) - ckpt[prefix + 'qkv.weight'] = qkv_weight - - for k, v in ckpt.items(): - dtype = torch_dtype if 'feed_forward.gate' not in k else torch.float32 - - v = v.to(dtype) - if "tok_embeddings" in k: - if not config.use_parallel_embedding: - v = gather_embedding(v, k, num_ckpts) - elif config.embedding_sharding_dim == 0: - # this needs a gather and then resplit along different dims - v = gather_embedding(v, k, num_ckpts) - v = split(v, mapping.tp_size, mapping.tp_rank, 0) - if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = v - elif "output" in k: - if mapping.is_last_pp_rank(): - if config.vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = pad_vocab_size(config.vocab_size, - mapping.tp_size) - pad_width = vocab_size_padded - config.vocab_size - v = torch.from_numpy( - np.pad(v.detach().cpu().numpy(), - ((0, pad_width), (0, 0)), - 'constant', - constant_values=0)) - weights['lm_head.weight'] = v - elif k == "norm.weight": - if mapping.is_last_pp_rank(): - weights['transformer.ln_f.weight'] = v - else: - # layer specific weights - layer_idx = extract_layer_idx(k) - - if layer_idx is None or int(layer_idx) not in layers_range: - continue - idx = int(layer_idx) - layers_range[0] - tllm_prex = f'transformer.layers.{idx}.' - - if 'attention_norm.weight' in k: - weights[tllm_prex + 'input_layernorm.weight'] = v - elif 'ffn_norm.weight' in k: - weights[tllm_prex + 'post_layernorm.weight'] = v - elif 'feed_forward.w3.weight' in k: - weights[tllm_prex + 'mlp.gate.weight'] = v - elif 'feed_forward.w2.weight' in k: - weights[tllm_prex + 'mlp.proj.weight'] = v - elif 'feed_forward.w1.weight' in k: - weights[tllm_prex + 'mlp.fc.weight'] = v - elif 'attention.wo.weight' in k: - weights[tllm_prex + 'attention.dense.weight'] = v - elif 'attention.qkv.weight' in k: - weights[tllm_prex + 'attention.qkv.weight'] = v - elif 'feed_forward.gate' in k: - weights[tllm_prex + 'mlp.router.weight'] = v - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Weights loaded. Total time: {t}') - return weights - - -def load_from_hf_safetensors(model_dir: str, config: PretrainedConfig, mapping): - logger.info('Loading weights from Huggingface LLaMA safetensors...') - tik = time.time() - import json - import os - - import safetensors - weights = {} - - model_dir = model_dir if model_dir.endswith("/") else model_dir + "/" - safetensors_map = {} - try: - with open(model_dir + "model.safetensors.index.json", 'r') as fr: - sharding_map = json.load(fr) - for k, v in sharding_map['weight_map'].items(): - safetensors_map[k] = int(v[6:11]) - 1 - except FileNotFoundError: - pass - shard_files = [] - for name in os.listdir(model_dir): - if name.endswith(".safetensors"): - shard_files.append(name) - shard_files.sort() - safetensors_ptrs = [ - safetensors.safe_open(model_dir + shard_file, - framework="pt", - device="cpu") for shard_file in shard_files - ] - - num_hidden_layers = config.num_hidden_layers - vocab_size = config.vocab_size - pad_vocab = vocab_size % mapping.tp_size != 0 - vocab_size_padded = pad_vocab_size(config.vocab_size, mapping.tp_size) - dtype = config.dtype - - moe_config = MoeConfig(config.moe_num_experts, config.moe_top_k, - config.moe_tp_mode, config.moe_normalization_mode) - - model_prefix = "model." - key_list = [ - "embed_tokens.weight", # vocab_embedding - "lm_head.weight", # lm_head - "norm.weight", # ln_f - "self_attn.", # attention.qkv - "_proj.weight", # qkv suffix - "self_attn.o_proj.weight", # attention.dense - "mlp.up_proj.weight", # mlp.gate - "mlp.down_proj.weight", # mlp.proj - "mlp.gate_proj.weight", # mlp.fc - "input_layernorm.weight", # input_layernorm - "post_attention_layernorm.weight", # post_layernorm - ] - - torch_dtype = str_dtype_to_torch(dtype) - - def load(key, tp_dim=-1, no_prefix=0): - if not no_prefix: - key = model_prefix + key - ptr_idx = safetensors_map[key] if key in safetensors_map else 0 - if tp_dim == -1: - res = safetensors_ptrs[ptr_idx].get_tensor(key) - else: - tensor_slice = safetensors_ptrs[ptr_idx].get_slice(key) - tensor_shape = tensor_slice.get_shape() - if tensor_shape[tp_dim] % mapping.tp_size != 0: - logger.error( - "Current weight shape is invalid for mapping.tp_size=" + - str(mapping.tp_size)) - slice_width = tensor_shape[tp_dim] // mapping.tp_size - if tp_dim == 0: - res = tensor_slice[slice_width * mapping.tp_rank:slice_width * - (mapping.tp_rank + 1), :] - elif tp_dim == 1: - res = tensor_slice[:, - slice_width * mapping.tp_rank:slice_width * - (mapping.tp_rank + 1)] - else: - assert False, "Invalid TP dim" - return res.to(torch_dtype).contiguous( - ) if "block_sparse_moe.gate" not in key else res.to(torch.float32) - - if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = load( - key_list[0], config.embedding_sharding_dim - if config.use_parallel_embedding else -1) # vocab_embedding - - if mapping.is_last_pp_rank(): - v = load(key_list[1], -1, 1) if pad_vocab else load(key_list[1], 0, - 1) # lm_head - if pad_vocab: - v = torch.nn.functional.pad( - v, (0, 0, 0, vocab_size_padded - vocab_size), 'constant', 0) - v = split(v, mapping.tp_size, mapping.tp_rank) - weights['lm_head.weight'] = v - weights['transformer.ln_f.weight'] = load(key_list[2]) # ln_f - - layers_range = mapping.pp_layers(num_hidden_layers) - for l in layers_range: - layer_idx = l - layers_range[0] - prefix = f'layers.{l}.' - tllm_prex = f'transformer.layers.{layer_idx}' - - # Attention - qkv_list = [] - for comp in ["q", "k", "v"]: - comp_part = load(prefix + key_list[3] + comp + key_list[4], 0) - qkv_list.append(comp_part) - weights[f'{tllm_prex}.attention.qkv.weight'] = torch.cat(qkv_list, 0) - weights[f'{tllm_prex}.attention.dense.weight'] = load( - prefix + key_list[5], 1) # attention.dense - - # MLP - if not moe_config.has_moe(): - weights[f'{tllm_prex}.mlp.gate.weight'] = load( - prefix + key_list[6], 0) # mlp.gate - weights[f'{tllm_prex}.mlp.proj.weight'] = load( - prefix + key_list[7], 1) # mlp.proj - weights[f'{tllm_prex}.mlp.fc.weight'] = load( - prefix + key_list[8], 0) # mlp.fc - - else: - weights[f'{tllm_prex}.mlp.router.weight'] = load( - prefix + 'block_sparse_moe.gate.weight') - rank_experts = list(range(moe_config.num_experts)) - if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL: - rank_experts = mapping.ep_experts(moe_config.num_experts) - - expert_weight_list = [] - for suffix in range(3): - tp_dim = -1 - if moe_config.tp_mode == moe_config.ParallelismMode.TENSOR_PARALLEL: - tp_dim = 1 if suffix == 1 else 0 - expert_weight_list.append( - torch.stack( - list( - load( - prefix + - f'block_sparse_moe.experts.{expert}.w{suffix + 1}.weight', - tp_dim=tp_dim) for expert in rank_experts))) - - w1 = expert_weight_list[0] - w2 = expert_weight_list[1] - w3 = expert_weight_list[2] - - weights[f'{tllm_prex}.mlp.fc.weight'] = \ - torch.concat([w3, w1], dim=-2).contiguous() - weights[f'{tllm_prex}.mlp.proj.weight'] = w2.contiguous() - - weights[f'{tllm_prex}.input_layernorm.weight'] = load( - prefix + key_list[9]) # input_layernorm - weights[f'{tllm_prex}.post_layernorm.weight'] = load( - prefix + key_list[10]) # post_layernorm - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - logger.info(f'Weights loaded. Total time: {t}') - - return weights diff --git a/tensorrt_llm/models/medusa/config.py b/tensorrt_llm/models/medusa/config.py new file mode 100644 index 000000000..1e6df3c5a --- /dev/null +++ b/tensorrt_llm/models/medusa/config.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..llama.config import LLaMAConfig + + +class MedusaConfig(LLaMAConfig): + + def __init__(self, + *, + num_medusa_heads: int = 4, + num_medusa_layers: int = 1, + max_draft_len: int = 63, + **kwargs): + self.num_medusa_heads = num_medusa_heads + self.num_medusa_layers = num_medusa_layers + self.max_draft_len = max_draft_len + super().__init__(**kwargs) + + def to_dict(self): + output = super().to_dict() + # Serialize the fields added in MedusaConfig + output['num_medusa_heads'] = self.num_medusa_heads + output['num_medusa_layers'] = self.num_medusa_layers + output['max_draft_len'] = self.max_draft_len + return output diff --git a/tensorrt_llm/models/medusa/model.py b/tensorrt_llm/models/medusa/model.py index 016442657..689e8d822 100644 --- a/tensorrt_llm/models/medusa/model.py +++ b/tensorrt_llm/models/medusa/model.py @@ -21,6 +21,7 @@ from ...layers import ColumnLinear from ...mapping import Mapping from ...module import Module, ModuleList +from .config import MedusaConfig class MedusaLayer(Module): @@ -80,8 +81,9 @@ def forward(self, x): class MedusaForCausalLm(LLaMAForCausalLM): + config_class = MedusaConfig - def __init__(self, config): + def __init__(self, config: MedusaConfig): super().__init__(config) self.num_medusa_heads = config.num_medusa_heads diff --git a/tensorrt_llm/models/medusa/weight.py b/tensorrt_llm/models/medusa/weight.py index fd50f3ef6..251ea3edc 100644 --- a/tensorrt_llm/models/medusa/weight.py +++ b/tensorrt_llm/models/medusa/weight.py @@ -7,7 +7,7 @@ from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import MedusaLM -from tensorrt_llm.models.llama.weight import split +from tensorrt_llm.models.convert_utils import split def load_medusa_hf(medusa_path: str, diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 5ea7e0459..322320eee 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -12,8 +12,8 @@ import torch from .._common import default_net -from .._utils import (numpy_to_torch, release_gc, str_dtype_to_torch, - str_dtype_to_trt, trt_dtype_to_torch) +from .._utils import (get_init_params, numpy_to_torch, release_gc, + str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch) from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits from ..layers import (AttentionParams, Embedding, FusedGatedMLP, FusedRgLru, GatedMLP, KeyValueCacheParams, LoraParams, @@ -27,7 +27,6 @@ from ..module import Module, ModuleList from ..parameter import Parameter from ..quantization import QuantMode -from ..quantization.layers import FP8Linear from ..quantization.mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo from ..top_model_mixin import TopModelMixin from .convert_utils import weight_only_quantize_dict @@ -97,7 +96,11 @@ def quant_algo_to_modelopt_qformat(self): qformat = 'full_prec' return qformat - def asdict(self): + @classmethod + def from_dict(cls, config: dict): + return cls(**config) + + def to_dict(self): return dataclasses.asdict(self) @@ -118,159 +121,144 @@ def save_checkpoint(output_dir: str, config: dict, weights: dict) -> None: class PretrainedConfig: def __init__(self, + *, architecture: str, dtype: str, - logits_dtype: str, - vocab_size: int, - max_position_embeddings: int, hidden_size: int, num_hidden_layers: int, num_attention_heads: int, - num_key_value_heads: int, - hidden_act: str, - intermediate_size: int, - norm_epsilon: float, - position_embedding_type: str, - world_size: int, - tp_size: int, - pp_size: int, - gpus_per_node: int, - quantization: Union[QuantConfig, dict], + vocab_size: Optional[int] = None, + hidden_act: str = 'gelu', + logits_dtype: str = 'float32', + norm_epsilon: float = 1e-5, + position_embedding_type: Union[ + PositionEmbeddingType, + str] = PositionEmbeddingType.learned_absolute, + max_position_embeddings: Optional[int] = None, + num_key_value_heads: Optional[int] = None, + intermediate_size: Optional[int] = None, + mapping: Optional[Union[Mapping, dict]] = None, + quantization: Optional[Union[QuantConfig, dict]] = None, use_parallel_embedding: bool = False, embedding_sharding_dim: int = 0, share_embedding_table: bool = False, - head_size: int = None, + head_size: Optional[int] = None, qk_layernorm: bool = False, **kwargs): self.architecture = architecture self.dtype = dtype - self.logits_dtype = logits_dtype - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.head_size = hidden_size // num_attention_heads if head_size is None else head_size - self.qk_layernorm = qk_layernorm self.hidden_act = hidden_act - self.intermediate_size = intermediate_size + + self.logits_dtype = logits_dtype self.norm_epsilon = norm_epsilon - self.position_embedding_type = PositionEmbeddingType.from_string( - position_embedding_type) + + if isinstance(position_embedding_type, str): + position_embedding_type = PositionEmbeddingType.from_string( + position_embedding_type) + assert isinstance(position_embedding_type, PositionEmbeddingType) + self.position_embedding_type = position_embedding_type + + self.max_position_embeddings = max_position_embeddings + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + if intermediate_size is None: + intermediate_size = hidden_size * 4 + self.intermediate_size = intermediate_size + + if mapping is None: + mapping = Mapping() + elif isinstance(mapping, dict): + mapping = Mapping.from_dict(mapping) + assert isinstance(mapping, Mapping) + self.mapping = mapping + + if quantization is None: + quantization = QuantConfig() + elif isinstance(quantization, dict): + quantization = QuantConfig.from_dict(quantization) + assert isinstance(quantization, QuantConfig) + self.quantization = quantization + self.use_parallel_embedding = use_parallel_embedding self.embedding_sharding_dim = embedding_sharding_dim self.share_embedding_table = share_embedding_table - self.mapping = Mapping(world_size=world_size, - tp_size=tp_size, - pp_size=pp_size, - gpus_per_node=gpus_per_node) - if isinstance(quantization, dict): - self.quantization = dataclasses.replace(QuantConfig(), - **quantization) - else: - assert isinstance( - quantization, QuantConfig - ), f"Expecting type of QuantConfig, found {type(quantization)}" - self.quantization = quantization - self.kv_dtype = self.dtype - if self.quant_mode.has_int8_kv_cache(): - self.kv_dtype = 'int8' - elif self.quant_mode.has_fp8_kv_cache(): - self.kv_dtype = 'fp8' + + if share_embedding_table and mapping.tp_size > 1: + if (not use_parallel_embedding) or (use_parallel_embedding and + embedding_sharding_dim == 1): + raise NotImplementedError( + "For tensor parallelism, sharing the embedding table must set" \ + "use_parallel_embedding=True and embedding_sharding_dim=0" + ) + if share_embedding_table and mapping.pp_size > 1: + raise NotImplementedError( + "Embedding table cannot be shared for pipeline parallelism") + + if head_size is None: + head_size = hidden_size // num_attention_heads + self.head_size = head_size + self.qk_layernorm = qk_layernorm for key, value in kwargs.items(): try: setattr(self, key, value) + logger.warning( + f"Implicitly setting {self.__class__.__name__}.{key} = {value}" + ) except AttributeError as err: raise err + @property + def kv_dtype(self): + if self.quant_mode.has_int8_kv_cache(): + return 'int8' + elif self.quant_mode.has_fp8_kv_cache(): + return 'fp8' + else: + return self.dtype + def set_if_not_exist(self, key, value): if not hasattr(self, key): setattr(self, key, value) @classmethod - def from_dict(cls, config): - config = copy.deepcopy( - config - ) # many config.pop calls inside, make one local copy of the config dict such that the function has no side effects - architecture = config.pop('architecture') - dtype = config.pop('dtype') - vocab_size = config.pop('vocab_size', None) - hidden_size = config.pop('hidden_size') - num_hidden_layers = config.pop('num_hidden_layers') - num_attention_heads = config.pop('num_attention_heads') - hidden_act = config.pop('hidden_act', None) - norm_epsilon = config.pop('norm_epsilon', 1e-5) - position_embedding_type = config.pop('position_embedding_type', - 'learned_absolute') - logits_dtype = config.pop('logits_dtype', 'float32') - num_key_value_heads = config.pop('num_key_value_heads', - num_attention_heads) - intermediate_size = config.pop('intermediate_size', None) - max_position_embeddings = config.pop('max_position_embeddings', None) - use_parallel_embedding = config.pop('use_parallel_embedding', False) - embedding_sharding_dim = config.pop('embedding_sharding_dim', 0) - share_embedding_table = config.pop('share_embedding_table', False) - - mapping = config.pop('mapping', {}) - world_size = mapping.get('world_size', 1) - tp_size = mapping.get('tp_size', 1) - pp_size = mapping.get('pp_size', 1) - gpus_per_node = mapping.get('gpus_per_node', 8) - - if share_embedding_table and tp_size > 1: - if (not use_parallel_embedding) or (use_parallel_embedding and - embedding_sharding_dim == 1): - raise NotImplementedError( - "For tensor parallelism, sharing the embedding table must set" \ - "use_parallel_embedding=True and embedding_sharding_dim=0" - ) - if share_embedding_table and pp_size > 1: - raise NotImplementedError( - "Embedding table cannot be shared for pipeline parallelism") + def from_dict(cls, config: dict): + # Maybe we need AutoConfig for this + from . import MODEL_MAP + model_cls = MODEL_MAP[config['architecture']] + config_cls = getattr(model_cls, 'config_class', cls) + return config_cls(**config) - quant_config = QuantConfig() + def to_dict(self): + output = copy.deepcopy(self.__dict__) - if 'quantization' in config: - # override the default quantization object from the given dict, allows user to specify partial set of the fields - quant_config_from_user = config.pop('quantization') - if isinstance(quant_config_from_user, dict): - quant_config = dataclasses.replace(quant_config, - **quant_config_from_user) - # allow user to directly pass one QuantConfig object - else: - assert isinstance(quant_config_from_user, QuantConfig) - quant_config = quant_config_from_user + output['position_embedding_type'] = str(self.position_embedding_type) + output['mapping'] = self.mapping.to_dict() + output['mapping'].pop('rank') + output['quantization'] = self.quantization.to_dict() - return cls(architecture, dtype, logits_dtype, vocab_size, - max_position_embeddings, hidden_size, num_hidden_layers, - num_attention_heads, num_key_value_heads, hidden_act, - intermediate_size, norm_epsilon, position_embedding_type, - world_size, tp_size, pp_size, gpus_per_node, quant_config, - use_parallel_embedding, embedding_sharding_dim, - share_embedding_table, **config) + return output @classmethod def from_json_file(cls, config_file: str): with open(config_file) as f: config = json.load(f) - return PretrainedConfig.from_dict(config) + return cls.from_dict(config) - def to_dict(self): - output = copy.deepcopy(self.__dict__) + @classmethod + def from_checkpoint(cls, ckpt_dir: str): + return cls.from_json_file(os.path.join(ckpt_dir, 'config.json')) - output['position_embedding_type'] = str(self.position_embedding_type) - output['mapping'] = { - 'world_size': self.mapping.world_size, - 'tp_size': self.mapping.tp_size, - 'pp_size': self.mapping.pp_size, - 'gpus_per_node': self.mapping.gpus_per_node, - } - output['quantization'] = dataclasses.asdict(self.quantization) - - return output + def to_json_file(self, config_file: str): + with open(config_file, 'w') as f: + json.dump(self.to_dict(), f, indent=4) @property def quant_mode(self): @@ -368,9 +356,16 @@ def __init__(self, config: PretrainedConfig): def __post_init__(self): from ..quantization.quantize import quantize - quantize(self, self.config.quantization) + # Currently, use_parallel_embedding and share_embedding_table must be enabled before weight loading; + # otherwise, the model will be inconsistent with the weights loaded from checkpoint. + optimize_model( + self, + use_parallel_embedding=self.config.use_parallel_embedding, + share_embedding_table=self.config.share_embedding_table, + ) + def release(self): release_gc() @@ -389,23 +384,32 @@ def from_config(cls, config: PretrainedConfig): @classmethod def from_checkpoint(cls, ckpt_dir: str, - rank: int = 0, - config: PretrainedConfig = None): + rank: Optional[int] = None, + config: Optional[PretrainedConfig] = None): if config is None: config = PretrainedConfig.from_json_file( os.path.join(ckpt_dir, 'config.json')) + + if rank is not None: config.set_rank(rank) - model = cls.from_config(config) - - weights = {} - with safetensors.safe_open(os.path.join(ckpt_dir, - f'rank{rank}.safetensors'), - framework='pt', - device='cpu') as f: - for key in f.keys(): - weights[key] = f.get_tensor(key) - preprocess_weights(weights, config) - model.load(weights) + + model = cls(config) + weights = None + if config.architecture in WEIGHT_LOADER_MODELS: + weights_path = os.path.join(ckpt_dir, 'rank0.safetensors') + else: + rank = config.mapping.rank + weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors') + + assert os.path.isfile(weights_path) + weights = safetensors.torch.load_file(weights_path) + + is_checkpoint_pruned = getattr(config, 'is_pruned', False) + if weights is not None: + preprocess_weights(weights, + config, + from_pruned=is_checkpoint_pruned) + model.load(weights, from_pruned=is_checkpoint_pruned) return model @@ -472,11 +476,10 @@ def save_checkpoint(self, output_dir, save_config=True): name: numpy_to_torch(param.raw_value) for name, param in self.named_parameters() } - from safetensors.torch import save_file - save_file(weights, os.path.join(output_dir, f'rank{rank}.safetensors')) + safetensors.torch.save_file( + weights, os.path.join(output_dir, f'rank{rank}.safetensors')) if save_config: - with open(os.path.join(output_dir, 'config.json'), 'w') as f: - json.dump(self.config.to_dict(), f, indent=4) + self.config.to_json_file(os.path.join(output_dir, 'config.json')) def prepare_inputs(self, max_batch_size, @@ -605,12 +608,12 @@ def prepare_inputs(self, @classmethod def quantize( cls, - hf_model_dir, - output_dir, - quant_config: QuantConfig, - *, - dtype='float16', + hf_model_dir: str, + output_dir: str, + dtype: str = 'float16', mapping: Optional[Mapping] = None, + quant_config: Optional[QuantConfig] = None, + *, calib_dataset='cnn_dailymail', calib_batches=512, calib_batch_size=1, @@ -650,6 +653,8 @@ def __init__(self, config: PretrainedConfig, transformer, lm_head): super().__init__(config) self.transformer = transformer self.lm_head = lm_head + config.set_if_not_exist('mup_width_multiplier', 1.0) + self.mup_width_multiplier = config.mup_width_multiplier def forward(self, input_ids: Tensor, @@ -699,6 +704,10 @@ def forward(self, # [batch_size, hidden_size] -> [batch_size, vocab_size] lm_logits = self.lm_head(hidden_states) + if hasattr(self.config, 'output_multiplier_scale'): + lm_logits *= getattr(self.config, 'output_multiplier_scale', 1) + if self.mup_width_multiplier is not None: + lm_logits = lm_logits / self.mup_width_multiplier lm_logits.mark_output('logits', self.config.logits_dtype) else: hidden_states.mark_output('hidden_states_output', self.config.dtype) @@ -718,41 +727,35 @@ def forward(self, return hidden_states -def fuse_gate_mlp(model): +def fuse_gate_mlp( + model: PretrainedModel, + gemm_swiglu_plugin_dtype: Optional[str] = None, +) -> PretrainedModel: from ..quantization.quantize import fp8_quantize - for layer in model.transformer.layers: - if not hasattr(layer, 'mlp'): - continue - - quant_algo = model.config.quantization.quant_algo - if isinstance(layer.mlp, GatedMLP): - fused_layer = FusedGatedMLP( - hidden_size=layer.mlp.hidden_size, - ffn_hidden_size=layer.mlp.ffn_hidden_size, - hidden_act=layer.mlp.hidden_act, - bias=layer.mlp.bias, - dtype=layer.mlp.dtype, - tp_group=layer.mlp.tp_group, - tp_size=layer.mlp.tp_size, - quant_mode=layer.mlp.quant_mode) + quant_algo = model.config.quantization.quant_algo + for name, mlp, layer in model.named_modules_with_parent(): + if isinstance(mlp, GatedMLP): + init_params = get_init_params(mlp) + init_params["inner_layernorm"] = mlp.inner_layernorm is not None + fused_layer = FusedGatedMLP(**init_params) if quant_algo == QuantAlgo.FP8: fused_layer = fp8_quantize(fused_layer, model.config.quantization) - if isinstance(layer.mlp.dtype, str): - dtype = str_dtype_to_torch(layer.mlp.dtype) + if isinstance(mlp.dtype, str): + dtype = str_dtype_to_torch(mlp.dtype) else: - dtype = trt_dtype_to_torch(layer.mlp.dtype) + dtype = trt_dtype_to_torch(mlp.dtype) # dequantize gate_weight = numpy_to_torch( - layer.mlp.gate.weight.raw_value).to(dtype) * numpy_to_torch( - layer.mlp.gate.weights_scaling_factor.raw_value) + mlp.gate.weight.raw_value).to(dtype) * numpy_to_torch( + mlp.gate.weights_scaling_factor.raw_value) fc_weight = numpy_to_torch( - layer.mlp.fc.weight.raw_value).to(dtype) * numpy_to_torch( - layer.mlp.fc.weights_scaling_factor.raw_value) + mlp.fc.weight.raw_value).to(dtype) * numpy_to_torch( + mlp.fc.weights_scaling_factor.raw_value) # concat fused_weight = torch.cat([gate_weight, fc_weight], dim=0) @@ -760,37 +763,50 @@ def fuse_gate_mlp(model): # quantize fused_weight_scaling_factor = numpy_to_torch( max( - layer.mlp.gate.weights_scaling_factor.raw_value, - layer.mlp.fc.weights_scaling_factor.raw_value, + mlp.gate.weights_scaling_factor.raw_value, + mlp.fc.weights_scaling_factor.raw_value, )) fused_weight = (fused_weight / fused_weight_scaling_factor).to( torch.float8_e4m3fn) - fused_layer.fused_fc.weight.value = fused_weight + if gemm_swiglu_plugin_dtype == 'fp8': + # gemm_swiglu_plugin needs (k, n) weights + # but weights should still be k-major for fp8 + fused_layer.fused_fc.weight = Parameter( + shape=(fused_layer.fused_fc.in_features, + fused_layer.fused_fc.out_features), + dtype='fp8') + fused_layer.fused_fc.weight.value = fused_weight.view( + fused_layer.fused_fc.in_features, + fused_layer.fused_fc.out_features) + else: + fused_layer.fused_fc.weight.value = fused_weight fused_layer.fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor - fused_layer.fused_fc.activation_scaling_factor.value = \ - max(layer.mlp.gate.activation_scaling_factor.raw_value, - layer.mlp.fc.activation_scaling_factor.raw_value - ) + fused_layer.fused_fc.activation_scaling_factor.value = max( + mlp.gate.activation_scaling_factor.raw_value, + mlp.fc.activation_scaling_factor.raw_value, + ) elif quant_algo is None: - fused_layer.fused_fc.weight.value = np.concatenate([ - layer.mlp.gate.weight.raw_value, - layer.mlp.fc.weight.raw_value - ], - axis=0) - if layer.mlp.bias: - fused_layer.fused_fc.bias.value = np.concatenate([ - layer.mlp.gate.bias.raw_value, - layer.mlp.fc.bias.raw_value + fused_layer.fused_fc.weight.value = np.concatenate( + [ + mlp.gate.weight.raw_value, + mlp.fc.weight.raw_value, ], - axis=0) + axis=0, + ) + if mlp.bias: + fused_layer.fused_fc.bias.value = np.concatenate( + [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value], + axis=0) else: raise ValueError(f'Unsupported quant algo: {quant_algo}') - fused_layer.proj = layer.mlp.proj + fused_layer.proj = mlp.proj + fused_layer.inner_layernorm = mlp.inner_layernorm - layer.mlp = fused_layer + mlp_name = name.rsplit('.', 1)[-1] + setattr(layer, mlp_name, fused_layer) return model @@ -798,31 +814,40 @@ def fuse_gate_mlp(model): def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel: '''Split all the models' Attention layer's QKV GEMM into 3 GEMMs layer.q layer.k, layer.v and return the changed model ''' - for name, layer in model.named_modules(remove_duplicate=True): + from ..quantization.quantize import quantize + + for name, layer in model.named_modules(): if isinstance(layer, Attention) and not layer.cross_attention: assert layer.tp_size == 1, "please disable manual tp when enable auto parallel" - if layer.unfuse_qkv_gemm: + if layer.qkv is None: continue - layer.unfuse_qkv_gemm = True - linear_class = FP8Linear if layer.quant_mode.has_fp8_qdq( - ) else ColumnLinear - q = linear_class(layer.hidden_size, - layer.attention_hidden_size, - bias=layer.bias, - dtype=layer.dtype, - gather_output=False) - k = linear_class(layer.hidden_size, - layer.num_attention_kv_heads * - layer.attention_head_size, - bias=layer.bias, - dtype=layer.dtype, - gather_output=False) - v = linear_class(layer.hidden_size, - layer.num_attention_kv_heads * - layer.attention_head_size, - bias=layer.bias, - dtype=layer.dtype, - gather_output=False) + qkv_params = get_init_params(layer.qkv, ColumnLinear) + qkv_params["bias"] = qkv_params["bias"] is not None + qkv_params["strict_dtype"] = qkv_params["strict_dtype"] is not None + q = ColumnLinear( + **{ + **qkv_params, + "out_features": + layer.tp_size * layer.num_attention_heads * + layer.attention_head_size, + }) + k = ColumnLinear( + **{ + **qkv_params, + "out_features": + layer.tp_size * layer.num_attention_kv_heads * + layer.attention_head_size, + }) + v = ColumnLinear( + **{ + **qkv_params, + "out_features": + layer.tp_size * layer.num_attention_kv_heads * + layer.attention_head_size, + }) + q = quantize(q, model.config.quantization) + k = quantize(k, model.config.quantization) + v = quantize(v, model.config.quantization) if layer.qkv.weight.is_inited(): qkv_weight = layer.qkv.weight.raw_value weights = np.split(qkv_weight, [ @@ -851,53 +876,42 @@ def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel: def fuse_rg_lru(model: PretrainedModel) -> PretrainedModel: - for layer in model.transformer.layers: - if not hasattr(layer, 'recurrent'): - continue - - if isinstance(layer.recurrent.rg_lru, RgLru): - rg_lru = layer.recurrent.rg_lru - fused_layer = FusedRgLru(lru_width=rg_lru.lru_width, - num_heads=rg_lru.num_heads, - dtype=rg_lru.dtype, - tp_group=rg_lru.tp_group, - tp_size=rg_lru.tp_size, - tp_rank=rg_lru.tp_rank) - - fused_layer.gate.weight.value = np.concatenate([ - rg_lru.input_gate.weight.raw_value, - rg_lru.recurrent_gate.weight.raw_value - ], - axis=-1) - fused_layer.gate.bias.value = np.concatenate([ - rg_lru.input_gate.bias.raw_value, - rg_lru.recurrent_gate.bias.raw_value - ], - axis=-1) + for name, rg_lru, parent in model.named_modules_with_parent(): + if isinstance(rg_lru, RgLru): + fused_layer = FusedRgLru(**get_init_params(rg_lru)) + fused_layer.gate.weight.value = np.concatenate( + [ + rg_lru.input_gate.weight.raw_value, + rg_lru.recurrent_gate.weight.raw_value, + ], + axis=-1, + ) + fused_layer.gate.bias.value = np.concatenate( + [ + rg_lru.input_gate.bias.raw_value, + rg_lru.recurrent_gate.bias.raw_value, + ], + axis=-1, + ) fused_layer.recurrent_param.value = rg_lru.recurrent_param.raw_value - layer.recurrent.rg_lru = fused_layer + rg_lru_name = name.rsplit('.', 1)[-1] + setattr(parent, rg_lru_name, fused_layer) return model -def set_prompt_tuning( - model: DecoderModelForCausalLM) -> DecoderModelForCausalLM: +def set_prompt_tuning(model: PretrainedModel) -> PretrainedModel: '''Replace the given models embedding layer with a PromptTuningEmbedding layer in-place, return the changed model - Pre-conditions: model.transformer.vocab_embedding exists - Post-conditions: isinstance(model.transformer.vocab_embedding, PromptTuningEmbedding) + Pre-conditions: vocab_embedding exists + Post-conditions: isinstance(vocab_embedding, PromptTuningEmbedding) ''' - if isinstance(model.transformer.vocab_embedding, Embedding): - embedding = model.transformer.vocab_embedding - model.transformer.vocab_embedding = PromptTuningEmbedding( - num_embeddings=embedding.num_embeddings, - embedding_dim=embedding.embedding_dim, - dtype=embedding.dtype, - tp_size=embedding.tp_size, - tp_group=embedding.tp_group, - sharding_dim=embedding.sharding_dim, - tp_rank=embedding.tp_rank) - - model.transformer.vocab_embedding.weight.value = embedding.weight.raw_value + for name, embedding, parent in model.named_modules_with_parent(): + layer_name = name.rsplit('.', 1)[-1] + if layer_name == "vocab_embedding" and isinstance(embedding, Embedding): + ptuning_embedding = PromptTuningEmbedding( + **get_init_params(embedding)) + ptuning_embedding.weight.value = embedding.weight.raw_value + parent.vocab_embedding = ptuning_embedding return model @@ -905,7 +919,7 @@ def add_lora(model: PretrainedModel, max_lora_rank: Optional[int]) -> PretrainedModel: ''' Add lora layers to the Attention/BertAttention/Linear/RowLinear/FusedGatedMLP layers to the given model, return the changed model ''' - for name, layer in model.named_modules(remove_duplicate=True): + for name, layer in model.named_modules(): max_rank = max_lora_rank if isinstance(layer, (Attention, BertAttention)): if max_rank is None: @@ -934,7 +948,7 @@ def add_lora(model: PretrainedModel, if max_rank is None: max_rank = min(layer.hidden_size, layer.ffn_hidden_size // layer.tp_size) - layer.mlp_in_lora = Lora( + layer.lora = Lora( in_hidden_size=layer.hidden_size, out_hidden_sizes=[ layer.ffn_hidden_size // layer.tp_size, @@ -948,79 +962,97 @@ def add_lora(model: PretrainedModel, def to_ootb_moe(model: PretrainedModel) -> PretrainedModel: ''' Use OOTB MoE instead of MoE plugin, return the changed model ''' - for name, module in model.named_modules(remove_duplicate=True): - if not hasattr(module, 'mlp'): - continue - layer = module.mlp + for name, layer, parent in model.named_modules_with_parent(): if isinstance(layer, MOE): - module.mlp = layer.to(MoeOOTB, model.config) + layer_name = name.rsplit('.', 1)[-1] + ootb_layer = layer.to(MoeOOTB, model.config) + setattr(parent, layer_name, ootb_layer) return model -def parallelize_embedding(model: DecoderModelForCausalLM): - if model.config.mapping.is_first_pp_rank(): - for name, module in model.transformer.named_children(): - if name.endswith('embedding') and isinstance(module, Embedding): - assert module.tp_group is None, "The embedding has already been parallelized." - model.transformer._modules[name] = module.__class__( - module.num_embeddings, - module.embedding_dim, - dtype=module.dtype, - tp_group=model.config.mapping.tp_group, - tp_size=model.config.mapping.tp_size, - sharding_dim=model.config.embedding_sharding_dim, - tp_rank=model.config.mapping.tp_rank) - +def parallelize_embedding(model: PretrainedModel) -> PretrainedModel: + for name, embedding, parent in model.named_modules_with_parent(): + layer_name = name.rsplit('.', 1)[-1] + if isinstance(embedding, Embedding) and embedding.tp_group is None: + init_params = get_init_params(embedding) + init_params["tp_group"] = model.config.mapping.tp_group + init_params["tp_size"] = model.config.mapping.tp_size + init_params["tp_rank"] = model.config.mapping.tp_rank + init_params["sharding_dim"] = model.config.embedding_sharding_dim + new_embedding = embedding.__class__(**init_params) + setattr(parent, layer_name, new_embedding) return model -def share_embedding(model: DecoderModelForCausalLM): - model.lm_head.weight = model.transformer.vocab_embedding.weight - if hasattr( - model.transformer.vocab_embedding, "per_token_scale" - ) and model.transformer.vocab_embedding.per_token_scale is not None: - model.lm_head.per_channel_scale = model.transformer.vocab_embedding.per_token_scale +def share_embedding(model: PretrainedModel) -> PretrainedModel: + lm_head = None + vocab_embedding = None + for name, layer in model.named_modules(): + layer_name = name.rsplit('.', 1)[-1] + if layer_name == "lm_head": + lm_head = layer + if layer_name == "vocab_embedding": + vocab_embedding = layer + if lm_head is not None and vocab_embedding is not None: + break + + if lm_head is not None and vocab_embedding is not None: + lm_head.weight = vocab_embedding.weight + if (hasattr(vocab_embedding, "per_token_scale") + and vocab_embedding.per_token_scale is not None): + lm_head.per_channel_scale = vocab_embedding.per_token_scale return model -def set_fp8_context_fhma(model: DecoderModelForCausalLM): - for layer in model.transformer.layers: - scale = [1.0 - ] / layer.attention.dense.activation_scaling_factor.raw_value - layer.attention.attention_output_orig_quant_scale = Parameter( - value=scale.astype(np.float32)) +def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel: + for name, layer in model.named_modules(): + if isinstance(layer, Attention): + scale = [1.0] / layer.dense.activation_scaling_factor.raw_value + layer.attention_output_orig_quant_scale = Parameter( + value=scale.astype(np.float32)) return model -def optimize_model(model: DecoderModelForCausalLM, - use_parallel_embedding: bool = False, - share_embedding_table: bool = False, - use_fused_mlp: bool = False, - use_unfused_qkv_gemm: bool = False, - use_ootb_moe: bool = False, - use_prompt_tuning: bool = False, - use_lora: bool = False, - max_lora_rank: Optional[int] = None, - use_fp8_context_fmha: bool = False, - use_fused_rg_lru: bool = False): +def optimize_model( + model: PretrainedModel, + use_parallel_embedding: bool = False, + share_embedding_table: bool = False, + use_ootb_moe: bool = False, + use_fused_mlp: bool = False, + gemm_swiglu_plugin_dtype: Optional[str] = None, + use_fused_rg_lru: bool = False, + use_unfused_qkv_gemm: bool = False, + use_prompt_tuning: bool = False, + use_lora: bool = False, + max_lora_rank: Optional[int] = None, + use_fp8_context_fmha: bool = False, +) -> PretrainedModel: + """ + Run optimization passes on model. + There are dependencies between some passes, + so we always run passes in the order of arguments to guarantee the execution order. + """ + # before weight loading if use_parallel_embedding: model = parallelize_embedding(model) if share_embedding_table: model = share_embedding(model) + + # After weight loading + if use_ootb_moe: + model = to_ootb_moe(model) if use_fused_mlp: - model = fuse_gate_mlp(model) + model = fuse_gate_mlp(model, gemm_swiglu_plugin_dtype) + if use_fused_rg_lru: + model = fuse_rg_lru(model) if use_unfused_qkv_gemm: model = unfuse_qkv_gemm(model) - if use_ootb_moe: - model = to_ootb_moe(model) if use_prompt_tuning: model = set_prompt_tuning(model) if use_lora: model = add_lora(model, max_lora_rank) if use_fp8_context_fmha: model = set_fp8_context_fhma(model) - if use_fused_rg_lru: - model = fuse_rg_lru(model) return model @@ -1122,68 +1154,15 @@ def preprocess_weights(weights: Dict[str, torch.Tensor], if 'attention.dense.bias' in name or 'mlp.proj.bias' in name: weights[name] = torch.zeros_like(param) - -def load_model( - model_config: Optional[PretrainedConfig] = None, - ckpt_dir: Optional[str] = None, - model_cls: Optional[type[PretrainedModel]] = None, -): - from . import MODEL_MAP - - assert model_config is not None or ckpt_dir is not None, "must provide either model_config or ckpt_dir" - - if model_config is None: - model_config = PretrainedConfig.from_json_file( - os.path.join(ckpt_dir, 'config.json')) - - architecture = model_config.architecture - - if model_cls is None: - if architecture not in MODEL_MAP: - raise RuntimeError( - f'Unsupported model architecture: {architecture}') - model_cls = MODEL_MAP[architecture] - - # TODO: use PretrainedModel.from_checkpoint instead after PretrainedModel becomes base class of all models. - model = model_cls.from_config(model_config) - weights = None - if ckpt_dir is not None: - if model_config.architecture in WEIGHT_LOADER_MODELS: - model_path = os.path.join(ckpt_dir, 'rank0.safetensors') - else: - rank = model_config.mapping.rank - model_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors') - - if os.path.isfile(model_path): - weights = {} - with safetensors.safe_open(model_path, framework='pt', - device='cpu') as f: - for key in f.keys(): - weights[key] = f.get_tensor(key) - else: - logger.warning( - f"Cannot find {model_path}. Use dummy model weights.") - - if model_config.share_embedding_table and weights: + # For share_embedding_table + if model_config.share_embedding_table: if "lm_head.weight" in weights and "transformer.vocab_embedding.weight" in weights: - assert not ( - weights["lm_head.weight"] - - weights["transformer.vocab_embedding.weight"] - ).any( - ), "When share_embedding_table is enabled, lm_head.weight and transformer.vocab_embedding.weight must be same." - - # Currently, use_parallel_embedding and share_embedding_table should be enabled before weight loading; - # otherwise, the model will be inconsistent with the weights loaded from checkpoint. - model = optimize_model( - model, - use_parallel_embedding=model_config.use_parallel_embedding, - share_embedding_table=model_config.share_embedding_table, - ) - is_checkpoint_pruned = getattr(model_config, 'is_pruned', False) - if weights is not None: - preprocess_weights(weights, - model_config, - from_pruned=is_checkpoint_pruned) - model.load(weights, from_pruned=is_checkpoint_pruned) - - return model + if (weights["lm_head.weight"] - + weights["transformer.vocab_embedding.weight"]).any(): + logger.warning( + "lm_head.weight and transformer.vocab_embedding.weight are not identical, " + "share_embedding_table cannot be enabled; setting share_embedding_table=False." + ) + model_config.share_embedding_table = False + else: + weights.pop("lm_head.weight") diff --git a/tensorrt_llm/models/phi3/phi3small/__init__.py b/tensorrt_llm/models/phi3/phi3small/__init__.py new file mode 100644 index 000000000..71bf6d298 --- /dev/null +++ b/tensorrt_llm/models/phi3/phi3small/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorrt_llm/models/phi3/phi3small/convert.py b/tensorrt_llm/models/phi3/phi3small/convert.py new file mode 100644 index 000000000..c11824876 --- /dev/null +++ b/tensorrt_llm/models/phi3/phi3small/convert.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from tensorrt_llm.quantization import QuantAlgo + +from ...._utils import str_dtype_to_torch + + +def shuffle_qkv_weights(weights, config): + # Input weights are organized as + # (q00, q01, ... q0m, k0, v0), (q10, q11, ... q1m, k1, v1), ... (qn0, qn1, ... qnm, kn, vn) + # where n = num_kv_heads, m = num_attention_heads // num_kv_heads (i.e. #q_heads per kv_head) + # + # Output weights will be organized as + # (q00, q01, ..., qnm), (k0, k1, .., kn), (v0, v1, .., vn) + + num_heads = config['num_attention_heads'] + num_kv_heads = config['num_kv_heads'] if 'num_kv_heads' in config.keys( + ) else config['num_key_value_heads'] + num_q_per_kv = num_heads // num_kv_heads + + hidden_size = config['hidden_size'] + head_dim = hidden_size // num_heads + + input_shape = weights.shape + if weights.dim() < 2: + weights = weights.unsqueeze(1) + + weights = weights.reshape(num_kv_heads, (num_q_per_kv + 2), head_dim, + weights.shape[-1]) + q = weights[:, :-2, :, :] + k = weights[:, -2, :, :] + v = weights[:, -1, :, :] + + # num_heads x head_dim x hidden_size + q = q.reshape(-1, q.shape[2], q.shape[3]) + + # num_heads + (2 * num_kv_heads) x head_dim x hidden_size + weights = torch.cat([q, k, v], dim=0) + weights = weights.reshape(-1, weights.shape[2]) + + weights = weights.squeeze() + assert input_shape == weights.shape + + return weights + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return torch.chunk(v, tp_size)[idx].contiguous() + else: + return torch.chunk(v, tp_size, dim=dim)[idx].contiguous() + + +def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank): + """ + Splits the QKV matrix according to tensor parallelism + """ + v = v.reshape(3, n_hidden, n_hidden) + split_v = split(v, tensor_parallel, rank, dim=1) + split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden) + return split_v.contiguous() + + +def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): + """ + Splits the QKV bias according to tensor parallelism + """ + v = v.reshape(3, n_hidden) + split_v = split(v, tensor_parallel, rank, dim=1) + split_v = split_v.reshape(3 * (n_hidden // tensor_parallel)) + return split_v.contiguous() + + +def split_matrix_tp(v, tensor_parallel, rank, dim): + return split(v, tensor_parallel, rank, dim=dim) + + +def split_embedding( + param: torch.Tensor, + tp_size: int, + tp_rank: int, + use_parallel_embedding: bool = False, + sharding_dim: int = 0, +) -> torch.Tensor: + if param is None: + return None + if not use_parallel_embedding: + return param + + vocab_size, hidden_size = param.size() + if sharding_dim == 0: + if vocab_size % tp_size != 0: + vocab_size_padded = pad_vocab_size(vocab_size, tp_size) + pad_width = vocab_size_padded - vocab_size + param = torch.nn.functional.pad(param, (0, 0, 0, pad_width), + value=0) + else: + assert hidden_size % tp_size == 0 + return split(param, tp_size, tp_rank, dim=sharding_dim) + + +def get_weight(config, prefix, dtype): + return config[prefix + '.weight'].to(dtype).detach() + + +def get_bias(config, prefix, dtype): + return config[prefix + '.bias'].to(dtype).detach() + + +def get_weight_and_bias(config, prefix, dtype): + return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype) + + +def get_tllm_linear_weight(weight, + prefix, + bias=None, + use_weight_only=False, + plugin_weight_only_quant_type=torch.int8): + results = {} + if use_weight_only: + v = weight.t().contiguous() + processed_torch_weights, torch_weight_scales = \ + torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( + v, plugin_weight_only_quant_type) + results[prefix + '.weight'] = processed_torch_weights + results[prefix + '.per_channel_scale'] = torch_weight_scales + else: + results[prefix + '.weight'] = weight.contiguous() + + if bias is not None: + results[prefix + '.bias'] = bias + + return results + + +def split_weights_tp(config, weights, args, rank, dtype): + num_heads = config['num_attention_heads'] + num_kv_heads = config['num_kv_heads'] + hidden_size = config['hidden_size'] + + mha_mode = num_heads == num_kv_heads + tp_size = args.tp_size + + use_weight_only = args.use_weight_only + plugin_weight_only_quant_type = None + if use_weight_only and args.weight_only_precision == 'int8': + plugin_weight_only_quant_type = torch.int8 + elif use_weight_only and args.weight_only_precision == 'int4': + plugin_weight_only_quant_type = torch.quint4x2 + + # Helper + def get_weight(weight, prefix, bias): + return get_tllm_linear_weight(weight, prefix, bias, use_weight_only, + plugin_weight_only_quant_type) + + for layer_id in range(config['num_hidden_layers']): + layer_prefix = f"transformer.layers.{layer_id}." + + prefix = layer_prefix + 'attention.qkv' + qkv_weight, qkv_bias = get_weight_and_bias(weights, prefix, dtype) + + if not mha_mode: + num_q_per_kv = num_heads // num_kv_heads + + qkv_weight = qkv_weight.reshape(num_q_per_kv + 2, -1, hidden_size) + q = qkv_weight[:num_q_per_kv, :, :].reshape(-1, hidden_size) + k = qkv_weight[num_q_per_kv:num_q_per_kv + 1, :, :].reshape( + -1, hidden_size) + v = qkv_weight[num_q_per_kv + 1:num_q_per_kv + 2, :, :].reshape( + -1, hidden_size) + split_weight = torch.cat( + [split(x, tp_size, rank) for x in [q, k, v]], dim=0) + + qkv_bias = qkv_bias.reshape(num_q_per_kv + 2, -1) + q = qkv_bias[:num_q_per_kv, :].reshape(-1) + k = qkv_bias[num_q_per_kv:num_q_per_kv + 1, :].reshape(-1) + v = qkv_bias[num_q_per_kv + 1:num_q_per_kv + 2, :].reshape(-1) + split_bias = torch.cat([split(x, tp_size, rank) for x in [q, k, v]], + dim=0) + else: + split_weight = split_qkv_tp(qkv_weight, num_heads, hidden_size, + tp_size, rank) + split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size, + tp_size, rank) + + weights.update(get_weight(split_weight, prefix, split_bias)) + + prefix = layer_prefix + 'attention.dense' + attn_dense_weight, attn_dense_bias = get_weight_and_bias( + weights, prefix, dtype) + split_v = split_matrix_tp(attn_dense_weight, tp_size, rank, dim=1) + weights.update(get_weight(split_v, prefix, attn_dense_bias)) + + prefix = layer_prefix + 'mlp.fc' + mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(weights, prefix, dtype) + split_v = split_matrix_tp(mlp_fc_weight, tp_size, rank, dim=0) + bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0) + weights.update(get_weight(split_v, prefix, bias)) + + prefix = layer_prefix + 'mlp.proj' + mlp_proj_weight, mlp_proj_bias = get_weight_and_bias( + weights, prefix, dtype) + split_v = split_matrix_tp(mlp_proj_weight, tp_size, rank, dim=1) + weights.update(get_weight(split_v, prefix, mlp_proj_bias)) + + weights['transformer.vocab_embedding.weight'] = split_embedding( + weights['transformer.vocab_embedding.weight'], tp_size, rank) + weights['lm_head.weight'] = split_matrix_tp(weights['lm_head.weight'], + tp_size, + rank, + dim=0) + + return weights + + +def convert_hf_weights(hf_model, config, args, rank): + torch_dtype = str_dtype_to_torch(args.dtype) + hf_state_dict = hf_model.state_dict() + weights = {} + + # replace key name + for key, value in hf_state_dict.items(): + # Decoder Layers + if "model.layers." in key: + key = key.replace("model.layers.", "transformer.layers.") + key = key.replace("self_attn.", "attention.") + key = key.replace("query_key_value.", "qkv.") + key = key.replace("mlp.up_proj.", "mlp.fc.") + key = key.replace("mlp.down_proj.", "mlp.proj.") + key = key.replace("post_attention_layernorm.", "post_layernorm.") + # Embedding + key = key.replace("model.embed_tokens.weight", + "transformer.vocab_embedding.weight") + # Final Layer norm + key = key.replace("model.final_layernorm.", "transformer.ln_f.") + weights[key] = value.to(torch_dtype).cpu() + + weights['lm_head.weight'] = weights[ + 'transformer.vocab_embedding.weight'].clone() + + # Transform QKV weights from custom Phi3Small format to TRT-LLM format + for key, value in weights.items(): + if "qkv." in key: + weights[key] = shuffle_qkv_weights(weights[key], config) + + weights = split_weights_tp(config, weights, args, rank, torch_dtype) + + return weights + + +def convert_hf_config(hf_config, dtype, args): + config = { + 'architecture': 'Phi3SmallForCausalLM', + 'dtype': dtype, + 'num_hidden_layers': hf_config.num_hidden_layers, + 'num_attention_heads': hf_config.num_attention_heads, + 'num_kv_heads': hf_config.num_key_value_heads, + 'rotary_embedding_base': hf_config.rope_embedding_base, + 'hidden_size': hf_config.hidden_size, + 'intermediate_size': hf_config.intermediate_size, + 'vocab_size': hf_config.vocab_size, + 'max_position_embeddings': hf_config.max_position_embeddings, + 'hidden_act': hf_config.hidden_act, + 'share_embedding_table': False, + 'gegelu_limit': hf_config.gegelu_limit, + 'mup_attn_multiplier': hf_config.mup_attn_multiplier, + 'mup_embedding_multiplier': hf_config.mup_embedding_multiplier, + 'mup_use_scaling': hf_config.mup_use_scaling, + 'mup_width_multiplier': hf_config.mup_width_multiplier, + 'blocksparse_block_size': hf_config.blocksparse_block_size, + 'blocksparse_homo_head_pattern': + hf_config.blocksparse_homo_head_pattern, + 'blocksparse_num_local_blocks': hf_config.blocksparse_num_local_blocks, + 'blocksparse_vertical_stride': hf_config.blocksparse_vert_stride, + 'dense_attention_every_n_layers': + hf_config.dense_attention_every_n_layers, + } + + if args is not None: + config.update({ + 'mapping': { + 'world_size': args.tp_size * args.pp_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + } + }) + + if args.use_weight_only and args.weight_only_precision == 'int8': + config.update({'quantization': {'quant_algo': QuantAlgo.W8A16}}) + elif args.use_weight_only and args.weight_only_precision == 'int4': + config.update({'quantization': {'quant_algo': QuantAlgo.W4A16}}) + + if hf_config.max_position_embeddings >= 128000: + config.update({ + 'original_max_position_embeddings': + hf_config.original_max_position_embeddings, + 'longrope_scaling_short_factors': + hf_config.rope_scaling["short_factor"], + 'longrope_scaling_long_factors': + hf_config.rope_scaling["long_factor"], + 'longrope_long_mscale': + hf_config.rope_scaling["long_mscale"], + 'longrope_short_mscale': + hf_config.rope_scaling["short_mscale"] + }) + return config diff --git a/tensorrt_llm/models/phi3/phi3small/model.py b/tensorrt_llm/models/phi3/phi3small/model.py new file mode 100644 index 000000000..2d0ff7d79 --- /dev/null +++ b/tensorrt_llm/models/phi3/phi3small/model.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import safetensors +from transformers import AutoModelForCausalLM + +from ...._utils import pad_vocab_size +from ....functional import PositionEmbeddingType, Tensor +from ....layers import (MLP, Attention, AttentionMaskType, + BlockSparseAttnParams, Embedding, LayerNorm, + ParallelLMHead) +from ....module import Module +from ...modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, + PretrainedConfig) +from .convert import convert_hf_config, convert_hf_weights + + +class Phi3SmallDecoderLayer(Module): + + def __init__(self, config: PretrainedConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + tp_group = config.mapping.tp_group + tp_size = config.mapping.tp_size + self.gegelu_limit = config.gegelu_limit + + self.input_layernorm = LayerNorm(normalized_shape=config.hidden_size, + dtype=config.dtype) + + # MuP uses norm_factor=attention_head_size (rather than sqrt(attention_head_size)) + # We achieve this using q_scaling = sqrt(attention_head_size) + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + attention_head_size = hidden_size / num_attention_heads + q_scaling = attention_head_size**.5 + + block_sparse = ( + (layer_idx + 1) % config.dense_attention_every_n_layers) != 0 + attention_mask_type = AttentionMaskType.blocksparse if block_sparse else AttentionMaskType.causal + + block_sparse_attn_params = BlockSparseAttnParams( + config.blocksparse_block_size, config.blocksparse_homo_head_pattern, + config.blocksparse_num_local_blocks, + config.blocksparse_vertical_stride) + + layers_range = config.mapping.pp_layers(config.num_hidden_layers) + local_layer_idx = layer_idx - layers_range[0] + + position_embedding_type = PositionEmbeddingType.rope_gpt_neox + original_max_position_embeddings = config.max_position_embeddings + + rope_scaling_short_factors, rope_scaling_long_factors = 1.0, 1.0 + rope_scaling_short_mscale, rope_scaling_long_mscale = 1.0, 1.0 + + if hasattr(config, "longrope_scaling_short_factors"): + rope_scaling_short_factors = np.asarray( + config.longrope_scaling_short_factors).astype(np.float32) + rope_scaling_long_factors = np.asarray( + config.longrope_scaling_long_factors).astype(np.float32) + rope_scaling_short_mscale = config.longrope_short_mscale + rope_scaling_long_mscale = config.longrope_long_mscale + + position_embedding_type = PositionEmbeddingType.long_rope + original_max_position_embeddings = config.original_max_position_embeddings + + self.attention = Attention( + local_layer_idx=local_layer_idx, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_kv_heads=config.num_kv_heads, + position_embedding_type=position_embedding_type, + rotary_embedding_base=config.rotary_embedding_base, + max_position_embeddings=config.max_position_embeddings, + original_max_position_embeddings=original_max_position_embeddings, + dtype=config.dtype, + attention_mask_type=attention_mask_type, + bias=True, + q_scaling=q_scaling, + tp_group=tp_group, + tp_size=tp_size, + quant_mode=config.quant_mode, + rope_scaling_short_factors=rope_scaling_short_factors, + rope_scaling_long_factors=rope_scaling_long_factors, + rope_scaling_short_mscale=rope_scaling_short_mscale, + rope_scaling_long_mscale=rope_scaling_long_mscale, + block_sparse_params=block_sparse_attn_params) + + self.post_layernorm = LayerNorm(normalized_shape=config.hidden_size, + dtype=config.dtype) + + self.mlp = MLP(hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + hidden_act=config.hidden_act, + dtype=config.dtype, + tp_group=tp_group, + tp_size=tp_size, + quant_mode=config.quant_mode) + + def forward( + self, + hidden_states: Tensor, + attention_mask=None, + use_cache=False, + kv_cache_params=None, + attention_params=None, + ): + residual = hidden_states + input_layernorm_output = self.input_layernorm(hidden_states) + + # Self attention + attention_output = self.attention( + input_layernorm_output, + attention_mask=attention_mask, + use_cache=use_cache, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + ) + + if use_cache: + attention_output, presents = attention_output + + hidden_states = residual + attention_output + + # Fully connected + residual = hidden_states + hidden_states = self.post_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, gegelu_limit=self.gegelu_limit) + hidden_states = residual + hidden_states + + if use_cache: + return (hidden_states, presents) + return hidden_states + + +class Phi3SmallModel(Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.vocab_embedding = Embedding(num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + dtype=config.dtype) + + self.layers = DecoderLayerList(Phi3SmallDecoderLayer, config) + self.ln_f = LayerNorm(normalized_shape=config.hidden_size, + dtype=config.dtype) + self.mup_embedding_multiplier = config.mup_embedding_multiplier + + def forward( + self, + input_ids: Tensor, + position_ids=None, + use_cache=False, + attention_mask=None, + kv_cache_params=None, + attention_params=None, + prompt_embedding_table=None, + prompt_tasks=None, + prompt_vocab_size=None, + ): + args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + hidden_states = self.vocab_embedding(input_ids, *args) + + if self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0: + hidden_states = hidden_states * self.mup_embedding_multiplier + + hidden_states = self.layers( + hidden_states, + use_cache=use_cache, + attention_mask=attention_mask, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + ) + if use_cache: + hidden_states, presents = hidden_states + + hidden_states = self.ln_f(hidden_states) + + if use_cache: + return (hidden_states, tuple(presents)) + return hidden_states + + +class Phi3SmallForCausalLM(DecoderModelForCausalLM): + + def __init__(self, config: PretrainedConfig): + transformer = Phi3SmallModel(config) + vocab_size_padded = pad_vocab_size(config.vocab_size, + config.mapping.tp_size) + + lm_head = ParallelLMHead(config.hidden_size, + vocab_size_padded, + bias=False, + dtype=config.dtype, + tp_group=config.mapping.tp_group, + tp_size=config.mapping.tp_size, + gather_output=True) + + super().__init__(config, transformer, lm_head) + + @classmethod + def convert_hf_checkpoint(cls, model_dir, dtype, output_dir, args=None): + ''' + Convert Huggingface checkpoint to TRT-LLM checkpoint + ''' + + hf_model = AutoModelForCausalLM.from_pretrained(model_dir, + torch_dtype="auto", + trust_remote_code=True) + + config = convert_hf_config(hf_model.config, dtype, args) + with open(os.path.join(output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + + def covert_and_save(rank): + weights = convert_hf_weights(hf_model, config, args, rank) + safetensors.torch.save_file( + weights, os.path.join(output_dir, f'rank{rank}.safetensors')) + + world_size = args.tp_size * args.pp_size + if args.workers == 1: + for rank in range(world_size): + covert_and_save(rank) + else: + with ThreadPoolExecutor(max_workers=args.workers) as p: + futures = [ + p.submit(covert_and_save, rank) + for rank in range(world_size) + ] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." diff --git a/tensorrt_llm/models/qwen/convert.py b/tensorrt_llm/models/qwen/convert.py index 2478c0a1b..1b6a6a445 100644 --- a/tensorrt_llm/models/qwen/convert.py +++ b/tensorrt_llm/models/qwen/convert.py @@ -953,33 +953,12 @@ def convert_hf_qwen(hf_model, v = get_weight(model_params, key_list[7], dtype) - if hf_model.config.tie_word_embeddings: - # lm_head.weight has the same weights as embedding - if mapping.is_last_pp_rank(): - if vocab_size % mapping.tp_size != 0: - # padding - vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) - pad_width = vocab_size_padded - vocab_size - - v = torch.from_numpy( - np.pad(v.detach().cpu().numpy(), ((0, pad_width), (0, 0)), - 'constant', - constant_values=0)) - weights['lm_head.weight'] = split(v, mapping.tp_size, - mapping.tp_rank) - - if use_parallel_embedding: - v = split_matrix_tp(v, - mapping.tp_size, - mapping.tp_rank, - dim=sharding_dim) - - if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = v - - lm_head_weights = get_weight(model_params, 'lm_head', dtype) - if mapping.is_last_pp_rank(): + if hf_model.config.tie_word_embeddings: + # lm_head.weight has the same weights as embedding + lm_head_weights = v + else: + lm_head_weights = get_weight(model_params, 'lm_head', dtype) if vocab_size % mapping.tp_size != 0: # padding @@ -995,6 +974,17 @@ def convert_hf_qwen(hf_model, tensor_parallel, mapping.tp_rank, dim=0) + + if use_parallel_embedding: + v = split_matrix_tp(v, + mapping.tp_size, + mapping.tp_rank, + dim=sharding_dim) + + if mapping.is_first_pp_rank(): + weights['transformer.vocab_embedding.weight'] = v + + if mapping.is_last_pp_rank(): ln_f_w = get_weight(model_params, key_list[8], dtype) weights['transformer.ln_f.weight'] = ln_f_w @@ -1092,7 +1082,7 @@ def create_config_from_hugging_face(hf_model, 'pp_size': mapping.pp_size } }) - config['quantization'] = quantization.asdict() + config['quantization'] = quantization.to_dict() config.update(override_fields) return config diff --git a/tensorrt_llm/models/qwen/weight.py b/tensorrt_llm/models/qwen/weight.py index eb8b17d26..3df6e5afa 100644 --- a/tensorrt_llm/models/qwen/weight.py +++ b/tensorrt_llm/models/qwen/weight.py @@ -167,18 +167,19 @@ def process_and_assign_weight(v: List[torch.Tensor], if qwen_type == 'qwen': qkv_bias = model_params[prefix + key_list[0] + suf].to(torch_dtype).cpu().contiguous() + q_emb = qkv_bias.shape[0] // 3 + qkv_bias = qkv_bias.reshape(3, q_emb) + split_v = split(qkv_bias, mapping.tp_size, mapping.rank, dim=1) + qkv_bias = split_v.reshape(3 * (q_emb // mapping.tp_size)) else: qkv_bias_list = [] for comp in ["q_proj", "k_proj", "v_proj"]: comp_part = model_params[prefix + key_list[0] + comp + suf].to( torch_dtype).cpu().contiguous() + comp_part = torch_split(comp_part, dim=0) qkv_bias_list.append(comp_part) qkv_bias = torch.cat(qkv_bias_list, dim=0) - q_emb = qkv_bias.shape[0] // 3 - qkv_bias = qkv_bias.reshape(3, q_emb) - split_v = split(qkv_bias, mapping.tp_size, mapping.rank, dim=1) - split_v = split_v.reshape(3 * (q_emb // mapping.tp_size)) - weights[tllm_prex + ".attention.qkv.bias"] = split_v + weights[tllm_prex + ".attention.qkv.bias"] = qkv_bias # 4.3 attention.dense qkv_dense_list = [] for suf in suffixs: diff --git a/tensorrt_llm/models/recurrentgemma/model.py b/tensorrt_llm/models/recurrentgemma/model.py index e6d012d24..be11ed431 100644 --- a/tensorrt_llm/models/recurrentgemma/model.py +++ b/tensorrt_llm/models/recurrentgemma/model.py @@ -49,8 +49,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int): num_heads=config.num_attention_heads, dtype=config.dtype, tp_group=config.mapping.tp_group, - tp_size=config.mapping.tp_size, - tp_rank=config.mapping.tp_rank) + tp_size=config.mapping.tp_size) elif self.temporal_block_type == 'attention': layer_types = config.layer_types * ( (layer_idx + 1) // layer_type_len) diff --git a/tensorrt_llm/module.py b/tensorrt_llm/module.py index d4626dda5..96575c0d8 100644 --- a/tensorrt_llm/module.py +++ b/tensorrt_llm/module.py @@ -103,6 +103,25 @@ def named_modules(self, memo=None, prefix='', remove_duplicate=True): remove_duplicate): yield m + def named_modules_with_parent(self, + memo=None, + prefix='', + parent=None, + remove_duplicate=True): + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self, parent + for name, module in self._modules.items(): + if module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + for m in module.named_modules_with_parent( + memo, submodule_prefix, self, remove_duplicate): + yield m + def named_children(self): memo = set() for name, module in self._modules.items(): diff --git a/tensorrt_llm/network.py b/tensorrt_llm/network.py index ad139d94f..c6be2f557 100644 --- a/tensorrt_llm/network.py +++ b/tensorrt_llm/network.py @@ -15,6 +15,7 @@ import collections import contextlib import hashlib +import inspect import weakref from collections import defaultdict from dataclasses import dataclass, field @@ -242,6 +243,16 @@ def _set_layer_name(self, layer): layer_name = str(layer.type).split('.')[-1] current_module = self._module_call_stack.get_current_module() + func_stack = [] + frame = inspect.currentframe().f_back.f_back + while frame: + func_name = frame.f_code.co_name + if func_name == "forward": + break + func_stack.insert(0, func_name) + frame = frame.f_back + current_module = f"{current_module}.{'.'.join(func_stack)}" + if layer.type == trt.LayerType.PLUGIN_V2: layer_name = '_'.join( [layer_name, diff --git a/tensorrt_llm/parameter.py b/tensorrt_llm/parameter.py index 42dc42bbd..aef8154fd 100644 --- a/tensorrt_llm/parameter.py +++ b/tensorrt_llm/parameter.py @@ -123,6 +123,11 @@ def raw_value(self) -> np.ndarray: @value.setter def value(self, v: Union[np.ndarray, torch.Tensor]): v = self._regularize_value(v) + + if v.shape != self.shape and v.ndim == 0 and max(self.shape) == 1: + # convert the scalar into a tensor which each dim is 1. + v = v.reshape(self.shape) + assert v.shape == self.shape, \ f'The value updated is not the same shape as the original. ' \ f'Updated: {v.shape}, original: {self.shape}' diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py index 3a12b54a4..f1ddfea4a 100644 --- a/tensorrt_llm/plugin/plugin.py +++ b/tensorrt_llm/plugin/plugin.py @@ -59,7 +59,10 @@ class ContextFMHAType(IntEnum): enabled_with_fp32_acc = 2 -PLUGIN_DTYPE_OPTIONS = ["auto", "float16", "float32", "bfloat16", "int32", None] +DEFAULT_PLUGIN_DTYPE_OPTIONS = [ + "auto", "float16", "float32", "bfloat16", "int32", None +] +PLUGIN_DTYPE_OPTIONS_MAP = {"gemm_swiglu_plugin": ["fp8", None]} def _make_plugin_property(field_name: str, field_type: type): @@ -81,8 +84,11 @@ def prop(self, value): assert isinstance(value, bool), \ f"Plugin {field_name} expects {field_type}, got {type(value)}" elif field_type in (str, Optional[str]): - assert value in PLUGIN_DTYPE_OPTIONS, \ - f"Plugin {field_name} expects values in {PLUGIN_DTYPE_OPTIONS}, got {value}" + plugin_dtype_options = DEFAULT_PLUGIN_DTYPE_OPTIONS + if field_name in PLUGIN_DTYPE_OPTIONS_MAP: + plugin_dtype_options = PLUGIN_DTYPE_OPTIONS_MAP[field_name] + assert value in plugin_dtype_options, \ + f"Plugin {field_name} expects values in {plugin_dtype_options}, got {value}" if field_name == 'dtype': assert value not in ['auto', None], \ "Plugin dtype cannot be auto or None" @@ -110,7 +116,7 @@ class PluginConfig(metaclass=PluginConfigMeta): There are two option categories: * Plugin options (typically with xxx_plugin naming). These options can be assigned with: - * "float16"/"bfloat16"/"float32"/"int32", which means the plugin is enabled with the specified precision; + * "float16"/"bfloat16"/"float32"/"int32", which means the plugin is enabled with the specified precision; (Some plugins only support limited dtype, i.e., gemm_swiglu_plugin only supports fp8 now) * "auto", which means the plugin is enabled with the precision of `dtype` field (the `dtype` field must be same to model dtype, i.e., the one in PretrainedConfig); * None, which means the plugin is disabled. * Other features. These options can be assigned with boolean: @@ -126,6 +132,7 @@ class PluginConfig(metaclass=PluginConfigMeta): _bert_attention_plugin: Optional[str] = field(default="auto", init=False) _gpt_attention_plugin: Optional[str] = field(default="auto", init=False) _gemm_plugin: Optional[str] = field(default=None, init=False) + _gemm_swiglu_plugin: Optional[str] = field(default=None, init=False) _smooth_quant_gemm_plugin: Optional[str] = field(default=None, init=False) _identity_plugin: Optional[str] = field(default=None, init=False) _layernorm_quantization_plugin: Optional[str] = field(default=None, @@ -265,6 +272,7 @@ def set_nccl_plugin(self, "bert_attention_plugin", "gpt_attention_plugin", "gemm_plugin", + "gemm_swiglu_plugin", "lookup_plugin", "lora_plugin", "moe_plugin", @@ -297,11 +305,14 @@ def add_plugin_argument(parser): if field_name not in cli_plugin_args: continue if field.type in (str, Optional[str]): + plugin_dtype_options = DEFAULT_PLUGIN_DTYPE_OPTIONS + if field_name in PLUGIN_DTYPE_OPTIONS_MAP: + plugin_dtype_options = PLUGIN_DTYPE_OPTIONS_MAP[field_name] parser.add_argument( "--" + field_name, type=str, default=field.default if field.default else "disable", - choices=[x if x else "disable" for x in PLUGIN_DTYPE_OPTIONS], + choices=[x if x else "disable" for x in plugin_dtype_options], help=f"Whether to enable/disable {field_name} and the dtype.") elif field.type == bool: parser.add_argument( diff --git a/tensorrt_llm/quantization/layers.py b/tensorrt_llm/quantization/layers.py index 35232ec27..e7ebbeef6 100644 --- a/tensorrt_llm/quantization/layers.py +++ b/tensorrt_llm/quantization/layers.py @@ -249,6 +249,7 @@ def __init__( self.register_parameter('bias', None) self.eps = eps + self.dtype = dtype self.quant_mode = quant_mode if self.quant_mode.has_act_and_weight_quant(): @@ -300,6 +301,7 @@ def __init__( self.register_parameter('bias', None) self.eps = eps + self.dtype = dtype self.quant_mode = quant_mode if self.quant_mode.has_act_and_weight_quant(): @@ -855,7 +857,10 @@ def __init__( self.weights_scaling_factor = Parameter(shape=(1, ), dtype=trt.float32) def forward(self, x, lora_runtime_params=None): - assert lora_runtime_params is None, "lora is not supported on FP8Linear now" + assert lora_runtime_params is None or default_net( + ).plugin_config.lora_plugin == self.dtype + + lora_hidden_state = x if lora_runtime_params is not None else None if default_net().strongly_typed: assert is_same_dtype( x.dtype, @@ -881,7 +886,9 @@ def forward(self, x, lora_runtime_params=None): return self.multiply_gather(dequantized_out, w_deq_out, gemm_plugin=None, - use_fp8=True) + use_fp8=True, + lora_runtime_params=lora_runtime_params, + lora_hidden_state=lora_hidden_state) class FP8RowLinear(RowLinear): @@ -908,8 +915,10 @@ def __init__( self.weights_scaling_factor = Parameter(shape=(1, ), dtype=trt.float32) def forward(self, x, lora_runtime_params=None): - assert lora_runtime_params is None, "lora is not supported on FP8RowLinear now" + assert lora_runtime_params is None or default_net( + ).plugin_config.lora_plugin == self.dtype + lora_hidden_state = x if lora_runtime_params is not None else None activation_scaling_factor = cast(self.activation_scaling_factor.value, self.dtype) if x.dtype != trt.fp8: @@ -933,7 +942,9 @@ def forward(self, x, lora_runtime_params=None): return self.multiply_reduce(dequantized_out, w_deq_out, gemm_plugin=None, - use_fp8=True) + use_fp8=True, + lora_runtime_params=lora_runtime_params, + lora_hidden_state=lora_hidden_state) class SmoothQuantGatedMLP(SmoothQuantMLP): @@ -1002,7 +1013,7 @@ class SmoothQuantAttention(Module): def __init__( self, *, - layer_idx, + local_layer_idx, hidden_size, num_attention_heads, num_kv_heads=None, @@ -1012,7 +1023,7 @@ def __init__( attention_head_size=None, attention_mask_type=AttentionMaskType.padding, bias=True, - qkv_bias_only=False, + dense_bias=None, dtype=None, position_embedding_type=PositionEmbeddingType.learned_absolute, rotary_embedding_base=10000.0, @@ -1026,7 +1037,7 @@ def __init__( quant_mode=QuantMode(0), ): super().__init__() - self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx self.attention_mask_type = attention_mask_type self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size self.num_attention_heads = num_attention_heads // tp_size @@ -1037,6 +1048,9 @@ def __init__( self.max_position_embeddings = 0 if max_position_embeddings is None else max_position_embeddings self.tp_size = tp_size self.tp_rank = tp_rank + self.dense_bias = dense_bias + if dense_bias is None: + self.dense_bias = bias self.num_layers = num_layers self.apply_query_key_layer_scaling = apply_query_key_layer_scaling @@ -1108,7 +1122,7 @@ def __init__( tp_size * self.num_attention_heads * self.attention_head_size + (2 * tp_size * self.num_attention_kv_heads * self.attention_head_size), - bias=(bias or qkv_bias_only), + bias=bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size, @@ -1118,7 +1132,7 @@ def __init__( self.dense = SmoothQuantRowLinear(tp_size * self.num_attention_heads * self.attention_head_size, hidden_size, - bias=bias, + bias=self.dense_bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size, @@ -1190,7 +1204,7 @@ def forward( context_lengths=attention_params.context_lengths, cache_indirection=kv_cache_params.cache_indirection, host_request_types=attention_params.host_request_types, - layer_idx=self.layer_idx, + layer_idx=self.local_layer_idx, num_heads=self.num_attention_heads, num_kv_heads=self.num_attention_kv_heads, hidden_size_per_head=self.attention_head_size, diff --git a/tensorrt_llm/quantization/quantize.py b/tensorrt_llm/quantization/quantize.py index b0a4df49a..3999cbea8 100644 --- a/tensorrt_llm/quantization/quantize.py +++ b/tensorrt_llm/quantization/quantize.py @@ -1,3 +1,4 @@ +from .._utils import get_init_params from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP, LayerNorm, RmsNorm, RowLinear) from ..models.modeling_utils import QuantConfig @@ -12,11 +13,12 @@ from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo -def weight_only_quantize(model, - quant_config: QuantConfig, - current_key_name=None): - assert quant_config.quant_mode.is_weight_only() - +def quantize_layers( + model, + quant_config: QuantConfig, + quant_map, + preprocess_init_params=None, +): exclude_modules = quant_config.exclude_modules or [ 'lm_head', 'router', @@ -25,221 +27,129 @@ def weight_only_quantize(model, 'block_embedding', ] - for name, module in model.named_children(): - if current_key_name is None: - current_key_name = [] - current_key_name.append(name) - - if len(list(module.children())) > 0: - weight_only_quantize(module, quant_config, current_key_name) - - if isinstance(module, ColumnLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - transb = True if name == "lm_head" else False - model._modules[name] = WeightOnlyQuantColumnLinear( - in_features=module.in_features, - out_features=module.out_features * module.tp_size, - bias=module.bias is not None, - dtype=module.dtype, - tp_group=module.tp_group, - tp_size=module.tp_size, - gather_output=module.gather_output, - quant_mode=quant_config.quant_mode, - transb=transb) - elif isinstance(module, RowLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = WeightOnlyQuantRowLinear( - in_features=module.in_features * module.tp_size, - out_features=module.out_features, - bias=module.bias is not None, - dtype=module.dtype, - tp_group=module.tp_group, - tp_size=module.tp_size, - quant_mode=quant_config.quant_mode) - elif isinstance(module, Embedding) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = WeightOnlyQuantEmbedding( - num_embeddings=module.num_embeddings, - embedding_dim=module.embedding_dim, - dtype=module.dtype, - tp_size=module.tp_size, - tp_group=module.tp_group, - sharding_dim=module.sharding_dim, - tp_rank=module.tp_rank, - quant_mode=quant_config.quant_mode) - - current_key_name.pop(-1) + for name, module, parent in model.named_modules_with_parent(): + module_name = name.rsplit('.', 1)[-1] + if module_name not in exclude_modules: + quant_cls = None + for cls in quant_map: + if isinstance(module, cls): + quant_cls = quant_map[cls] + break + + if quant_cls is None: + continue + + init_params = get_init_params(module, quant_cls) + if "bias" in init_params: + init_params["bias"] = init_params["bias"] is not None + if isinstance(module, ColumnLinear): + init_params[ + "out_features"] = module.out_features * module.tp_size + elif isinstance(module, RowLinear): + init_params["in_features"] = module.in_features * module.tp_size + if preprocess_init_params is not None: + preprocess_init_params(init_params, name, module) + quant_layer = quant_cls(**init_params) + setattr(parent, module_name, quant_layer) setattr(model, 'quant_mode', quant_config.quant_mode) return model -def weight_only_groupwise_quantize(model, - quant_config: QuantConfig, - current_key_name=None): +def weight_only_quantize(model, quant_config: QuantConfig): assert quant_config.quant_mode.is_weight_only() - exclude_modules = quant_config.exclude_modules or ['lm_head', 'router'] - - for name, module in model.named_children(): - if current_key_name is None: - current_key_name = [] - current_key_name.append(name) - - if len(list(module.children())) > 0: - weight_only_groupwise_quantize(module, quant_config, - current_key_name) - - if isinstance(module, ColumnLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = WeightOnlyGroupwiseQuantColumnLinear( - in_features=module.in_features, - out_features=module.out_features * module.tp_size, - group_size=quant_config.group_size, - pre_quant_scale=quant_config.pre_quant_scale, - zero=quant_config.has_zero_point, - bias=module.bias is not None, - use_w4a8_awq=quant_config.quant_algo == QuantAlgo.W4A8_AWQ, - dtype=module.dtype, - tp_group=module.tp_group, - tp_size=module.tp_size, - gather_output=module.gather_output) - elif isinstance(module, RowLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = WeightOnlyGroupwiseQuantRowLinear( - in_features=module.in_features * module.tp_size, - out_features=module.out_features, - group_size=quant_config.group_size, - pre_quant_scale=quant_config.pre_quant_scale, - zero=quant_config.has_zero_point, - bias=module.bias is not None, - use_w4a8_awq=quant_config.quant_algo == QuantAlgo.W4A8_AWQ, - dtype=module.dtype, - tp_group=module.tp_group, - tp_size=module.tp_size) - - current_key_name.pop(-1) + quant_map = { + ColumnLinear: WeightOnlyQuantColumnLinear, + RowLinear: WeightOnlyQuantRowLinear, + Embedding: WeightOnlyQuantEmbedding, + } + + def preprocess_init_params(init_params, name, module): + init_params["quant_mode"] = quant_config.quant_mode + if isinstance(module, ColumnLinear): + module_name = name.rsplit('.', 1)[-1] + init_params["transb"] = module_name == "lm_head" + + quantize_layers( + model, + quant_config, + quant_map, + preprocess_init_params, + ) + return model - setattr(model, 'quant_mode', quant_config.quant_mode) + +def weight_only_groupwise_quantize(model, quant_config: QuantConfig): + assert quant_config.quant_mode.is_weight_only() + + quant_map = { + ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear, + RowLinear: WeightOnlyGroupwiseQuantRowLinear, + } + + def preprocess_init_params(init_params, name, module): + init_params["group_size"] = quant_config.group_size + init_params["pre_quant_scale"] = quant_config.pre_quant_scale + init_params["zero"] = quant_config.has_zero_point + init_params[ + "use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ + + quantize_layers( + model, + quant_config, + quant_map, + preprocess_init_params, + ) return model def smooth_quantize_ootb( model, quant_config: QuantConfig, - current_key_name=None, ): - exclude_modules = quant_config.exclude_modules or ['lm_head', 'router'] - - for name, module in model.named_children(): - if current_key_name is None: - current_key_name = [] - current_key_name.append(name) - - if len(list(module.children())) > 0: - smooth_quantize_ootb(module, quant_config, current_key_name) - - if isinstance(module, ColumnLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = Int8SmoothQuantLinear( - module.in_features, module.out_features * module.tp_size, - module.bias, module.dtype, module.tp_group, module.tp_size, - module.gather_output) - elif isinstance(module, RowLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = Int8SmoothQuantRowLinear( - module.in_features * module.tp_size, module.out_features, - module.bias, module.dtype, module.tp_group, module.tp_size) - - current_key_name.pop(-1) - - setattr(model, 'quant_mode', quant_config.quant_mode) + quant_map = { + ColumnLinear: Int8SmoothQuantLinear, + RowLinear: Int8SmoothQuantRowLinear, + } + + quantize_layers( + model, + quant_config, + quant_map, + ) return model def smooth_quantize_plugin(model, quant_mode): - for layer_idx, layer in enumerate(model.transformer.layers): - config = layer.config - - assert hasattr(layer, - "input_layernorm"), "The layer has no input_layernorm" - quant_norm_cls = None - if isinstance(layer.input_layernorm, RmsNorm): - quant_norm_cls = SmoothQuantRmsNorm - elif isinstance(layer.input_layernorm, LayerNorm): - quant_norm_cls = SmoothQuantLayerNorm - assert quant_norm_cls is not None - layer.input_layernorm = quant_norm_cls( - normalized_shape=config.hidden_size, - eps=config.norm_epsilon, - dtype=config.dtype, - quant_mode=quant_mode) - - assert hasattr(layer, "attention"), "The layer has no attention" - qkv_bias = layer.attention.qkv.bias is not None - dense_bias = layer.attention.dense.bias is not None - head_size = config.head_size if hasattr(config, 'head_size') else None - layer.attention = SmoothQuantAttention( - layer_idx=layer_idx, - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - attention_head_size=head_size, - max_position_embeddings=config.max_position_embeddings, - num_layers=config.num_hidden_layers, - dtype=config.dtype, - attention_mask_type=layer.attention.attention_mask_type, - position_embedding_type=layer.attention.position_embedding_type, - rotary_embedding_base=layer.attention.rotary_embedding_base, - rotary_embedding_scaling=layer.attention.rotary_embedding_scaling, - rotary_embedding_percentage=layer.attention. - rotary_embedding_percentage, - tp_group=config.mapping.tp_group, - tp_size=config.mapping.tp_size, - tp_rank=config.mapping.tp_rank, - quant_mode=quant_mode, - bias=(qkv_bias and dense_bias), - qkv_bias_only=(qkv_bias and not dense_bias)) - - assert hasattr(layer, "mlp"), "The layer has no mlp" - - mlp_norm_cls = None - if isinstance(layer.mlp, GatedMLP): - mlp_norm_cls = SmoothQuantGatedMLP - elif isinstance(layer.mlp, MLP): - mlp_norm_cls = SmoothQuantMLP - - mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size - layer.mlp = mlp_norm_cls(hidden_size=config.hidden_size, - ffn_hidden_size=mlp_hidden_size, - hidden_act=config.hidden_act, - dtype=config.dtype, - tp_group=config.mapping.tp_group, - tp_size=config.mapping.tp_size, - quant_mode=quant_mode, - bias=layer.mlp.bias) - assert hasattr(layer, - "post_layernorm"), "The layer has no post_layernorm" - - quant_norm_cls = None - if isinstance(layer.post_layernorm, RmsNorm): - quant_norm_cls = SmoothQuantRmsNorm - elif isinstance(layer.post_layernorm, LayerNorm): - quant_norm_cls = SmoothQuantLayerNorm - assert quant_norm_cls is not None - - layer.post_layernorm = quant_norm_cls( - normalized_shape=config.hidden_size, - eps=config.norm_epsilon, - dtype=config.dtype, - quant_mode=quant_mode) + quant_map = { + RmsNorm: SmoothQuantRmsNorm, + LayerNorm: SmoothQuantLayerNorm, + GatedMLP: SmoothQuantGatedMLP, + MLP: SmoothQuantMLP, + Attention: SmoothQuantAttention, + } + for name, layer, parent in model.named_modules_with_parent(): + layer_name = name.rsplit('.', 1)[-1] + if layer_name in ['ln_f']: + continue + + quant_cls = None + for cls in quant_map: + if isinstance(layer, cls): + quant_cls = quant_map[cls] + break + + if quant_cls is None: + continue + + init_params = get_init_params(layer, quant_cls) + init_params["quant_mode"] = quant_mode + if isinstance(layer, Attention): + init_params[ + "num_attention_heads"] = layer.num_attention_heads * layer.tp_size + quant_layer = quant_cls(**init_params) + setattr(parent, layer_name, quant_layer) setattr(model, 'quant_mode', quant_mode) return model @@ -253,49 +163,25 @@ def smooth_quantize(model, quant_config: QuantConfig): return smooth_quantize_ootb(model, quant_config) -def fp8_quantize(model, quant_config: QuantConfig, current_key_name=None): +def fp8_quantize(model, quant_config: QuantConfig): assert quant_config.quant_mode.has_fp8_qdq() - exclude_modules = quant_config.exclude_modules or ['lm_head', 'router'] - for name, module in model.named_children(): - if current_key_name is None: - current_key_name = [] - current_key_name.append(name) - - if len(list(module.children())) > 0: - fp8_quantize(module, quant_config, current_key_name) - - if isinstance(module, ColumnLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = FP8Linear( - in_features=module.in_features, - out_features=module.out_features * module.tp_size, - bias=module.bias is not None, - dtype=module.dtype, - tp_group=module.tp_group, - tp_size=module.tp_size, - gather_output=module.gather_output) - elif isinstance(module, RowLinear) and name not in exclude_modules: - if not any(key in '.'.join(current_key_name) - for key in exclude_modules): - model._modules[name] = FP8RowLinear( - in_features=module.in_features * module.tp_size, - out_features=module.out_features, - bias=module.bias is not None, - dtype=module.dtype, - tp_group=module.tp_group, - tp_size=module.tp_size) - - current_key_name.pop(-1) + quant_map = { + ColumnLinear: FP8Linear, + RowLinear: FP8RowLinear, + } - setattr(model, 'quant_mode', quant_config.quant_mode) + quantize_layers( + model, + quant_config, + quant_map, + ) return model def kv_cache_quantize(model, quant_config: QuantConfig): assert quant_config.quant_mode.has_kv_cache_quant() - for name, module in model.named_modules(remove_duplicate=True): + for name, module in model.named_modules(): if isinstance(module, (Attention, SmoothQuantAttention)): module.kv_cache_scaling_factor = Parameter(shape=(1, ), dtype='float32') diff --git a/tensorrt_llm/quantization/quantize_by_modelopt.py b/tensorrt_llm/quantization/quantize_by_modelopt.py index 8ea3ee9bd..cfd19482b 100644 --- a/tensorrt_llm/quantization/quantize_by_modelopt.py +++ b/tensorrt_llm/quantization/quantize_by_modelopt.py @@ -122,6 +122,7 @@ def quant_cfg_choices(): "Gemma": "gemma", "MixtralForCausalLM": "llama", "ArcticForCausalLM": "llama", + "Phi3SmallForCausalLM": "phi", } @@ -408,11 +409,30 @@ def quantize_and_export(*, model_dir, calib_dataset, dtype, qformat, qwen_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) tensorrt_llm_config["qwen_type"] = qwen_config.model_type + if qwen_config.model_type == "qwen2": + tensorrt_llm_config["norm_epsilon"] = qwen_config.rms_norm_eps + tensorrt_llm_config["rotary_base"] = qwen_config.rope_theta tensorrt_llm_config[ "intermediate_size"] = qwen_config.intermediate_size with open(f"{export_path}/config.json", "w") as f: json.dump(tensorrt_llm_config, f, indent=4) + if model_type == 'phi': + with open(f"{export_path}/config.json", "r") as f: + tensorrt_llm_config = json.load(f) + phi_config = AutoConfig.from_pretrained(model_dir, + trust_remote_code=True) + + from ..models.phi3.phi3small.convert import \ + convert_hf_config as phi_config_converter + phi_config = phi_config_converter(phi_config, dtype, None) + + for key, value in phi_config.items(): + tensorrt_llm_config[key] = value + + with open(f"{export_path}/config.json", "w") as f: + json.dump(tensorrt_llm_config, f, indent=4) + torch.cuda.empty_cache( ) # otherwise torch is keeping using GPU, other routine like build engine has less free GPU to use end_time = time.time() diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index c56b3f344..da8c065a4 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -16,7 +16,6 @@ import copy import csv import math -import os import platform from dataclasses import dataclass, field from functools import reduce, wraps @@ -478,6 +477,8 @@ class SamplingConfig: random_seed: Union[int, torch.Tensor] = field(init=False, default=None) output_cum_log_probs: bool = field(init=False, default=False) output_log_probs: bool = field(init=False, default=False) + no_repeat_ngram_size: Union[int, torch.Tensor] = field(init=False, + default=None) def update(self, **kwargs): unused_kwargs = dict() @@ -1142,6 +1143,18 @@ def __setup_decoder(self, input_ids: torch.Tensor, else: self.random_seed = None + if isinstance(scfg.no_repeat_ngram_size, torch.Tensor): + assert scfg.no_repeat_ngram_size.dtype == torch.int32, f"scfg.no_repeat_ngram_size.dtype ({scfg.no_repeat_ngram_size.dtype}) must be torch.int32" + assert scfg.no_repeat_ngram_size.shape[ + 0] == batch_size, f"scfg.no_repeat_ngram_size.shape[0] ({scfg.no_repeat_ngram_size.shape[0]}) must equal to batch_size ({batch_size})" + self.no_repeat_ngram_size = scfg.no_repeat_ngram_size + elif scfg.no_repeat_ngram_size is not None: + self.no_repeat_ngram_size = torch.full([batch_size], + scfg.no_repeat_ngram_size, + dtype=torch.int32) + else: + self.no_repeat_ngram_size = None + if self.mapping.is_last_pp_rank(): self.dynamic_decoder.setup( batch_size, scfg.num_beams, self.top_k, self.top_p, @@ -1150,8 +1163,8 @@ def __setup_decoder(self, input_ids: torch.Tensor, self.host_length_penalty, self.host_early_stopping, self.beam_search_diversity_rate, self.random_seed, self.top_p_decay, self.top_p_min, self.top_p_reset_ids, - scfg.output_log_probs, scfg.num_beams > 1 - or scfg.output_cum_log_probs) + self.no_repeat_ngram_size, scfg.output_log_probs, + scfg.num_beams > 1 or scfg.output_cum_log_probs) assert scfg.end_id is not None, "end_id cannot be none" assert scfg.pad_id is not None, 'pad_id cannot be none' @@ -1562,24 +1575,22 @@ def setup(self, else: # Without plugin, we need extra kv cache buffers. # Because we don't support inplace update, so we need separate buffer for inputs and outputs. - # Not applicable to cross KV buffers as it's constant - for i in range(self.first_layer, self.last_layer): - if self.layer_types[i] == 'attention': - trt_dtype = self.runtime.engine.get_tensor_dtype( - f'present_key_value_{i}') - - if trt_dtype == trt.fp8: - # PyTorch doesn't support fp8 datatype, use int8 instead of it because int8 datatype size is same with fp8. - # TODO: Remove this section when PyTorch support fp8 datatype - dtype = torch.int8 - else: - dtype = self._tensor_dtype(f'present_key_value_{i}') - self.buffer[f'1_present_key_value_{i}'] = torch.empty( - cache_shape, dtype=dtype, device=self.device) - if os.getenv('TRTLLM_DISABLE_OOTB_KVCACHE_REUSE') != 'ON': - # We can do reuse between different layers' inputs and outputs, i.e. current layer's output can - # reuse previous layer's input memory. But this need one extra buffer as the guard. - break + # We can do reuse between different layers' inputs and outputs, i.e. current layer's output can + # reuse previous layer's input memory. But this need one extra buffer as the guard. + i = self.first_layer + if self.layer_types[ + i] == 'attention': # Not applicable to cross KV buffers as it's constant + trt_dtype = self.runtime.engine.get_tensor_dtype( + f'present_key_value_{i}') + + if trt_dtype == trt.fp8: + # PyTorch doesn't support fp8 datatype, use int8 instead of it because int8 datatype size is same with fp8. + # TODO: Remove this section when PyTorch support fp8 datatype + dtype = torch.int8 + else: + dtype = self._tensor_dtype(f'present_key_value_{i}') + self.buffer[f'1_present_key_value_{i}'] = torch.empty( + cache_shape, dtype=dtype, device=self.device) if self.use_mamba_conv1d_plugin: conv_state_shape = ( @@ -1994,11 +2005,12 @@ def add_tensor_with_shape(x, name, shape): add_tensor(self.cross_qkv_reuse, 'cross_qkv_reuse') else: # minimize - # hacky way: such that qkv gemm becomes a gemv which is cheap and negligible + # use TensorRT Empty Tensor to skip redundant computation + # 0 for generation phase, >0 for context phase encoder_output_shape = [ - 1, encoder_output.shape[-1] + 0, encoder_output.shape[-1] ] if self.remove_input_padding else [ - 1, 1, encoder_output.shape[-1] + 1, 0, encoder_output.shape[-1] ] else: # OOTB path doesn't have kv cache for now, so this encoder_output is @@ -2057,32 +2069,28 @@ def add_tensor_with_shape(x, name, shape): idx] == 'attention': next_shape = (batch_size * beam_width, 2, self.num_heads_kv, max_context_length + step, self.head_size) - if os.getenv("TRTLLM_DISABLE_OOTB_KVCACHE_REUSE") != 'ON': - # We will make current layer's output KV-cache overwrite previous layers input KV-cache - # buffer id: ... 5, 6, 7, 8, 9, ... - # layer n: out in - # layer n+1: out in - # layer n+2 out in - # And when finish a step, we will make every layer's in/out buffer index subtract 1 in - # a circular buffer way to make sure current outputs become next step's inputs. - buffer_num = self.num_attn_layers + 1 # attention layer num + 1 extra buffer. - # Subtract 1 for every step. - input_ind = attn_layer_idx - (step % buffer_num) - # When underflow, go to the back to achieve a circular buffers. - if input_ind < 0: - input_ind = self.num_attn_layers + 1 + input_ind - # Output buffer is just before input buffer. When input is buffer 0, output should use the back buffer to achieve circular buffers. - output_ind = input_ind - 1 if input_ind > 0 else self.num_attn_layers - - # We only allocate layer num of normal buffers. If index is overflow, use the extra buffer. - input_name = f'present_key_value_{self.attn_to_general_idx[input_ind]}' if input_ind != self.num_attn_layers \ - else f'1_present_key_value_{self.attn_to_general_idx[0]}' - output_name = f'present_key_value_{self.attn_to_general_idx[output_ind]}' if output_ind != self.num_attn_layers \ - else f'1_present_key_value_{self.attn_to_general_idx[0]}' - attn_layer_idx += 1 - else: - input_name = f'1_present_key_value_{idx}' if step % 2 else f'present_key_value_{idx}' - output_name = f'present_key_value_{idx}' if step % 2 else f'1_present_key_value_{idx}' + # We will make current layer's output KV-cache overwrite previous layers input KV-cache + # buffer id: ... 5, 6, 7, 8, 9, ... + # layer n: out in + # layer n+1: out in + # layer n+2 out in + # And when finish a step, we will make every layer's in/out buffer index subtract 1 in + # a circular buffer way to make sure current outputs become next step's inputs. + buffer_num = self.num_attn_layers + 1 # attention layer num + 1 extra buffer. + # Subtract 1 for every step. + input_ind = attn_layer_idx - (step % buffer_num) + # When underflow, go to the back to achieve a circular buffers. + if input_ind < 0: + input_ind = self.num_attn_layers + 1 + input_ind + # Output buffer is just before input buffer. When input is buffer 0, output should use the back buffer to achieve circular buffers. + output_ind = input_ind - 1 if input_ind > 0 else self.num_attn_layers + + # We only allocate layer num of normal buffers. If index is overflow, use the extra buffer. + input_name = f'present_key_value_{self.attn_to_general_idx[input_ind]}' if input_ind != self.num_attn_layers \ + else f'1_present_key_value_{self.attn_to_general_idx[0]}' + output_name = f'present_key_value_{self.attn_to_general_idx[output_ind]}' if output_ind != self.num_attn_layers \ + else f'1_present_key_value_{self.attn_to_general_idx[0]}' + attn_layer_idx += 1 add_tensor_with_shape(self.buffer[input_name], f'past_key_value_{idx}', next_shape) @@ -2628,7 +2636,7 @@ def handle_per_step( sequence_limit_lengths: torch.Tensor, sequence_lengths: torch.Tensor, next_step_tensors: Dict[str, RuntimeTensor], stop_words_data, - bad_words_data, no_repeat_ngram_size, encoder_output: torch.Tensor, + bad_words_data, encoder_output: torch.Tensor, encoder_input_lengths: torch.Tensor, stopping_criteria: StoppingCriteria, logits_processor: LogitsProcessor, **kwargs): @@ -2677,6 +2685,11 @@ def handle_per_step( host_cross_kv_cache_block_offsets, hidden_states, prompt_embedding_table, tasks, prompt_vocab_size, encoder_output, encoder_input_lengths) + # print(f"=============step {step} Before ctx phase") + # for name, tensor in ctx_tensors.items(): + # print(name, ":", tensor.shape) + # if "key_value" not in name: + # print(tensor.to_torch()) context = self.runtime.ctx_context self.runtime._set_tensors(context, ctx_tensors) if self.debug_mode: @@ -2736,6 +2749,9 @@ def handle_per_step( dim=1, index=last_token_ids.to(dtype=torch.int64)).view( batch_size, self.vocab_size_padded) + # print(f"=============step {step} After context phase") + # print("logits", self.buffer['logits'].shape) + # print(self.buffer['logits']) if step == 0 and beam_width > 1: assert not self.is_medusa_mode @@ -2840,6 +2856,11 @@ def handle_per_step( host_cross_kv_cache_block_offsets, hidden_states, prompt_embedding_table, tasks, prompt_vocab_size, encoder_output, encoder_input_lengths) + # print(f"=============step {step} Before gen phase") + # for name, tensor in next_step_tensors.items(): + # print(name, ":", tensor.shape) + # if "key_value" not in name: + # print(tensor.to_torch()) # there are some tensors created inside the _get_next_step_shape_buffer, not owned by any object # needs to pro-long the life time of the tensors inside the next_step_tensors array # otherwise, it maybe released before the next step actually enqueued @@ -2889,11 +2910,11 @@ def handle_per_step( context_lengths, sequence_limit_lengths, stop_words_list_ptrs, stop_words_lens, max_stop_words_len, bad_words_list_ptrs, bad_words_lens, - max_bad_words_len, no_repeat_ngram_size, - this_src_cache_indirection, self.output_ids, - self.new_tokens, self.finished, self.finished, - self.sequence_length_buffer, self.cum_log_probs, - self.log_probs, self.log_probs_tiled, self.parent_ids, + max_bad_words_len, this_src_cache_indirection, + self.output_ids, self.new_tokens, self.finished, + self.finished, self.sequence_length_buffer, + self.cum_log_probs, self.log_probs, + self.log_probs_tiled, self.parent_ids, this_tgt_cache_indirection, self.beam_hyps_output_ids_cba, self.beam_hyps_seq_len_cba, @@ -2989,7 +3010,6 @@ def decode_regular(self, sequence_limit_lengths: torch.Tensor, stop_words_data, bad_words_data, - no_repeat_ngram_size, output_sequence_lengths: bool = False, return_dict: bool = False, encoder_output: torch.Tensor = None, @@ -3054,9 +3074,8 @@ def profile_fn(benchmark_profiler_obj, step_count): host_context_lengths, attention_mask, cross_attention_mask, prompt_vocab_size, ite, sequence_limit_lengths, sequence_lengths, next_step_tensors, stop_words_data, - bad_words_data, no_repeat_ngram_size, encoder_output, - encoder_input_lengths, stopping_criteria, logits_processor, - **kwargs) + bad_words_data, encoder_output, encoder_input_lengths, + stopping_criteria, logits_processor, **kwargs) if step == 0: if benchmark_profiler is not None: benchmark_profiler.record_cuda_event('first_token') @@ -3136,7 +3155,6 @@ def decode_stream(self, sequence_limit_lengths: torch.Tensor, stop_words_data, bad_words_data, - no_repeat_ngram_size, output_sequence_lengths: bool = False, return_dict: bool = False, encoder_output: torch.Tensor = None, @@ -3175,8 +3193,8 @@ def get_outputs_dict(output_ids): host_context_lengths, attention_mask, cross_attention_mask, prompt_vocab_size, ite, sequence_limit_lengths, sequence_lengths, next_step_tensors, stop_words_data, - bad_words_data, no_repeat_ngram_size, encoder_output, - encoder_input_lengths, stopping_criteria, logits_processor) + bad_words_data, encoder_output, encoder_input_lengths, + stopping_criteria, logits_processor) if step == 0: outputs_context_logits = context_logits if should_stop is not None: @@ -3232,7 +3250,6 @@ def decode(self, prompt_vocab_size: torch.Tensor = None, stop_words_list=None, bad_words_list=None, - no_repeat_ngram_size=None, streaming: bool = False, output_sequence_lengths: bool = False, return_dict: bool = False, @@ -3432,9 +3449,9 @@ def decode(self, cache_indirections, input_ids, hidden_states, prompt_embedding_table, tasks, prompt_vocab_size, ite, sequence_limit_lengths, stop_words_data, bad_words_data, - no_repeat_ngram_size, output_sequence_lengths, return_dict, - encoder_output, encoder_input_lengths, stopping_criteria, - logits_processor, cross_attention_mask, **kwargs) + output_sequence_lengths, return_dict, encoder_output, + encoder_input_lengths, stopping_criteria, logits_processor, + cross_attention_mask, **kwargs) else: return self.decode_regular( batch_size, scfg, sequence_lengths, context_lengths, @@ -3442,9 +3459,9 @@ def decode(self, cache_indirections, input_ids, hidden_states, prompt_embedding_table, tasks, prompt_vocab_size, ite, sequence_limit_lengths, stop_words_data, bad_words_data, - no_repeat_ngram_size, output_sequence_lengths, return_dict, - encoder_output, encoder_input_lengths, stopping_criteria, - logits_processor, cross_attention_mask, **kwargs) + output_sequence_lengths, return_dict, encoder_output, + encoder_input_lengths, stopping_criteria, logits_processor, + cross_attention_mask, **kwargs) class ChatGLMGenerationSession(GenerationSession): diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index fa5c25412..d310c81e8 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -83,6 +83,7 @@ def from_dir( max_tokens_in_paged_kv_cache: int | None = None, kv_cache_enable_block_reuse: bool = False, enable_chunked_context: bool = False, + is_enc_dec: bool = False, ) -> 'ModelRunnerCpp': """ Create a ModelRunnerCpp instance from an engine directory. @@ -128,14 +129,68 @@ def from_dir( Enables block reuse in kv cache. enable_chunked_context (bool): Enables chunked context. + is_enc_dec (bool): + Whether the model is encoder-decoder architecture. Returns: ModelRunnerCpp: An instance of ModelRunnerCpp. """ + if is_enc_dec: + encoder_config_path = Path(engine_dir) / "encoder" / "config.json" + encoder_json_config = GptJsonConfig.parse_file(encoder_config_path) + encoder_json_config.model_config + decoder_config_path = Path(engine_dir) / "decoder" / "config.json" + decoder_json_config = GptJsonConfig.parse_file(decoder_config_path) + decoder_model_config = decoder_json_config.model_config + + tp_size = decoder_json_config.tensor_parallelism + pp_size = decoder_json_config.pipeline_parallelism + gpus_per_node = decoder_json_config.gpus_per_node + world_config = WorldConfig.mpi(tensor_parallelism=tp_size, + pipeline_parallelism=pp_size, + gpus_per_node=gpus_per_node) + assert rank == world_config.rank + + profiler.start('load tensorrt_llm engine') + + kv_cache_config = trtllm.KvCacheConfig( + free_gpu_memory_fraction=0.45, # hardcode for now + max_attention_window=max_attention_window_size, + sink_token_length=sink_token_length) + + executor = trtllm.Executor( + Path(engine_dir) / "encoder", + Path(engine_dir) / "decoder", trtllm.ModelType.ENCODER_DECODER, + trtllm.ExecutorConfig(max_beam_width=max_beam_width, + kv_cache_config=kv_cache_config)) + + profiler.stop('load tensorrt_llm engine') + + loading_time = profiler.elapsed_time_in_sec( + "load tensorrt_llm engine") + logger.info(f'Load engine takes: {loading_time} sec') + + return cls(executor, + max_batch_size=max_batch_size, + max_input_len=max_input_len, + max_seq_len=max_input_len + max_output_len, + max_beam_width=max_beam_width, + model_config=decoder_model_config, + world_config=world_config) + config_path = Path(engine_dir) / "config.json" json_config = GptJsonConfig.parse_file(config_path) model_config = json_config.model_config + if max_batch_size is None: + max_batch_size = model_config.max_batch_size + if max_input_len is None: + max_input_len = model_config.max_input_len + if max_output_len is None: + max_output_len = model_config.max_seq_len - model_config.max_input_len + if max_beam_width is None: + max_beam_width = model_config.max_beam_width + # Note: Parallel configuration will be fetched automatically from trtllm.Executor constructor # by inspecting the json file. These lines serve the purpose of serving vocab_size_padded and # num_layers properties. @@ -160,11 +215,12 @@ def from_dir( if medusa_choices is not None: decoding_config.medusa_choices = medusa_choices - executor = trtllm.Executor( - engine_dir, trtllm.ModelType.DECODER_ONLY, - trtllm.ExecutorConfig(max_beam_width=max_beam_width, - kv_cache_config=kv_cache_config, - decoding_config=decoding_config)) + trtllm_config = trtllm.ExecutorConfig(max_beam_width=max_beam_width, + kv_cache_config=kv_cache_config, + decoding_config=decoding_config) + trtllm_config.enable_chunked_context = enable_chunked_context + executor = trtllm.Executor(engine_dir, trtllm.ModelType.DECODER_ONLY, + trtllm_config) profiler.stop('load tensorrt_llm engine') @@ -250,6 +306,7 @@ def gather_generation_logits(self) -> bool: def generate(self, batch_input_ids: List[torch.Tensor], *, + encoder_input_ids: List[torch.Tensor] = None, sampling_config: Optional[SamplingConfig] = None, lora_uids: Optional[list] = None, streaming: bool = False, @@ -318,6 +375,8 @@ def generate(self, # Convert tensor input to plain lists batch_input_ids_list = [a.tolist() for a in batch_input_ids] + encoder_input_ids_list = [a.tolist() for a in encoder_input_ids + ] if encoder_input_ids else None if sampling_config is None: # Convert from old API of SamplingConfig @@ -327,7 +386,7 @@ def generate(self, "top_p_decay", "random_seed", "temperature", "min_length", "beam_search_diversity_rate", "repetition_penalty", "presence_penalty", "frequency_penalty", "length_penalty", - "early_stopping" + "early_stopping", "no_repeat_ngram_size" ] rename_params = {"num_beams": "beam_width"} sampling_params = { @@ -344,8 +403,9 @@ def generate(self, else: sampling_config = copy.deepcopy(sampling_config) - self._check_inputs(batch_input_ids_list, sampling_config, - max_new_tokens) + self._check_inputs( + encoder_input_ids_list if encoder_input_ids else + batch_input_ids_list, sampling_config, max_new_tokens) output_config = trtllm.OutputConfig( return_context_logits=self.gather_context_logits, @@ -363,6 +423,8 @@ def generate(self, requests = [ trtllm.Request(input_token_ids=input_ids, + encoder_input_token_ids=encoder_input_ids_list[i] + if encoder_input_ids is not None else None, max_new_tokens=max_new_tokens, pad_id=pad_id, end_id=end_id, @@ -372,9 +434,10 @@ def generate(self, streaming=streaming, output_config=output_config, prompt_tuning_config=prompt_tuning_config) - for input_ids, stop_words, bad_words, prompt_tuning_config in zip( - batch_input_ids_list, stop_words_list, bad_words_list, - prompt_tuning_configs) + for i, (input_ids, stop_words, bad_words, + prompt_tuning_config) in enumerate( + zip(batch_input_ids_list, stop_words_list, + bad_words_list, prompt_tuning_configs)) ] request_ids = self.session.enqueue_requests(requests) diff --git a/tensorrt_llm/runtime/session.py b/tensorrt_llm/runtime/session.py index bd89f8803..6161193ce 100644 --- a/tensorrt_llm/runtime/session.py +++ b/tensorrt_llm/runtime/session.py @@ -167,7 +167,7 @@ def set_shapes(self, if not ok: raise ValueError( f"Couldn't assign {name} with shape {tensor_dict[name].shape}, " - f"engine supports [min, opt, max] = {self.engine.get_profile_shape(context.active_optimization_profile, name)}" + f"engine supports [min, opt, max] = {self.engine.get_tensor_profile_shape(name, context.active_optimization_profile)}" ) def infer_shapes( diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index 134419857..64c431872 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.11.0.dev2024052800" +__version__ = "0.11.0.dev2024060400" diff --git a/tests/functional/test_alibi.py b/tests/functional/test_alibi.py index dbdfc5d1d..bef52da41 100644 --- a/tests/functional/test_alibi.py +++ b/tests/functional/test_alibi.py @@ -16,27 +16,30 @@ import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner from transformers.models.bloom.modeling_bloom import build_alibi_tensor import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): +class TestAlibi(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def create_random_bool_mask(self, batch_size, seq_len): - mask = torch.zeros(size=[batch_size, seq_len], dtype=torch.bool) - seq_lens = torch.randint(low=1, high=seq_len + 1, size=[batch_size]) + mask = torch.zeros(size=[batch_size, seq_len], + dtype=torch.bool, + device="cuda") + seq_lens = torch.randint(low=1, + high=seq_len + 1, + size=[batch_size], + device="cuda") for b in range(batch_size): mask[b, :seq_lens[b]] = True @@ -52,9 +55,8 @@ def test_alibi_biases(self, num_heads, batch_size, seq_len): # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): trt_key = Tensor(name='fake_key', shape=(seq_len, ), dtype=tensorrt_llm.str_dtype_to_trt('int32')) @@ -64,40 +66,35 @@ def test_alibi_biases(self, num_heads, batch_size, seq_len): tensorrt_llm.functional.generate_alibi_slopes( num_heads=num_heads)) output = tensorrt_llm.functional.generate_alibi_biases( - slopes, key_len).trt_tensor - output.name = 'output' - network.mark_output(output) + slopes, key_len) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - print(seq_len) - outputs = runner.infer( - feed_dict={ - 'fake_key': np.empty(shape=(seq_len, ), dtype=np.int32) - }) + inputs = { + 'fake_key': torch.empty((seq_len, ), + dtype=torch.int32, + device="cuda") + } + session = create_session(builder, network, precision="float32") + outputs = run_session(session, inputs) trt_alibi_output = outputs['output'] # transformers reference binary_mask = self.create_random_bool_mask(batch_size, seq_len) - ref = build_alibi_tensor(binary_mask, num_heads, - torch.float32).cpu().numpy() + ref = build_alibi_tensor(binary_mask, num_heads, torch.float32) ref = ref.reshape(batch_size, num_heads, 1, seq_len) # We only require that the alibi bias matches in the "valid" regions. Our TRT, # implementation differs in this regard for efficiency reasons but it does not matter # because these values will get masked before the softmax. - binary_mask = binary_mask.cpu().numpy().reshape(batch_size, 1, 1, - seq_len) + binary_mask = binary_mask.reshape(batch_size, 1, 1, seq_len) ref *= binary_mask - trt_alibi_output = np.repeat(trt_alibi_output, batch_size, axis=0) + trt_alibi_output = torch.repeat_interleave(trt_alibi_output, + batch_size, + dim=0) trt_alibi_output *= binary_mask # compare diff - np.testing.assert_allclose(ref, trt_alibi_output, atol=1e-3) - - -if __name__ == "__main__": - unittest.main() + torch.testing.assert_close(trt_alibi_output, ref, atol=1e-3, rtol=1e-2) diff --git a/tests/functional/test_arange.py b/tests/functional/test_arange.py index 7e11db211..019c9132a 100644 --- a/tests/functional/test_arange.py +++ b/tests/functional/test_arange.py @@ -12,16 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestArange(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -30,30 +34,24 @@ def test_arange_int(self): # test data start = 0 end = 128 - dtype = 'int32' # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): output = tensorrt_llm.functional.arange(start=start, end=end, - dtype=dtype).trt_tensor - output.name = 'output' - network.mark_output(output) - output.dtype = tensorrt_llm.str_dtype_to_trt(dtype) + dtype="int32") + output.mark_output('output', "int32") # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={}) + inputs = {} + session = create_session(builder, network, precision="float32") + outputs = run_session(session, inputs) - ref = torch.arange(start, end).int() - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + ref = torch.arange(start, end).int().cuda() + torch.testing.assert_close(outputs['output'], ref) def test_arange_tensor(self): # test data @@ -63,9 +61,8 @@ def test_arange_tensor(self): # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): start = tensorrt_llm.functional.constant(np.array(s, dtype=np.int32)) @@ -75,17 +72,17 @@ def test_arange_tensor(self): output = tensorrt_llm.functional.arange( start=start, end=tensorrt_llm.functional.shape(end_tensor, 0), - dtype=dtype).trt_tensor - output.name = 'output' - network.mark_output(output) - output.dtype = tensorrt_llm.str_dtype_to_trt(dtype) + dtype=dtype) + + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={}) - - ref = torch.arange(s, e).int() - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + inputs = {} + session = create_session(builder, network, precision="float32") + outputs = run_session(session, inputs) + + # pytorch run + ref = torch.arange(s, e).int().cuda() + + # compare diff + torch.testing.assert_close(outputs['output'], ref) diff --git a/tests/functional/test_argmax.py b/tests/functional/test_argmax.py index fc832801d..4be7fdb2f 100644 --- a/tests/functional/test_argmax.py +++ b/tests/functional/test_argmax.py @@ -12,22 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest from itertools import product -import numpy as np - # isort: off import torch # isort: on from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session + -class TestFunctional(unittest.TestCase): +class TestArgmax(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -38,31 +40,28 @@ def test_argmax(self, dtype, keep_dim, dim): # test data x_shape = (4, 12, 32) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.argmax(x, dim, - keepdim=keep_dim).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.argmax(x, dim, keepdim=keep_dim) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + inputs = {'x': x_data} + session = create_session(builder, network, precision=dtype) + outputs = run_session(session, inputs) # pytorch run ref = x_data.argmax(dim=dim, keepdim=keep_dim) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + # ref is torch.int64, outputs is torch.int32 + torch.testing.assert_close(ref.int(), outputs['output'].int()) diff --git a/tests/functional/test_assertion.py b/tests/functional/test_assertion.py index e35fff2c3..3f3b934ac 100644 --- a/tests/functional/test_assertion.py +++ b/tests/functional/test_assertion.py @@ -17,37 +17,33 @@ import unittest import torch -from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor from tensorrt_llm.functional import shape sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): +class TestAssertion(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - @parameterized.expand([('float32', )], name_func=unittest_name_func) - def test_assertion(self, dtype): + def test_assertion(self): + dtype = 'float32' + torch_dtype = tensorrt_llm.str_dtype_to_torch(dtype) # test data x_shape = (2, 4, 8) y_shape = (4, 4, 4) - x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) - y_data = torch.rand(y_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + x_data = torch.rand(x_shape, dtype=torch_dtype, device="cuda") + y_data = torch.rand(y_shape, dtype=torch_dtype, device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -56,12 +52,10 @@ def test_assertion(self, dtype): dtype=tensorrt_llm.str_dtype_to_trt(dtype)) tensorrt_llm.functional.assertion(shape(x, 1) == shape(y, 1)) - output = tensorrt_llm.functional.identity(x).trt_tensor - output.name = 'output' - network.mark_output(output) - output.dtype = tensorrt_llm.str_dtype_to_trt(dtype) + output = tensorrt_llm.functional.identity(x) + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - runner.infer(feed_dict={'x': x_data.numpy(), 'y': y_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = {'x': x_data, 'y': y_data} + run_session(session, inputs) diff --git a/tests/functional/test_avg_pool2d.py b/tests/functional/test_avg_pool2d.py index e5ab93e57..f30bcedd0 100644 --- a/tests/functional/test_avg_pool2d.py +++ b/tests/functional/test_avg_pool2d.py @@ -12,17 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestAvgPool2D(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -30,27 +33,29 @@ def setUp(self): def test_avg_pool2d(self): # test data dtype = 'float32' - x_data = torch.randn(16, 50, 32) + x_data = torch.randn(16, 50, 32, device="cuda") kernel_size = (3, 2) stride = (2, 1) # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.avg_pool2d( - x, kernel_size=kernel_size, stride=stride).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.avg_pool2d(x, + kernel_size=kernel_size, + stride=stride) + + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref = torch.nn.functional.avg_pool2d(x_data, @@ -58,6 +63,4 @@ def test_avg_pool2d(self): stride=stride) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-6) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_cast.py b/tests/functional/test_cast.py index dfbb8b688..cedfadcd7 100644 --- a/tests/functional/test_cast.py +++ b/tests/functional/test_cast.py @@ -12,43 +12,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestCast(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - def test_cast(self): + def test_cast_fp16_to_fp32(self): dtype = 'float16' x_data = torch.randn( - (2, 3, 4, 5), dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + (2, 3, 4, 5), + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.cast(x, 'float32').trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.cast(x, 'float32') + output.mark_output('output', 'float32') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = x_data.to(torch.float32) - self.assertEqual(ref.cpu().numpy().dtype, outputs['output'].dtype) - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + + # compare diff + assert ref.dtype == outputs[ + 'output'].dtype, "data type after cast is not the same" + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_conv2d.py b/tests/functional/test_conv2d.py index 17c020a8f..2aa6ff6f9 100644 --- a/tests/functional/test_conv2d.py +++ b/tests/functional/test_conv2d.py @@ -12,54 +12,55 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestConv2D(unittest.TestCase): def setUp(self): + # Disable TF32 because accuracy is bad + torch.backends.cudnn.allow_tf32 = False tensorrt_llm.logger.set_level('error') def test_conv2d(self): # test data dtype = 'float32' - x_data = torch.randn(8, 4, 5, 5) - weight_data = torch.randn(8, 4, 3, 3) + x_data = torch.randn(8, 4, 5, 5, device="cuda") + weight_data = torch.randn(8, 4, 3, 3, device="cuda") padding = (1, 1) # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - weight = tensorrt_llm.constant(weight_data.numpy()) + weight = tensorrt_llm.constant(weight_data.cpu().numpy()) - output = tensorrt_llm.functional.conv2d(x, weight, - padding=padding).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.conv2d(x, weight, padding=padding) + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref = torch.nn.functional.conv2d(x_data, weight_data, padding=padding) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_cos.py b/tests/functional/test_cos.py index c5f96be99..4e25cd9d7 100644 --- a/tests/functional/test_cos.py +++ b/tests/functional/test_cos.py @@ -12,43 +12,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestCos(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_exp(self): dtype = 'float32' - x_data = torch.randn(2, 3, 4, 5) + x_data = torch.randn(2, 3, 4, 5, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.cos(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.cos(x) + output.mark_output('output', dtype) - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.cos(x_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_cumsum.py b/tests/functional/test_cumsum.py index b8f21f2dd..60c524649 100644 --- a/tests/functional/test_cumsum.py +++ b/tests/functional/test_cumsum.py @@ -15,20 +15,19 @@ import os import sys import unittest +from itertools import product -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): +class TestCumsum(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -42,37 +41,116 @@ def setUp(self): ('float16', (5, 6, 8), 1), ('float16', (5, 6, 8), 2), ('float16', (5, 6, 8), -3), + ('float32', (1, 512), -1), + ('float16', (3, 5, 5, 6), -1), + ('int32', (1, 33), -1), + ('int32', (1, 65), -1), + ('float32', (1, 50000), -1), + ('float32', (1, 2, 50000), -1), + ('float32', (3, 5, 5, 50000), -1), ], name_func=unittest_name_func) def test_cumsum(self, dtype, x_shape, dim): + torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) if 'float' in dtype: - x_data = torch.rand( - x_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + x_data = torch.rand(x_shape, dtype=torch_dtype, device="cuda") else: - x_data = torch.randint( - -100, - 100, - x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + x_data = torch.randint(-100, + 100, + x_shape, + dtype=torch_dtype, + device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.cumsum(x, dim=dim).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.cumsum(x, dim=dim) + output.mark_output('output', dtype) + + # trt run + session = create_session( + builder, + network, + precision='float32' if 'int32' in dtype else dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + + # pytorch run + ref = torch.cumsum(x_data, dim=dim).to(torch_dtype) + + # compare diff + tols = { + "float32": { + "rtol": 1e-05, + "atol": 1e-05 + }, + "float16": { + "rtol": 1e-02, + "atol": 1e-02 + }, + "int32": { + "rtol": 0, + "atol": 0 + }, + } + torch.testing.assert_close(outputs['output'], ref, **tols[dtype]) + + @parameterized.expand( + list( + product(['float32', 'float16', 'int32'], + [(256, ), (3, 16), (5, 6, 8)], [True, False])) + + list(product(['float32'], [(3, 5, 5, 50000)], + [True])), # False seems to be running into a TRT bug + name_func=unittest_name_func) + def test_cumsum_dynamic_last_dim(self, dtype, x_shape, prefer_plugin=True): + dim = -1 + torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) + if 'float' in dtype: + x_data = torch.rand(x_shape, dtype=torch_dtype, device="cuda") + else: + x_data = torch.randint(-100, + 100, + x_shape, + dtype=torch_dtype, + device="cuda") + + shape_except_last_dim = list(x_data.shape[:-1]) + last_dim_size = x_data.shape[-1] + assert last_dim_size >= 1 + builder = tensorrt_llm.Builder() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor( + name='x', + shape=shape_except_last_dim + [-1], # last dim dynamic + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) + output = tensorrt_llm.functional.cumsum(x, + dim=dim, + prefer_plugin=prefer_plugin) + output.mark_output('output', dtype) + # needs profile for dynamic shape + profile = builder.trt_builder.create_optimization_profile() + profile.set_shape('x', shape_except_last_dim + [1], + shape_except_last_dim + [last_dim_size], + shape_except_last_dim + [last_dim_size * 2]) + session = create_session( + builder, + network, + precision='float32' if 'int32' in dtype else dtype, + optimization_profiles=[profile]) + inputs = {'x': x_data} + outputs = run_session(session, inputs) - build_engine = EngineFromNetwork( - (builder.trt_builder, net.trt_network), - config=CreateConfig(fp16=(dtype == 'float16'))) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data.numpy()}) + ref = torch.cumsum(x_data, dim=dim).to(torch_dtype) - ref = torch.cumsum(x_data.cuda(), dim=dim) + # compare diff tols = { "float32": { "rtol": 1e-05, @@ -87,5 +165,4 @@ def test_cumsum(self, dtype, x_shape, dim): "atol": 0 }, } - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output'], - **tols[dtype]) + torch.testing.assert_close(outputs['output'], ref, **tols[dtype]) diff --git a/tests/functional/test_einsum.py b/tests/functional/test_einsum.py index f42f9ad75..2f4e99c90 100644 --- a/tests/functional/test_einsum.py +++ b/tests/functional/test_einsum.py @@ -16,68 +16,54 @@ import sys import unittest -import numpy as np import torch -from parameterized import parameterized -from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile, - TrtRunner) import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): +class TestEinsum(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - @parameterized.expand([('float32', )], name_func=unittest_name_func) - def test_einsum(self, dtype): - # torch 1.13: "baddbmm_with_gemm" not implemented for 'Half' + def test_einsum(self): + dtype = 'float32' # test data x_shape = (12, 12, 96, 96) y_shape = (12, 12, 96, 64) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") y_data = torch.rand(y_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") equation = 'bnth,bnhs->bnts' # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) y = Tensor(name='y', shape=y_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.einsum(equation, [x, y]).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.einsum(equation, [x, y]) + output.mark_output('output', dtype) # trt run - profiles = [ - Profile().add('x', x_shape, x_shape, - x_shape).add('y', y_shape, y_shape, y_shape) - ] - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network), - config=CreateConfig(profiles=profiles)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - 'y': y_data.numpy() - }) + session = create_session(builder, network, precision=dtype) + inputs = {'x': x_data, 'y': y_data} + outputs = run_session(session, inputs) # pytorch run ref = torch.functional.einsum(equation, [x_data, y_data]) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-4) + torch.testing.assert_close(outputs['output'], ref, atol=5e-3, rtol=2e-4) diff --git a/tests/functional/test_embedding_single_gpu.py b/tests/functional/test_embedding_single_gpu.py index 7903026e5..145ce6545 100644 --- a/tests/functional/test_embedding_single_gpu.py +++ b/tests/functional/test_embedding_single_gpu.py @@ -12,88 +12,70 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math import os import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -def split_vocab_size(vocab_size, tp_size): - return int(math.ceil(vocab_size / tp_size)) - - -def split(v, tp_size, idx, dim=0): - if tp_size == 1: - return v - if len(v.shape) == 1: - return np.ascontiguousarray(np.split(v, tp_size)[idx]) - elif len(v.shape) == 2: - return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) - return None - - -class TestFunctional(unittest.TestCase): +class TestEmbedding(unittest.TestCase): def setUp(self): + torch.random.manual_seed(0) tensorrt_llm.logger.set_level('error') @parameterized.expand([( 'float32', - 1, + True, ), ( 'float32', - 0, + False, ), ( 'float16', - 1, + True, ), ( 'float16', - 0, + False, )], name_func=unittest_name_func) def test_embedding(self, dtype, use_lookup_plugin): - # torch gelu does not support float16 - fp16 = (dtype == 'float16') # meta data batch_size = 10 vocab_size = 1000 n_embed = 1024 - np.random.seed(0) # test data ## input index - index_shape = (batch_size) - index_np = np.random.randint(low=0, - high=vocab_size, - size=index_shape, - dtype=np.int32) - index_data = torch.from_numpy(index_np) + index_shape = (batch_size, ) + index_data = torch.randint(0, + vocab_size, + index_shape, + dtype=torch.int32, + device="cuda") ## weight data - weight_np = np.random.rand(vocab_size, n_embed).astype(dtype) - weight_data = torch.from_numpy(weight_np) + weight_data = torch.rand(vocab_size, + n_embed, + dtype=tensorrt_llm.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() + network = builder.create_network() - net = builder.create_network() if use_lookup_plugin: - net.plugin_config.lookup_plugin = dtype + network.plugin_config.lookup_plugin = dtype - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + with tensorrt_llm.net_guard(network): index = Tensor(name='index', shape=index_data.shape, dtype=tensorrt_llm.str_dtype_to_trt('int32')) @@ -104,27 +86,19 @@ def test_embedding(self, dtype, use_lookup_plugin): output = tensorrt_llm.functional.embedding(input=index, weight=weight) - - output = output.trt_tensor - output.name = 'output' - network.mark_output(output) + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork( - (builder.trt_builder, net.trt_network), - config=CreateConfig(fp16=(dtype == 'float16'))) - - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'index': index_np, - 'weight': weight_np - }) + session = create_session(builder, network, precision=dtype) + inputs = { + 'index': index_data, + 'weight': weight_data, + } + outputs = run_session(session, inputs) # pytorch run embedding = torch.nn.Embedding.from_pretrained(weight_data) ref = embedding(index_data) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-3) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_exp.py b/tests/functional/test_exp.py index 471d0ffd0..f1d5b33ba 100644 --- a/tests/functional/test_exp.py +++ b/tests/functional/test_exp.py @@ -12,43 +12,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestExp(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_exp(self): dtype = 'float32' - x_data = torch.randn(2, 3, 4, 5) + x_data = torch.randn(2, 3, 4, 5, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.exp(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.exp(x) + output.mark_output('output', dtype) - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.exp(x_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_expand.py b/tests/functional/test_expand.py index ff91560af..0b57e085a 100644 --- a/tests/functional/test_expand.py +++ b/tests/functional/test_expand.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest import numpy as np @@ -20,114 +22,116 @@ import torch import tensorrt as trt # isort: on -from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile, - TrtRunner) import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestExpand(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - def test_expand_1(self): + def test_expand_2d(self): # test data dtype = 'float32' input_shape = (1, 10) output_shape = (2, 10) - input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_data = torch.rand(input_shape, + dtype=tensorrt_llm.str_dtype_to_torch(dtype), + device="cuda") shape_data = torch.tensor(output_shape).int() # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) shape = Tensor(name='shape', shape=(len(input_shape), ), dtype=trt.int32) - output = tensorrt_llm.functional.expand(input, shape).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.expand(input, shape) + output.mark_output('output', dtype) - # trt run - profiles = [Profile().add('shape', (1, 1), input_shape, (10, 10))] - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network), - config=CreateConfig(profiles=profiles)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'input': input_data.numpy(), - 'shape': shape_data.numpy() - }) + profile = builder.trt_builder.create_optimization_profile() + profile.set_shape_input('shape', output_shape, output_shape, + output_shape) + + session = create_session(builder, + network, + precision=dtype, + optimization_profiles=[profile]) + inputs = {'input': input_data, 'shape': shape_data} + outputs = run_session(session, inputs) # pytorch run ref = input_data.expand(output_shape) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) - def test_expand_2(self): + def test_expand_4d(self): # test data dtype = 'float32' input_shape = (2, 1, 1, 10) output_shape = (2, 1, 12, 10) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") shape_data = torch.tensor(output_shape).int() # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) shape = Tensor(name='shape', shape=(len(input_shape), ), dtype=trt.int32) - output = tensorrt_llm.functional.expand(input, shape).trt_tensor - output.name = 'output' - network.mark_output(output) - - # trt run - profiles = [ - Profile().add('shape', (1, 1, 1, 1), input_shape, (10, 10, 10, 10)) - ] - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network), - config=CreateConfig(profiles=profiles)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'input': input_data.numpy(), - 'shape': shape_data.numpy() - }) + output = tensorrt_llm.functional.expand(input, shape) + output.mark_output('output') + + profile = builder.trt_builder.create_optimization_profile() + profile.set_shape_input('shape', output_shape, output_shape, + output_shape) + session = create_session(builder, + network, + precision=dtype, + optimization_profiles=[profile]) + inputs = {'input': input_data, 'shape': shape_data} + outputs = run_session(session, inputs) # pytorch run ref = input_data.expand(output_shape) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) - def test_expand_3(self): + def test_expand_implicit(self): # test data dtype = 'float32' hidden_dim = 10 input_shape = (1, hidden_dim) batch_size = 8 - input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_data = torch.rand(input_shape, + dtype=tensorrt_llm.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -135,18 +139,18 @@ def test_expand_3(self): np.array([0] * batch_size, dtype=np.int32)) expand_shape = tensorrt_llm.functional.concat( [tensorrt_llm.functional.shape(input_length, 0), hidden_dim]) - output = tensorrt_llm.functional.expand(input, - expand_shape).trt_tensor - output.name = 'output' - network.mark_output(output) - output.dtype = tensorrt_llm.str_dtype_to_trt(dtype) + output = tensorrt_llm.functional.expand(input, expand_shape) + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = input_data.expand([batch_size, hidden_dim]) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_flip.py b/tests/functional/test_flip.py index 1af7e2ded..281122b62 100644 --- a/tests/functional/test_flip.py +++ b/tests/functional/test_flip.py @@ -12,41 +12,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestFlip(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_flip(self): dtype = 'float32' - x_data = torch.randn(4, 6, 3, 4) + x_data = torch.randn(4, 6, 3, 4, device="cuda") dims = [-2, 0, -3] + + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.flip(x, dims).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.flip(x, dims) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data.numpy()}) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.flip(x_data, dims) - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_gather.py b/tests/functional/test_gather.py index 4c1341e20..054870738 100644 --- a/tests/functional/test_gather.py +++ b/tests/functional/test_gather.py @@ -12,30 +12,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestGather(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_gather(self): dtype = 'float32' - x_data = torch.randn(2, 128, 768) - y_data = torch.tensor([101, 127]).int() + x_data = torch.randn(2, 128, 768, device="cuda") + y_data = torch.tensor([101, 127], device="cuda").int() + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -57,24 +61,19 @@ def test_gather(self): tensorrt_llm.functional.shape(x, 0), tensorrt_llm.functional.shape(x, 2) ])) + output.mark_output('output', dtype) - output = output.trt_tensor - output.name = 'output' - network.mark_output(output) - - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - 'y': y_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'x': x_data, 'y': y_data} + outputs = run_session(session, inputs) + # pytorch run y_data = y_data.reshape(y_data.size(0), 1, 1) y_data = y_data.expand(y_data.size(0), 1, x_data.size(-1)) ref = torch.gather(x_data, dim=1, index=y_data.to(dtype=torch.int64)).view( x_data.size(0), x_data.size(2)) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_gather_nd.py b/tests/functional/test_gather_nd.py index ac297334e..1cf479650 100644 --- a/tests/functional/test_gather_nd.py +++ b/tests/functional/test_gather_nd.py @@ -12,21 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np - # isort: off import torch # isort: on from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session + -class TestFunctional(unittest.TestCase): +class TestGatherND(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -71,26 +73,29 @@ def setUp(self): [6, 7], [4, 5]]]), ( - torch.rand((2, 9, 4), dtype=torch.float32), + torch.rand((2, 9, 4), dtype=torch.float32, device="cuda"), torch.tensor([[[0, 1, 2, 3, 4, 5], [0, 1, 3, 4, 5, 6], [0, 1, 4, 5, 6, 7], [0, 2, 3, 4, 6, 8]], [[0, 1, 2, 3, 4, 5], [0, 1, 3, 4, 5, 7], - [0, 2, 3, 5, 6, 7], [0, 3, 4, 5, 6, 7]]]), + [0, 2, 3, 5, 6, 7], [0, 3, 4, 5, 6, 7]]], + device="cuda"), [], ), ]) def test_gatherND(self, data, indices, ref): - data = data if isinstance(data, torch.Tensor) else torch.tensor(data) - indices = indices if isinstance(indices, - torch.Tensor) else torch.tensor(indices) - ref = ref if isinstance(ref, torch.Tensor) else torch.tensor(ref) + dtype = "float32" + data = data if isinstance(data, + torch.Tensor) else torch.tensor(data).cuda() + indices = indices if isinstance( + indices, torch.Tensor) else torch.tensor(indices).cuda() + ref = ref if isinstance(ref, torch.Tensor) else torch.tensor(ref).cuda() indices = indices.unsqueeze(-1) # needed for TRT gatherND # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + d = Tensor(name='d', shape=data.shape, dtype=tensorrt_llm.torch_dtype_to_trt(data.dtype)) @@ -98,64 +103,46 @@ def test_gatherND(self, data, indices, ref): shape=indices.shape, dtype=tensorrt_llm.torch_dtype_to_trt(indices.dtype)) - output = tensorrt_llm.functional.gather_nd(d, idx, 1).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.gather_nd(d, idx, 1) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'd': data.numpy(), - 'idx': indices.numpy(), - }) - + session = create_session(builder, network, precision=dtype) + inputs = {'d': data, 'idx': indices} + outputs = run_session(session, inputs) # compare diff indices = indices.squeeze(-1) tref = torch.stack([data[i, indices[i]] for i in range(data.shape[0])]) if ref.numel() == 0: - np.testing.assert_allclose(tref, outputs['output'], atol=1e-5) + torch.testing.assert_close(tref, outputs['output']) else: - np.testing.assert_allclose(ref, outputs['output'], atol=1e-5) - np.testing.assert_allclose(ref, tref, atol=1e-5) - return - - @parameterized.expand([( - [[91, 92, 93, 95, -1, -1, 94, 96, -1, -1, -1, 97], - [93, 94, 95, 92, -1, 95, 96, 93, -1, -1, 97, 96]], - [ - # [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], - # [0, 1, 2, 3, 6, 7, 11, 0, 1, 2, 3, 5, 6, 7, 10, 11] - [0, 0], - [0, 1], - [0, 2], - [0, 3], - [0, 6], - [0, 7], - [0, 11], - [1, 0], - [1, 1], - [1, 2], - [1, 3], - [1, 5], - [1, 6], - [1, 7], - [1, 10], - [1, 11] - ], - [91, 92, 93, 95, 94, 96, 97, 93, 94, 95, 92, 95, 96, 93, 97, 96])]) + torch.testing.assert_close(ref, outputs['output']) + torch.testing.assert_close(ref, tref) + + @parameterized.expand([ + ([[91, 92, 93, 95, -1, -1, 94, 96, -1, -1, -1, 97], + [93, 94, 95, 92, -1, 95, 96, 93, -1, -1, 97, 96]], [[0, 0], [0, 1], + [0, 2], [0, 3], + [0, 6], [0, 7], + [0, 11], [1, 0], + [1, 1], [1, 2], + [1, 3], [1, 5], + [1, 6], [1, 7], + [1, 10], [1, 11]], + [91, 92, 93, 95, 94, 96, 97, 93, 94, 95, 92, 95, 96, 93, 97, 96]) + ]) def test_gatherND_b0(self, data, indices, ref): - data = data if isinstance(data, torch.Tensor) else torch.tensor(data) - indices = indices if isinstance(indices, - torch.Tensor) else torch.tensor(indices) - ref = ref if isinstance(ref, torch.Tensor) else torch.tensor(ref) - # indices = indices.unsqueeze(-1) # needed for TRT gatherND + dtype = "float32" + data = data if isinstance(data, + torch.Tensor) else torch.tensor(data).cuda() + indices = indices if isinstance( + indices, torch.Tensor) else torch.tensor(indices).cuda() + ref = ref if isinstance(ref, torch.Tensor) else torch.tensor(ref).cuda() # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): d = Tensor(name='d', shape=data.shape, dtype=tensorrt_llm.torch_dtype_to_trt(data.dtype)) @@ -163,51 +150,34 @@ def test_gatherND_b0(self, data, indices, ref): shape=indices.shape, dtype=tensorrt_llm.torch_dtype_to_trt(indices.dtype)) - output = tensorrt_llm.functional.gather_nd(d, idx, 0).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.gather_nd(d, idx, 0) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'd': data.numpy(), - 'idx': indices.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = {'d': data, 'idx': indices} + outputs = run_session(session, inputs) # compare diff - # indices = indices.squeeze(-1) tref = data[indices[:, 0], indices[:, 1]] if ref.numel() == 0: - np.testing.assert_allclose(tref, outputs['output'], atol=1e-5) + torch.testing.assert_close(tref, outputs['output']) else: - np.testing.assert_allclose(ref, outputs['output'], atol=1e-5) - np.testing.assert_allclose(ref, tref, atol=1e-5) - return - - -#### - - def test_gatherND_selectH(self): #, data, indices, ref): - # This usecase is used to gather in ReDrafter for validated end-tokens ( diff stopping point for diff seqs ) - data = torch.rand((2, 9, 4), dtype=torch.float32) - indices = torch.randint(9, size=(2, ), dtype=torch.int32) - ref = [] - data = data if isinstance(data, torch.Tensor) else torch.tensor(data) - indices = indices if isinstance(indices, - torch.Tensor) else torch.tensor(indices) - ref = ref if isinstance(ref, torch.Tensor) else torch.tensor(ref) - indices = torch.stack([torch.arange(2, dtype=torch.int32), indices], - dim=1) - # print(data) - # print(indices) - # indices = indices.unsqueeze(-1) # needed for TRT gatherND + torch.testing.assert_close(ref, outputs['output']) + torch.testing.assert_close(ref, tref) + + def test_gatherND_selectH(self): + dtype = "float32" + # This usecase is used to gather in ReDrafter for validated end-tokens (diff stopping point for diff seqs) + data = torch.rand((2, 9, 4), dtype=torch.float32, device="cuda") + indices = torch.randint(9, size=(2, ), dtype=torch.int32, device="cuda") + indices = torch.stack( + [torch.arange(2, dtype=torch.int32).cuda(), indices], dim=1) # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): d = Tensor(name='d', shape=data.shape, dtype=tensorrt_llm.torch_dtype_to_trt(data.dtype)) @@ -215,28 +185,16 @@ def test_gatherND_selectH(self): #, data, indices, ref): shape=indices.shape, dtype=tensorrt_llm.torch_dtype_to_trt(indices.dtype)) - output = tensorrt_llm.functional.gather_nd(d, idx, 0).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.gather_nd(d, idx, 0) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'd': data.numpy(), - 'idx': indices.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = {'d': data, 'idx': indices} + outputs = run_session(session, inputs) + + # pytorch run + ref = data[indices[:, 0], indices[:, 1]] # compare diff - # indices = indices.squeeze(-1) - tref = data[indices[:, 0], indices[:, 1]] - # tref = torch.stack([data[i, indices[i]] for i in range(data.shape[0])]) - if ref.numel() == 0: - np.testing.assert_allclose(tref, outputs['output'], atol=1e-5) - else: - np.testing.assert_allclose(ref, outputs['output'], atol=1e-5) - np.testing.assert_allclose(ref, tref, atol=1e-5) - # print(tref.numpy()) - # print(outputs['output']) - # assert False, "FORCED" - return + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_geglu.py b/tests/functional/test_geglu.py index bd3d92389..ebda46d70 100644 --- a/tests/functional/test_geglu.py +++ b/tests/functional/test_geglu.py @@ -16,52 +16,49 @@ import sys import unittest -import numpy as np import torch -from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner from torch_ref import geglu import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): +class TestGeglu(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - @parameterized.expand([('float32', )], name_func=unittest_name_func) - def test_geglu(self, dtype): + def test_geglu(self): + dtype = 'float32' # test data x_shape = (12, 2, 96) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.geglu(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.geglu(x) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run - ref = geglu(x_data.cuda()) + ref = geglu(x_data) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-3) + torch.testing.assert_close(ref, outputs['output'], atol=1e-3, rtol=1e-2) diff --git a/tests/functional/test_gelu.py b/tests/functional/test_gelu.py index 35513ee72..72517f204 100644 --- a/tests/functional/test_gelu.py +++ b/tests/functional/test_gelu.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import math import os import sys @@ -20,16 +19,16 @@ import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import skip_bf16_pre_ampere, unittest_name_func +from utils.util import (create_session, run_session, skip_bf16_pre_ampere, + unittest_name_func) -class TestFunctional(unittest.TestCase): +class TestGelu(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -43,47 +42,40 @@ def gelu(x, dtype): math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) return res - @parameterized.expand(itertools.product( - ('float32', 'float16', 'bfloat16'), - (False, True), - ), + @parameterized.expand(('float32', 'float16', 'bfloat16'), name_func=unittest_name_func) - def test_gelu(self, dtype, strongly_typed): + def test_gelu(self, dtype): # Skip tests that are not supported in pre-ampere architecture skip_bf16_pre_ampere(dtype) torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) x_shape = (12, 12, 96, 96) - x_data = torch.rand(x_shape, dtype=torch_dtype) + x_data = torch.rand(x_shape, dtype=torch_dtype, device="cuda") # construct trt network builder = tensorrt_llm.Builder() - builder.strongly_typed = strongly_typed - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.gelu(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.gelu(x) + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data}) - out = outputs['output'].to(torch_dtype) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) - # Reference - ref = self.gelu(x_data, dtype) + # pytorch run + ref = self.gelu(x_data, dtype).to(torch_dtype) + # compare diff if dtype == 'bfloat16': atol, rtol = 1e-5, 2e-2 else: atol, rtol = 1e-5, 2e-3 - torch.testing.assert_close(out, ref, atol=atol, rtol=rtol) - - -if __name__ == '__main__': - unittest.main() + torch.testing.assert_close(outputs['output'], ref, atol=atol, rtol=rtol) diff --git a/tests/functional/test_gemm_swiglu.py b/tests/functional/test_gemm_swiglu.py new file mode 100644 index 000000000..6cb349a65 --- /dev/null +++ b/tests/functional/test_gemm_swiglu.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import unittest + +import numpy as np +import pytest +import tensorrt as trt +import torch +from parameterized import parameterized +from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner + +import tensorrt_llm +from tensorrt_llm import Tensor +from tensorrt_llm._utils import str_dtype_to_torch, str_dtype_to_trt +from tensorrt_llm.functional import gemm_swiglu + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +# Monkey Patching for torch.float8_e4m3fn support +from polygraphy.datatype import DataType +from utils.util import getSMVersion + +original_to_dtype = DataType.to_dtype + + +def patched_to_dtype(dtype, target_module): + if dtype == DataType.FLOAT8E4M3FN and target_module == 'torch': + return torch.float8_e4m3fn + else: + return original_to_dtype(dtype, target_module) + + +DataType.to_dtype = patched_to_dtype + + +class TestGemmSwiglu(unittest.TestCase): + + def setUp(self): + tensorrt_llm.logger.set_level('error') + + def reference_gemm_swiglu_sm90(self, x: torch.Tensor, w: torch.Tensor, + scale_d0: float, scale_d1: float, + scale_output: float, dtype): + silu = torch.nn.SiLU() + y = torch.matmul(x.to(torch.float32), w.to(torch.float32)) + split, split_gate = torch.split(y, y.size(1) // 2, dim=1) + y_swiglu = ( + (scale_d0 * split) * silu(scale_d1 * split_gate)) * scale_output + return y_swiglu.to(str_dtype_to_torch(dtype)) + + def run_gemm_swiglu_sm90(self, m, n, k, scale_d0, scale_d1, scale_output, + dtype): + assert n % 32 == 0, "dim N must be a integer multiples of 32" + assert k % 16 == 0, "dim K must be a integer multiples of 16" + + torch.random.manual_seed(42) + + shape_x = (m, k) + x = torch.randint(-2, 2, shape_x).to(str_dtype_to_torch(dtype)) + shape_w = (k, n) + w = torch.randint(-2, 2, shape_w).to(str_dtype_to_torch(dtype)) + + # Create builder + builder = tensorrt_llm.Builder() + # Create empty network + net = builder.create_network() + # Allow plugin of dtype type + net.plugin_config.set_gemm_swiglu_plugin(dtype) + with tensorrt_llm.net_guard(net): + network = tensorrt_llm.default_trtnet() + # Init TensorRT-LLM tensor for x + x_tensor = Tensor(name='x', + shape=x.shape, + dtype=str_dtype_to_trt(dtype)) + # Init TensorRT-LLM tensor for w + w_tensor = Tensor(name='w', + shape=w.shape, + dtype=str_dtype_to_trt(dtype)) + # Get output tensor + output = gemm_swiglu(x_tensor, w_tensor, None, scale_d0, scale_d1, + scale_output).trt_tensor + output.name = 'output' + network.mark_output(output) + output.dtype = str_dtype_to_trt(dtype) + + # Build engine + build_engine = EngineFromNetwork( + (builder.trt_builder, net.trt_network), + config=CreateConfig( + fp16=(dtype == "float16"), + fp8=(dtype == 'fp8'), + memory_pool_limits={trt.MemoryPoolType.WORKSPACE: 33554432})) + + # Infer engine + feed_dict = {'x': x, 'w': w.t().reshape(shape_w)} + with TrtRunner(build_engine) as runner: + outputs = runner.infer(feed_dict=feed_dict, check_inputs=False) + ref = self.reference_gemm_swiglu_sm90(x, w, scale_d0, scale_d1, + scale_output, dtype) + # print(f"ref:\n{ref.float().cpu().numpy()}") + # print(f"trt:\n{outputs['output'].float()}") + np.testing.assert_allclose(ref.float().cpu().numpy(), + outputs['output'].float(), + rtol=1e-3) + + @parameterized.expand([('fp8')]) + @pytest.mark.skipif(getSMVersion() != 90, + reason="GemmSwigluSm90 is only supported in SM90" + ) # Skip tests that are not supported in SM90 + def test_gemm_swiglu_sm90(self, dtype): + bs = 2 + inseq = 13 + hidden_size = 256 + out_size = 32 + scale_d0 = 0.2 + scale_d1 = 1.3 + scale_output = 0.001 + + self.run_gemm_swiglu_sm90(bs * inseq, out_size, hidden_size, scale_d0, + scale_d1, scale_output, dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/functional/test_group_norm.py b/tests/functional/test_group_norm.py index 71b35bdb5..efa2b34be 100644 --- a/tests/functional/test_group_norm.py +++ b/tests/functional/test_group_norm.py @@ -16,19 +16,17 @@ import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): +class TestFGroupNorm(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -41,32 +39,29 @@ def test_group_norm(self, dtype): num_groups = 3 x_shape = (2, num_channels, 3, 3) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.group_norm(x, - num_groups).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.group_norm(x, num_groups) + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork( - (builder.trt_builder, net.trt_network), - config=CreateConfig(fp16=(dtype == 'float16'))) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run - ref = torch.nn.functional.group_norm(x_data.cuda(), num_groups) + ref = torch.nn.functional.group_norm(x_data, num_groups) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-2) + torch.testing.assert_close(ref, outputs['output'], atol=1e-2, rtol=1e-2) diff --git a/tests/functional/test_identity.py b/tests/functional/test_identity.py index d1d8489a6..61629ff27 100644 --- a/tests/functional/test_identity.py +++ b/tests/functional/test_identity.py @@ -16,19 +16,18 @@ import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import skip_bf16_pre_ampere, unittest_name_func +from utils.util import (create_session, run_session, skip_bf16_pre_ampere, + unittest_name_func) -class TestFunctional(unittest.TestCase): +class TestIdentity(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -42,27 +41,29 @@ def test_identity(self, dtype, use_plugin): skip_bf16_pre_ampere(dtype) x_data = torch.randn( - (4, 6, 3, 4), dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + (4, 6, 3, 4), + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") + + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() + network = builder.create_network() if use_plugin: - net.plugin_config.identity_plugin = dtype - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network.plugin_config.identity_plugin = dtype + + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.identity(x).trt_tensor - output.name = 'output' - network.mark_output(output) - output.dtype = tensorrt_llm.str_dtype_to_trt(dtype) + output = tensorrt_llm.functional.identity(x) + output.mark_output('output', dtype) - build_engine = EngineFromNetwork( - (builder.trt_builder, net.trt_network), - config=CreateConfig(fp16=(dtype == 'float16'), - bf16=(dtype == 'bfloat16'))) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data}) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) - np.testing.assert_allclose(x_data.to(torch.float32), - outputs['output'].to(torch.float32)) + # compare diff + torch.testing.assert_close(x_data, outputs['output']) diff --git a/tests/functional/test_index_select.py b/tests/functional/test_index_select.py index c76626fbe..3be12a836 100644 --- a/tests/functional/test_index_select.py +++ b/tests/functional/test_index_select.py @@ -12,49 +12,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestIndexSelect(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_index_select(self): dtype = 'float32' - x_data = torch.randn(1, 512, 4) - y_data = torch.tensor([128, 256, 384, 512]).int() - 1 + x_data = torch.randn(1, 512, 4, device="cuda") + y_data = torch.tensor([128, 256, 384, 512], device="cuda").int() - 1 + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) y = Tensor(name='y', shape=y_data.shape, dtype=tensorrt_llm.str_dtype_to_trt('int32')) - output = tensorrt_llm.functional.index_select(x, dim=1, - index=y).trt_tensor - output.name = 'output' - network.mark_output(output) - - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - 'y': y_data.numpy() - }) + output = tensorrt_llm.functional.index_select(x, dim=1, index=y) + output.mark_output('output') + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'x': x_data, 'y': y_data} + outputs = run_session(session, inputs) + + # pytorch run ref = torch.index_select(x_data, dim=1, index=y_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_interpolate.py b/tests/functional/test_interpolate.py index fcf2a1fee..ec663e45d 100644 --- a/tests/functional/test_interpolate.py +++ b/tests/functional/test_interpolate.py @@ -12,17 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestInterpolate(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -34,15 +37,17 @@ def test_interpolate_without_scales_nearest_5d(self): output_shape = (16, 24, 32) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'nearest' # construct trt network align_corners_flag = False builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -51,20 +56,23 @@ def test_interpolate_without_scales_nearest_5d(self): size=output_shape, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.interpolate(input_data, size=output_shape, mode=mode) - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_interpolate_without_scales_bilinear_4d_disable_align_corner(self): # test data @@ -73,15 +81,17 @@ def test_interpolate_without_scales_bilinear_4d_disable_align_corner(self): output_shape = (16, 24) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'bilinear' # construct trt network align_corners_flag = False builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -90,22 +100,23 @@ def test_interpolate_without_scales_bilinear_4d_disable_align_corner(self): size=output_shape, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.interpolate(input_data, size=output_shape, mode=mode) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_interpolate_without_scales_bilinear_4d_enable_align_corner(self): # test data @@ -114,15 +125,17 @@ def test_interpolate_without_scales_bilinear_4d_enable_align_corner(self): output_shape = (16, 24) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'bilinear' # construct trt network align_corners_flag = True builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -131,22 +144,23 @@ def test_interpolate_without_scales_bilinear_4d_enable_align_corner(self): size=output_shape, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.interpolate(input_data, size=output_shape, mode=mode) - - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_interpolate_without_scales_bicubic_4d_enable_align_corner(self): # test data @@ -155,15 +169,17 @@ def test_interpolate_without_scales_bicubic_4d_enable_align_corner(self): output_shape = (16, 24) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'bicubic' # construct trt network align_corners_flag = True builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -172,22 +188,23 @@ def test_interpolate_without_scales_bicubic_4d_enable_align_corner(self): size=output_shape, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.interpolate(input_data, size=output_shape, mode=mode) - - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-3) + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_interpolate_with_scale_3d_nearest_exact(self): # test data @@ -195,15 +212,17 @@ def test_interpolate_with_scale_3d_nearest_exact(self): input_shape = (1, 4, 8, 16) scales_factor = (2, 4) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'nearest-exact' # construct trt network align_corners_flag = False builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -212,19 +231,23 @@ def test_interpolate_with_scale_3d_nearest_exact(self): scale_factor=scales_factor, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) + + # pytorch run ref = torch.nn.functional.interpolate(input_data, scale_factor=scales_factor, mode=mode) - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_interpolate_with_scale_4d_bicubic(self): # test data @@ -232,15 +255,17 @@ def test_interpolate_with_scale_4d_bicubic(self): input_shape = (1, 4, 8, 12) scales_factor = (2.5, 2) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'bicubic' # construct trt network align_corners_flag = False builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -249,21 +274,24 @@ def test_interpolate_with_scale_4d_bicubic(self): scale_factor=scales_factor, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.interpolate(input_data, scale_factor=scales_factor, mode=mode) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_interpolate_with_scale_4d_bilinear(self): # test data @@ -271,15 +299,17 @@ def test_interpolate_with_scale_4d_bilinear(self): input_shape = (1, 1, 8, 32) scales_factor = (2.5, 4) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'bilinear' # construct trt network align_corners_flag = False builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -288,22 +318,25 @@ def test_interpolate_with_scale_4d_bilinear(self): scale_factor=scales_factor, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) + + # pytorch run ref = torch.nn.functional.interpolate(input_data, scale_factor=scales_factor, align_corners=align_corners_flag, mode=mode) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_interpolate_with_scale_5d_trilinear_enable_align_corner(self): # test data @@ -311,15 +344,17 @@ def test_interpolate_with_scale_5d_trilinear_enable_align_corner(self): input_shape = (1, 1, 8, 16, 32) scales_factor = (2.5, 2, 4) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") mode = 'trilinear' # construct trt network align_corners_flag = True builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -328,19 +363,21 @@ def test_interpolate_with_scale_5d_trilinear_enable_align_corner(self): scale_factor=scales_factor, mode=mode, align_corners=align_corners_flag, - ).trt_tensor - output.name = 'output' - network.mark_output(output) + ) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.interpolate(input_data, scale_factor=scales_factor, align_corners=align_corners_flag, mode=mode) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_logsoftmax.py b/tests/functional/test_logsoftmax.py index 3a6af43a1..944cbe3ae 100644 --- a/tests/functional/test_logsoftmax.py +++ b/tests/functional/test_logsoftmax.py @@ -12,38 +12,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest from itertools import product -import numpy as np - # isort: off import torch # isort: on from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session, unittest_name_func + -class TestFunctional(unittest.TestCase): +class TestLogSoftmax(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - def test_lt(self, dtype='float32'): + def test_lt(self): + dtype = 'float32' t_shape = (2, 3) x_data = torch.rand(t_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") y_data = torch.rand(t_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=t_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -51,127 +55,113 @@ def test_lt(self, dtype='float32'): shape=t_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.lt(x, y).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.lt(x, y) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - 'y': y_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = {'x': x_data, 'y': y_data} + outputs = run_session(session, inputs) # pytorch run ref = torch.lt(x_data, y_data) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) - return + torch.testing.assert_close(ref, outputs['output']) @parameterized.expand(list(product(['float32']))) def test_log(self, dtype): # test data x_shape = (4, 6, 8) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.log(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.log(x) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref = x_data.log() # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + torch.testing.assert_close(ref, outputs['output']) - @parameterized.expand(list(product(['float32'], [0, 1, 2], [False, True]))) + @parameterized.expand(list(product(['float32'], [0, 1, 2], [False, True])), + name_func=unittest_name_func) def test_sum(self, dtype, dim, keepdim): # test data x_shape = (4, 6, 8) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.sum(x, dim, keepdim).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.sum(x, dim, keepdim) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref = x_data.sum(dim=dim, keepdim=keepdim) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + torch.testing.assert_close(ref, outputs['output']) - @parameterized.expand(list(product(['float32'], [0, 1, 2]))) + @parameterized.expand(list(product(['float32'], [0, 1, 2])), + name_func=unittest_name_func) def test_log_softmax(self, dtype, dim): # test data x_shape = (4, 6, 8) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.log_softmax(x, dim=dim).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.log_softmax(x, dim=dim) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref = x_data.log_softmax(dim=dim) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_matmul.py b/tests/functional/test_matmul.py index 3accef2a3..7a013c390 100644 --- a/tests/functional/test_matmul.py +++ b/tests/functional/test_matmul.py @@ -14,72 +14,60 @@ # limitations under the License. import unittest -import numpy as np - # isort: off import torch -import tensorrt as trt # isort: on import os import sys from parameterized import parameterized -from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import skip_bf16_pre_ampere, unittest_name_func +from utils.util import (create_session, run_session, skip_bf16_pre_ampere, + unittest_name_func) class TestMatmul(unittest.TestCase): def setUp(self): + torch.backends.cudnn.allow_tf32 = False tensorrt_llm.logger.set_level('error') def _matmul(self, m, n, k, dtype, ta, tb): shape1 = (k, m) if ta else (m, k) - mat1 = torch.randn( - shape1, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) * 1e-1 + mat1 = torch.randn(shape1, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") * 1e-1 shape2 = (n, k) if tb else (k, n) - mat2 = torch.randn( - shape2, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) * 1e-1 + mat2 = torch.randn(shape2, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") * 1e-1 + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=mat1.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) y = Tensor(name='y', shape=mat2.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.matmul(x, y, transa=ta, - transb=tb).trt_tensor - output.name = 'output' - network.mark_output(output) - output.dtype = tensorrt_llm.str_dtype_to_trt(dtype) - - build_engine = EngineFromNetwork( - (builder.trt_builder, net.trt_network), - config=CreateConfig( - fp16=(dtype == 'float16'), - bf16=(dtype == 'bfloat16'), - precision_constraints='obey', - memory_pool_limits={trt.MemoryPoolType.WORKSPACE: 33554432})) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': mat1, 'y': mat2}) + output = tensorrt_llm.functional.matmul(x, y, transa=ta, transb=tb) + output.mark_output('output', dtype) - if ta: - mat1 = mat1.cuda().transpose(0, 1) - if tb: - mat2 = mat2.cuda().transpose(0, 1) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'x': mat1, 'y': mat2} + outputs = run_session(session, inputs) tols = { "float32": { - "rtol": 1e-05, - "atol": 1e-05 + "rtol": 4e-4, + "atol": 1e-02 }, "float16": { "rtol": 1e-02, @@ -91,17 +79,13 @@ def _matmul(self, m, n, k, dtype, ta, tb): }, } - if dtype != "float32": - mat1 = mat1.cuda() - mat2 = mat2.cuda() - else: - mat1 = mat1.cpu() - mat2 = mat2.cpu() - - ref = torch.matmul(mat1, mat2).to(torch.float32) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'].to(torch.float32), - **tols[dtype]) + # pytorch run + if ta: + mat1 = mat1.transpose(0, 1) + if tb: + mat2 = mat2.transpose(0, 1) + ref = torch.matmul(mat1, mat2) + torch.testing.assert_close(ref, outputs['output'], **tols[dtype]) @parameterized.expand([('float16', False, False), ('float16', False, True), ('float16', True, False), ('float16', True, True), @@ -128,31 +112,30 @@ def test_matmul(self, dtype, transa, transb): def test_matmul_broadcast(self): dtype = 'float32' - x_data = torch.randn(16, 4, 4, 5) - y_data = torch.randn(16, 1, 5, 4) + x_data = torch.randn(16, 4, 4, 5, device="cuda") + y_data = torch.randn(16, 1, 5, 4, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) y = Tensor(name='y', shape=y_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.matmul(x, y).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.matmul(x, y) + output.mark_output('output', dtype) - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - 'y': y_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'x': x_data, 'y': y_data} + outputs = run_session(session, inputs) + # pytorch run ref = torch.matmul(x_data, y_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_moe.py b/tests/functional/test_moe.py index ae4efbf5c..1c7609c49 100644 --- a/tests/functional/test_moe.py +++ b/tests/functional/test_moe.py @@ -12,8 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import math import unittest +from collections import OrderedDict import numpy as np @@ -25,8 +27,6 @@ import sys from parameterized import parameterized -from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile, - TrtRunner) import tensorrt_llm from tensorrt_llm import Tensor @@ -36,7 +36,8 @@ from tensorrt_llm.quantization import QuantMode sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import getSMVersion, skip_bf16_pre_ampere, unittest_name_func +from utils.util import (create_session, getSMVersion, run_session, + skip_bf16_pre_ampere, unittest_name_func) default_actfn = 'gelu' default_hidden_size = { @@ -85,7 +86,7 @@ def config_is_allowed(config): def gen_uniform_weights(*args, **kwargs): - return (torch.rand(*args, **kwargs) * 2 - 1).contiguous() + return (torch.rand(*args, **kwargs) * 2 - 1).contiguous().cuda() def quant_dequant_int(weights, quant_mode): @@ -148,7 +149,7 @@ def gated_matmul(input, weights, bias, actfn): return fc1 * doact(gate, gated2act(actfn)) -class TestFunctional(unittest.TestCase): +class TestMoE(unittest.TestCase): def setUp(self): # There is a known precision issues where the topk may select different experts when the routing probabilities are similar. @@ -417,14 +418,15 @@ def test_mixture_of_experts(self, num_experts, top_k, hidden_size, actfn, act_2_quant = 0.0 for i, input in enumerate(inputs): - result, act2_quant_values = self.referenceImpl( + result, act2_quant_values = self.generate_reference( input, top_k, actfn, weight_dtype, quant_mode, norm_mode) - reference_values.append(result.cpu().float()) + reference_values.append(result) act_2_quant = max(act_2_quant, act2_quant_values) self.create_fp8_scaling_factors(act_1_quant, act_2_quant) - engine = self.buildTrtEngine( + # build trt engine + session = self.create_trt_session( (-1, -1, hidden_size), num_experts, top_k, @@ -440,8 +442,9 @@ def test_mixture_of_experts(self, num_experts, top_k, hidden_size, actfn, max_sizes=[max_num_seq, max_seq_len, hidden_size]) for input, ref in zip(inputs, reference_values): - # construct trt network - trt_res = self.runTrtEngine(engine, input)['output'].float() + # run trt output + inputs = {"input_hidden_states": input} + outputs = run_session(session, inputs) tolerances = { 'float32': 1e-2, @@ -455,18 +458,11 @@ def test_mixture_of_experts(self, num_experts, top_k, hidden_size, actfn, # Bit of a hack to allow bigger tolerance for the Mixtral tests if hidden_size > 1024: - # Do some extra checks on the full distribution - self.assertAlmostEqual(np.mean((trt_res - ref).numpy()), - 0.0, - delta=2e-4) - self.assertAlmostEqual(np.var((trt_res - ref).numpy()), - 0.0, - delta=tolerance) # Set a higher tolerance because we hit a small fraction of outlier cases (<<1%) tolerance = 0.3 - np.testing.assert_allclose(trt_res, - ref, + torch.testing.assert_close(outputs['output'].float(), + ref.float(), rtol=tolerance, atol=tolerance) @@ -554,23 +550,26 @@ def MLP(network, trt_key): mlp.proj.bias.value = np.ascontiguousarray( torch_to_numpy(self.fc2_bias[0].cpu())) - output = mlp(trt_key).trt_tensor - output.name = 'mlp_output' - network.mark_output(output) - output.dtype = dtype + output = mlp(trt_key) + output.mark_output('mlp_output', dtype) - res = self.trtImpl(input_data, - num_experts, - top_k, - hidden_size, - ffn_hidden_size, - actfn, - bias, - dtype, - weight_dtype=weight_dtype, - quant_mode=quant_mode, - custom_network=MLP, - use_plugin=use_plugin) + session = self.create_trt_session( + tuple(input_data.shape), + num_experts, + top_k, + hidden_size, + ffn_hidden_size, + actfn, + bias, + dtype, + weight_dtype, + quant_mode, + norm_mode=MoeConfig.ExpertScaleNormalizationMode.NONE, + custom_network=MLP, + use_plugin=use_plugin) + + inputs = {"input_hidden_states": input_data} + outputs = run_session(session, inputs) tolerances = { 'float32': 1e-2, @@ -580,8 +579,8 @@ def MLP(network, trt_key): 'int8': 2e-1, 'int4': 2e-1, } - np.testing.assert_allclose(res['output'].float(), - res['mlp_output'].float(), + torch.testing.assert_close(outputs['output'], + outputs['mlp_output'], rtol=tolerances[dtype_str], atol=tolerances[dtype_str]) @@ -614,30 +613,38 @@ def set_weight_layer(self, moe_weight_wrapper.weight.value = np.ascontiguousarray( torch_to_numpy(input_weights)) - def buildTrtEngine(self, - input_shape, - num_experts, - top_k, - hidden_size, - ffn_hidden_size, - actfn, - bias, - dtype: trt.DataType, - weight_dtype: trt.DataType, - quant_mode, - norm_mode, - custom_network=None, - use_plugin=True, - max_sizes=None): + def create_trt_session(self, + input_shape, + num_experts, + top_k, + hidden_size, + ffn_hidden_size, + actfn, + bias, + dtype: trt.DataType, + weight_dtype: trt.DataType, + quant_mode, + norm_mode, + custom_network=None, + use_plugin=True, + max_sizes=None): builder = tensorrt_llm.Builder() - builder.strongly_typed = weight_dtype == trt.fp8 - net = builder.create_network() - net.plugin_config.moe_plugin = (trt_dtype_to_str(dtype) - if use_plugin else None) - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + if use_plugin: + network.plugin_config.moe_plugin = trt_dtype_to_str(dtype) + with tensorrt_llm.net_guard(network): + if max_sizes: + dim_range = OrderedDict([("max_num_seq", [[1, 1, + max_sizes[0]]]), + ("max_seq_len", [[1, 1, + max_sizes[1]]]), + ("hidden_size", [hidden_size])]) + else: + dim_range = None + trt_key = Tensor(name='input_hidden_states', shape=input_shape, + dim_range=dim_range, dtype=dtype) moe = tensorrt_llm.layers.MOE(moe_config=MoeConfig( @@ -674,67 +681,20 @@ def buildTrtEngine(self, if custom_network: custom_network(network, trt_key) - output = moe(trt_key).trt_tensor - output.name = 'output' - network.mark_output(output) - output.dtype = dtype - - profiles = None - if max_sizes: - profiles = [ - Profile().add('input_hidden_states', (1, 1, hidden_size), - (1, 1, hidden_size), max_sizes) - ] - - config = CreateConfig(builder_optimization_level=4, profiles=profiles) - if not builder.strongly_typed: - config = CreateConfig(fp16=(dtype == trt.float16), - bf16=(dtype == trt.bfloat16), - int8=(weight_dtype == trt.int8), - fp8=(weight_dtype == trt.fp8), - precision_constraints='obey', - builder_optimization_level=4, - profiles=profiles) + output = moe(trt_key) + output.mark_output('output', dtype) # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network), - config=config) - assert build_engine is not None - return build_engine - - def runTrtEngine(self, engine, input_data): - with TrtRunner(engine) as runner: - feed_dict = { - 'input_hidden_states': input_data, - } - outputs = runner.infer(feed_dict=feed_dict) - return outputs - - def trtImpl(self, - input_data, - num_experts, - top_k, - hidden_size, - ffn_hidden_size, - actfn, - bias, - dtype: trt.DataType, - weight_dtype: trt.DataType = None, - quant_mode=QuantMode(0), - norm_mode=MoeConfig.ExpertScaleNormalizationMode.NONE, - custom_network=None, - use_plugin=True): - build_engine = self.buildTrtEngine(tuple(input_data.shape), num_experts, - top_k, hidden_size, ffn_hidden_size, - actfn, bias, dtype, weight_dtype, - quant_mode, norm_mode, - custom_network, use_plugin) - - outputs = self.runTrtEngine(build_engine, input_data) - return outputs - - def referenceImpl(self, inputs, k, actfn, weight_dtype, quant_mode, - norm_mode): + session = create_session(builder, + network, + precision=trt_dtype_to_str(dtype), + int8=weight_dtype == trt.int8, + quant_mode=quant_mode, + opt_level=4) + return session + + def generate_reference(self, inputs, k, actfn, weight_dtype, quant_mode, + norm_mode): # Always run the ref implementation at full precision TODO is this a good choice? inputs = inputs.cuda().float() inputs_merged = inputs.view(-1, inputs.shape[-1]) @@ -771,7 +731,3 @@ def referenceImpl(self, inputs, k, actfn, weight_dtype, quant_mode, assert final.shape == (inputs.shape[-1], ) results[i] += scale * final return results.view(*inputs.shape), max_act_2 - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/functional/test_nccl.py b/tests/functional/test_nccl.py index 90d5f8c23..facc59e7a 100644 --- a/tests/functional/test_nccl.py +++ b/tests/functional/test_nccl.py @@ -25,9 +25,8 @@ from cuda import cudart from parameterized import parameterized -from polygraphy.backend.trt import CreateConfig, EngineFromNetwork -import tensorrt_llm as tllm +import tensorrt_llm from tensorrt_llm import Mapping, Tensor from tensorrt_llm._ipc_utils import peer_access from tensorrt_llm.functional import (AllReduceConfig, AllReduceStrategy, @@ -35,15 +34,16 @@ from tensorrt_llm.plugin.plugin import current_all_reduce_helper sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import (create_session, run_session, skip_bf16_pre_ampere, + unittest_name_func) class TestCommunicationPlugin(unittest.TestCase): def setUp(self): - tllm.logger.set_level('error') - self.world_size = tllm.mpi_world_size() - self.rank = tllm.mpi_rank() + tensorrt_llm.logger.set_level('error') + self.world_size = tensorrt_llm.mpi_world_size() + self.rank = tensorrt_llm.mpi_rank() torch.cuda.set_device(self.rank) cudart.cudaSetDevice(self.rank) @@ -67,6 +67,8 @@ def setUp(self): name_func=unittest_name_func) def test_allreduce(self, dtype: str, strategy: AllReduceStrategy, config: AllReduceConfig, size: int): + + skip_bf16_pre_ampere(dtype) if self.world_size == 1: pytest.skip("Skip single GPU NCCL") @@ -75,7 +77,7 @@ def test_allreduce(self, dtype: str, strategy: AllReduceStrategy, workspace = None - torch_dtype = tllm._utils.str_dtype_to_torch(dtype) + torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype) dtype_size = torch.finfo(torch_dtype).bits // 8 allreduce_ref = torch.zeros(self.reference_tensors[0][:size].shape, @@ -85,9 +87,10 @@ def test_allreduce(self, dtype: str, strategy: AllReduceStrategy, allreduce_ref = allreduce_ref + self.reference_tensors[i][:size].to( torch_dtype) - builder = tllm.Builder() - net = builder.create_network() - net.plugin_config.set_nccl_plugin(dtype, use_custom_all_reduce=True) + # construct trt network + builder = tensorrt_llm.Builder() + network = builder.create_network() + network.plugin_config.set_nccl_plugin(dtype, use_custom_all_reduce=True) _, workspace = current_all_reduce_helper().allocate_workspace( self.mapping, size * dtype_size) @@ -95,48 +98,26 @@ def test_allreduce(self, dtype: str, strategy: AllReduceStrategy, inner_loop = 5 with peer_access(self.mapping): - with tllm.net_guard(net): - network = tllm.default_trtnet() + with tensorrt_llm.net_guard(network): x = Tensor(name='x', shape=input.shape, - dtype=tllm.str_dtype_to_trt(dtype)) + dtype=tensorrt_llm.str_dtype_to_trt(dtype)) current_all_reduce_helper().set_workspace_tensor(self.mapping) current = x for i in range(inner_loop): current = allreduce(current, self.mapping.tp_group, strategy, config) - output = current.trt_tensor - - output.name = 'output' - output.dtype = tllm.str_dtype_to_trt(dtype) - network.mark_output(output) - - build_engine = EngineFromNetwork( - (builder.trt_builder, net.trt_network), - config=CreateConfig( - fp16=(dtype == 'float16'), - bf16=(dtype == 'bfloat16'), - precision_constraints='obey', - )) - - output = torch.zeros_like(input) - - stream = torch.cuda.current_stream() - feed_dict = {'x': input, 'all_reduce_workspace': workspace} - - session = tllm.runtime.Session.from_engine(build_engine()) - session.run(inputs=feed_dict, - outputs={"output": output}, - stream=stream.cuda_stream) - torch.cuda.synchronize() - self.assertTrue( - torch.allclose(output.cpu(), - (self.mapping.tp_size**(inner_loop - 1)) * - allreduce_ref.cpu())) + current.mark_output('output', dtype) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'x': input, 'all_reduce_workspace': workspace} + outputs = run_session(session, inputs) -if __name__ == "__main__": - unittest.main() + # compare diff + torch.testing.assert_close(outputs['output'], + (self.mapping.tp_size**(inner_loop - 1)) * + allreduce_ref) diff --git a/tests/functional/test_outer.py b/tests/functional/test_outer.py index c0d58f5ae..4634b6e39 100644 --- a/tests/functional/test_outer.py +++ b/tests/functional/test_outer.py @@ -12,17 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestOuter(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -30,13 +33,14 @@ def setUp(self): def test_outer(self): # test data dtype = 'float32' - x_data = torch.arange(1., 5.) - y_data = torch.arange(1., 4.) + x_data = torch.arange(1., 5.).cuda() + y_data = torch.arange(1., 4.).cuda() + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -44,20 +48,16 @@ def test_outer(self): shape=y_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.outer(x, y).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.outer(x, y) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - 'y': y_data.numpy() - }) + session = create_session(builder, network, precision=dtype) + inputs = {'x': x_data, 'y': y_data} + outputs = run_session(session, inputs) # pytorch run ref = torch.outer(x_data, y_data) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_permute.py b/tests/functional/test_permute.py index d9a933088..3a15ff41c 100644 --- a/tests/functional/test_permute.py +++ b/tests/functional/test_permute.py @@ -16,19 +16,17 @@ import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): +class TestPermute(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -39,27 +37,29 @@ def test_permute(self, dtype): x_shape = (4, 12, 64, 129) dims = [0, 1, 3, 2] x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.permute(x, dims).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.permute(x, dims) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref = torch.permute(x_data, dims) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_repeat_interleave.py b/tests/functional/test_repeat_interleave.py index 68cdd79e8..2d758f531 100644 --- a/tests/functional/test_repeat_interleave.py +++ b/tests/functional/test_repeat_interleave.py @@ -1,20 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): +class TestRepeatInterleave(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -24,25 +36,30 @@ def test_repeat_interleave(self, axis): dtype = 'float32' repeats = 3 x_data = torch.randn( - (2, 3, 4), dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + (2, 3, 4), + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.repeat_interleave( - x, repeats, axis).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.repeat_interleave(x, repeats, axis) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.repeat_interleave(x_data, repeats, axis) - np.testing.assert_allclose(ref, outputs['output']) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_scatter.py b/tests/functional/test_scatter.py index 24f11bc44..4d7b8487d 100644 --- a/tests/functional/test_scatter.py +++ b/tests/functional/test_scatter.py @@ -12,19 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor from tensorrt_llm._utils import str_dtype_to_torch +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestScatter(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -51,28 +54,22 @@ def setUp(self): 2, ), ]) - def test_scatter(self, - input_data=[[[-3.0, -2.0, -1.0, 10.0, -25.0]], - [[0.0, 1.0, 2.0, -2.0, -1.0]]], - indices=[[[1, 2, 3, 0, 4]], [[4, 1, 2, 3, 0]]], - updates=[[[-1.0, 2.4, 3.2, 10.8, 8.9]], - [[0, -11.2, 34.2, 223.9, -100]]], - dim=2): + def test_scatter(self, input_data, indices, updates, dim): dtype = 'float32' torch_dtype = str_dtype_to_torch(dtype) - input_data = input_data if isinstance( - input_data, torch.Tensor) else torch.tensor(input_data) - indices = indices if isinstance( - indices, torch.Tensor) else torch.tensor(indices).int() - updates = updates if isinstance(updates, - torch.Tensor) else torch.tensor(updates) + input_data = input_data.cuda() if isinstance( + input_data, torch.Tensor) else torch.tensor(input_data).cuda() + indices = indices.cuda() if isinstance( + indices, torch.Tensor) else torch.tensor(indices).int().cuda() + updates = updates.cuda() if isinstance( + updates, torch.Tensor) else torch.tensor(updates).cuda() input_data = input_data.to(torch_dtype) updates = updates.to(torch_dtype) + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): input_t = Tensor(name='input', shape=input_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -88,27 +85,18 @@ def test_scatter(self, indices=indices_t, updates=updates_t) - output = output.trt_tensor - output.name = 'output' - network.mark_output(output) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer( - feed_dict={ - 'input': input_data.numpy(), - 'indices': indices.numpy(), - 'updates': updates.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'input': input_data, 'indices': indices, 'updates': updates} + outputs = run_session(session, inputs) + # pytorch run ref = torch.scatter(input_data, dim=dim, index=indices.to(dtype=torch.int64), src=updates) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) - # print(ref) - # print(outputs['output']) - return + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_select.py b/tests/functional/test_select.py index 6c39be7c7..4e81df049 100644 --- a/tests/functional/test_select.py +++ b/tests/functional/test_select.py @@ -12,72 +12,77 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestSelect(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_select_from_int(self): dtype = 'float32' - x_data = torch.randn(2, 3, 4) + x_data = torch.randn(2, 3, 4, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.select(x, dim=0, - index=0).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.select(x, dim=0, index=0) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.select(x_data, dim=0, index=0) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) def test_select_from_tensor(self): dtype = 'float32' - x_data = torch.randn(2, 3, 4) + x_data = torch.randn(2, 3, 4, device="cuda") builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) y = tensorrt_llm.functional.constant(np.array([1], dtype=np.int32)) - output = tensorrt_llm.functional.select(x, dim=2, - index=y).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.select(x, dim=2, index=y) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.select(x_data, dim=2, index=1) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_sigmoid.py b/tests/functional/test_sigmoid.py index 4cc0a86c6..0c93fc55a 100644 --- a/tests/functional/test_sigmoid.py +++ b/tests/functional/test_sigmoid.py @@ -12,43 +12,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestSigmoid(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_sigmoid(self): dtype = 'float32' - x_data = torch.randn(2, 3, 4, 5) + x_data = torch.randn(2, 3, 4, 5, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.sigmoid(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.sigmoid(x) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.sigmoid(x_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_silu.py b/tests/functional/test_silu.py index 9ec8cbf1d..6e0802137 100644 --- a/tests/functional/test_silu.py +++ b/tests/functional/test_silu.py @@ -12,43 +12,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestSiLU(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_silu(self): dtype = 'float32' - x_data = torch.randn(2, 3, 4, 5) + x_data = torch.randn(2, 3, 4, 5, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.silu(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.silu(x) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.silu(x_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_sin.py b/tests/functional/test_sin.py index 6c228207a..5288ad7b9 100644 --- a/tests/functional/test_sin.py +++ b/tests/functional/test_sin.py @@ -12,15 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session + class TestFunctional(unittest.TestCase): @@ -29,26 +32,28 @@ def setUp(self): def test_exp(self): dtype = 'float32' - x_data = torch.randn(2, 3, 4, 5) + x_data = torch.randn(2, 3, 4, 5, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.sin(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.sin(x) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.sin(x_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_slice.py b/tests/functional/test_slice.py index 226d2cfd7..e6373c96c 100644 --- a/tests/functional/test_slice.py +++ b/tests/functional/test_slice.py @@ -24,14 +24,12 @@ import sys from parameterized import parameterized -from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile, - TrtRunner) import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func class TestFunctional(unittest.TestCase): @@ -41,19 +39,20 @@ def setUp(self): @parameterized.expand([('float32', ), ('float16', )], name_func=unittest_name_func) - def test_slice_1(self, dtype): + def test_slice_explicit(self, dtype): # test data x_shape = (1, 256) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") starts_data = torch.tensor([0, 128]).int() sizes_data = torch.tensor([1, 1]).int() # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -61,43 +60,40 @@ def test_slice_1(self, dtype): sizes = Tensor(name='sizes', shape=(2, ), dtype=trt.int32) - output = tensorrt_llm.functional.slice(x, starts, sizes).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.slice(x, starts, sizes) + output.mark_output('output') + + profile = builder.trt_builder.create_optimization_profile() + profile.set_shape_input('starts', (0, 128), (0, 128), (0, 128)) + profile.set_shape_input('sizes', (1, 1), (1, 1), (1, 1)) # trt run - profiles = [ - Profile().add('starts', (0, 0), (0, 128), - (0, 256)).add('sizes', (1, 1), (1, 1), (1, 256)) - ] - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network), - config=CreateConfig(profiles=profiles)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer( - feed_dict={ - 'x': x_data.numpy(), - 'starts': starts_data.numpy(), - 'sizes': sizes_data.numpy(), - }) + session = create_session(builder, + network, + precision=dtype, + optimization_profiles=[profile]) + inputs = {'x': x_data, 'starts': starts_data, 'sizes': sizes_data} + outputs = run_session(session, inputs) # pytorch run ref = x_data[0:1, 128:129] # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) - def test_slice_2(self): + def test_slice_implicit(self): dtype = 'float32' x_shape = (256, ) slice_length = 128 x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -107,17 +103,19 @@ def test_slice_2(self): np.array([0] * slice_length, dtype=np.int32)) sizes = tensorrt_llm.functional.shape(output_length, 0) - output = tensorrt_llm.functional.slice(x, starts, - sizes.view([1])).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.slice(x, starts, sizes.view([1])) + + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'x': x_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = x_data[0:slice_length] - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_softplus.py b/tests/functional/test_softplus.py index 8846537ca..d0ef18e96 100644 --- a/tests/functional/test_softplus.py +++ b/tests/functional/test_softplus.py @@ -12,43 +12,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestSoftPlus(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_softplus(self): dtype = 'float32' - x_data = torch.randn(3, 5) + x_data = torch.randn(3, 5, device="cuda") + + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.softplus(x, 1.6, 3.2).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.softplus(x, 1.6, 3.2) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.nn.functional.softplus(x_data, 1.6, 3.2) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-6) + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_split.py b/tests/functional/test_split.py index a49d50591..4ccb2591b 100644 --- a/tests/functional/test_split.py +++ b/tests/functional/test_split.py @@ -16,19 +16,17 @@ import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): +class TestSplit(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -56,13 +54,14 @@ def test_split(self, dtype, dim, split_size_or_sections): # test data x_shape = (128, 256) x_data = torch.rand(x_shape, - dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -70,16 +69,14 @@ def test_split(self, dtype, dim, split_size_or_sections): outputs = tensorrt_llm.functional.split(x, split_size_or_sections, dim) for i in range(len(outputs)): - output = outputs[i].trt_tensor - output.name = f'output_{i}' - network.mark_output(output) + outputs[i].mark_output(f'output_{i}') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref_outputs = torch.split(x_data, split_size_or_sections, dim) @@ -87,5 +84,4 @@ def test_split(self, dtype, dim, split_size_or_sections): # compare diff assert len(outputs.keys()) == len(ref_outputs) for i in range(len(ref_outputs)): - np.testing.assert_allclose(ref_outputs[i].cpu().numpy(), - outputs[f'output_{i}']) + torch.testing.assert_close(ref_outputs[i], outputs[f'output_{i}']) diff --git a/tests/functional/test_squeeze.py b/tests/functional/test_squeeze.py index ff515a4f5..0aa2e523e 100644 --- a/tests/functional/test_squeeze.py +++ b/tests/functional/test_squeeze.py @@ -12,57 +12,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor from tensorrt_llm._utils import str_dtype_to_torch +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestSqueeze(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - def test_squeeze(self, - input_data=[[[-3.0, -2.0, -1.0, 10.0, -25.0]], - [[0.0, 1.0, 2.0, -2.0, -1.0]]], - dim=1): + def test_squeeze(self): dtype = 'float32' + input_data = torch.tensor([[[-3.0, -2.0, -1.0, 10.0, -25.0]], + [[0.0, 1.0, 2.0, -2.0, -1.0]]]).cuda() + dim = 1 torch_dtype = str_dtype_to_torch(dtype) - input_data = input_data if isinstance( - input_data, torch.Tensor) else torch.tensor(input_data) input_data = input_data.to(torch_dtype) + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): input_t = Tensor(name='input', shape=input_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) output = tensorrt_llm.functional.squeeze(input_t, dim=dim) + output.mark_output('output') - output = output.trt_tensor - output.name = 'output' - network.mark_output(output) - - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'input': input_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = torch.squeeze(input_data, dim=dim) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) - # print(ref) - # print(outputs['output']) - return + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_swiglu.py b/tests/functional/test_swiglu.py index c0a9f86e6..28a27aca4 100644 --- a/tests/functional/test_swiglu.py +++ b/tests/functional/test_swiglu.py @@ -12,45 +12,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner from torch_ref import swiglu import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestSwiglu(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') def test_swiglu(self): dtype = 'float32' - x_data = torch.randn(12, 2, 96) + x_data = torch.randn(12, 2, 96, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.swiglu(x).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.swiglu(x) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) + # pytorch run ref = swiglu(x_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-4, - rtol=1e-4) + + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_topk.py b/tests/functional/test_topk.py index 1bd232ecf..aed1dbaff 100644 --- a/tests/functional/test_topk.py +++ b/tests/functional/test_topk.py @@ -16,18 +16,16 @@ import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -from utils.util import unittest_name_func +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): +class TestTopK(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -44,35 +42,27 @@ def test_topk(self, input_shape, k, d, largest): indices_dtype = 'int32' # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() + network = builder.create_network() + input_data = torch.rand(*input_shape, + dtype=torch.float32, + device="cuda") + with tensorrt_llm.net_guard(network): - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() - input_data = np.random.rand(*input_shape).astype(np.float32) - m = tensorrt_llm.functional.constant(input_data) + m = tensorrt_llm.functional.constant(input_data.cpu().numpy()) topk_values, topk_indices = tensorrt_llm.functional.topk( m, k, d, largest=largest) - topk_values = topk_values.trt_tensor - topk_indices = topk_indices.trt_tensor - topk_values.name = 'output_values' - topk_indices.name = 'output_indices' - network.mark_output(topk_values) - network.mark_output(topk_indices) - topk_values.dtype = tensorrt_llm.str_dtype_to_trt(value_dtype) - topk_indices.dtype = tensorrt_llm.str_dtype_to_trt(indices_dtype) + topk_values.mark_output('output_values', value_dtype) + topk_indices.mark_output('topk_indices', indices_dtype) - # trt run - build_engine = EngineFromNetwork( - (builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={}) - values, indices = torch.topk(torch.Tensor(input_data), - k, - dim=d, - largest=largest) + # trt run + session = create_session(builder, network) + inputs = {} + outputs = run_session(session, inputs) - np.testing.assert_allclose(values.cpu().numpy(), - outputs['output_values'], - atol=1e-5) - np.testing.assert_allclose(indices.cpu().numpy(), - outputs['output_indices']) + # pytorch run + values, indices = torch.topk(input_data, k, dim=d, largest=largest) + + # compare diff + torch.testing.assert_close(values, outputs['output_values']) + # dtype does not match + torch.testing.assert_close(indices.int(), outputs['topk_indices'].int()) diff --git a/tests/functional/test_transpose.py b/tests/functional/test_transpose.py index ec542ae55..718f2bcca 100644 --- a/tests/functional/test_transpose.py +++ b/tests/functional/test_transpose.py @@ -12,17 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestTranspose(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -30,31 +33,30 @@ def setUp(self): def test_transpose(self): # test data dtype = 'float32' - x_data = torch.randn(2, 3) + x_data = torch.randn(2, 3, device="cuda") dim0 = 0 dim1 = 1 # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + x = Tensor(name='x', shape=x_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.transpose(x, dim0, dim1).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.transpose(x, dim0, dim1) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'x': x_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = { + 'x': x_data, + } + outputs = run_session(session, inputs) # pytorch run ref = torch.transpose(x_data, dim0, dim1) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_unsqueeze.py b/tests/functional/test_unsqueeze.py index 15d6713eb..695631370 100644 --- a/tests/functional/test_unsqueeze.py +++ b/tests/functional/test_unsqueeze.py @@ -12,57 +12,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor from tensorrt_llm._utils import str_dtype_to_torch +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session -class TestFunctional(unittest.TestCase): + +class TestUnsqueeze(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') - def test_unsqueeze(self, - input_data=[[[-3.0, -2.0, -1.0, 10.0, -25.0]], - [[0.0, 1.0, 2.0, -2.0, -1.0]]], - axis=0): + def test_unsqueeze(self): dtype = 'float32' - torch_dtype = str_dtype_to_torch(dtype) - input_data = input_data if isinstance( - input_data, torch.Tensor) else torch.tensor(input_data) - input_data = input_data.to(torch_dtype) + str_dtype_to_torch(dtype) + input_data = torch.tensor([[[-3.0, -2.0, -1.0, 10.0, -25.0]], + [[0.0, 1.0, 2.0, -2.0, -1.0]]]).cuda() + axis = 0 + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): input_t = Tensor(name='input', shape=input_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) output = tensorrt_llm.functional.unsqueeze(input_t, axis=axis) + output.mark_output('output') - output = output.trt_tensor - output.name = 'output' - network.mark_output(output) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'input': input_data} - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'input': input_data.numpy(), - }) + outputs = run_session(session, inputs) + # pytorch run ref = torch.unsqueeze(input_data, dim=axis) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) - # print(ref) - # print(outputs['output']) - return + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_view.py b/tests/functional/test_view.py index 955e7954f..180f89a19 100644 --- a/tests/functional/test_view.py +++ b/tests/functional/test_view.py @@ -12,22 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np - # isort: off import torch import tensorrt as trt # isort: on -from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile, - TrtRunner) import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session + -class TestFunctional(unittest.TestCase): +class TestView(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('error') @@ -38,31 +39,34 @@ def test_view_static(self): input_shape = (4, 3) output_shape = (12, 1) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) output = tensorrt_llm.functional.view(input=input, - shape=output_shape).trt_tensor - output.name = 'output' - network.mark_output(output) + shape=output_shape) + output.mark_output('output') # trt run - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={'input': input_data.numpy()}) + session = create_session(builder, network, precision=dtype) + inputs = { + 'input': input_data, + } + outputs = run_session(session, inputs) # pytorch run ref = input_data.view(output_shape) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) def test_view_dynamic(self): # test data @@ -70,37 +74,38 @@ def test_view_dynamic(self): input_shape = (4, 3) output_shape = (2, 6) input_data = torch.rand( - input_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype)) + input_shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype), + device="cuda") shape_data = torch.tensor(output_shape).int() # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): + input = Tensor(name='input', shape=input_shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) shape = Tensor(name='shape', shape=(len(input_shape), ), dtype=trt.int32) - output = tensorrt_llm.functional.view(input=input, - shape=shape).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.view(input=input, shape=shape) + output.mark_output('output') # trt run - profiles = [Profile().add('shape', (1, 1), input_shape, (12, 12))] - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network), - config=CreateConfig(profiles=profiles)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 'input': input_data.numpy(), - 'shape': shape_data.numpy() - }) + profile = builder.trt_builder.create_optimization_profile() + profile.set_shape_input('shape', output_shape, output_shape, + output_shape) + session = create_session(builder, + network, + precision=dtype, + optimization_profiles=[profile]) + inputs = {'input': input_data, 'shape': shape_data} + outputs = run_session(session, inputs) # pytorch run ref = input_data.view(output_shape) # compare diff - np.testing.assert_allclose(ref.cpu().numpy(), outputs['output']) + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/functional/test_where.py b/tests/functional/test_where.py index 032a27001..180f05cd0 100644 --- a/tests/functional/test_where.py +++ b/tests/functional/test_where.py @@ -12,18 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys import unittest -import numpy as np import torch from parameterized import parameterized -from polygraphy.backend.trt import EngineFromNetwork, TrtRunner import tensorrt_llm from tensorrt_llm import Tensor +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.util import create_session, run_session, unittest_name_func -class TestFunctional(unittest.TestCase): + +class TestWhere(unittest.TestCase): def setUp(self): tensorrt_llm.logger.set_level('warning') @@ -31,50 +34,42 @@ def setUp(self): @parameterized.expand([ (True, ), (False, ), - ]) - def test_where_from_bool(self, condition=True): + ], name_func=unittest_name_func) + def test_where_from_bool(self, condition): dtype = 'float32' - t_data = torch.randn(2, 3) - f_data = torch.randn(2, 3) + t_data = torch.randn(2, 3, device="cuda") + f_data = torch.randn(2, 3, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): t = Tensor(name='t', shape=t_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) f = Tensor(name='f', shape=f_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) - output = tensorrt_llm.functional.where(condition, t, f).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.where(condition, t, f) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 't': t_data.numpy(), - 'f': f_data.numpy(), - }) + session = create_session(builder, network, precision=dtype) + inputs = {'t': t_data, 'f': f_data} + outputs = run_session(session, inputs) - ref = torch.where(torch.tensor(condition), t_data, f_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) + ref = torch.where(torch.tensor(condition).cuda(), t_data, f_data) + torch.testing.assert_close(ref, outputs['output']) def test_where_from_tensor(self): dtype = 'float32' - t_data = torch.randn(3, 4) - f_data = torch.randn(3, 4) - c_data = torch.randint(2, size=(3, 1), dtype=torch.bool) - ref = torch.where(c_data, t_data, f_data) - print(ref) + t_data = torch.randn(3, 4, device="cuda") + f_data = torch.randn(3, 4, device="cuda") + c_data = torch.randint(2, size=(3, 1), dtype=torch.bool, device="cuda") + # construct trt network builder = tensorrt_llm.Builder() - net = builder.create_network() - with tensorrt_llm.net_guard(net): - network = tensorrt_llm.default_trtnet() + network = builder.create_network() + with tensorrt_llm.net_guard(network): t = Tensor(name='t', shape=t_data.shape, dtype=tensorrt_llm.str_dtype_to_trt(dtype)) @@ -84,23 +79,16 @@ def test_where_from_tensor(self): c = Tensor(name='c', shape=c_data.shape, dtype=tensorrt_llm.str_dtype_to_trt('bool')) - output = tensorrt_llm.functional.where(c, t, f).trt_tensor - output.name = 'output' - network.mark_output(output) + output = tensorrt_llm.functional.where(c, t, f) + output.mark_output('output') - build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network)) - with TrtRunner(build_engine) as runner: - outputs = runner.infer(feed_dict={ - 't': t_data.numpy(), - 'f': f_data.numpy(), - 'c': c_data.numpy(), - }) + # trt run + session = create_session(builder, network, precision=dtype) + inputs = {'t': t_data, 'f': f_data, 'c': c_data} + outputs = run_session(session, inputs) + + # pytorch run + ref = torch.where(c_data, t_data, f_data) - np.testing.assert_allclose(ref.cpu().numpy(), - outputs['output'], - atol=1e-5) - print(t_data) - print(f_data) - print(c_data) - print(outputs['output']) - # assert False, "FORCED" + # compare diff + torch.testing.assert_close(ref, outputs['output']) diff --git a/tests/hlapi/test_llm_download.py b/tests/hlapi/test_llm_download.py new file mode 100644 index 000000000..feb39b9e6 --- /dev/null +++ b/tests/hlapi/test_llm_download.py @@ -0,0 +1,17 @@ +from tensorrt_llm.hlapi import LLM, ModelConfig +from tensorrt_llm.hlapi.utils import download_hf_model + +prompts = ["A B C"] + + +def test_download_hf_model(): + dir = download_hf_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + assert dir.exists() + print(f"Downloaded model to {dir}") + + +def test_llm_with_model_downloaded(): + config = ModelConfig(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + llm = LLM(config) + for output in llm.generate(prompts): + print(output) diff --git a/tests/hlapi/test_llm_multi_gpu.py b/tests/hlapi/test_llm_multi_gpu.py index 9f32f7149..5bcc0a044 100644 --- a/tests/hlapi/test_llm_multi_gpu.py +++ b/tests/hlapi/test_llm_multi_gpu.py @@ -94,11 +94,12 @@ def test_llm_generate_tp2(engine_from_checkpoint): print(output) +# TODO[yuxianq]: Enable auto_parallel after fixing the issue +#@pytest.mark.parametrize("use_auto_parallel", [True, False], ids=[ "enable_auto_parallel", "disable_auto_parallel"]) @skip_single_gpu -@pytest.mark.parametrize("use_auto_parallel", [True, False], - ids=["enable_auto_parallel", "disable_auto_parallel"]) def test_llm_generate_async_tp2( - use_auto_parallel, engine_from_checkpoint: tempfile.TemporaryDirectory): + engine_from_checkpoint: tempfile.TemporaryDirectory, + use_auto_parallel=False): model_dir = engine_from_checkpoint.name if not use_auto_parallel else get_model_path( llama_model_path) tokenizer_dir = get_model_path(llama_model_path) diff --git a/tests/model/test_arctic.py b/tests/model/test_arctic.py index b599824ee..135b7c247 100644 --- a/tests/model/test_arctic.py +++ b/tests/model/test_arctic.py @@ -26,8 +26,8 @@ import tensorrt_llm from tensorrt_llm import Builder from tensorrt_llm._utils import str_dtype_to_trt -from tensorrt_llm.models.llama.weight import load_from_hf_llama -from tensorrt_llm.models.modeling_utils import PretrainedConfig +from tensorrt_llm.models import PretrainedConfig +from tensorrt_llm.models.llama.convert import load_weights_from_hf_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType @@ -71,27 +71,25 @@ def _gen_tensorrt_llm_network(self, network, hf_mistral, 'mapping': { 'world_size': tensor_parallel, 'tp_size': tensor_parallel, + 'rank': rank, }, 'use_parallel_embedding': False, 'embedding_sharding_dim': 0, - 'moe_num_experts': 0, - 'moe_top_k': 0, - 'moe_tp_mode': 1, - 'moe_normalization_mode': 1, + 'moe': { + 'num_experts': 0, + 'top_k': 0, + 'tp_mode': 1, + 'normalization_mode': 1, + }, 'use_fused_mlp': False, } # Initialize model - tensorrt_llm_mistral = tensorrt_llm.models.LLaMAForCausalLM( - PretrainedConfig.from_dict(config)) + config = PretrainedConfig.from_dict(config) + tensorrt_llm_mistral = tensorrt_llm.models.LLaMAForCausalLM(config) + if not mistral_config.residual_mlp: - weights = load_from_hf_llama(tensorrt_llm_mistral, - hf_mistral, - dtype=dtype, - mapping=tensorrt_llm.Mapping( - world_size=tensor_parallel, - rank=rank, - tp_size=tensor_parallel)) + weights = load_weights_from_hf_model(hf_mistral, config) tensorrt_llm_mistral.load(weights) # Prepare network.set_named_parameters( @@ -133,7 +131,7 @@ def _gen_tensorrt_llm_engine(self, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, - strongly_typed=(dtype in ["float16", "bfloat16"]), + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() diff --git a/tests/model/test_bloom.py b/tests/model/test_bloom.py index c5ec08001..4faa5c0bf 100644 --- a/tests/model/test_bloom.py +++ b/tests/model/test_bloom.py @@ -59,10 +59,9 @@ def _gen_hf_bloom(self, hidden_act, n_layer, max_length, dtype): def _gen_tensorrt_llm_network(self, network, builder, hf_bloom, bloom_config, batch_size, input_len, - output_len, fp16, gpt_attention_plugin, + output_len, dtype, gpt_attention_plugin, tensor_parallel, apply_query_key_layer_scaling): - dtype = 'float16' if fp16 else 'float32' config = { 'architecture': 'BloomForCausalLM', 'dtype': dtype, @@ -123,7 +122,6 @@ def _gen_tensorrt_llm_runtime(self, runtime = None builder = Builder() - fp16 = (dtype == 'float16') with tempfile.TemporaryDirectory() as tmpdirname: builder_config = builder.create_builder_config( @@ -132,7 +130,7 @@ def _gen_tensorrt_llm_runtime(self, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, - strongly_typed=fp16, + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() @@ -146,7 +144,7 @@ def _gen_tensorrt_llm_runtime(self, self._gen_tensorrt_llm_network(network, builder, hf_bloom, bloom_config, batch_size, input_len, - output_len, fp16, use_plugin, + output_len, dtype, use_plugin, world_size, apply_query_key_layer_scaling) diff --git a/tests/model/test_falcon.py b/tests/model/test_falcon.py index 65b46e3d2..0e493b5dc 100644 --- a/tests/model/test_falcon.py +++ b/tests/model/test_falcon.py @@ -163,7 +163,7 @@ def generate_trtllm_runtime(self, use_alibi=hf_config.alibi, parallel_attention=hf_config.parallel_attn, use_refit=use_refit, - strongly_typed=(dtype == "float16"), + strongly_typed=True, ) network = builder.create_network() diff --git a/tests/model/test_gpt.py b/tests/model/test_gpt.py index f86a35b6c..23f48d7e6 100644 --- a/tests/model/test_gpt.py +++ b/tests/model/test_gpt.py @@ -61,12 +61,11 @@ def _gen_hf_gpt(self, hidden_act, n_layer, max_length, dtype): return gpt_config, hf_gpt def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config, - batch_size, input_len, output_len, fp16, + batch_size, input_len, output_len, dtype, gpt_attention_plugin, tensor_parallel, apply_query_key_layer_scaling, gather_context_logits, gather_generation_logits): - dtype = 'float16' if fp16 else 'float32' config = { 'architecture': 'GPTForCausalLM', 'dtype': dtype, @@ -138,7 +137,6 @@ def _gen_tensorrt_llm_runtime(self, runtime = None builder = Builder() - fp16 = (dtype == 'float16') with tempfile.TemporaryDirectory() as tmpdirname: @@ -150,7 +148,7 @@ def _gen_tensorrt_llm_runtime(self, use_refit=use_refit, gather_context_logits=gather_context_logits, gather_generation_logits=gather_generation_logits, - strongly_typed=fp16, + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() @@ -166,7 +164,7 @@ def _gen_tensorrt_llm_runtime(self, self._gen_tensorrt_llm_network(network, builder, hf_gpt, gpt_config, batch_size, input_len, output_len, - fp16, use_plugin, world_size, + dtype, use_plugin, world_size, apply_query_key_layer_scaling, gather_context_logits, gather_generation_logits) diff --git a/tests/model/test_gpt_e2e.py b/tests/model/test_gpt_e2e.py index f375eed9a..c22ecf027 100644 --- a/tests/model/test_gpt_e2e.py +++ b/tests/model/test_gpt_e2e.py @@ -112,30 +112,27 @@ def build_engines(): convert_ckpt(str(gpt2_dir), str(fp16_ckpt_dir), "--dtype=float16") print("\nBuilding fp16 engines") - build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-default/1-gpu'), - '--strongly_typed') + build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-default/1-gpu')) build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-plugin/1-gpu'), - '--gpt_attention_plugin=float16', '--strongly_typed') + '--gpt_attention_plugin=float16') # Skip tests that are not supported in pre-ampere architecture if getSMVersion() >= 80: build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-plugin-fmha/1-gpu'), - '--gpt_attention_plugin=float16', '--context_fmha=enable', - '--strongly_typed') + '--gpt_attention_plugin=float16', '--context_fmha=enable') build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-plugin-packed/1-gpu'), '--gpt_attention_plugin=float16', - '--remove_input_padding=enable', '--strongly_typed') + '--remove_input_padding=enable') # Skip tests that are not supported in pre-ampere architecture if getSMVersion() >= 80: build_engine(fp16_ckpt_dir, str(engine_dir / 'fp16-plugin-packed-fmha/1-gpu'), '--gpt_attention_plugin=float16', - '--remove_input_padding=enable', '--context_fmha=enable', - '--strongly_typed') + '--remove_input_padding=enable', '--context_fmha=enable') print("Done.") diff --git a/tests/model/test_gptj.py b/tests/model/test_gptj.py index 56986775e..1fa6d660e 100644 --- a/tests/model/test_gptj.py +++ b/tests/model/test_gptj.py @@ -132,7 +132,7 @@ def _gen_tensorrt_llm_runtime(self, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, - strongly_typed=(dtype == "float16"), + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() diff --git a/tests/model/test_gptneox.py b/tests/model/test_gptneox.py index 696f3d575..313f3470a 100644 --- a/tests/model/test_gptneox.py +++ b/tests/model/test_gptneox.py @@ -68,7 +68,7 @@ def _gen_hf_gpt_neox(self, hidden_act, n_layer, max_length, dtype): def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config, batch_size, beam_width, input_len, output_len, - fp16, gpt_attention_plugin, rank, + dtype, gpt_attention_plugin, rank, tensor_parallel, apply_query_key_layer_scaling): num_layers = gpt_config.num_hidden_layers @@ -80,7 +80,6 @@ def _gen_tensorrt_llm_network(self, network, builder, hf_gpt, gpt_config, list(range(tensor_parallel)) - dtype = 'float16' if fp16 else 'float32' config = { 'architecture': 'GPTNeoXForCausalLM', 'dtype': dtype, @@ -147,7 +146,6 @@ def _gen_tensorrt_llm_runtime(self, runtime = None builder = Builder() - fp16 = (dtype == 'float16') with tempfile.TemporaryDirectory() as tmpdirname: builder_config = builder.create_builder_config( @@ -156,7 +154,7 @@ def _gen_tensorrt_llm_runtime(self, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, - strongly_typed=fp16, + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() @@ -170,7 +168,7 @@ def _gen_tensorrt_llm_runtime(self, self._gen_tensorrt_llm_network(network, builder, hf_gpt, gpt_config, batch_size, beam_width, input_len, - output_len, fp16, + output_len, dtype, use_attention_plugin, rank, world_size, apply_query_key_layer_scaling) diff --git a/tests/model/test_llama.py b/tests/model/test_llama.py index 500d0befc..4a33ab3a5 100644 --- a/tests/model/test_llama.py +++ b/tests/model/test_llama.py @@ -29,9 +29,10 @@ import tensorrt_llm from tensorrt_llm import Builder from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str -from tensorrt_llm.models.llama.weight import (load_from_hf_llama, - load_from_meta_llama) -from tensorrt_llm.models.modeling_utils import PretrainedConfig, optimize_model +from tensorrt_llm.models import PretrainedConfig +from tensorrt_llm.models.llama.convert import (load_weights_from_hf_model, + load_weights_from_meta_ckpt) +from tensorrt_llm.models.modeling_utils import optimize_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType @@ -73,8 +74,9 @@ def _gen_tensorrt_llm_network(self, network, hf_llama, 'mapping': { 'world_size': tensor_parallel, 'tp_size': tensor_parallel, + 'rank': rank, }, - "moe_config": { + "moe": { "num_experts": 0, "top_k": 0, "tp_mode": 2, @@ -82,22 +84,12 @@ def _gen_tensorrt_llm_network(self, network, hf_llama, }, 'use_parallel_embedding': False, 'embedding_sharding_dim': 0, - 'moe_num_experts': 0, - 'moe_top_k': 0, - 'moe_tp_mode': 2, - 'moe_normalization_mode': 1, } # Initialize model - tensorrt_llm_llama = tensorrt_llm.models.LLaMAForCausalLM( - PretrainedConfig.from_dict(config)) - weights = load_from_hf_llama(tensorrt_llm_llama, - hf_llama, - dtype=dtype, - mapping=tensorrt_llm.Mapping( - world_size=tensor_parallel, - rank=rank, - tp_size=tensor_parallel)) + config = tensorrt_llm.models.LLaMAConfig.from_dict(config) + tensorrt_llm_llama = tensorrt_llm.models.LLaMAForCausalLM(config) + weights = load_weights_from_hf_model(hf_llama, config) tensorrt_llm_llama.load(weights) optimize_model(tensorrt_llm_llama, **opt_flags) @@ -141,7 +133,7 @@ def _gen_tensorrt_llm_engine(self, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, - strongly_typed=(dtype in ["float16", "bfloat16"]), + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() @@ -508,7 +500,6 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): tp_size = tp_info[0] rank = tp_info[1] - dtype = "float16" use_parallel_embedding = (emb_sharding_dim >= 0) embedding_sharding_dim = abs(emb_sharding_dim) hf_llama = LlamaForCausalLM.from_pretrained( @@ -540,48 +531,29 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): 'mapping': { 'world_size': tp_size, 'tp_size': tp_size, + 'rank': rank, }, - "moe_config": { + "moe": { "num_experts": 0, "top_k": 0, - "tp_mode": 2, - "normalization_mode": 1 + "tp_mode": 1, + "normalization_mode": 1, }, 'use_parallel_embedding': use_parallel_embedding, 'embedding_sharding_dim': embedding_sharding_dim, - 'moe_num_experts': 0, - 'moe_top_k': 0, - 'moe_tp_mode': 1, - 'moe_normalization_mode': 1, 'use_fused_mlp': False, } - cfg = PretrainedConfig.from_dict(config) - tensorrt_llm_llama_wHF = tensorrt_llm.models.LLaMAForCausalLM(cfg) - tensorrt_llm_llama_wHF = optimize_model( - tensorrt_llm_llama_wHF, - use_parallel_embedding=use_parallel_embedding) + + config = PretrainedConfig.from_dict(config) + tensorrt_llm_llama_wHF = tensorrt_llm.models.LLaMAForCausalLM(config) # print_layers(tensorrt_llm_llama_wHF) - weights_wHF = load_from_hf_llama(tensorrt_llm_llama_wHF, - hf_llama, - mapping=tensorrt_llm.Mapping( - world_size=tp_size, - rank=rank, - tp_size=tp_size), - dtype=dtype) + weights_wHF = load_weights_from_hf_model(hf_llama, config) tensorrt_llm_llama_wHF.load(weights_wHF) # print_layers(tensorrt_llm_llama_wHF) - tensorrt_llm_llama_wMETA = tensorrt_llm.models.LLaMAForCausalLM(cfg) - tensorrt_llm_llama_wMETA = optimize_model( - tensorrt_llm_llama_wMETA, - use_parallel_embedding=use_parallel_embedding) + tensorrt_llm_llama_wMETA = tensorrt_llm.models.LLaMAForCausalLM(config) # print_layers(tensorrt_llm_llama_wMETA) - weights_wMETA = load_from_meta_llama(meta_path, - mapping=tensorrt_llm.Mapping( - world_size=tp_size, - rank=rank, - tp_size=tp_size), - config=cfg) + weights_wMETA = load_weights_from_meta_ckpt(meta_path, config) tensorrt_llm_llama_wMETA.load(weights_wMETA) # print_layers(tensorrt_llm_llama_wMETA) # token embedding diff --git a/tests/model/test_mistral.py b/tests/model/test_mistral.py index 2edfb5266..235acc7f7 100644 --- a/tests/model/test_mistral.py +++ b/tests/model/test_mistral.py @@ -29,9 +29,9 @@ import tensorrt_llm from tensorrt_llm import Builder from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str -from tensorrt_llm.models.llama.weight import (load_from_hf_llama, - load_from_meta_llama) -from tensorrt_llm.models.modeling_utils import PretrainedConfig, optimize_model +from tensorrt_llm.models import PretrainedConfig +from tensorrt_llm.models.llama.convert import (load_weights_from_hf_model, + load_weights_from_meta_ckpt) from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType @@ -75,26 +75,23 @@ def _gen_tensorrt_llm_network(self, network, hf_mistral, 'mapping': { 'world_size': tensor_parallel, 'tp_size': tensor_parallel, + 'rank': rank, }, 'use_parallel_embedding': False, 'embedding_sharding_dim': 0, - 'moe_num_experts': 0, - 'moe_top_k': 0, - 'moe_tp_mode': 1, - 'moe_normalization_mode': 1, + "moe": { + "num_experts": 0, + "top_k": 0, + "tp_mode": 1, + "normalization_mode": 1, + }, 'use_fused_mlp': False, } # Initialize model - tensorrt_llm_mistral = tensorrt_llm.models.LLaMAForCausalLM( - PretrainedConfig.from_dict(config)) - weights = load_from_hf_llama(tensorrt_llm_mistral, - hf_mistral, - dtype=dtype, - mapping=tensorrt_llm.Mapping( - world_size=tensor_parallel, - rank=rank, - tp_size=tensor_parallel)) + config = PretrainedConfig.from_dict(config) + tensorrt_llm_mistral = tensorrt_llm.models.LLaMAForCausalLM(config) + weights = load_weights_from_hf_model(hf_mistral, config) tensorrt_llm_mistral.load(weights) # Prepare network.set_named_parameters( @@ -136,7 +133,7 @@ def _gen_tensorrt_llm_engine(self, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, - strongly_typed=(dtype in ["float16", "bfloat16"]), + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() @@ -457,7 +454,6 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): tp_size = tp_info[0] rank = tp_info[1] - dtype = "float16" use_parallel_embedding = (emb_sharding_dim >= 0) embedding_sharding_dim = abs(emb_sharding_dim) hf_mistral = MistralForCausalLM.from_pretrained( @@ -491,53 +487,29 @@ def print_layers(m: tensorrt_llm.models.LLaMAForCausalLM): 'mapping': { 'world_size': tp_size, 'tp_size': tp_size, + 'rank': rank, }, - "moe_config": { + "moe": { "num_experts": 0, "top_k": 0, - "tp_mode": 2, + "tp_mode": 1, "normalization_mode": 1 }, 'use_parallel_embedding': use_parallel_embedding, 'embedding_sharding_dim': embedding_sharding_dim, - 'moe_num_experts': 0, - 'moe_top_k': 0, - 'moe_tp_mode': 1, - 'moe_normalization_mode': 1, 'use_fused_mlp': False, } - cfg = PretrainedConfig.from_dict(config) + config = PretrainedConfig.from_dict(config) + tensorrt_llm_mistral_wHF = tensorrt_llm.models.LLaMAForCausalLM(config) # print_layers(tensorrt_llm_mistral_wHF) - tensorrt_llm_mistral_wHF = tensorrt_llm.models.LLaMAForCausalLM(cfg) - tensorrt_llm_mistral_wHF = optimize_model( - tensorrt_llm_mistral_wHF, - use_parallel_embedding=use_parallel_embedding) - - weights = load_from_hf_llama(tensorrt_llm_mistral_wHF, - hf_mistral, - mapping=tensorrt_llm.Mapping( - world_size=tp_size, - rank=rank, - tp_size=tp_size), - dtype=dtype) - + weights = load_weights_from_hf_model(hf_mistral, config) tensorrt_llm_mistral_wHF.load(weights) - # print_layers(tensorrt_llm_mistral_wHF) - tensorrt_llm_mistral_wMAI = tensorrt_llm.models.LLaMAForCausalLM(cfg) - tensorrt_llm_mistral_wMAI = optimize_model( - tensorrt_llm_mistral_wMAI, - use_parallel_embedding=use_parallel_embedding) - + tensorrt_llm_mistral_wMAI = tensorrt_llm.models.LLaMAForCausalLM(config) # print_layers(tensorrt_llm_mistral_wMAI) - weights = load_from_meta_llama(mistralai_path, - mapping=tensorrt_llm.Mapping( - world_size=tp_size, - rank=rank, - tp_size=tp_size), - dtype=dtype) + weights = load_weights_from_meta_ckpt(mistralai_path, config) tensorrt_llm_mistral_wMAI.load(weights) # print_layers(tensorrt_llm_mistral_wMAI) # token embedding diff --git a/tests/model/test_phi.py b/tests/model/test_phi.py index bbfc891f2..584889e36 100644 --- a/tests/model/test_phi.py +++ b/tests/model/test_phi.py @@ -21,7 +21,7 @@ import numpy as np import torch from parameterized import parameterized -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import PhiConfig, PhiForCausalLM import tensorrt_llm from tensorrt_llm import Builder @@ -35,9 +35,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.util import skip_fp32_accum_pre_ampere, unittest_name_func -# Fixed code revision or updated config can break the tests. -HF_CODE_REVISION = "cb2f4533604d8b67de604e7df03bfe6f3ca22869" - def compare_max_abs_error(ref, res, str): # calculate max abs error @@ -55,22 +52,17 @@ def setUp(self): torch.random.manual_seed(1773) def generate_hf_model(self, dtype: str): - # Need to use the latest remote code for config and model class. - gpt_config = AutoConfig.from_pretrained("microsoft/phi-2", - code_revision=HF_CODE_REVISION, - trust_remote_code=True) - gpt_config.num_hidden_layers = 2 - model = AutoModelForCausalLM.from_config( - gpt_config, trust_remote_code=True).cuda().to( - tensorrt_llm._utils.str_dtype_to_torch(dtype)).eval() - return gpt_config, model + phi_config = PhiConfig(num_hidden_layers=2) + model = PhiForCausalLM(phi_config).cuda().to( + tensorrt_llm._utils.str_dtype_to_torch(dtype)).eval() + return phi_config, model def initialize_network(self, network: tensorrt_llm.Network, hf_model, hf_config, dtype: str, batch_size: int, beam_width: int, input_len: int, output_len: int, tensor_parallel: int, rank: int): config = { - 'architecture': hf_config.architectures[0], + 'architecture': 'PhiForCausalLM', 'dtype': dtype, 'num_hidden_layers': hf_config.num_hidden_layers, 'num_attention_heads': hf_config.num_key_value_heads, @@ -129,7 +121,6 @@ def generate_trtllm_runtime(self, runtime = None builder = Builder() - fp16 = (dtype == 'float16') with tempfile.TemporaryDirectory() as tmpdirname: builder_config = builder.create_builder_config( @@ -138,7 +129,7 @@ def generate_trtllm_runtime(self, timing_cache='model.cache', tensor_parallel=world_size, # TP only use_refit=use_refit, - strongly_typed=fp16, + strongly_typed=True, ) network = builder.create_network() network.plugin_config.to_legacy_setting() diff --git a/tests/model_api/test_model_level_api.py b/tests/model_api/test_model_level_api.py index 6eb0fc641..04120ecf6 100644 --- a/tests/model_api/test_model_level_api.py +++ b/tests/model_api/test_model_level_api.py @@ -10,6 +10,7 @@ from tensorrt_llm.builder import BuildConfig, build from tensorrt_llm.executor import GenerationExecutor, SamplingConfig from tensorrt_llm.models import LLaMAForCausalLM +from tensorrt_llm.models.llama.config import LLaMAConfig sys.path.append(os.path.join(os.path.dirname(__file__), '..')) from utils.llm_data import llm_models_root @@ -80,7 +81,7 @@ def test_save_load(): @profile(tag="fake-weights") @force_ampere def test_high_level_fake_weights(): - '''sanity to make sure the flow works. The key is "skip_loading_weights" param + '''sanity to make sure the flow works. ''' input_text = [ 'Born in north-east France, Soyer trained as a', @@ -90,9 +91,8 @@ def test_high_level_fake_weights(): hf_model_dir = llm_models_root() / "llama-models/llama-7b-hf" # Fake weights, skipping save and load engine. Make it faster to sanity test - llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, - 'float16', - skip_loading_weights=True) + config = LLaMAConfig.from_hugging_face(hf_model_dir, dtype='float16') + llama = LLaMAForCausalLM(config) build_config = BuildConfig(max_batch_size=max_batch_size, max_input_len=max_isl, max_output_len=max_osl, diff --git a/tests/test_module.py b/tests/test_module.py index 81685b7ee..30e3513fe 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -76,6 +76,15 @@ def test_module(self): self.assertEqual(4, len(list(m.named_modules()))) self.assertEqual(5, len(list(m.named_network_outputs()))) + self.assertEqual( + [("", m), ("m1", m.m1), ("m1.m1", m.m1.m1), ("m1.m2", m.m1.m2)], + list(m.named_modules()), + ) + self.assertEqual( + [("", m, None), ("m1", m.m1, m), ("m1.m1", m.m1.m1, m.m1), + ("m1.m2", m.m1.m2, m.m1)], + list(m.named_modules_with_parent()), + ) def test_module_list(self): m = Module4() diff --git a/tests/utils/util.py b/tests/utils/util.py index d7e59d348..5c93b786d 100644 --- a/tests/utils/util.py +++ b/tests/utils/util.py @@ -11,6 +11,7 @@ import tensorrt_llm from tensorrt_llm._utils import torch_dtype_to_trt, trt_dtype_to_torch from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.quantization import QuantMode from tensorrt_llm.runtime import TensorInfo @@ -155,7 +156,8 @@ def create_session(builder, int8=False, opt_level=None, memory_pool_limit=None, - optimization_profiles=[]): + optimization_profiles=[], + quant_mode=QuantMode(0)): """ This function creates an engine and a tensorrt_llm.runtime.Session for the engine. Args: @@ -167,7 +169,8 @@ def create_session(builder, """ builder_config = builder.create_builder_config(precision=precision, int8=int8, - opt_level=opt_level) + opt_level=opt_level, + quant_mode=quant_mode) # Some tests require to set mem pool limit to avoid OOM if memory_pool_limit is not None: builder_config.trt_builder_config.set_memory_pool_limit( @@ -176,6 +179,8 @@ def create_session(builder, if len(optimization_profiles) > 0: for profile in optimization_profiles: builder_config.trt_builder_config.add_optimization_profile(profile) + # Disable TF32 for accuracy in testing. + builder_config.trt_builder_config.clear_flag(trt.BuilderFlag.TF32) engine = builder.build_engine(network, builder_config) assert engine is not None, "Failed to build engine" session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)