[Serving] Support multi-threading CPU sampling (mlc-ai#1232)
This PR supports multi-threaded token sampling on CPU.

In serving scenarios, the token sampling process becomes one of the
bottlenecks, because the batched model computation has higher
throughput than in single-sequence settings. Therefore, we enhance
the CPU sampling with multi-threading.

Specifically,
* this PR changes the design scope of Sampler. Prior to this PR,
the sampling functions exposed by Sampler focused on sampling a
single token at a time. After this PR, each function processes a
batch of tokens, which makes the multi-threading more manageable
(see the sketch below).
* the multi-threading is currently backed by OpenMP, chosen
according to our micro-benchmark.

Note: to effectively enable OpenMP, the project now needs to be compiled with gcc/g++.
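
As an illustration of the new batched design, below is a minimal
standalone sketch of OpenMP-backed batch sampling. The names and the
plain inverse-CDF sampling loop are simplifications for exposition
(the PR itself uses the registered top-p sampling kernel), so treat
this as a sketch rather than the code in this commit:

#include <omp.h>

#include <cstdint>
#include <vector>

// Sample one token per row of a row-major (n, vocab) probability
// matrix. `uniform_draws` holds one pre-drawn uniform number per row.
std::vector<int32_t> SampleBatch(const float* probs, int n, int vocab,
                                 const std::vector<double>& uniform_draws,
                                 int num_threads) {
  std::vector<int32_t> tokens(n);
#pragma omp parallel for num_threads(num_threads)
  for (int i = 0; i < n; ++i) {
    const float* row = probs + static_cast<int64_t>(i) * vocab;
    // Inverse-CDF sampling: scan the row until the cumulative
    // probability mass passes the pre-drawn uniform number.
    double cumsum = 0.0;
    int32_t picked = vocab - 1;
    for (int v = 0; v < vocab; ++v) {
      cumsum += row[v];
      if (uniform_draws[i] <= cumsum) {
        picked = v;
        break;
      }
    }
    tokens[i] = picked;
  }
  return tokens;
}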
MasterJH5574 committed Nov 13, 2023
1 parent 3f38242 commit 36ea52d
Showing 3 changed files with 72 additions and 65 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -4,11 +4,11 @@ project(mlc_llm C CXX)
include(CheckCXXCompilerFlag)
if(NOT MSVC)
check_cxx_compiler_flag("-std=c++17" SUPPORT_CXX17)
set(CMAKE_CXX_FLAGS "-std=c++17 ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "-std=c++17 -O3 -fopenmp -ffast-math -march=native ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_STANDARD 17)
else()
check_cxx_compiler_flag("/std:c++17" SUPPORT_CXX17)
set(CMAKE_CXX_FLAGS "/std:c++17 ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "/std:c++17 -O3 -fopenmp -ffast-math -march=native ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_STANDARD 17)
endif()

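(Build note: since OpenMP is enabled here via -fopenmp, a gcc/g++ toolchain has to be selected at configure time. An illustrative invocation, not part of this commit, is: cmake -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ ..)
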
37 changes: 12 additions & 25 deletions cpp/serve/engine.cc
@@ -11,7 +11,6 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "../random.h"
#include "model.h"
#include "request.h"
#include "request_state.h"
@@ -289,8 +288,8 @@ class Engine {
ICHECK_EQ(logits->shape[0], 1);
ICHECK_EQ(logits->shape[1], 1);

std::vector<int32_t> next_token = SampleTokens(logits, /*model_id=*/0, /*sampler_id=*/0,
{state.mstates[0]}, {request->generation_cfg});
ShapeTuple next_token = SampleTokens(logits, /*model_id=*/0, /*sampler_id=*/0,
{state.mstates[0]}, {request->generation_cfg});
ICHECK_EQ(next_token.size(), 1);

// - Update the committed tokens of states.
@@ -356,7 +355,7 @@ class Engine {
ICHECK_EQ(logits->shape[1], 1);

// - Sample tokens.
std::vector<int32_t> next_tokens =
ShapeTuple next_tokens =
SampleTokens(logits, /*model_id=*/0, /*sampler_id=*/0, mstates, generation_cfg);
ICHECK_EQ(next_tokens.size(), num_requests);

@@ -699,9 +698,9 @@ class Engine {
* in the input batch.
* \return The sampled tokens, one for each request in the batch.
*/
std::vector<int32_t> SampleTokens(NDArray logits_on_device, int model_id, int sampler_id,
Array<RequestModelState> request_mstates,
Array<GenerationConfig> generation_cfg) {
ShapeTuple SampleTokens(NDArray logits_on_device, int model_id, int sampler_id,
Array<RequestModelState> request_mstates,
Array<GenerationConfig> generation_cfg) {
ICHECK(logits_on_device.defined());
ICHECK_EQ(logits_on_device->ndim, 3);
ICHECK_EQ(logits_on_device->shape[1], 1)
@@ -723,12 +722,10 @@ class Engine {
logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(probs_on_device);
} else {
logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(logits_on_device);
for (int i = 0; i < num_sequence; ++i) {
// The "compute_probs_from_logits_inplace" function updates
// `logits_or_probs_on_cpu` in place.
fsampler_compute_probs_from_logits_inplace_[sampler_id](
logits_or_probs_on_cpu, /*token_offset=*/i, request_mstates[i], generation_cfg[i]);
}
// The "compute_probs_from_logits_inplace" function updates
// `logits_or_probs_on_cpu` in place.
fsampler_compute_probs_from_logits_inplace_[sampler_id](
logits_or_probs_on_cpu, std::move(request_mstates), generation_cfg);
}
// `CopyLogitsOrProbsToCPU` flattens the first two dimensions.
ICHECK_EQ(logits_or_probs_on_cpu->ndim, 2);
@@ -737,18 +734,8 @@ class Engine {
// NOTE: Though we have the probability field in RequestModelState,
// we do not save the probabilities right now.
// We will handle this in the future when we work on speculation.
std::vector<int32_t> new_tokens;
std::vector<double> randnums;
new_tokens.reserve(num_sequence);
randnums.reserve(num_sequence);
for (int i = 0; i < num_sequence; ++i) {
randnums.push_back(RandomGenerator::GetInstance().GetRandomNumber());
}
for (int i = 0; i < num_sequence; ++i) {
int32_t token_id = fsampler_sample_token_from_probs_[sampler_id](
logits_or_probs_on_cpu, /*token_offset=*/i, generation_cfg[i], randnums[i]);
new_tokens.push_back(token_id);
}
ShapeTuple new_tokens =
fsampler_sample_token_from_probs_[sampler_id](logits_or_probs_on_cpu, generation_cfg);
return new_tokens;
}

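After this change, the engine-side sampling flow reduces to one batched
call per step. A simplified sketch of the shape of that flow (it reuses
the project's RequestModelState and GenerationConfig types plus TVM's
runtime types, and abbreviates the PackedFunc lookup and the
device-to-CPU copy, so it is an outline rather than the exact code):

#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

using namespace tvm::runtime;

// `fcompute_probs_inplace` and `fsample_from_probs` stand for the
// sampler module's "compute_probs_from_logits_inplace" and
// "sample_token_from_probs" functions.
ShapeTuple SampleTokensBatched(NDArray probs_on_cpu,
                               Array<RequestModelState> mstates,
                               Array<GenerationConfig> generation_cfg,
                               PackedFunc fcompute_probs_inplace,
                               PackedFunc fsample_from_probs) {
  // Normalize the whole (n, vocab) batch of logits to probabilities
  // in place, in a single call.
  fcompute_probs_inplace(probs_on_cpu, mstates, generation_cfg);
  // One call samples a token for every sequence in the batch; the
  // sampler module parallelizes internally with OpenMP.
  ShapeTuple tokens = fsample_from_probs(probs_on_cpu, generation_cfg);
  return tokens;
}
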
96 changes: 58 additions & 38 deletions cpp/serve/sampler.cc
@@ -7,13 +7,15 @@

#include "sampler.h"

#include <omp.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <cmath>

#include "../random.h"
#include "request_state.h"

namespace mlc {
@@ -28,7 +30,10 @@ namespace serve {
*/
class SamplerModule : public ModuleNode {
public:
explicit SamplerModule(DLDevice device) : device_(device) {
explicit SamplerModule(DLDevice device) : device_(device), rng_(RandomGenerator::GetInstance()) {
// Set number of sampling threads from OpenMP.
num_threads_ = static_cast<int>(omp_get_max_threads()) / 2;

// Set sampling function.
auto fsample_topp_from_prob_ptr =
tvm::runtime::Registry::Get("vm.builtin.sample_top_p_from_prob");
@@ -48,13 +53,13 @@ class SamplerModule : public ModuleNode {
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "compute_probs_from_logits_inplace") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 4);
ComputeProbsFromLogitsInplace(args[0], args[1], args[2], args[3]);
CHECK_EQ(args.size(), 3);
ComputeProbsFromLogitsInplace(args[0], args[1], args[2]);
});
} else if (name == "sample_token_from_probs") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 4);
*rv = SampleTokenFromProbs(args[0], args[1], args[2], args[3]);
CHECK_EQ(args.size(), 2);
*rv = SampleTokenFromProbs(args[0], args[1]);
});
} else if (name == "require_gpu_softmax") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
@@ -92,60 +97,69 @@ class SamplerModule : public ModuleNode {

/*!
* \brief Compute the probability distribution from on-cpu logits for
* a **single token** **in place**.
* The input logits are batched. We use an input "token offset"
* to determine the start logit offset of the token to compute.
* a batch of tokens **in place**.
* \param logits The input logits on CPU.
* \param token_offset The input token offset to determine where the
* logits of the target token start.
* \param state The request state, which contains the history generated tokens.
* \param states The request states, which contain the history of generated tokens.
* \param generation_cfg The generation config.
* \note The function returns nothing. It in-place updates the input logits array.
*/
void ComputeProbsFromLogitsInplace(NDArray logits, int token_offset, RequestModelState state,
GenerationConfig generation_cfg) {
void ComputeProbsFromLogitsInplace(NDArray logits, Array<RequestModelState> states,
Array<GenerationConfig> generation_cfg) {
// logits: (n, v)
CHECK_EQ(logits->ndim, 2);
CHECK_EQ(logits->device.device_type, kDLCPU);

// - Invoke environment compute function if exists.
if (flogits_to_probs_inplace_.defined()) {
flogits_to_probs_inplace_(logits, token_offset, state, generation_cfg);
flogits_to_probs_inplace_(logits, states, generation_cfg);
return;
}

// - Apply repetition penalty (inplace).
if (generation_cfg->repetition_penalty != 1.0) {
ApplyRepetitionPenaltyOnCPU(logits, token_offset, state, generation_cfg->repetition_penalty);
}
// - Compute probability (inplace) from logits.
// Using softmax if temperature is non-zero.
// Or set probability of the max-logit position to 1.
if (generation_cfg->temperature >= 1e-6) {
ApplySoftmaxWithTemperatureOnCPU(logits, token_offset, generation_cfg->temperature);
} else {
SetProbWithArgmaxOnCPU(logits, token_offset);
int n = logits->shape[0];
#pragma omp parallel for num_threads(this->num_threads_)
for (int i = 0; i < n; ++i) {
// - Apply repetition penalty (inplace).
if (generation_cfg[i]->repetition_penalty != 1.0) {
ApplyRepetitionPenaltyOnCPU(logits, i, states[i], generation_cfg[i]->repetition_penalty);
}
// - Compute probability (inplace) from logits.
// Using softmax if temperature is non-zero.
// Or set probability of the max-logit position to 1.
if (generation_cfg[i]->temperature >= 1e-6) {
ApplySoftmaxWithTemperatureOnCPU(logits, i, generation_cfg[i]->temperature);
} else {
SetProbWithArgmaxOnCPU(logits, i);
}
}
}

/*!
* \brief Sample a token from the input probability distribution.
* The input prob distribution are batched. We use an input "token offset"
* to determine the start probability offset of the token to compute.
* \param probs The input probability distribution.
* \param token_offset The input token offset to determine where the
* probability distribution of the target token start.
* \brief Sample tokens from a batch of input probability distributions.
* \param probs The input batch of probability distributions.
* \param generation_cfg The generation config.
* \param random_number A random number for sampling.
* \return The sampled token.
* \return The sampled tokens, one for each instance of the batch.
*/
int32_t SampleTokenFromProbs(NDArray probs, int token_offset, GenerationConfig generation_cfg,
double random_number) {
ShapeTuple SampleTokenFromProbs(NDArray probs, Array<GenerationConfig> generation_cfg) {
// probs: (n, v)
CHECK_EQ(probs->ndim, 2);
CHECK_EQ(probs->device.device_type, kDLCPU);
// Sample top p from probability.
return fsample_topp_from_prob_(probs, token_offset, generation_cfg->top_p, random_number);

int n = probs->shape[0];
std::vector<double> random_numbers;
std::vector<int32_t> sampled_tokens;
random_numbers.reserve(n);
sampled_tokens.resize(n);
for (int i = 0; i < n; ++i) {
random_numbers.push_back(rng_.GetRandomNumber());
}

#pragma omp parallel for num_threads(this->num_threads_)
for (int i = 0; i < n; ++i) {
// Sample top p from probability.
sampled_tokens[i] =
fsample_topp_from_prob_(probs, i, generation_cfg[i]->top_p, random_numbers[i]);
}
return ShapeTuple(sampled_tokens.begin(), sampled_tokens.end());
}

/*! \brief Apply repetition penalty to logits based on history tokens. */
@@ -182,7 +196,9 @@ class SamplerModule : public ModuleNode {
CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU);
int vocab_size = logits->shape[1];

float* logits_raw_data = static_cast<float*>(logits->data) + (token_offset * vocab_size);
float* __restrict logits_raw_data =
static_cast<float*>(__builtin_assume_aligned(logits->data, 4)) +
(token_offset * vocab_size);
float m = std::numeric_limits<float>::min();
float inv_temp = 1.0f / temperature;
double d = 0.0f;
@@ -227,6 +243,10 @@ class SamplerModule : public ModuleNode {

/*! \brief The runtime device where the input logits is. */
DLDevice device_;
/*! \brief Number of CPU threads for parallel sampling and other computation. */
int num_threads_;
/*! \brief The random generator. */
RandomGenerator& rng_;
/*! \brief Customized function which computes prob distribution from logits */
PackedFunc flogits_to_probs_inplace_;
/*! \brief Function which samples a token from prob distribution with top_p value. */
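A design choice in SampleTokenFromProbs above is worth calling out: the
random numbers are drawn serially before the parallel region, likely
both to avoid concurrent access to the shared RandomGenerator and to
keep sampling reproducible regardless of how OpenMP schedules the loop.
A minimal self-contained sketch of that pattern (the per-row top-p
sampling is reduced to a stand-in here; illustrative only):

#include <omp.h>

#include <cstdint>
#include <random>
#include <vector>

std::vector<int32_t> SampleAll(int n, int vocab, std::mt19937_64& rng,
                               int num_threads) {
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  // Draw serially: the generator is shared mutable state, and a fixed
  // draw order makes results independent of thread scheduling.
  std::vector<double> draws(n);
  for (int i = 0; i < n; ++i) {
    draws[i] = dist(rng);
  }

  std::vector<int32_t> tokens(n);
#pragma omp parallel for num_threads(num_threads)
  for (int i = 0; i < n; ++i) {
    // Stand-in for the real top-p sampling over row i of the
    // probability matrix.
    tokens[i] = static_cast<int32_t>(draws[i] * vocab);
  }
  return tokens;
}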
