Update TensorRT-LLM
Co-authored-by: Rong Zhou <[email protected]>
Co-authored-by: Onur Galoglu <[email protected]>
Co-authored-by: Fabian Joswig <[email protected]>
4 people authored Aug 20, 2024
1 parent 74b324f commit 32ed92e
Showing 259 changed files with 7,861 additions and 2,845 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -5,7 +5,9 @@ __pycache__/
*.cache
*.nsys-rep
.VSCodeCounter
build*/
cpp/build*
build
!tensorrt_llm/bench/build
!builders/
*.egg-info/
.coverage
9 changes: 6 additions & 3 deletions README.md
@@ -17,12 +17,15 @@ TensorRT-LLM
<div align="left">

## Latest News
* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫
🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647)
* [2024/08/13] 🐍 DIY Code Completion with #Mamba ⚡ #TensorRT #LLM for speed 🤖 NIM for ease ☁️ deploy anywhere
[➡️ link](https://developer.nvidia.com/blog/revolutionizing-code-completion-with-codestral-mamba-the-next-gen-coding-llm/)
<div align="center">
<img src="docs/source/media/picture-08-06-2024.png" width="50%">
<img src="docs/source/media/picture-08-13-2024.png" width="50%">
<div align="left">

* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫
🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647)

* [2024/07/30] Introducing🍊 @SliceXAI ELM Turbo 🤖 train ELM once ⚡ #TensorRT #LLM optimize ☁️ deploy anywhere
[➡️ link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms)

2 changes: 1 addition & 1 deletion benchmarks/README.md
@@ -7,5 +7,5 @@ There are currently three workflows to benchmark TensorRT-LLM:
- The recommended workflow that uses TensorRT-LLM C++ API and can take advantage of the latest features of TensorRT-LLM.
* [Python benchmarks](./python)
- The Python benchmarking scripts can only benchmark the Python runtime, which does not support the latest features, such as in-flight batching.
* [The Python benchmarking suite](./suite)
* [The Python benchmarking suite](./Suite.md)
- This benchmarking suite is currently a work in progress and is prone to large changes.
316 changes: 316 additions & 0 deletions benchmarks/Suite.md

Large diffs are not rendered by default.

109 changes: 84 additions & 25 deletions benchmarks/cpp/gptManagerBenchmark.cpp
@@ -24,6 +24,7 @@
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
Expand Down Expand Up @@ -173,6 +174,9 @@ struct BenchmarkParams

// Decoding params
std::optional<std::vector<std::vector<SizeType32>>> medusaChoices;

std::optional<texec::LookaheadDecodingConfig> executorLookaheadConfig;
std::optional<texec::LookaheadDecodingConfig> requestLookaheadConfig;
};

class InferenceRequestsAsyncSend
@@ -509,6 +513,7 @@ class Recorder
{
if (!mStreaming)
{
TLLM_LOG_DEBUG("response.getResult().outputTokenIds");
auto outputTokenIds = response.getResult().outputTokenIds;

int32_t outSeqLen = 0;
@@ -824,9 +829,11 @@ class ExecutorServer
executorConfig.setMaxNumTokens(benchmarkParams.maxNumTokens.value());
}

executorConfig.setDecodingConfig(texec::DecodingConfig(
benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(),
std::nullopt, benchmarkParams.medusaChoices));
executorConfig.setDecodingConfig(
texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices));
executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig);

if (executorModelType == texec::ModelType::kDECODER_ONLY)
@@ -910,7 +917,7 @@ class ExecutorServer
for (auto const& response : responses)
{
auto const reqId = response.getRequestId();

TLLM_LOG_DEBUG("response.getResult().isFinal");
if (response.getResult().isFinal)
{
mActiveCount--;
@@ -1323,7 +1330,8 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr,
ITensor::SharedPtr const& loraConfig = nullptr)
ITensor::SharedPtr const& loraConfig = nullptr,
std::optional<tensorrt_llm::executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt)
{
auto request = std::make_shared<InferenceRequest>(reqId);
auto const& inputIds = sample.inputIds;
@@ -1361,6 +1369,10 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
{
request->setLoraConfig(loraConfig);
}
if (lookaheadConfig)
{
request->setLookaheadConfig(lookaheadConfig.value());
}
if (streaming)
{
request->setIsStreaming(true);
@@ -1372,18 +1384,20 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
std::optional<SizeType32> const& eosId, std::optional<SizeType32> const& padId, bool streaming = false,
bool const& returnContextLogits = false, bool const& returnGenerationLogits = false,
std::optional<texec::LoraConfig> const& loraConfig = std::nullopt,
std::optional<texec::LookaheadDecodingConfig> const& lookaheadConfig = std::nullopt,
std::optional<texec::VecTokens> encoderInputTokenIds = std::nullopt)
{
auto samplingConfig = texec::SamplingConfig{beamWidth};
auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false};
return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId,
std::nullopt, // badWords
std::nullopt, // stopWords
std::nullopt, // embeddingBias
std::nullopt, // speculativeDecoding
std::nullopt, // pTuning
loraConfig,
std::nullopt, // logitsPostProcessorName
std::nullopt, // badWords
std::nullopt, // stopWords
std::nullopt, // embeddingBias
std::nullopt, // speculativeDecoding
std::nullopt, // pTuning
loraConfig, // loraConfig
lookaheadConfig, // lookaheadConfig
std::nullopt, // logitsPostProcessorName
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
}

@@ -1429,9 +1443,11 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
optionalParams.maxBatchSize = benchmarkParams.maxBatchSize;
optionalParams.maxNumTokens = benchmarkParams.maxNumTokens;
optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy};
optionalParams.decodingConfig = texec::DecodingConfig(
benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa() : texec::DecodingMode::Auto(),
std::nullopt, benchmarkParams.medusaChoices);
optionalParams.decodingConfig
= texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);

@@ -1501,8 +1517,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
++reqId;
if (i == terminateReqId)
++reqId;
auto request = makeRequest(
reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
auto request = makeRequest(reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, nullptr, nullptr, nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
@@ -1517,7 +1533,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor, nullptr,
nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);

if (i < numSamples - 1)
@@ -1541,7 +1558,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor,
nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
@@ -1644,13 +1662,13 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
Sample s{std::vector<int32_t>{decoderStartTokenId}, 1, static_cast<int32_t>(taskId)};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false,
loraConfig, std::vector<int32_t>{1, 2, 3, 4, 5}));
loraConfig, std::nullopt, std::vector<int32_t>{1, 2, 3, 4, 5}));
}
else
{
Sample s{std::vector<int32_t>{1, 2, 3, 4, 5}, 1, static_cast<int32_t>(taskId)};
requests.emplace_back(
makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig));
makeExecutorRequest(s, beamWidth, eosId, padId, false, false, false, loraConfig, std::nullopt));
}
}
executorServer->enqueue(std::move(requests), true);
@@ -1668,12 +1686,14 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
Sample s{std::vector<int32_t>{decoderStartTokenId}, samples[0].outputLen, samples[0].taskId};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming,
returnContextLogits, returnGenerationLogits, std::nullopt, samples[0].inputIds));
returnContextLogits, returnGenerationLogits, std::nullopt,
benchmarkParams.requestLookaheadConfig, samples[0].inputIds));
}
else
{
requests.emplace_back(makeExecutorRequest(samples[0], beamWidth, eosId, padId,
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits));
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, std::nullopt,
benchmarkParams.requestLookaheadConfig));
}
}
executorServer->enqueue(std::move(requests), true);
@@ -1699,12 +1719,14 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
{
Sample s{std::vector<int32_t>{decoderStartTokenId}, samples[i].outputLen, samples[i].taskId};
requests.emplace_back(makeExecutorRequest(s, beamWidth, eosId, padId, benchmarkParams.streaming,
returnContextLogits, returnGenerationLogits, loraConfig, samples[i].inputIds));
returnContextLogits, returnGenerationLogits, loraConfig, benchmarkParams.requestLookaheadConfig,
samples[i].inputIds));
}
else
{
requests.emplace_back(makeExecutorRequest(samples[i], beamWidth, eosId, padId,
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig));
benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig,
benchmarkParams.requestLookaheadConfig));
}
}

@@ -1789,6 +1811,25 @@ std::vector<std::vector<SizeType32>> parseVectorOfVectors(std::string const& inp
return result;
}

texec::LookaheadDecodingConfig parseLookaheadConfig(std::string const& input)
{
std::regex regex("\\[ *(\\d+) *, *(\\d+) *, *(\\d+) *\\]");
std::smatch match;
if (std::regex_match(input, match, regex))
{
TLLM_CHECK(match.size() == 4);
auto w = std::stoi(match[1]);
auto n = std::stoi(match[2]);
auto g = std::stoi(match[3]);
return texec::LookaheadDecodingConfig(w, n, g);
}
else
{
TLLM_LOG_WARNING("cannot parse lookahead config from '%s'", input.c_str());
return texec::LookaheadDecodingConfig();
}
}

} // namespace

int main(int argc, char* argv[])
@@ -1898,6 +1939,14 @@ int main(int argc, char* argv[])

options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("executor_lookahead_config",
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size]",
cxxopts::value<std::string>());

options.add_options()("request_lookahead_config",
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size], and each <= "
"executor lookahead config",
cxxopts::value<std::string>());

auto result = options.parse(argc, argv);

@@ -2055,6 +2104,16 @@ int main(int argc, char* argv[])
{
benchmarkParams.medusaChoices = parseVectorOfVectors(result["medusa_choices"].as<std::string>());
}
if (result.count("executor_lookahead_config"))
{
benchmarkParams.executorLookaheadConfig
= parseLookaheadConfig(result["executor_lookahead_config"].as<std::string>());
}
if (result.count("request_lookahead_config"))
{
benchmarkParams.requestLookaheadConfig
= parseLookaheadConfig(result["request_lookahead_config"].as<std::string>());
}

// Argument: multi_block_mode
benchmarkParams.multiBlockMode = result["multi_block_mode"].as<bool>();
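
For readers who want the gist of the gptManagerBenchmark.cpp changes above without reading the full diff: the sketch below mirrors, in plain standard C++, how the new `--executor_lookahead_config` / `--request_lookahead_config` strings of the form `[max_window_size, max_ngram_size, max_verification_set_size]` are parsed, and how the decoding mode is then chosen (Medusa choices take precedence, then an executor-level lookahead config, otherwise Auto). The `LookaheadConfig` struct, `DecodingMode` enum, and `selectDecodingMode` helper here are illustrative stand-ins only; the benchmark itself constructs `texec::LookaheadDecodingConfig`, `texec::DecodingMode`, and `texec::DecodingConfig`.

```cpp
// Standalone sketch (not the benchmark itself) of the lookahead-config parsing
// and decoding-mode selection added in this commit.
#include <cstdio>
#include <optional>
#include <regex>
#include <string>

struct LookaheadConfig // stand-in for texec::LookaheadDecodingConfig
{
    int maxWindowSize;
    int maxNgramSize;
    int maxVerificationSetSize;
};

// Parse "[w, n, g]", e.g. the value passed as --executor_lookahead_config "[7, 7, 7]".
std::optional<LookaheadConfig> parseLookaheadConfig(std::string const& input)
{
    std::regex const re("\\[ *(\\d+) *, *(\\d+) *, *(\\d+) *\\]");
    std::smatch match;
    if (!std::regex_match(input, match, re))
    {
        std::fprintf(stderr, "cannot parse lookahead config from '%s'\n", input.c_str());
        return std::nullopt;
    }
    return LookaheadConfig{std::stoi(match[1]), std::stoi(match[2]), std::stoi(match[3])};
}

enum class DecodingMode { Auto, Medusa, Lookahead };

// Mirrors the nested ternary used when building the DecodingConfig:
// Medusa choices win, then an executor lookahead config, otherwise Auto.
DecodingMode selectDecodingMode(bool hasMedusaChoices, bool hasExecutorLookaheadConfig)
{
    return hasMedusaChoices            ? DecodingMode::Medusa
        : hasExecutorLookaheadConfig   ? DecodingMode::Lookahead
                                       : DecodingMode::Auto;
}

int main()
{
    auto const cfg = parseLookaheadConfig("[7, 7, 7]");
    if (cfg)
    {
        std::printf("w=%d n=%d g=%d, mode=%d\n", cfg->maxWindowSize, cfg->maxNgramSize,
            cfg->maxVerificationSetSize,
            static_cast<int>(selectDecodingMode(/*hasMedusaChoices=*/false,
                /*hasExecutorLookaheadConfig=*/true)));
    }
    return 0;
}
```

Per the option help text, each field of the per-request config is expected to be no larger than the corresponding field of the executor-level config.
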
25 changes: 12 additions & 13 deletions benchmarks/python/all_reduce.py
@@ -23,7 +23,6 @@

import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor
from tensorrt_llm._ipc_utils import peer_access
from tensorrt_llm._utils import OMPI_COMM_TYPE_HOST, mpi_comm
from tensorrt_llm.functional import AllReduceStrategy, allreduce
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
@@ -106,18 +105,18 @@ def allreduce_benchmark(dtype: str,
_, start = cuda.cuEventCreate(0)
_, stop = cuda.cuEventCreate(0)
runtimes = []
with peer_access(mapping):
tllm.mpi_barrier()

for _ in range(10):
cuda.cuEventRecord(start, stream.cuda_stream)
session.run(inputs=feed_dict,
outputs={"output": output},
stream=stream.cuda_stream)
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
runtimes.append(ms)

tllm.mpi_barrier()

for _ in range(10):
cuda.cuEventRecord(start, stream.cuda_stream)
session.run(inputs=feed_dict,
outputs={"output": output},
stream=stream.cuda_stream)
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
runtimes.append(ms)

median_ms = sorted(runtimes)[len(runtimes) // 2]
assert torch.allclose(output, (input * world_size)**inner_loop)