[Serving] Substitute OpenMP multi-threading with TVM threading backend
This PR switches the multi-threaded sampling implementation from OpenMP to the TVM threading backend.
MasterJH5574 committed Nov 15, 2023
1 parent 06b6c1f commit cef6522
Showing 2 changed files with 205 additions and 39 deletions.
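Editor's note: the sketch below (not part of the commit) condenses the mechanism this change switches to. Under OpenMP, the per-sequence loop was driven by a `#pragma omp parallel for`; with this commit, the loop body becomes a callback handed to TVM's `TVMBackendParallelLaunch`, which the new `parallel_for_with_threading_backend` helper in cpp/serve/sampler.cc wraps. The program is a self-contained illustration assuming a TVM runtime is available to link against; `FillTask`, `FillLambda`, and `main` are illustrative names, not code from this repository.

    #include <tvm/runtime/c_backend_api.h>

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct FillTask {
      std::vector<int>* out;
    };

    // Callback invoked once per worker thread by TVMBackendParallelLaunch.
    static int FillLambda(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
      FillTask* task = static_cast<FillTask*>(cdata);
      int64_t n = static_cast<int64_t>(task->out->size());
      int num_task = penv->num_task;
      // Static division of [0, n) across the worker threads, as in the commit.
      int64_t step = (n + num_task - 1) / num_task;
      int64_t begin = std::min<int64_t>(step * task_id, n);
      int64_t end = std::min<int64_t>(begin + step, n);
      for (int64_t i = begin; i < end; ++i) {
        (*task->out)[i] = static_cast<int>(i);
      }
      return 0;
    }

    int main() {
      std::vector<int> a(10, -1);
      FillTask task{&a};
      // num_task = 0 lets the TVM thread pool pick the number of workers.
      TVMBackendParallelLaunch(FillLambda, &task, /*num_task=*/0);
      for (int v : a) std::printf("%d ", v);
      std::printf("\n");
      return 0;
    }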
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -4,11 +4,11 @@ project(mlc_llm C CXX)
include(CheckCXXCompilerFlag)
if(NOT MSVC)
check_cxx_compiler_flag("-std=c++17" SUPPORT_CXX17)
set(CMAKE_CXX_FLAGS "-std=c++17 -O3 -fopenmp -ffast-math -march=native ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "-std=c++17 -O3 -ffast-math -march=native ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_STANDARD 17)
else()
check_cxx_compiler_flag("/std:c++17" SUPPORT_CXX17)
set(CMAKE_CXX_FLAGS "/std:c++17 -O3 -fopenmp -ffast-math -march=native ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "/std:c++17 -O3 -ffast-math -march=native ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_STANDARD 17)
endif()

240 changes: 203 additions & 37 deletions cpp/serve/sampler.cc
@@ -7,7 +7,7 @@

#include "sampler.h"

#include <omp.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
@@ -22,6 +22,30 @@ namespace mlc {
namespace llm {
namespace serve {

int SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample);

/*!
* \brief Execute the given lambda function in parallel with
* the TVM threading backend.
* \tparam T The type of the lambda: "void (int i)".
* \param flambda The lambda to be executed in parallel.
* It should have the signature "void (int i)".
* \param begin The start index of this parallel loop (inclusive).
* \param end The end index of this parallel loop (exclusive).
* \example
*
* The for loop
* for (int i = 0; i < 10; i++) {
* a[i] = i;
* }
* should work the same as:
* parallel_for_with_threading_backend([&a](int i) {
* a[i] = i;
* }, 0, 10);
*/
template <typename T>
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end);

/*!
* \brief The sampler runtime module.
* It contains functions to
@@ -31,16 +31,6 @@ namespace serve {
class SamplerModule : public ModuleNode {
public:
explicit SamplerModule(DLDevice device) : device_(device), rng_(RandomGenerator::GetInstance()) {
// Set number of sampling threads from OpenMP.
num_threads_ = static_cast<int>(omp_get_max_threads()) / 2;

// Set sampling function.
auto fsample_topp_from_prob_ptr =
tvm::runtime::Registry::Get("vm.builtin.sample_top_p_from_prob");
ICHECK(fsample_topp_from_prob_ptr)
<< "Cannot find env function vm.builtin.sample_top_p_from_prob";
fsample_topp_from_prob_ = *fsample_topp_from_prob_ptr;

// Set customized "logits -> prob" function.
const PackedFunc* f_logits_to_probs =
Registry::Get("mlc.llm.compute_probs_from_logits_inplace");
@@ -115,22 +115,23 @@ class SamplerModule : public ModuleNode {
return;
}

int n = logits->shape[0];
#pragma omp parallel for num_threads(this->num_threads_)
for (int i = 0; i < n; ++i) {
// - Apply repetition penalty (inplace).
if (generation_cfg[i]->repetition_penalty != 1.0) {
ApplyRepetitionPenaltyOnCPU(logits, i, states[i], generation_cfg[i]->repetition_penalty);
}
// - Compute probability (inplace) from logits.
// Using softmax if temperature is non-zero.
// Or set probability of the max-logit position to 1.
if (generation_cfg[i]->temperature >= 1e-6) {
ApplySoftmaxWithTemperatureOnCPU(logits, i, generation_cfg[i]->temperature);
} else {
SetProbWithArgmaxOnCPU(logits, i);
}
}
parallel_for_with_threading_backend(
[this, &logits, &states, &generation_cfg](int i) {
// - Apply repetition penalty (inplace).
if (generation_cfg[i]->repetition_penalty != 1.0) {
ApplyRepetitionPenaltyOnCPU(logits, i, states[i],
generation_cfg[i]->repetition_penalty);
}
// - Compute probability (inplace) from logits.
// Using softmax if temperature is non-zero.
// Or set probability of the max-logit position to 1.
if (generation_cfg[i]->temperature >= 1e-6) {
ApplySoftmaxWithTemperatureOnCPU(logits, i, generation_cfg[i]->temperature);
} else {
SetProbWithArgmaxOnCPU(logits, i);
}
},
0, logits->shape[0]);
}

/*!
@@ -153,12 +153,13 @@ class SamplerModule : public ModuleNode {
random_numbers.push_back(rng_.GetRandomNumber());
}

#pragma omp parallel for num_threads(this->num_threads_)
for (int i = 0; i < n; ++i) {
// Sample top p from probability.
sampled_tokens[i] =
fsample_topp_from_prob_(probs, i, generation_cfg[i]->top_p, random_numbers[i]);
}
parallel_for_with_threading_backend(
[&sampled_tokens, &probs, &generation_cfg, &random_numbers](int i) {
// Sample top p from probability.
sampled_tokens[i] =
SampleTopPFromProb(probs, i, generation_cfg[i]->top_p, random_numbers[i]);
},
0, n);
return ShapeTuple(sampled_tokens.begin(), sampled_tokens.end());
}

@@ -243,21 +243,171 @@ class SamplerModule : public ModuleNode {

/*! \brief The runtime device where the input logits is. */
DLDevice device_;
/*! \brief Number of CPU threads for parallel sampling and other computation. */
int num_threads_;
/*! \brief The random generator. */
RandomGenerator& rng_;
/*! \brief Customized function which computes prob distribution from logits */
PackedFunc flogits_to_probs_inplace_;
/*! \brief Function which samples a token from prob distribution with top_p value. */
PackedFunc fsample_topp_from_prob_;
};

tvm::runtime::Module CreateSamplerModule(DLDevice device) {
ObjectPtr<SamplerModule> n = make_object<SamplerModule>(device);
return Module(n);
}

int SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample) {
// prob: (*, v)
// The prob array may have arbitrary ndim and shape.
// The last dimension corresponds to the prob distribution size.
// We use the `unit_offset` parameter to determine which slice
// of the prob array we sample from.

ICHECK(prob.IsContiguous());
ICHECK(prob.DataType() == DataType::Float(32));

if (prob->device.device_type != kDLCPU) {
prob = prob.CopyTo(DLDevice{kDLCPU, 0});
}

ICHECK(prob->device.device_type == kDLCPU);

int64_t ndata = prob->shape[prob->ndim - 1];
const float* __restrict p_prob =
static_cast<float*>(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * ndata);
constexpr double one = 1.0f - 1e-5f;

if (top_p >= one) {
// Specially handle case where top_p == 1.
double prob_sum = 0.0f;
for (int64_t i = 0; i < ndata; ++i) {
prob_sum += p_prob[i];
if (prob_sum >= uniform_sample) {
return i;
}
}
LOG(INFO) << "prob sum = " << prob_sum << ", sample = " << uniform_sample;
ICHECK(false) << "The prob distribution possibly contains NaN.";
}

// Key observation: for top-p sampling we usually only need to preserve
// a small number of the highest-probability elements before sorting.
thread_local std::vector<std::pair<float, int>> data;

auto sample_top_p_with_filter = [&](float cutoff) -> int64_t {
data.clear();
// Filter the data with the cutoff.
float cutoff_sum = 0.0f;
for (int64_t i = 0; i < ndata; ++i) {
if (p_prob[i] >= cutoff) {
cutoff_sum += p_prob[i];
data.emplace_back(std::make_pair(p_prob[i], static_cast<int>(i)));
if (cutoff_sum > 1 - cutoff) {
// Shortcut: when the remaining elements cannot have total
// probability larger than the cutoff, we can stop early.
break;
}
}
}
if (data.size() == 0) return -1;
auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
return lhs.first > rhs.first;
};
std::sort(data.begin(), data.end(), fcmp);

// Shortcut: if we know that
// uniform_sample < p[0] / top_p
// then we also know that uniform_sample < p[0] / top_p_sum,
// because top_p_sum is guaranteed to be smaller than top_p,
// so we can simply return the argmax sample
// without computing anything else.
if (uniform_sample < data[0].first / top_p) return data[0].second;

// compute top_p_sum
float cum_sum_prob = 0.0f;
float top_p_sum = 0.0f;
for (auto it = data.begin(); it != data.end(); ++it) {
float prob = it->first;
if (cum_sum_prob < top_p) {
top_p_sum += prob;
} else {
// we get to the right cutoff pt
break;
}
cum_sum_prob += prob;
it->first = cum_sum_prob;
}
// The total sum under the given cutoff is not sufficient to cover top_p,
// which means we may need to retry with a smaller cutoff.
if (cum_sum_prob < top_p && cutoff != 0.0f) return -1;

for (auto it = data.begin(); it != data.end(); ++it) {
if (uniform_sample < it->first / top_p_sum) {
return it->second;
}
}
return data[data.size() - 1].second;
};

if (top_p < 1) {
// Sample with a cutoff filter. By the pigeonhole principle we keep
// at most 1024 elements; in practice the filter usually keeps far
// fewer (on the order of 10-20).
data.reserve(256);
int64_t sampled_index = sample_top_p_with_filter(top_p / 1024);
if (sampled_index >= 0) return sampled_index;
}
// Fall back to the full prob distribution (rare case).
data.reserve(ndata);
int64_t sampled_index = sample_top_p_with_filter(0.0f);
ICHECK_GE(sampled_index, 0);
return sampled_index;
}

namespace detail {

// The detailed implementation of `parallel_for_with_threading_backend`.
// Since these are templates, the implementation must be visible at the
// point of instantiation; it cannot be hidden in a separate .cc file.

template <typename T>
struct ParallelForWithThreadingBackendLambdaInvoker {
static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
int num_task = penv->num_task;
// Convert void* back to lambda type.
T* lambda_ptr = static_cast<T*>(cdata);
// Invoke the lambda with the task id and the total number of tasks.
(*lambda_ptr)(task_id, num_task);
return 0;
}
};

template <typename T>
inline void parallel_launch_with_threading_backend(T flambda) {
// Launch the lambda by passing its address.
void* cdata = &flambda;
TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker<T>::TVMParallelLambdaInvoke,
cdata, /*num_task=*/0);
}

} // namespace detail

template <typename T>
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end) {
auto flaunch = [begin, end, flambda](int task_id, int num_task) {
// For each thread, do static division and call into flambda.
int64_t total_len = end - begin;
int64_t step = (total_len + num_task - 1) / num_task;
int64_t local_begin = std::min(begin + step * task_id, end);
int64_t local_end = std::min(local_begin + step, end);
for (int64_t i = local_begin; i < local_end; ++i) {
flambda(i);
}
};
// Launch with all threads.
detail::parallel_launch_with_threading_backend(flaunch);
}

} // namespace serve
} // namespace llm
} // namespace mlc
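Editor's note (not part of the commit): the static division inside `parallel_for_with_threading_backend` splits `[begin, end)` into equal contiguous chunks, one per task reported by the threading backend. A worked example with `begin = 0`, `end = 10`, `num_task = 4`:

    step   = (10 - 0 + 4 - 1) / 4 = 3
    task 0 -> i = 0, 1, 2
    task 1 -> i = 3, 4, 5
    task 2 -> i = 6, 7, 8
    task 3 -> i = 9

Each task then calls `flambda(i)` on its chunk; passing `/*num_task=*/0` to `TVMBackendParallelLaunch` leaves the choice of worker count to the TVM thread pool.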

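Editor's note (not part of the commit): a worked pass through `SampleTopPFromProb` with, say, `top_p = 0.9`. The first attempt filters with cutoff `0.9 / 1024 ≈ 0.00088`, keeping only tokens whose probability reaches that threshold and stopping early once the kept mass exceeds `1 - cutoff`. The kept pairs are sorted by probability; if the uniform sample falls below `p[0] / top_p`, the argmax token is returned immediately, otherwise the token is drawn from the cumulative distribution normalized by `top_p_sum`. Only if the kept mass never reaches `top_p` does the function retry with cutoff `0.0`, i.e. over the full distribution.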