Commit

open source 7f370deb0090d885d7518c2b146399ba3933c004 (#2273)
* Update TensorRT-LLM

---------
Co-authored-by: Qingquan Song <[email protected]>
DanBlanaru authored Sep 30, 2024
1 parent 40274aa commit 48686bc
Showing 253 changed files with 6,042 additions and 2,958 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -7,8 +7,8 @@ TensorRT-LLM
[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.5.1-green)](https://developer.nvidia.com/cuda-downloads)
-[![trt](https://img.shields.io/badge/TRT-10.3.0-green)](https://developer.nvidia.com/tensorrt)
-[![version](https://img.shields.io/badge/release-0.13.0.dev-green)](./tensorrt_llm/version.py)
+[![trt](https://img.shields.io/badge/TRT-10.4.0-green)](https://developer.nvidia.com/tensorrt)
+[![version](https://img.shields.io/badge/release-0.14.0.dev-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

[Architecture](./docs/source/architecture/overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Results](./docs/source/performance/perf-overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](./examples/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](./docs/source/)
@@ -17,6 +17,12 @@ TensorRT-LLM
<div align="left">

## Latest News
* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
<div align="center">
<img src="docs/source/media/image-09-29-2024.png" width="50%">
<div align="left">

* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)

71 changes: 61 additions & 10 deletions benchmarks/cpp/gptManagerBenchmark.cpp
@@ -159,6 +159,8 @@ struct BenchmarkParams
std::optional<int> sinkTokenLength{std::nullopt};
bool multiBlockMode{true};
bool enableContextFMHAFP32Acc{false};
bool cudaGraphMode{false};
SizeType32 cudaGraphCacheSize{0};

// lora / peft params
std::optional<std::string> loraDir{std::nullopt};
@@ -470,7 +472,38 @@ class Recorder
mRequestBenchInfos[requestId].firstTokenSeen = true;
}

-mRequestBenchInfos[requestId].outputLength += 1;
+mRequestBenchInfos[requestId].decodingIter += 1;
}

void recordToken(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
{
int32_t outputLength = 1;
for (auto& tensor : responseTensors)
{
if (tensor.name == inference_request::kSequenceLengthTensorName)
{
// Tensor of shape nBeams, and we only need the first one
outputLength = *(bufferCast<int32_t>(*(tensor.tensor)));
break;
}
}

mRequestBenchInfos[requestId].outputLength += outputLength;
this->recordToken(requestId);
}

void recordToken(uint64_t requestId, texec::Response const& response)
{
auto outputTokenIds = response.getResult().outputTokenIds;

int32_t outputLength = 1;
for (auto const& beam : outputTokenIds)
{
outputLength = std::max(static_cast<int32_t>(beam.size()), outputLength);
}

mRequestBenchInfos[requestId].outputLength += outputLength;
this->recordToken(requestId);
}

void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors, bool hasError)
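
As a point of reference, here is a minimal standalone sketch (not part of this diff) of the counting rule the new recordToken(requestId, response) overload applies: the beams in outputTokenIds are alternative hypotheses for the same request, so a streamed response contributes the length of its longest beam rather than the sum across beams. The free function and plain std::vector input below are illustrative only.

```cpp
// Illustrative sketch: tokens contributed by one streamed response,
// mirroring the max-over-beams rule used in recordToken above.
#include <algorithm>
#include <cstdint>
#include <vector>

int32_t tokensInResponse(std::vector<std::vector<int32_t>> const& outputTokenIds)
{
    int32_t outputLength = 1; // each decoding iteration yields at least one token
    for (auto const& beam : outputTokenIds)
    {
        outputLength = std::max(static_cast<int32_t>(beam.size()), outputLength);
    }
    return outputLength;
}
```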
@@ -500,7 +533,7 @@ class Recorder
}
else
{
-this->recordToken(requestId);
+this->recordToken(requestId, responseTensors);
}
}

@@ -532,7 +565,7 @@ class Recorder
}
else
{
-this->recordToken(requestId);
+this->recordToken(requestId, response);
}
}
}
@@ -821,8 +854,9 @@ class ExecutorServer
benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks);
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
std::nullopt, benchmarkParams.loraHostCacheSize);
-texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(
-benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
+texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
+benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode,
+benchmarkParams.cudaGraphCacheSize);
texec::ExecutorConfig executorConfig(
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent);
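
For orientation, a hedged sketch of how the two new knobs reach the executor configuration: the constructor arguments mirror the four-argument call added above, while the header path, the texec alias, and the setExtendedRuntimePerfKnobConfig setter are assumptions based on the surrounding code rather than verified API.

```cpp
// Sketch only: wiring cuda_graph_mode / cuda_graph_cache_size into the
// executor config. Header path and setter name are assumptions.
#include <cstdint>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

void applyPerfKnobs(texec::ExecutorConfig& executorConfig, bool cudaGraphMode, int32_t cudaGraphCacheSize)
{
    texec::ExtendedRuntimePerfKnobConfig perfKnobs(
        /*multiBlockMode=*/true,            // existing knobs left at their benchmark defaults
        /*enableContextFMHAFP32Acc=*/false,
        cudaGraphMode,                      // run generation steps through captured CUDA graphs
        cudaGraphCacheSize);                // graphs kept in the runtime cache; more graphs, more GPU memory
    executorConfig.setExtendedRuntimePerfKnobConfig(perfKnobs);
}
```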
@@ -940,7 +974,7 @@ class ExecutorServer
{
if (!warmup && !response.hasError())
{
-mRecorder->recordToken(reqId);
+mRecorder->recordToken(reqId, response);
}
}
}
@@ -1228,7 +1262,7 @@ class GptServer
{
if (errMsg.empty())
{
-mRecorder->recordToken(requestId);
+mRecorder->recordToken(requestId, response_tensors);
}
}
}
@@ -1458,8 +1492,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
-optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
-benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
+optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
+benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize);

auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
@@ -1895,7 +1929,8 @@ int main(int argc, char* argv[])
options.add_options()("return_generation_logits", "Whether to return generation logits.",
cxxopts::value<bool>()->default_value("false"));

options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
options.add_options()("scheduler_policy",
"Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.",
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));

options.add_options()("first_batch_delay",
@@ -1946,6 +1981,12 @@ int main(int argc, char* argv[])
cxxopts::value<bool>()->default_value("true"));
options.add_options()(
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("cuda_graph_cache_size",
"Specify how many cuda graphs are cached in the runtime. Larger cache gives better perf, but consumes more GPU "
"memory.",
cxxopts::value<SizeType32>()->default_value("0"));

options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
cxxopts::value<bool>()->default_value("false"));
@@ -2131,6 +2172,12 @@ int main(int argc, char* argv[])
// Argument: enable_context_fmha_fp32_acc
benchmarkParams.enableContextFMHAFP32Acc = result["enable_context_fmha_fp32_acc"].as<bool>();

// Argument: cuda_graph_mode
benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as<bool>();

// Argument: cuda_graph_cache_size
benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as<SizeType32>();

std::optional<TokenIdType> padId;
// Argument: Padding token id
if (result.count("pad_id"))
@@ -2168,6 +2215,10 @@ int main(int argc, char* argv[])
{
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT;
}
else if (capacitySchedulerPolicyArg == "static_batch")
{
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kSTATIC_BATCH;
}
else
{
TLLM_LOG_ERROR("Unexpected scheduler policy: " + capacitySchedulerPolicyArg);
2 changes: 1 addition & 1 deletion benchmarks/python/gpt_benchmark.py
@@ -80,7 +80,7 @@ def __init__(self, args, batch_sizes, in_out_lens, gpu_weights_percents,

kv_cache_type = KVCacheType.CONTINUOUS
if hasattr(self, 'kv_cache_type'):
-kv_cache_type = self.kv_cache_type
+kv_cache_type = KVCacheType(self.kv_cache_type)
else:
if hasattr(self, 'paged_kv_cache'):
kv_cache_type = KVCacheType.PAGED if self.paged_kv_cache == True else KVCacheType.CONTINUOUS