diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c8b3189e6..66a1cb90a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,11 +4,11 @@ project(mlc_llm C CXX) include(CheckCXXCompilerFlag) if(NOT MSVC) check_cxx_compiler_flag("-std=c++17" SUPPORT_CXX17) - set(CMAKE_CXX_FLAGS "-std=c++17 -O3 -fopenmp -ffast-math -march=native ${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS "-std=c++17 -O3 -ffast-math -march=native ${CMAKE_CXX_FLAGS}") set(CMAKE_CUDA_STANDARD 17) else() check_cxx_compiler_flag("/std:c++17" SUPPORT_CXX17) - set(CMAKE_CXX_FLAGS "/std:c++17 -O3 -fopenmp -ffast-math -march=native ${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS "/std:c++17 -O3 -ffast-math -march=native ${CMAKE_CXX_FLAGS}") set(CMAKE_CUDA_STANDARD 17) endif() diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler.cc index 632d56374d..7012b22869 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler.cc @@ -7,7 +7,7 @@ #include "sampler.h" -#include +#include #include #include #include @@ -22,6 +22,30 @@ namespace mlc { namespace llm { namespace serve { +int SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample); + +/*! + * \brief Execute the given lambda function in parallel with + * threading backend in TVM. + * \tparam T The type of the lambda: "void (int i)". + * \param flambda The lambda to be executed in parallel. + * It should have the signature "void (int i)". + * \param begin The start index of this parallel loop (inclusive). + * \param end The end index of this parallel loop (exclusive). + * \example + * + * The for loop + * for (int i = 0; i < 10; i++) { + * a[i] = i; + * } + * should work the same as: + * parallel_for_with_threading_backend([&a](int i) { + * a[i] = i; + * }, 0, 10); + */ +template +inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end); + /*! * \brief The sampler runtime module. * It contains functions to @@ -31,16 +55,6 @@ namespace serve { class SamplerModule : public ModuleNode { public: explicit SamplerModule(DLDevice device) : device_(device), rng_(RandomGenerator::GetInstance()) { - // Set number of sampling threads from OpenMP. - num_threads_ = static_cast(omp_get_max_threads()) / 2; - - // Set sampling function. - auto fsample_topp_from_prob_ptr = - tvm::runtime::Registry::Get("vm.builtin.sample_top_p_from_prob"); - ICHECK(fsample_topp_from_prob_ptr) - << "Cannot find env function vm.builtin.sample_top_p_from_prob"; - fsample_topp_from_prob_ = *fsample_topp_from_prob_ptr; - // Set customized "logits -> prob" function. const PackedFunc* f_logits_to_probs = Registry::Get("mlc.llm.compute_probs_from_logits_inplace"); @@ -115,22 +129,23 @@ class SamplerModule : public ModuleNode { return; } - int n = logits->shape[0]; -#pragma omp parallel for num_threads(this->num_threads_) - for (int i = 0; i < n; ++i) { - // - Apply repetition penalty (inplace). - if (generation_cfg[i]->repetition_penalty != 1.0) { - ApplyRepetitionPenaltyOnCPU(logits, i, states[i], generation_cfg[i]->repetition_penalty); - } - // - Compute probability (inplace) from logits. - // Using softmax if temperature is non-zero. - // Or set probability of the max-logit position to 1. - if (generation_cfg[i]->temperature >= 1e-6) { - ApplySoftmaxWithTemperatureOnCPU(logits, i, generation_cfg[i]->temperature); - } else { - SetProbWithArgmaxOnCPU(logits, i); - } - } + parallel_for_with_threading_backend( + [this, &logits, &states, &generation_cfg](int i) { + // - Apply repetition penalty (inplace). + if (generation_cfg[i]->repetition_penalty != 1.0) { + ApplyRepetitionPenaltyOnCPU(logits, i, states[i], + generation_cfg[i]->repetition_penalty); + } + // - Compute probability (inplace) from logits. + // Using softmax if temperature is non-zero. + // Or set probability of the max-logit position to 1. + if (generation_cfg[i]->temperature >= 1e-6) { + ApplySoftmaxWithTemperatureOnCPU(logits, i, generation_cfg[i]->temperature); + } else { + SetProbWithArgmaxOnCPU(logits, i); + } + }, + 0, logits->shape[0]); } /*! @@ -153,12 +168,13 @@ class SamplerModule : public ModuleNode { random_numbers.push_back(rng_.GetRandomNumber()); } -#pragma omp parallel for num_threads(this->num_threads_) - for (int i = 0; i < n; ++i) { - // Sample top p from probability. - sampled_tokens[i] = - fsample_topp_from_prob_(probs, i, generation_cfg[i]->top_p, random_numbers[i]); - } + parallel_for_with_threading_backend( + [&sampled_tokens, &probs, &generation_cfg, &random_numbers](int i) { + // Sample top p from probability. + sampled_tokens[i] = + SampleTopPFromProb(probs, i, generation_cfg[i]->top_p, random_numbers[i]); + }, + 0, n); return ShapeTuple(sampled_tokens.begin(), sampled_tokens.end()); } @@ -243,14 +259,10 @@ class SamplerModule : public ModuleNode { /*! \brief The runtime device where the input logits is. */ DLDevice device_; - /*! \brief Number of CPU threads for parallel sampling and other computation. */ - int num_threads_; /*! \brief The random generator. */ RandomGenerator& rng_; /*! \brief Customized function which computes prob distribution from logits */ PackedFunc flogits_to_probs_inplace_; - /*! \brief Function which samples a token from prob distribution with top_p value. */ - PackedFunc fsample_topp_from_prob_; }; tvm::runtime::Module CreateSamplerModule(DLDevice device) { @@ -258,6 +270,160 @@ tvm::runtime::Module CreateSamplerModule(DLDevice device) { return Module(n); } +int SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample) { + // prob: (*, v) + // The prob array may have arbitrary ndim and shape. + // The last dimension corresponds to the prob distribution size. + // We use the `unit_offset` parameter to determine which slice + // of the prob array we sample from. + + ICHECK(prob.IsContiguous()); + ICHECK(prob.DataType() == DataType::Float(32)); + + if (prob->device.device_type != kDLCPU) { + prob = prob.CopyTo(DLDevice{kDLCPU, 0}); + } + + ICHECK(prob->device.device_type == kDLCPU); + + int64_t ndata = prob->shape[prob->ndim - 1]; + const float* __restrict p_prob = + static_cast(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * ndata); + constexpr double one = 1.0f - 1e-5f; + + if (top_p >= one) { + // Specially handle case where top_p == 1. + double prob_sum = 0.0f; + for (int64_t i = 0; i < ndata; ++i) { + prob_sum += p_prob[i]; + if (prob_sum >= uniform_sample) { + return i; + } + } + LOG(INFO) << "prob sum = " << prob_sum << ", sample = " << uniform_sample; + ICHECK(false) << "Possibly prob distribution contains NAN."; + } + + // Key observation: when we are doing top_p sampling + // usually we only need to preserve some of the elements with + // high probabilities before we do sort + thread_local std::vector> data; + + auto sample_top_p_with_filter = [&](float cuttoff) -> int64_t { + data.clear(); + // filter the data with cuttoff + float cutoff_sum = 0.0f; + for (int64_t i = 0; i < ndata; ++i) { + if (p_prob[i] >= cuttoff) { + cutoff_sum += p_prob[i]; + data.emplace_back(std::make_pair(p_prob[i], static_cast(i))); + if (cutoff_sum > 1 - cuttoff) { + // Short cut. When the remaining parts cannot have total + // probability larger than cutoff, we can quit. + break; + } + } + } + if (data.size() == 0) return -1; + auto fcmp = [](const std::pair& lhs, const std::pair& rhs) { + return lhs.first > rhs.first; + }; + std::sort(data.begin(), data.end(), fcmp); + + // short cut, if we know that + // uniform sample < p[0] / top_p + // we know that unform_sample < p[0] / top_p_sum + // because top_p_sum guarantees to be smaller than top_p + // so we can simply return the argmax sample + // without computing anything + if (uniform_sample < data[0].first / top_p) return data[0].second; + + // compute top_p_sum + float cum_sum_prob = 0.0f; + float top_p_sum = 0.0f; + for (auto it = data.begin(); it != data.end(); ++it) { + float prob = it->first; + if (cum_sum_prob < top_p) { + top_p_sum += prob; + } else { + // we get to the right cutoff pt + break; + } + cum_sum_prob += prob; + it->first = cum_sum_prob; + } + // we find that the current total sum by the given cutoff + // is not sufficient to cover everything + // this means we might need to retry a smaller cutoff pt. + if (cum_sum_prob < top_p && cuttoff != 0.0f) return -1; + + for (auto it = data.begin(); it != data.end(); ++it) { + if (uniform_sample < it->first / top_p_sum) { + return it->second; + } + } + return data[data.size() - 1].second; + }; + + if (top_p < 1) { + // sample through cutoff by a number + // by pigeonhole principle we will get at most 1024 elements + // usually it is much less by applying this filtering(order of 10 - 20) + data.reserve(256); + int64_t sampled_index = sample_top_p_with_filter(top_p / 1024); + if (sampled_index >= 0) return sampled_index; + } + // fallback via full prob, rare case + data.reserve(ndata); + int64_t sampled_index = sample_top_p_with_filter(0.0f); + ICHECK_GE(sampled_index, 0); + return sampled_index; +} + +namespace detail { + +// The detailed implementation of `parallel_for_with_threading_backend`. +// To avoid template expansion, the implementation cannot be placed +// in .cc files. + +template +struct ParallelForWithThreadingBackendLambdaInvoker { + static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + int num_task = penv->num_task; + // Convert void* back to lambda type. + T* lambda_ptr = static_cast(cdata); + // Invoke the lambda with the task id (thread id). + (*lambda_ptr)(task_id, num_task); + return 0; + } +}; + +template +inline void parallel_launch_with_threading_backend(T flambda) { + // Launch the lambda by passing its address. + void* cdata = &flambda; + TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker::TVMParallelLambdaInvoke, + cdata, /*num_task=*/0); +} + +} // namespace detail + +template +inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end) { + auto flaunch = [begin, end, flambda](int task_id, int num_task) { + // For each thread, do static division and call into flambda. + int64_t total_len = end - begin; + int64_t step = (total_len + num_task - 1) / num_task; + int64_t local_begin = std::min(begin + step * task_id, end); + int64_t local_end = std::min(local_begin + step, end); + for (int64_t i = local_begin; i < local_end; ++i) { + flambda(i); + } + }; + // Launch with all threads. + detail::parallel_launch_with_threading_backend(flaunch); +} + } // namespace serve } // namespace llm } // namespace mlc