From 26c1b8618d8831cfdc4da9851a35a74163deceb4 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 11 Nov 2023 15:49:57 -0500 Subject: [PATCH] [Serving] Support batched prefill and benchmark This PR adds batched prefill support to the serving framework, which helps improve prefill throughput. Some data structures are tweaked to reduce runtime overhead. This PR also adds a benchmark for the serving engine that uses a real-world dataset as input. --- 3rdparty/flashinfer | 2 +- cpp/serve/engine.cc | 389 ++++++++++++++---- cpp/serve/function_table.cc | 8 +- cpp/serve/function_table.h | 1 + cpp/serve/model.cc | 155 ++++++---- cpp/serve/request_state.cc | 15 + cpp/serve/request_state.h | 23 +- python/mlc_chat/serve/engine.py | 12 +- tests/python/serve/benchmark.py | 155 +++++++--- tests/python/serve/evaluate_engine.py | 74 +++++ tests/python/serve/test_serve_engine.py | 5 +- 11 files changed, 561 insertions(+), 278 deletions(-) create mode 100644 tests/python/serve/evaluate_engine.py diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index a3fd01c48f..beb51bcb7a 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit a3fd01c48ff1d91c2f690688700c134e1cf0c161 +Subproject commit beb51bcb7a6381c73f099ff67c30c6c17bcd116a diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 03f9bc67cd..62127e92ac 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -4,6 +4,7 @@ * \brief The implementation for runtime module of serving engine module in MLC LLM. */ #define __STDC_FORMAT_MACROS +#define PICOJSON_USE_INT64 #include #include @@ -67,21 +68,23 @@ class Engine { // Step 1. Create models and their PackedFuncs. ICHECK(models_.empty()); models_.reserve(num_models); - fmodel_single_seq_prefill_.clear(); + fmodel_batch_prefill_.clear(); fmodel_decode_.clear(); fmodel_token_embed_.clear(); fmodel_add_new_sequence_.clear(); fmodel_remove_sequence_.clear(); fmodel_softmax_with_temperature_.clear(); + fmodel_get_num_available_pages_.clear(); for (int i = 0; i < num_models; ++i) { Module model = CreateModelModule(reload_libs[i], model_paths[i], devices_[i]); models_.push_back(model); - fmodel_single_seq_prefill_.push_back(model->GetFunction("single_seq_prefill")); + fmodel_batch_prefill_.push_back(model->GetFunction("batch_prefill")); fmodel_decode_.push_back(model->GetFunction("decode")); fmodel_token_embed_.push_back(model->GetFunction("token_embed")); fmodel_add_new_sequence_.push_back(model->GetFunction("add_new_sequence")); fmodel_remove_sequence_.push_back(model->GetFunction("remove_sequence")); fmodel_softmax_with_temperature_.push_back(model->GetFunction("softmax_with_temperature")); + fmodel_get_num_available_pages_.push_back(model->GetFunction("get_num_available_pages")); } // Step 2. Fetch max single sequence length from models. max_single_sequence_length_ = std::numeric_limits::max(); @@ -124,7 +127,8 @@ class Engine { */ void AddRequest(Request request) { waiting_queue_.push_back(request); - request_states_.emplace(request, RequestState(models_.size(), request->inputs)); + request_states_.emplace( + request, RequestState(models_.size(), request->inputs, GetInputLength(request->inputs))); } /*! \brief Abort the input request.
*/ @@ -193,42 +197,48 @@ class Engine { } current_total_seq_len_ = 0; - prefill_total_time = 0.0f; - decode_total_time = 0.0f; - prefill_total_length = 0; - decode_total_length = 0; + request_total_prefill_time_ = 0.0f; + request_total_decode_time_ = 0.0f; + engine_total_prefill_time_ = 0.0f; + engine_total_decode_time_ = 0.0f; + total_prefill_length_ = 0; + total_decode_length_ = 0; tokenize_cache_.clear(); } /*! * \brief Return the engine runtime statistics in JSON string. * We collect the following entries: - * - prefill token latency (s/tok): avg latency of processing one token in prefill - * - decode token latency (s/tok): avg latency of processing one token in decode - * - token throughput (tok/s): avg number of tokens processed per second (prefill + decode) + * - single token prefill latency (s/tok): avg latency of processing one token in prefill + * - single token decode latency (s/tok): avg latency of processing one token in decode + * - engine time for prefill (sec) + * - engine time for decode (sec) + * - total number of processed tokens in prefill. + * - total number of processed tokens in decode. * \return The statistics in JSON string. */ String StatisticsJSON() { picojson::object config; - config["prefill_token_latency"] = picojson::value(prefill_total_time / prefill_total_length); - config["decode_token_latency"] = picojson::value(decode_total_time / decode_total_length); - config["token_throughput"] = picojson::value((prefill_total_length + decode_total_length) / - (prefill_total_time + decode_total_time)); + config["single_token_prefill_latency"] = + picojson::value(request_total_prefill_time_ / total_prefill_length_); + config["single_token_decode_latency"] = + picojson::value(request_total_decode_time_ / total_decode_length_); + config["engine_total_prefill_time"] = picojson::value(engine_total_prefill_time_); + config["engine_total_decode_time"] = picojson::value(engine_total_decode_time_); + config["total_prefill_tokens"] = picojson::value(total_prefill_length_); + config["total_decode_tokens"] = picojson::value(total_decode_length_); return picojson::value(config).serialize(true); } private: /*! \brief Pick applicable requests and run prefill. */ bool StepPrefill() { - auto [requests, sample_new_token] = GetRequestsToPrefill(); - ICHECK_EQ(requests.size(), sample_new_token.size()); + auto [requests, states, sample_new_token] = GetRequestsToPrefill(); if (requests.empty()) { return false; } - // Collect ids of the first-time prefilled requests. - // We will set the prefill finish time for these requests. - std::unordered_set first_time_prefill_ids; + auto tstart = std::chrono::high_resolution_clock::now(); for (Request request : requests) { int req_id = running_queue_.size(); @@ -242,65 +252,83 @@ class Engine { running_queue_.push_back(request); // - Assign request id for the requests. AssignIDForRequest(request, req_id); - first_time_prefill_ids.insert(req_id); } - // NOTE: Right now only single-sequence prefill is supported. - // So we prefill the requests one by one for now. - for (int req_idx = 0; req_idx < static_cast(requests.size()); ++req_idx) { - Request request = requests[req_idx]; - RequestState& state = request_states_.at(request); - ICHECK_EQ(state.mstates.size(), models_.size()); - NDArray logits{nullptr}; - current_total_seq_len_ += GetRequestPrefillInputLength(request); - // - Prefill the inputs for each model. 
- for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - Module model = models_[model_id]; - RequestModelState mstate = state.mstates[model_id]; + int sum_prefill_lengths = 0; + NDArray logits_for_sample{nullptr}; + Array mstates_for_sample; + Array generation_cfg_for_sample; + mstates_for_sample.reserve(requests.size()); + generation_cfg_for_sample.reserve(requests.size()); + for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { + Module model = models_[model_id]; + auto [request_list, mstates, prefill_lengths] = + FilterPrefillRequests(requests, states, model_id); + Array embeddings; + std::vector request_ids; + embeddings.reserve(request_list.size()); + request_ids.reserve(request_list.size()); + for (int i = 0; i < static_cast(request_list.size()); ++i) { + Request request = request_list[i]; + int prefill_length = prefill_lengths[i]; + RequestModelState mstate = mstates[i]; + if (model_id == 0) { + // Accumulate the sequence length. + sum_prefill_lengths += prefill_length; + current_total_seq_len_ += prefill_length; + mstates_for_sample.push_back(mstate); + generation_cfg_for_sample.push_back(request->generation_cfg); + } ICHECK(mstate->draft_output_tokens.empty()); ICHECK(mstate->draft_output_token_prob.empty()); ICHECK(mstate->draft_output_prob_dist.empty()); - if (mstate->inputs.empty()) { - continue; - } + ICHECK(!mstate->inputs.empty()); + request_ids.push_back(mstate->request_id); for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - NDArray embedding = GetEmbedding(mstate->inputs[i], fmodel_token_embed_[model_id]); - NDArray output_logits = - fmodel_single_seq_prefill_[model_id](embedding, mstate->request_id); - - // Only the output logits of the last input on the first model will be send for sampling. - if (model_id == 0 && i == static_cast(mstate->inputs.size()) - 1) { - logits = output_logits; - } + embeddings.push_back(GetEmbedding(mstate->inputs[i], fmodel_token_embed_[model_id])); } // Clean up `inputs` after prefill mstate->inputs.clear(); } - if (!sample_new_token[req_idx]) { - // This request does not need to sample a new token. - // In this case, it must not be a first-time prefilled request. - ICHECK(first_time_prefill_ids.count(state.mstates[0]->request_id)); - continue; - } - ICHECK(logits.defined()); + NDArray logits = fmodel_batch_prefill_[model_id]( + embeddings, ShapeTuple(request_ids.begin(), request_ids.end()), prefill_lengths); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], 1); + ICHECK_EQ(logits->shape[1], request_list.size()); - ShapeTuple next_token = SampleTokens(logits, /*model_id=*/0, /*sampler_id=*/0, - {state.mstates[0]}, {request->generation_cfg}); - ICHECK_EQ(next_token.size(), 1); + if (model_id == 0) { + // We only need to sample for model 0 in prefill. + logits_for_sample = logits; + } + } + if (sample_new_token) { + // - Sample tokens. + int num_requests = requests.size(); + ICHECK(logits_for_sample.defined()); + ICHECK_EQ(logits_for_sample->shape[1], num_requests); + ICHECK_EQ(mstates_for_sample.size(), num_requests); + ICHECK_EQ(generation_cfg_for_sample.size(), num_requests); + logits_for_sample = logits_for_sample.CreateView( + {num_requests, 1, logits_for_sample->shape[2]}, logits_for_sample->dtype); + ShapeTuple next_tokens = SampleTokens(logits_for_sample, /*model_id=*/0, /*sampler_id=*/0, + mstates_for_sample, generation_cfg_for_sample); + ICHECK_EQ(next_tokens.size(), num_requests); // - Update the committed tokens of states. 
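To make the shape bookkeeping in the batched prefill path easier to follow, here is a small self-contained numpy sketch. It is an illustration of the idea only, not the engine's C++ code: the batched prefill returns logits of shape (1, num_requests, vocab), one row per request, and the engine views them as (num_requests, 1, vocab) so that sampling works the same way as in batched decode. Greedy argmax stands in for the real temperature/top-p sampler.

# Toy numpy illustration of how the batched prefill logits are consumed for sampling.
# Shapes mirror the checks above: the model returns logits of shape (1, num_requests, vocab),
# holding, for each request, the logits at its last prefilled position.
import numpy as np

num_requests, vocab_size = 3, 8
rng = np.random.default_rng(0)
logits = rng.standard_normal((1, num_requests, vocab_size)).astype("float32")

# Equivalent of logits_for_sample.CreateView({num_requests, 1, vocab}): one (1, vocab) slice
# per request, so the sampler can treat prefill exactly like a batched decode step.
logits_per_request = logits.reshape(num_requests, 1, vocab_size)

# Greedy sampling stands in for SampleTokens here; the real sampler applies temperature/top-p.
next_tokens = logits_per_request[:, 0, :].argmax(axis=-1)
assert next_tokens.shape == (num_requests,)
print(next_tokens)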
- for (RequestModelState mstate : state.mstates) { - mstate->committed_tokens.push_back(next_token[0]); - } - // If the request is first-time prefilled, set the prefill finish time. - if (first_time_prefill_ids.count(state.mstates[0]->request_id)) { - state.tprefill_finish = std::chrono::high_resolution_clock::now(); + // - If a request is first-time prefilled, set the prefill finish time. + auto tnow = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < num_requests; ++i) { + mstates_for_sample[i]->committed_tokens.push_back(next_tokens[i]); + if (mstates_for_sample[i]->committed_tokens.size() == 1) { + request_states_.at(requests[i])->tprefill_finish = tnow; + } } } + + auto tend = std::chrono::high_resolution_clock::now(); + engine_total_prefill_time_ += static_cast((tend - tstart).count()) / 1e9; + return true; } @@ -311,18 +339,19 @@ class Engine { return false; } - Array requests = GetRequestsToDecode(); - if (requests.empty()) { + PreemptUnfittableRequests(); + if (running_queue_.empty()) { return false; } + auto tstart = std::chrono::high_resolution_clock::now(); + // NOTE: Right now we only support decode all the running requests at a time. - ICHECK_EQ(requests.size(), running_queue_.size()); - int num_requests = requests.size(); + int num_requests = running_queue_.size(); // Check if the requests ids are in an ascending order. for (int i = 1; i < num_requests; ++i) { - ICHECK_GT(request_states_.at(requests[i]).mstates[0]->request_id, - request_states_.at(requests[i - 1]).mstates[0]->request_id); + ICHECK_GT(request_states_.at(running_queue_[i])->mstates[0]->request_id, + request_states_.at(running_queue_[i - 1])->mstates[0]->request_id); } current_total_seq_len_ += num_requests; @@ -337,10 +366,10 @@ class Engine { inputs.reserve(num_requests); mstates.reserve(num_requests); generation_cfg.reserve(num_requests); - for (Request request : requests) { + for (Request request : running_queue_) { RequestState& state = request_states_.at(request); - inputs.push_back(TokenData(ShapeTuple({state.mstates[0]->committed_tokens.back()}))); - mstates.push_back(state.mstates[0]); + inputs.push_back(TokenData(ShapeTuple({state->mstates[0]->committed_tokens.back()}))); + mstates.push_back(state->mstates[0]); generation_cfg.push_back(request->generation_cfg); } @@ -363,6 +392,10 @@ class Engine { for (int i = 0; i < num_requests; ++i) { mstates[i]->committed_tokens.push_back(next_tokens[i]); } + + auto tend = std::chrono::high_resolution_clock::now(); + engine_total_decode_time_ += static_cast((tend - tstart).count()) / 1e9; + return true; } @@ -413,8 +446,9 @@ class Engine { // The request to abort is in running queue int req_id = it_running - running_queue_.begin(); running_queue_.erase(it_running); - current_total_seq_len_ -= GetRequestRawInputLength(request) + - request_states_.at(request).mstates[0]->committed_tokens.size() - 1; + RequestState state = request_states_.at(request); + current_total_seq_len_ -= + state->raw_input_length + state->mstates[0]->committed_tokens.size() - 1; RemoveSequenceFromModels(req_id); UpdateRequestIDAfterRemoval(req_id); } else { @@ -440,8 +474,8 @@ class Engine { // - Update `inputs` for future prefill. 
RequestState& state = request_states_.at(request); current_total_seq_len_ -= - GetRequestRawInputLength(request) + state.mstates[0]->committed_tokens.size() - 1; - for (RequestModelState mstate : state.mstates) { + state->raw_input_length + state->mstates[0]->committed_tokens.size() - 1; + for (RequestModelState mstate : state->mstates) { mstate->request_id = -1; mstate->draft_output_tokens.clear(); mstate->draft_output_token_prob.clear(); @@ -465,51 +499,78 @@ class Engine { * additionally return a boolean flag indicating if a new * token needs to be sampled from logits after prefill. */ - std::pair, std::vector> GetRequestsToPrefill() { - // NOTE: Right now we only support single-sequence prefill. + std::tuple, Array, bool> GetRequestsToPrefill() { + // - Try to prefill pending requests. + std::vector prefill_requests; + std::vector states; if (!waiting_queue_.empty()) { - Array prefill_requests{waiting_queue_.front()}; - if (CanPrefill(prefill_requests)) { + int total_input_length = 0; + int total_required_pages = 0; + ICHECK(fmodel_get_num_available_pages_[0].defined()); + int num_available_pages = fmodel_get_num_available_pages_[0](); + + for (int i = 0; i < static_cast(waiting_queue_.size()); ++i) { + Request request = waiting_queue_[i]; + RequestState state = request_states_.at(request); + int input_length = GetInputLength(state->mstates[0]->inputs); + int num_require_pages = + (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + if (CanPrefill(i + 1, total_input_length, total_required_pages, num_available_pages)) { + prefill_requests.push_back(request); + states.push_back(state); + } else { + total_input_length -= input_length; + total_required_pages -= num_require_pages; + break; + } + } + if (!prefill_requests.empty()) { // Need to sample a new token for waiting requests. - return {prefill_requests, {true}}; + return {prefill_requests, states, true}; } } + // Try to prefill for small models. for (Request request : running_queue_) { - Array mstates = request_states_.at(request).mstates; + RequestState state = request_states_.at(request); + Array mstates = state->mstates; for (int i = 0; i < static_cast(mstates.size()); ++i) { if (!mstates[i]->inputs.empty()) { ICHECK_NE(i, 0); - // This return happens only for "small" models in - // speculative inference settings. - // Therefore no need to sample new token from logits. - return {{request}, {false}}; + prefill_requests.push_back(request); + states.push_back(state); + break; } } } - - return {}; + // This return happens only for "small" models in + // speculative inference settings. + // Therefore no need to sample new token from logits. + return {prefill_requests, states, false}; } - /*! \brief Find requests to decode. */ - Array GetRequestsToDecode() { + /*! \brief Preempt the requests unfittable for decode. */ + void PreemptUnfittableRequests() { if (running_queue_.empty()) { - return {}; + return; } - // NOTE: Right now we only support decode all the running requests at a time. - Array requests(running_queue_); - if (CanDecode(requests)) { - return requests; + int num_available_pages = fmodel_get_num_available_pages_[0](); + while (true) { + if (CanDecode(running_queue_.size())) { + break; + } + StepPreempt(running_queue_.back()); } - return {}; } /*! \brief Assign the given id for the given request. */ void AssignIDForRequest(Request request, int req_id) { // Set id in the request state. 
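The new GetRequestsToPrefill walks the waiting queue in order, accumulating input lengths and the number of KV-cache pages each request would need, and stops at the first request that CanPrefill rejects. The Python sketch below mirrors that admission heuristic; it is illustrative only, and the default limits (page size, batch size, sequence-length caps, the 8-token-per-sequence decode reserve) are placeholders rather than the engine's configuration.

# Illustrative sketch of the page-based prefill admission heuristic (GetRequestsToPrefill
# together with CanPrefill). All default limits below are placeholders.
from typing import List


def pick_prefill_batch(
    waiting_input_lengths: List[int],
    num_running: int,
    current_total_seq_len: int,
    num_available_pages: int,
    page_size: int = 16,
    max_num_sequence: int = 80,
    max_single_sequence_length: int = 4096,
    max_total_sequence_length: int = 16800,
) -> int:
    """Return how many waiting requests (taken in order) can be prefilled together."""
    total_input, total_pages, admitted = 0, 0, 0
    for input_len in waiting_input_lengths:
        pages = (input_len + page_size - 1) // page_size   # pages this request needs
        total_input += input_len
        total_pages += pages
        new_batch = num_running + admitted + 1
        fits = (
            new_batch <= max_num_sequence
            and total_input <= max_single_sequence_length
            # Keep enough free pages that every sequence can still grow during decode.
            and total_pages + new_batch <= num_available_pages
            and current_total_seq_len + total_input + 8 * new_batch <= max_total_sequence_length
        )
        if not fits:
            break   # stop at the first request that does not fit
        admitted += 1
    return admitted


# Example: two of the three waiting requests fit; the 4000-token prompt is deferred.
print(pick_prefill_batch([200, 900, 4000], num_running=2,
                         current_total_seq_len=300, num_available_pages=400))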
RequestState& state = request_states_.at(request); - for (RequestModelState mstate : state.mstates) { + for (RequestModelState mstate : state->mstates) { mstate->request_id = req_id; } // Add a new sequence to each model. @@ -521,9 +582,9 @@ class Engine { } /*! \brief Check if the input requests can be prefilled under conditions. */ - bool CanPrefill(Array requests) { + bool CanPrefill(int num_prefill_req, int total_input_length, int num_required_pages, + int num_available_pages) { int num_running_requests = running_queue_.size(); - int num_prefill_req = requests.size(); ICHECK_LE(num_running_requests, kv_cache_config_->max_num_sequence); // No exceeding of the maximum allowed requests that can @@ -532,55 +593,53 @@ class Engine { return false; } - // The total length + the maximum allowed single-sequence length does - // not exceed the maximum allowed total length. - // NOTE: this condition is heuristic and can be revised. - int total_input_length = 0; - int total_max_new_tokens = 0; - for (Request request : requests) { - total_input_length += GetRequestPrefillInputLength(request); - total_max_new_tokens += request->generation_cfg->max_new_tokens; - } - return current_total_seq_len_ + std::min(total_input_length + total_max_new_tokens, - num_prefill_req * max_single_sequence_length_) < - kv_cache_config_->max_total_sequence_length; + // NOTE: The conditions are heuristic and can be revised. + // Cond 1: total input length <= max allowed single sequence length. + // Cond 2: remaining pages >= 10, where 10 is a watermark number can + // be configured and adjusted in the future. + // Cond 3: at least one decode can be performed after prefill. + // Todo: move watermark to config. + int new_batch_size = num_running_requests + num_prefill_req; + return total_input_length <= max_single_sequence_length_ && + num_required_pages + new_batch_size <= num_available_pages && + current_total_seq_len_ + total_input_length + 8 * new_batch_size <= + kv_cache_config_->max_total_sequence_length; } /*! \brief Check if the input requests can be decoded under conditions. */ - bool CanDecode(Array requests) { - return current_total_seq_len_ + requests.size() < kv_cache_config_->max_total_sequence_length; + bool CanDecode(int num_requests) { + int num_available_pages = fmodel_get_num_available_pages_[0](); + return num_requests <= num_available_pages; } - /*! - * \brief Get the total equivalent **input length to prefill** - * of the given request's current state. - */ - int GetRequestPrefillInputLength(const Request& request) { - auto it = request_states_.find(request); - ICHECK(it != request_states_.end()); - - const RequestState& state = it->second; - ICHECK_EQ(state.mstates.size(), models_.size()); - int input_length = -1; - for (const RequestModelState& mstate : state.mstates) { - int length_sum = 0; - for (Data input : mstate->inputs) { - length_sum += GetInputLength(input); - } - if (input_length == -1) { - input_length = length_sum; - } else { - ICHECK_EQ(length_sum, input_length); + /*! \brief Filter the requests to prefill on the given model. 
*/ + std::tuple, Array, ShapeTuple> FilterPrefillRequests( + Array requests, Array states, int model_id) { + ICHECK_EQ(requests.size(), states.size()); + int num_requests = requests.size(); + Array filtered_requests; + Array filtered_mstates; + std::vector prefill_length; + filtered_requests.reserve(num_requests); + filtered_mstates.reserve(num_requests); + prefill_length.reserve(num_requests); + + for (int i = 0; i < num_requests; ++i) { + int length = GetInputLength(states[i]->mstates[model_id]->inputs); + if (length > 0) { + filtered_requests.push_back(requests[i]); + filtered_mstates.push_back(states[i]->mstates[model_id]); + prefill_length.push_back(length); } } - ICHECK_NE(input_length, -1); - return input_length; + return {filtered_requests, filtered_mstates, + ShapeTuple(prefill_length.begin(), prefill_length.end())}; } - /*! \brief Get the total equivalent **input length** of the given request. */ - int GetRequestRawInputLength(const Request& request) { + /*! \brief Get the total input length of the given inputs. */ + int GetInputLength(Array inputs) { int length_sum = 0; - for (Data input : request->inputs) { + for (Data input : inputs) { length_sum += GetInputLength(input); } return length_sum; @@ -595,6 +654,7 @@ class Engine { return tokens_input->token_ids.size(); } else { ICHECK(false) << "Cannot reach here"; + throw; } } @@ -633,6 +693,7 @@ class Engine { return fmodel_token_embed(Array{tokens_input->token_ids}); } else { ICHECK(false) << "Cannot reach here"; + throw; } } @@ -785,7 +846,7 @@ class Engine { void UpdateRequestIDAfterRemoval(int removed_req_id) { for (auto& it : request_states_) { RequestState& state = it.second; - for (RequestModelState mstate : state.mstates) { + for (RequestModelState mstate : state->mstates) { ICHECK_NE(mstate->request_id, removed_req_id); if (mstate->request_id > removed_req_id) { --mstate->request_id; @@ -819,10 +880,10 @@ class Engine { running_queue_.erase(it); RequestState& state = request_states_.at(request); - int num_input_tokens = GetRequestRawInputLength(request); - int num_output_tokens = state.mstates[0]->committed_tokens.size() - 1; + int num_input_tokens = state->raw_input_length; + int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1; current_total_seq_len_ -= num_input_tokens + num_output_tokens; - for (RequestModelState mstate : state.mstates) { + for (RequestModelState mstate : state->mstates) { ICHECK_EQ(mstate->request_id, req_id); mstate->request_id = -1; } @@ -830,15 +891,19 @@ class Engine { UpdateRequestIDAfterRemoval(req_id); auto trequest_finish = std::chrono::high_resolution_clock::now(); - prefill_total_time += static_cast((state.tprefill_finish - state.tadd).count()) / 1e9; - prefill_total_length += num_input_tokens; - decode_total_time += - static_cast((trequest_finish - state.tprefill_finish).count()) / 1e9; - decode_total_length += num_output_tokens; + request_total_prefill_time_ += + static_cast((state->tprefill_finish - state->tadd).count()) / 1e9; + total_prefill_length_ += num_input_tokens; + request_total_decode_time_ += + static_cast((trequest_finish - state->tprefill_finish).count()) / 1e9; + total_decode_length_ += num_output_tokens; // NOTE: right now we only return the generated text. // In the future we might optional return text or token ids. 
- request->fcallback(request, TextData(state.output)); + String output = ftokenizer_decode(ShapeTuple(state->mstates[0]->committed_tokens.begin(), + state->mstates[0]->committed_tokens.end())); + state->output = output.operator std::string(); + request->fcallback(request, TextData(state->output)); // Remove the request from states. request_states_.erase(request); @@ -851,23 +916,18 @@ class Engine { // - Case 0. There is remaining draft output ==> Unfinished // All draft outputs are supposed to be processed before finish. - for (RequestModelState mstate : state.mstates) { + for (RequestModelState mstate : state->mstates) { if (!mstate->draft_output_tokens.empty()) { return false; } } // - Decode committed tokens. - const std::vector& committed_tokens = state.mstates[0]->committed_tokens; - String output = ftokenizer_decode(ShapeTuple(committed_tokens.begin(), committed_tokens.end())); - state.output = output.operator std::string(); + const std::vector& committed_tokens = state->mstates[0]->committed_tokens; // Case 1. Any of the stop strings appears in output ==> Finished - for (String stop_str : request->generation_cfg->stop_strs) { - if (state.output.rfind(stop_str) != std::string::npos) { - return true; - } - } + // Todo: handle stop_str by tokenizing. So that we don't detokenize during check + // Case 2. Any of the stop tokens appears in the committed tokens ===> Finished if (std::any_of(request->generation_cfg->stop_tokens.begin(), request->generation_cfg->stop_tokens.end(), [&committed_tokens](int32_t token) { @@ -880,7 +940,7 @@ class Engine { return true; } // Case 4. Total length of the request reaches the maximum single sequence length ==> Finished - if (GetRequestRawInputLength(request) + static_cast(committed_tokens.size()) >= + if (state->raw_input_length + static_cast(committed_tokens.size()) >= max_single_sequence_length_) { return true; } @@ -905,12 +965,13 @@ class Engine { NDArray logits_or_probs_on_cpu_{nullptr}; // PackedFuncs from model/tokenizer/sampler/env. - std::vector fmodel_single_seq_prefill_; + std::vector fmodel_batch_prefill_; std::vector fmodel_decode_; std::vector fmodel_token_embed_; std::vector fmodel_add_new_sequence_; std::vector fmodel_remove_sequence_; std::vector fmodel_softmax_with_temperature_; + std::vector fmodel_get_num_available_pages_; std::vector fsampler_require_gpu_softmax_; std::vector fsampler_compute_probs_from_logits_inplace_; std::vector fsampler_sample_token_from_probs_; @@ -919,10 +980,18 @@ class Engine { // Runtime statistics int64_t current_total_seq_len_; - double prefill_total_time = 0; - double decode_total_time = 0; - int64_t prefill_total_length = 0; - int64_t decode_total_length = 0; + /*! \brief The sum of "prefill time of each request". */ + double request_total_prefill_time_ = 0.0f; + /*! \brief The sum of "decode time of each request". */ + double request_total_decode_time_ = 0.0f; + /*! \brief The total engine time on prefill. */ + double engine_total_prefill_time_ = 0.0f; + /*! \brief The total engine time on decode. */ + double engine_total_decode_time_ = 0.0f; + /*! \brief The total number of processed tokens in prefill. */ + int64_t total_prefill_length_ = 0; + /*! \brief The total number of processed tokens in decode. 
*/ + int64_t total_decode_length_ = 0; // Tokenization cache std::unordered_map tokenize_cache_; diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index be77c16970..2e9e1469ba 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -7,11 +7,11 @@ #include "function_table.h" #include +#include #include #include #include #include -#include #include #include @@ -108,8 +108,8 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, int num_shards) this->local_vm = fload_exec(); this->local_vm->GetFunction("vm_initialization")( static_cast(device.device_type), device.device_id, - static_cast(relax_vm::AllocatorType::kPooled), static_cast(kDLCPU), 0, - static_cast(relax_vm::AllocatorType::kPooled)); + static_cast(tvm::runtime::memory::AllocatorType::kPooled), static_cast(kDLCPU), 0, + static_cast(tvm::runtime::memory::AllocatorType::kPooled)); this->mod_get_func = [this](const std::string& name) -> PackedFunc { return this->local_vm->GetFunction(name, false); }; @@ -169,6 +169,8 @@ void FunctionTable::_InitFunctions() { get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device"); this->remove_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_remove"); this->popn_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn"); + this->get_num_available_pages_kv_cache_func_ = + get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages"); support_backtracking_kv_ = true; } diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index ef9b3a5976..640785389a 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -69,6 +69,7 @@ struct FunctionTable { PackedFunc sync_device_kv_cache_func_; PackedFunc remove_from_kv_cache_func_; PackedFunc popn_from_kv_cache_func_; + PackedFunc get_num_available_pages_kv_cache_func_; }; } // namespace serve diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 4fb85bfc52..9c2a1e768b 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -77,10 +77,10 @@ class ModelModule : public ModuleNode { CHECK_EQ(args.size(), 1); *rv = TokenEmbed(args[0]); }); - } else if (name == "single_seq_prefill") { + } else if (name == "batch_prefill") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 2); - *rv = SingleSequencePrefill(args[0], args[1]); + CHECK_EQ(args.size(), 3); + *rv = BatchPrefill(args[0], args[1], args[2]); }); } else if (name == "decode") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { @@ -115,15 +115,18 @@ class ModelModule : public ModuleNode { ICHECK_EQ(args.size(), 0); Reset(); }); + } else if (name == "get_num_available_pages") { + return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size(), 0); + ICHECK(kv_cache_.defined()); + *rv = ft_.get_num_available_pages_kv_cache_func_(kv_cache_); + }); } else if (name == "get_max_window_size") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size(), 0); CHECK_NE(max_window_size_, -1) << "The model has not been initialized"; *rv = max_window_size_; }); - } else if (name == "runtime_stats_text") { - // Todo: JSON style - return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = GetStats(); }); } else { return PackedFunc(nullptr); } @@ -156,7 +159,8 @@ class ModelModule : public ModuleNode { } // Copy input token ids to device. 
DLDataType dtype(DataType::Int(32)); - NDArray token_ids_nd = CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, 2048); + NDArray token_ids_nd = + CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, max_window_size_); ICHECK_EQ(token_ids_nd->ndim, 1); ICHECK_EQ(token_ids_nd->shape[0], total_length); token_ids_nd = token_ids_nd.CreateView({1, total_length}, dtype); @@ -165,12 +169,7 @@ class ModelModule : public ModuleNode { << "`embed` function is not found in the model. Please make sure the model is compiled " "with flag `--sep-embed` and `--enable-batching`"; - auto tstart = std::chrono::high_resolution_clock::now(); NDArray embeddings = ft_.embed_func_(ft_.CopyToWorker0(token_ids_nd), params_); - auto tend = std::chrono::high_resolution_clock::now(); - - this->embed_total_time += static_cast((tend - tstart).count()) / 1e9; - this->embed_total_tokens += total_length; // embeddings: (1, total_length, hidden_size) ICHECK_EQ(embeddings->ndim, 3); @@ -183,14 +182,33 @@ class ModelModule : public ModuleNode { * \brief Single-sequence prefill function. Embedding in, logits out. * \param embeddings The embedding of the input to be prefilled. * \param seq_id The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. * \return The logits for the next token. */ - NDArray SingleSequencePrefill(NDArray embeddings, int seq_id) { + NDArray BatchPrefill(Array embedding_arr, ShapeTuple seq_ids, ShapeTuple lengths) { + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + std::vector logit_pos; + logit_pos.reserve(num_sequences); + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + logit_pos.push_back(total_length); + if (i > 0) { + CHECK_GT(seq_ids[i], seq_ids[i - 1]) << "The input sequence ids must be non-decreasing."; + } + } + // embeddings: (1, n, h) - CHECK_EQ(embeddings->ndim, 3); - CHECK_EQ(embeddings->shape[0], 1); - CHECK_EQ(embeddings->device.device_type, device_.device_type); - CHECK_EQ(embeddings->device.device_id, device_.device_id); + NDArray embeddings = ConcatEmbeddings(std::move(embedding_arr), total_length); + ICHECK_EQ(embeddings->ndim, 3); + ICHECK_EQ(embeddings->shape[0], 1); + ICHECK_EQ(embeddings->shape[1], total_length); + ICHECK_EQ(embeddings->device.device_type, device_.device_type); + ICHECK_EQ(embeddings->device.device_id, device_.device_id); + + NDArray logit_pos_nd = CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32); CHECK(ft_.prefill_func_.defined()) << "`prefill_with_embed` function is not found in the model. Please make sure the model is " @@ -202,22 +220,20 @@ class ModelModule : public ModuleNode { // Reserve in KV cache for the length of the input. 
ft_.reset_append_length_kv_cache_func_(kv_cache_); - ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_id, /*length=*/embeddings->shape[1]); + for (int i = 0; i < num_sequences; ++i) { + ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_ids[i], lengths[i]); + } ft_.sync_device_kv_cache_func_(kv_cache_); - auto tstart = std::chrono::high_resolution_clock::now(); - // args: embeddings, kv_cache, params - Array ret = ft_.prefill_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_); - auto tend = std::chrono::high_resolution_clock::now(); - - this->prefill_total_time += static_cast((tend - tstart).count()) / 1e9; - this->prefill_total_tokens += embeddings->shape[1]; + // args: embeddings, logit_pos, kv_cache, params + Array ret = + ft_.prefill_func_(ft_.CopyToWorker0(embeddings), logit_pos_nd, kv_cache_, params_); - // logits: (1, 1, v) + // logits: (1, num_sequences, v) NDArray logits = Downcast(ret[0]); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], 1); + ICHECK_EQ(logits->shape[1], num_sequences); return logits; } @@ -251,13 +267,8 @@ class ModelModule : public ModuleNode { } ft_.sync_device_kv_cache_func_(kv_cache_); - auto tstart = std::chrono::high_resolution_clock::now(); // args: embeddings, kv_cache, params Array ret = ft_.decode_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_); - auto tend = std::chrono::high_resolution_clock::now(); - - this->decode_total_time += static_cast((tend - tstart).count()) / 1e9; - this->decode_total_tokens += embeddings->shape[0]; // logits: (b, 1, v) NDArray logits = Downcast(ret[0]); @@ -286,7 +297,7 @@ class ModelModule : public ModuleNode { for (GenerationConfig cfg : generation_cfg) { temperatures.push_back(cfg->temperature); } - NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 16); + NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32); ICHECK_EQ(temperatures_nd->ndim, 1); ICHECK_EQ(temperatures_nd->shape[0], batch_size); @@ -318,6 +329,57 @@ class ModelModule : public ModuleNode { return view; } + /*! \brief Concatenate the input embeddings. */ + NDArray ConcatEmbeddings(Array embedding_arr, int64_t total_length) { + ICHECK(!embedding_arr.empty()); + int hidden_size = -1; + DataType dtype; + for (NDArray inp_embeddings : embedding_arr) { + // inp_embedding: (1, n, h) + CHECK_EQ(inp_embeddings->ndim, 3); + CHECK_EQ(inp_embeddings->shape[0], 1); + CHECK_EQ(inp_embeddings->device.device_type, device_.device_type); + CHECK_EQ(inp_embeddings->device.device_id, device_.device_id); + if (hidden_size == -1) { + hidden_size = inp_embeddings->shape[2]; + dtype = inp_embeddings.DataType(); + } else { + CHECK_EQ(inp_embeddings->shape[2], hidden_size); + CHECK_EQ(inp_embeddings.DataType(), dtype); + } + } + + // - Resize the shared embedding array. + if (embeddings_.defined()) { + ICHECK_EQ(embeddings_->ndim, 3); + ICHECK_EQ(embeddings_->shape[0], 1); + ICHECK_EQ(embeddings_->shape[2], hidden_size); + } + int64_t init_size = embeddings_.defined() ? embeddings_->shape[1] : max_window_size_; + while (init_size < total_length) { + init_size *= 2; + } + if (!embeddings_.defined() || init_size != embeddings_->shape[1]) { + embeddings_ = NDArray::Empty({1, init_size, hidden_size}, dtype, device_); + } + + // - Copy input embeddings. 
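Conceptually, ConcatEmbeddings and the logit_pos array pack every request's embeddings back-to-back into one (1, total_length, hidden) tensor and record where each request ends, so the prefill kernel only needs to produce logits at those end positions. Below is a toy numpy sketch of that packing; it is illustrative only, since the real code operates on TVM NDArrays and the logits come from the model's LM head rather than from the embeddings themselves.

# Toy numpy illustration of ConcatEmbeddings + logit_pos in BatchPrefill.
import numpy as np

hidden = 4
lengths = [3, 5, 2]                                  # tokens to prefill per request
embedding_arr = [np.random.rand(1, n, hidden).astype("float32") for n in lengths]

total_length = sum(lengths)
embeddings = np.concatenate(embedding_arr, axis=1)   # shape (1, total_length, hidden)
assert embeddings.shape == (1, total_length, hidden)

# Cumulative end offset of each request inside the packed tensor, as computed in the loop above.
logit_pos = np.cumsum(lengths)                       # [3, 8, 10]

# A real prefill kernel produces hidden states for every position; only the rows at
# logit_pos - 1 feed the LM head, yielding logits of shape (1, num_requests, vocab).
last_hidden = embeddings[:, logit_pos - 1, :]
assert last_hidden.shape == (1, len(lengths), hidden)
print(logit_pos, last_hidden.shape)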
+ int64_t start_pos = 0; + for (NDArray inp_embeddings : embedding_arr) { + int64_t length = inp_embeddings->shape[1]; + CHECK_LE(start_pos + length, total_length); + + DLTensor copy_dst = *(embeddings_.operator->()); + copy_dst.byte_offset = start_pos * hidden_size * dtype.bytes(); + copy_dst.shape = inp_embeddings->shape; + NDArray::CopyFromTo(inp_embeddings.operator->(), ©_dst); + + start_pos += length; + } + CHECK_EQ(start_pos, total_length); + return embeddings_.CreateView({1, total_length, hidden_size}, dtype); + } + /*! \brief Load model configuration from JSON. */ void LoadModelConfigJSON(const std::string& config_str) { picojson::value config_json; @@ -350,37 +412,12 @@ class ModelModule : public ModuleNode { /*! \brief reset the runtime states. */ void Reset() { - // Reset the statistics. - this->embed_total_tokens = 0; - this->prefill_total_tokens = 0; - this->decode_total_tokens = 0; - this->embed_total_time = 0; - this->prefill_total_time = 0; - this->decode_total_time = 0; // Reset the KV cache. if (kv_cache_.defined()) { ft_.reset_kv_cache_func_(kv_cache_); } } - /*! \brief Return statistics in JSON format. */ - String GetStats() { - picojson::object stats; - stats["prefill_speed"] = picojson::value(prefill_total_tokens / prefill_total_time); - stats["decode_speed"] = picojson::value(decode_total_tokens / decode_total_time); - stats["embed_speed"] = picojson::value(embed_total_tokens / embed_total_time); - return picojson::value(stats).serialize(true); - } - - //---------------------------- - // Statistics - //---------------------------- - double embed_total_time = 0; - double decode_total_time = 0; - double prefill_total_time = 0; - int64_t embed_total_tokens = 0; - int64_t decode_total_tokens = 0; - int64_t prefill_total_tokens = 0; //---------------------------- // Model configurations //---------------------------- @@ -400,6 +437,8 @@ class ModelModule : public ModuleNode { ObjectRef params_; // Shared NDArray NDArray input_token_ids_{nullptr}; + NDArray embeddings_{nullptr}; + NDArray logit_pos_arr_{nullptr}; NDArray temperature_arr_{nullptr}; }; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 8e86682d78..38b5a7bf66 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -23,6 +23,21 @@ RequestModelState::RequestModelState(int model_id, Array inputs) { data_ = std::move(n); } +TVM_REGISTER_OBJECT_TYPE(RequestStateNode); + +RequestState::RequestState(int num_models, Array inputs, int raw_input_length) { + ObjectPtr n = make_object(); + Array mstates; + mstates.reserve(num_models); + for (int i = 0; i < num_models; ++i) { + mstates.push_back(RequestModelState(i, inputs)); + } + n->mstates = std::move(mstates); + n->raw_input_length = raw_input_length; + n->tadd = std::chrono::high_resolution_clock::now(); + data_ = std::move(n); +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 1de1f5e3f2..16f39548c2 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -86,13 +86,16 @@ class RequestModelState : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode); }; -struct RequestState { +class RequestStateNode : public Object { + public: /*! * \brief The state with regard to each model. * \sa RequestModelState */ Array mstates; + /*! \brief The summed up input length of the request. */ + int raw_input_length = 0; /*! \brief The decoded text string output. 
*/ std::string output = ""; @@ -101,13 +104,17 @@ struct RequestState { /*! \brief The time of finishing prefill stage. */ std::chrono::_V2::system_clock::time_point tprefill_finish; - explicit RequestState(int num_models, Array inputs) { - mstates.reserve(num_models); - for (int i = 0; i < num_models; ++i) { - mstates.push_back(RequestModelState(i, inputs)); - } - tadd = std::chrono::high_resolution_clock::now(); - } + static constexpr const char* _type_key = "mlc.serve.RequestState"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateNode, Object); +}; + +class RequestState : public ObjectRef { + public: + explicit RequestState(int num_models, Array inputs, int raw_input_length); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode); }; } // namespace serve diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py index da6ada6dad..f9c0ec6c99 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_chat/serve/engine.py @@ -221,12 +221,12 @@ def reset(self) -> None: def stats(self) -> Dict[str, float]: """The engine runtime statistics. We collect the following entries: - - prefill token latency (s/tok) - avg latency of processing one token in prefill - - decode token latency (s/tok) - avg latency of processing one token in decode - - token throughput (tok/s) - avg number of tokens processed per second (prefill + decode) + - single token prefill latency (s/tok): avg latency of processing one token in prefill + - single token decode latency (s/tok): avg latency of processing one token in decode + - engine time for prefill (sec) + - engine time for decode (sec) + - total number of processed tokens in prefill. + - total number of processed tokens in decode. 
""" stats_json_str = self._get_stats_func() stats = json.loads(stats_json_str) diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py index 19d46f99ba..1b1a7b71bb 100644 --- a/tests/python/serve/benchmark.py +++ b/tests/python/serve/benchmark.py @@ -1,81 +1,156 @@ -# pylint: disable=missing-docstring +# pylint: disable=line-too-long,missing-docstring,no-member,too-many-locals +# type: ignore import argparse +import json +import random import time -from typing import Any, Callable, List +from typing import Any, Callable, List, Tuple +import numpy as np +from transformers import AutoTokenizer + +from mlc_chat.chat_module import _get_model_path from mlc_chat.serve import GenerationConfig, KVCacheConfig from mlc_chat.serve.engine import Engine, ModelInfo def _parse_args(): args = argparse.ArgumentParser() - args.add_argument("--model-id", type=str, default="Llama-2-7b-chat-hf-q4f16_1") + args.add_argument("--model-id", type=str, default="Llama-2-7b-chat-hf-q0f16") + # Download dataset from + # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + args.add_argument("--dataset", type=str, required=True) args.add_argument("--device", type=str, default="auto") - args.add_argument("--input-length", type=int, default=32) - args.add_argument("--output-length", type=int, default=256) - args.add_argument("--batch-size", type=int, default=16) + args.add_argument("--num-prompts", type=int, default=500) + args.add_argument("--batch-size", type=int, default=80) + args.add_argument("--page-size", type=int, default=16) + args.add_argument("--max-total-seq-length", type=int, default=16800) + args.add_argument("--seed", type=int, default=0) + parsed = args.parse_args() + assert parsed.batch_size % 16 == 0 + assert parsed.page_size == 16 + assert parsed.max_total_seq_length >= 2048 return parsed -def time_evaluator(func: Callable, args: List[Any], num_runs: int = 3, num_warmups: int = 1): - # warmup run - print("Start warmup...") - for _ in range(num_warmups): - func(*args) +def sample_requests( + dataset_path: str, num_requests: int, model_id: str +) -> Tuple[List[str], List[GenerationConfig]]: + """Sample requests from dataset. + Acknowledgement to the benchmark scripts in the vLLM project. + """ + model_path, _ = _get_model_path(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + with open(dataset_path, encoding="utf-8") as f: + dataset = json.load(f) + + # Filter out the conversations with less than 2 turns. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + if len(data["conversations"]) >= 2 + ] + # Tokenize the prompts and completions. + prompts = [prompt for prompt, _ in dataset] + prompt_token_ids = tokenizer(prompts).input_ids + completions = [completion for _, completion in dataset] + completion_token_ids = tokenizer(completions).input_ids + tokenized_dataset = [] + for i in range(len(dataset)): + output_len = len(completion_token_ids[i]) + tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + + # Filter out too long sequences. + filtered_dataset: List[Tuple[str, int, int]] = [] + for prompt, prompt_token_ids, output_len in tokenized_dataset: + prompt_len = len(prompt_token_ids) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + # Sample the requests. + sampled_requests = random.sample(filtered_dataset, num_requests) - total_time = 0.0 - for run in range(num_runs): - print(f"Evaluator: start round {run}") + # Construct generation config. + prompts = [prompt for prompt, _, _ in sampled_requests] + generation_config_list = [ + GenerationConfig(temperature=1.0, top_p=1.0, max_new_tokens=output_len) + for _, _, output_len in sampled_requests + ] + return prompts, generation_config_list + + +def time_evaluator(func: Callable, args: List[Any], num_runs: int = 3): + times = [] + for _ in range(num_runs): start = time.perf_counter() func(*args) end = time.perf_counter() - total_time += end - start - print(f"Evaluator: finish round {run}") + times.append(end - start) - return total_time / num_runs + return np.array(times) def benchmark(args: argparse.Namespace): + random.seed(args.seed) + # Initialize model loading info and KV cache config model = ModelInfo(args.model_id, args.device) kv_cache_config = KVCacheConfig( - page_size=16, + page_size=args.page_size, max_num_sequence=args.batch_size, - max_total_sequence_length=args.output_length * args.batch_size * 2, + max_total_sequence_length=args.max_total_seq_length, ) - generation_config = GenerationConfig( - temperature=1.0, top_p=1.0, max_new_tokens=args.output_length - ) - prompts = [[0] * args.input_length] * args.batch_size + # Create engine engine = Engine(model, kv_cache_config) + # Sample prompts from dataset + prompts, generation_config = sample_requests(args.dataset, args.num_prompts, args.model_id) # Engine statistics - num_runs = 3 - prefill_token_latency = [] - decode_token_latency = [] - token_throughput = [] + num_runs = 1 + single_token_prefill_latency = [] + single_token_decode_latency = [] + engine_total_prefill_time = [] + engine_total_decode_time = [] + total_prefill_tokens = [] + total_decode_tokens = [] def engine_generate(): engine.reset() engine.generate(prompts, generation_config) engine_stats = engine.stats() - prefill_token_latency.append(engine_stats["prefill_token_latency"]) - decode_token_latency.append(engine_stats["decode_token_latency"]) - token_throughput.append(engine_stats["token_throughput"]) - - avg_e2e_latency = time_evaluator(engine_generate, args=[], num_runs=num_runs) + single_token_prefill_latency.append(engine_stats["single_token_prefill_latency"]) + single_token_decode_latency.append(engine_stats["single_token_decode_latency"]) + engine_total_prefill_time.append(engine_stats["engine_total_prefill_time"]) + engine_total_decode_time.append(engine_stats["engine_total_decode_time"]) + total_prefill_tokens.append(engine_stats["total_prefill_tokens"]) + total_decode_tokens.append(engine_stats["total_decode_tokens"]) - avg_prefill_token_latency = sum(prefill_token_latency[-num_runs:]) / num_runs - avg_decode_token_latency = sum(decode_token_latency[-num_runs:]) / num_runs - avg_token_throughput = sum(token_throughput[-num_runs:]) / num_runs + e2e_latency = time_evaluator(engine_generate, args=[], num_runs=num_runs) + single_token_prefill_latency = np.array(single_token_prefill_latency) + single_token_decode_latency = np.array(single_token_decode_latency) + engine_total_prefill_time = np.array(engine_total_prefill_time) + engine_total_decode_time = np.array(engine_total_decode_time) + total_prefill_tokens = np.array(total_prefill_tokens) + total_decode_tokens = np.array(total_decode_tokens) + prefill_throughput = total_prefill_tokens / engine_total_prefill_time + 
decode_throughput = total_decode_tokens / engine_total_decode_time + overall_throughput = (total_prefill_tokens + total_decode_tokens) / e2e_latency print(args) - print(f"Average end-to-end latency: {avg_e2e_latency} seconds for the entire batch") - print(f"Prefill token latency: {avg_prefill_token_latency * 1e3} ms/tok") - print(f"Decode token latency: {avg_decode_token_latency * 1e3} ms/tok") - print(f"Request throughput: {args.batch_size / (avg_e2e_latency / 60)} req/min") - print(f"Token throughput: {avg_token_throughput} tok/s") + print(f"Average end-to-end latency: {e2e_latency.mean():.4f} seconds for the entire batch") + print(f"Single token prefill latency: {single_token_prefill_latency.mean() * 1e3:.4f} ms/tok") + print(f"Single token decode latency: {single_token_decode_latency.mean() * 1e3:.4f} ms/tok") + print(f"Request throughput: {args.num_prompts / e2e_latency.mean():.4f} req/s") + print(f"Prefill token throughput: {prefill_throughput.mean():.4f} tok/s") + print(f"Decode token throughput: {decode_throughput.mean():.4f} tok/s") + print(f"Overall token throughput: {overall_throughput.mean():.4f} tok/s") if __name__ == "__main__": diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py new file mode 100644 index 0000000000..19f572bb77 --- /dev/null +++ b/tests/python/serve/evaluate_engine.py @@ -0,0 +1,74 @@ +# pylint: disable=line-too-long,missing-docstring +import argparse +import random +from typing import List, Tuple + +from mlc_chat.serve import GenerationConfig, KVCacheConfig +from mlc_chat.serve.engine import Engine, ModelInfo + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--model-id", type=str, default="Llama-2-7b-chat-hf-q0f16") + args.add_argument("--device", type=str, default="auto") + args.add_argument("--batch-size", type=int, default=128) + args.add_argument("--page-size", type=int, default=16) + args.add_argument("--max-total-seq-length", type=int, default=16000) + args.add_argument("--seed", type=int, default=0) + + parsed = args.parse_args() + assert parsed.batch_size % 16 == 0 + assert parsed.page_size == 16 + assert parsed.max_total_seq_length >= 2048 + return parsed + + +def generate_requests( + num_requests: int, input_length: int, output_length: int +) -> Tuple[List[List[int]], List[GenerationConfig]]: + prompt_ids = [] + for _ in range(num_requests): + token_ids = [] + for _ in range(input_length): + token_ids.append(random.randint(0, 30000)) + prompt_ids.append(token_ids) + generation_config_list = [ + GenerationConfig(temperature=1.0, top_p=1.0, max_new_tokens=output_length) + ] * num_requests + return prompt_ids, generation_config_list + + +def benchmark(args: argparse.Namespace): + random.seed(args.seed) + + # Initialize model loading info and KV cache config + model = ModelInfo(args.model_id, args.device) + kv_cache_config = KVCacheConfig( + page_size=args.page_size, + max_num_sequence=args.batch_size, + max_total_sequence_length=args.max_total_seq_length, + ) + + # Create engine + engine = Engine(model, kv_cache_config) + + print(args) + for num_requests in [1, 2, 4, 8, 16, 32, 64]: + if num_requests > args.batch_size: + continue + for input_length in [64, 128, 256, 512, 1024]: + if num_requests * input_length >= 16384: + continue + for output_length in [4]: + print(f"nreq={num_requests}\t" f"in={input_length}\t" f"out={output_length}") + prompt_ids, generation_config = generate_requests( + num_requests, input_length, output_length + ) + engine.reset() + engine.generate(prompt_ids, 
generation_config) + print() + + +if __name__ == "__main__": + ARGS = _parse_args() + benchmark(ARGS) diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 5104002922..b47595202d 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -3,6 +3,7 @@ from typing import Callable, List, Optional import numpy as np + from mlc_chat.serve import GenerationConfig, KVCacheConfig, Request, data from mlc_chat.serve.engine import Engine, ModelInfo @@ -66,7 +67,7 @@ def test_engine_basic(): # Hyperparameters for tests (you can try different combinations). num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] repetition_penalty = 1.0 # [1.0, 1.01] max_new_tokens: int = 256 # [32, 128, 256] np.random.seed(0) @@ -181,7 +182,7 @@ def step(self) -> None: print(f"Prompt {req_id}: {request.inputs[0]}") print(f"Output {req_id}:{output}\n") assert isinstance(output, data.TextData) - assert fin_time == num_requests + request.generation_config.max_new_tokens - 2 + assert fin_time == request.generation_config.max_new_tokens - 1 def test_engine_continuous_batching_2():
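As a final usage sketch, the snippet below shows how the pieces introduced in this patch fit together from Python: build the engine with a KVCacheConfig, generate a batch of prompts, and read the new statistics entries reported by StatisticsJSON. The model id and the numbers are placeholders; any batching-enabled MLC model should work the same way.

# Minimal usage sketch tying the patch together; model id and limits are placeholders.
from mlc_chat.serve import GenerationConfig, KVCacheConfig
from mlc_chat.serve.engine import Engine, ModelInfo

model = ModelInfo("Llama-2-7b-chat-hf-q0f16", "auto")
kv_cache_config = KVCacheConfig(
    page_size=16, max_num_sequence=32, max_total_sequence_length=16800
)
engine = Engine(model, kv_cache_config)

prompts = ["What is the meaning of life?", "Write a haiku about serving engines."]
generation_config = [
    GenerationConfig(temperature=0.9, top_p=1.0, max_new_tokens=64) for _ in prompts
]

engine.generate(prompts, generation_config)

stats = engine.stats()
print("prefill latency (s/tok):", stats["single_token_prefill_latency"])
print("decode latency (s/tok):", stats["single_token_decode_latency"])
print("prefill throughput (tok/s):",
      stats["total_prefill_tokens"] / stats["engine_total_prefill_time"])
print("decode throughput (tok/s):",
      stats["total_decode_tokens"] / stats["engine_total_decode_time"])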