diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index a3fd01c48f..beb51bcb7a 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit a3fd01c48ff1d91c2f690688700c134e1cf0c161
+Subproject commit beb51bcb7a6381c73f099ff67c30c6c17bcd116a
diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 03f9bc67cd..62127e92ac 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -4,6 +4,7 @@
  * \brief The implementation for runtime module of serving engine module in MLC LLM.
  */
 #define __STDC_FORMAT_MACROS
+#define PICOJSON_USE_INT64
 #include 
 #include 
@@ -67,21 +68,23 @@ class Engine {
     // Step 1. Create models and their PackedFuncs.
     ICHECK(models_.empty());
     models_.reserve(num_models);
-    fmodel_single_seq_prefill_.clear();
+    fmodel_batch_prefill_.clear();
     fmodel_decode_.clear();
     fmodel_token_embed_.clear();
     fmodel_add_new_sequence_.clear();
     fmodel_remove_sequence_.clear();
     fmodel_softmax_with_temperature_.clear();
+    fmodel_get_num_available_pages_.clear();
     for (int i = 0; i < num_models; ++i) {
       Module model = CreateModelModule(reload_libs[i], model_paths[i], devices_[i]);
       models_.push_back(model);
-      fmodel_single_seq_prefill_.push_back(model->GetFunction("single_seq_prefill"));
+      fmodel_batch_prefill_.push_back(model->GetFunction("batch_prefill"));
       fmodel_decode_.push_back(model->GetFunction("decode"));
       fmodel_token_embed_.push_back(model->GetFunction("token_embed"));
       fmodel_add_new_sequence_.push_back(model->GetFunction("add_new_sequence"));
       fmodel_remove_sequence_.push_back(model->GetFunction("remove_sequence"));
       fmodel_softmax_with_temperature_.push_back(model->GetFunction("softmax_with_temperature"));
+      fmodel_get_num_available_pages_.push_back(model->GetFunction("get_num_available_pages"));
     }
     // Step 2. Fetch max single sequence length from models.
     max_single_sequence_length_ = std::numeric_limits<int>::max();
@@ -124,7 +127,8 @@ class Engine {
    */
   void AddRequest(Request request) {
     waiting_queue_.push_back(request);
-    request_states_.emplace(request, RequestState(models_.size(), request->inputs));
+    request_states_.emplace(
+        request, RequestState(models_.size(), request->inputs, GetInputLength(request->inputs)));
   }
 
   /*! \brief Abort the input request. */
@@ -193,42 +197,48 @@ class Engine {
     }
 
     current_total_seq_len_ = 0;
-    prefill_total_time = 0.0f;
-    decode_total_time = 0.0f;
-    prefill_total_length = 0;
-    decode_total_length = 0;
+    request_total_prefill_time_ = 0.0f;
+    request_total_decode_time_ = 0.0f;
+    engine_total_prefill_time_ = 0.0f;
+    engine_total_decode_time_ = 0.0f;
+    total_prefill_length_ = 0;
+    total_decode_length_ = 0;
     tokenize_cache_.clear();
   }
 
   /*!
   * \brief Return the engine runtime statistics in JSON string.
   * We collect the following entries:
-   * - prefill token latency (s/tok): avg latency of processing one token in prefill
-   * - decode token latency (s/tok): avg latency of processing one token in decode
-   * - token throughput (tok/s): avg number of tokens processed per second (prefill + decode)
+   * - single token prefill latency (s/tok): avg latency of processing one token in prefill
+   * - single token decode latency (s/tok): avg latency of processing one token in decode
+   * - engine time for prefill (sec)
+   * - engine time for decode (sec)
+   * - total number of processed tokens in prefill.
+   * - total number of processed tokens in decode.
   * \return The statistics in JSON string.
   */
  String StatisticsJSON() {
    picojson::object config;
-    config["prefill_token_latency"] = picojson::value(prefill_total_time / prefill_total_length);
-    config["decode_token_latency"] = picojson::value(decode_total_time / decode_total_length);
-    config["token_throughput"] = picojson::value((prefill_total_length + decode_total_length) /
-                                                 (prefill_total_time + decode_total_time));
+    config["single_token_prefill_latency"] =
+        picojson::value(request_total_prefill_time_ / total_prefill_length_);
+    config["single_token_decode_latency"] =
+        picojson::value(request_total_decode_time_ / total_decode_length_);
+    config["engine_total_prefill_time"] = picojson::value(engine_total_prefill_time_);
+    config["engine_total_decode_time"] = picojson::value(engine_total_decode_time_);
+    config["total_prefill_tokens"] = picojson::value(total_prefill_length_);
+    config["total_decode_tokens"] = picojson::value(total_decode_length_);
    return picojson::value(config).serialize(true);
  }

 private:
  /*! \brief Pick applicable requests and run prefill. */
  bool StepPrefill() {
-    auto [requests, sample_new_token] = GetRequestsToPrefill();
-    ICHECK_EQ(requests.size(), sample_new_token.size());
+    auto [requests, states, sample_new_token] = GetRequestsToPrefill();
    if (requests.empty()) {
      return false;
    }
-    // Collect ids of the first-time prefilled requests.
-    // We will set the prefill finish time for these requests.
-    std::unordered_set<int> first_time_prefill_ids;
+    auto tstart = std::chrono::high_resolution_clock::now();

    for (Request request : requests) {
      int req_id = running_queue_.size();
@@ -242,65 +252,83 @@ class Engine {
      running_queue_.push_back(request);
      // - Assign request id for the requests.
      AssignIDForRequest(request, req_id);
-      first_time_prefill_ids.insert(req_id);
    }

-    // NOTE: Right now only single-sequence prefill is supported.
-    // So we prefill the requests one by one for now.
-    for (int req_idx = 0; req_idx < static_cast<int>(requests.size()); ++req_idx) {
-      Request request = requests[req_idx];
-      RequestState& state = request_states_.at(request);
-      ICHECK_EQ(state.mstates.size(), models_.size());
-      NDArray logits{nullptr};
-      current_total_seq_len_ += GetRequestPrefillInputLength(request);
-      // - Prefill the inputs for each model.
-      for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {
-        Module model = models_[model_id];
-        RequestModelState mstate = state.mstates[model_id];
+    int sum_prefill_lengths = 0;
+    NDArray logits_for_sample{nullptr};
+    Array<RequestModelState> mstates_for_sample;
+    Array<GenerationConfig> generation_cfg_for_sample;
+    mstates_for_sample.reserve(requests.size());
+    generation_cfg_for_sample.reserve(requests.size());
+    for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {
+      Module model = models_[model_id];
+      auto [request_list, mstates, prefill_lengths] =
+          FilterPrefillRequests(requests, states, model_id);
+      Array<NDArray> embeddings;
+      std::vector<int> request_ids;
+      embeddings.reserve(request_list.size());
+      request_ids.reserve(request_list.size());
+      for (int i = 0; i < static_cast<int>(request_list.size()); ++i) {
+        Request request = request_list[i];
+        int prefill_length = prefill_lengths[i];
+        RequestModelState mstate = mstates[i];
+        if (model_id == 0) {
+          // Accumulate the sequence length.
+          sum_prefill_lengths += prefill_length;
+          current_total_seq_len_ += prefill_length;
+          mstates_for_sample.push_back(mstate);
+          generation_cfg_for_sample.push_back(request->generation_cfg);
+        }
        ICHECK(mstate->draft_output_tokens.empty());
        ICHECK(mstate->draft_output_token_prob.empty());
        ICHECK(mstate->draft_output_prob_dist.empty());
-        if (mstate->inputs.empty()) {
-          continue;
-        }
+        ICHECK(!mstate->inputs.empty());
+        request_ids.push_back(mstate->request_id);
        for (int i = 0; i < static_cast<int>(mstate->inputs.size()); ++i) {
-          NDArray embedding = GetEmbedding(mstate->inputs[i], fmodel_token_embed_[model_id]);
-          NDArray output_logits =
-              fmodel_single_seq_prefill_[model_id](embedding, mstate->request_id);
-
-          // Only the output logits of the last input on the first model will be send for sampling.
-          if (model_id == 0 && i == static_cast<int>(mstate->inputs.size()) - 1) {
-            logits = output_logits;
-          }
+          embeddings.push_back(GetEmbedding(mstate->inputs[i], fmodel_token_embed_[model_id]));
        }
        // Clean up `inputs` after prefill
        mstate->inputs.clear();
      }

-      if (!sample_new_token[req_idx]) {
-        // This request does not need to sample a new token.
-        // In this case, it must not be a first-time prefilled request.
-        ICHECK(first_time_prefill_ids.count(state.mstates[0]->request_id));
-        continue;
-      }

-      ICHECK(logits.defined());
+      NDArray logits = fmodel_batch_prefill_[model_id](
+          embeddings, ShapeTuple(request_ids.begin(), request_ids.end()), prefill_lengths);
      ICHECK_EQ(logits->ndim, 3);
      ICHECK_EQ(logits->shape[0], 1);
-      ICHECK_EQ(logits->shape[1], 1);
+      ICHECK_EQ(logits->shape[1], request_list.size());

-      ShapeTuple next_token = SampleTokens(logits, /*model_id=*/0, /*sampler_id=*/0,
-                                           {state.mstates[0]}, {request->generation_cfg});
-      ICHECK_EQ(next_token.size(), 1);
+      if (model_id == 0) {
+        // We only need to sample for model 0 in prefill.
+        logits_for_sample = logits;
+      }
+    }

+    if (sample_new_token) {
+      // - Sample tokens.
+      int num_requests = requests.size();
+      ICHECK(logits_for_sample.defined());
+      ICHECK_EQ(logits_for_sample->shape[1], num_requests);
+      ICHECK_EQ(mstates_for_sample.size(), num_requests);
+      ICHECK_EQ(generation_cfg_for_sample.size(), num_requests);
+      logits_for_sample = logits_for_sample.CreateView(
+          {num_requests, 1, logits_for_sample->shape[2]}, logits_for_sample->dtype);
+      ShapeTuple next_tokens = SampleTokens(logits_for_sample, /*model_id=*/0, /*sampler_id=*/0,
                                            mstates_for_sample, generation_cfg_for_sample);
+      ICHECK_EQ(next_tokens.size(), num_requests);

      // - Update the committed tokens of states.
-      for (RequestModelState mstate : state.mstates) {
-        mstate->committed_tokens.push_back(next_token[0]);
-      }
-      // If the request is first-time prefilled, set the prefill finish time.
-      if (first_time_prefill_ids.count(state.mstates[0]->request_id)) {
-        state.tprefill_finish = std::chrono::high_resolution_clock::now();
+      // - If a request is first-time prefilled, set the prefill finish time.
+      auto tnow = std::chrono::high_resolution_clock::now();
+      for (int i = 0; i < num_requests; ++i) {
+        mstates_for_sample[i]->committed_tokens.push_back(next_tokens[i]);
+        if (mstates_for_sample[i]->committed_tokens.size() == 1) {
+          request_states_.at(requests[i])->tprefill_finish = tnow;
+        }
      }
    }
+
+    auto tend = std::chrono::high_resolution_clock::now();
+    engine_total_prefill_time_ += static_cast<double>((tend - tstart).count()) / 1e9;
+
    return true;
  }

@@ -311,18 +339,19 @@ class Engine {
      return false;
    }

-    Array<Request> requests = GetRequestsToDecode();
-    if (requests.empty()) {
+    PreemptUnfittableRequests();
+    if (running_queue_.empty()) {
      return false;
    }

+    auto tstart = std::chrono::high_resolution_clock::now();
+
    // NOTE: Right now we only support decode all the running requests at a time.
-    ICHECK_EQ(requests.size(), running_queue_.size());
-    int num_requests = requests.size();
+    int num_requests = running_queue_.size();

    // Check if the requests ids are in an ascending order.
    for (int i = 1; i < num_requests; ++i) {
-      ICHECK_GT(request_states_.at(requests[i]).mstates[0]->request_id,
-                request_states_.at(requests[i - 1]).mstates[0]->request_id);
+      ICHECK_GT(request_states_.at(running_queue_[i])->mstates[0]->request_id,
+                request_states_.at(running_queue_[i - 1])->mstates[0]->request_id);
    }

    current_total_seq_len_ += num_requests;
@@ -337,10 +366,10 @@ class Engine {
    inputs.reserve(num_requests);
    mstates.reserve(num_requests);
    generation_cfg.reserve(num_requests);
-    for (Request request : requests) {
+    for (Request request : running_queue_) {
      RequestState& state = request_states_.at(request);
-      inputs.push_back(TokenData(ShapeTuple({state.mstates[0]->committed_tokens.back()})));
-      mstates.push_back(state.mstates[0]);
+      inputs.push_back(TokenData(ShapeTuple({state->mstates[0]->committed_tokens.back()})));
+      mstates.push_back(state->mstates[0]);
      generation_cfg.push_back(request->generation_cfg);
    }

@@ -363,6 +392,10 @@ class Engine {
    for (int i = 0; i < num_requests; ++i) {
      mstates[i]->committed_tokens.push_back(next_tokens[i]);
    }
+
+    auto tend = std::chrono::high_resolution_clock::now();
+    engine_total_decode_time_ += static_cast<double>((tend - tstart).count()) / 1e9;
+
    return true;
  }

@@ -413,8 +446,9 @@ class Engine {
      // The request to abort is in running queue
      int req_id = it_running - running_queue_.begin();
      running_queue_.erase(it_running);
-      current_total_seq_len_ -= GetRequestRawInputLength(request) +
-                                request_states_.at(request).mstates[0]->committed_tokens.size() - 1;
+      RequestState state = request_states_.at(request);
+      current_total_seq_len_ -=
+          state->raw_input_length + state->mstates[0]->committed_tokens.size() - 1;
      RemoveSequenceFromModels(req_id);
      UpdateRequestIDAfterRemoval(req_id);
    } else {
@@ -440,8 +474,8 @@ class Engine {
    // - Update `inputs` for future prefill.
    RequestState& state = request_states_.at(request);
    current_total_seq_len_ -=
-        GetRequestRawInputLength(request) + state.mstates[0]->committed_tokens.size() - 1;
-    for (RequestModelState mstate : state.mstates) {
+        state->raw_input_length + state->mstates[0]->committed_tokens.size() - 1;
+    for (RequestModelState mstate : state->mstates) {
      mstate->request_id = -1;
      mstate->draft_output_tokens.clear();
      mstate->draft_output_token_prob.clear();
@@ -465,51 +499,78 @@ class Engine {
   * additionally return a boolean flag indicating if a new
   * token needs to be sampled from logits after prefill.
   */
-  std::pair<Array<Request>, std::vector<bool>> GetRequestsToPrefill() {
-    // NOTE: Right now we only support single-sequence prefill.
+  std::tuple<Array<Request>, Array<RequestState>, bool> GetRequestsToPrefill() {
+    // - Try to prefill pending requests.
+    std::vector<Request> prefill_requests;
+    std::vector<RequestState> states;
    if (!waiting_queue_.empty()) {
-      Array<Request> prefill_requests{waiting_queue_.front()};
-      if (CanPrefill(prefill_requests)) {
+      int total_input_length = 0;
+      int total_required_pages = 0;
+      ICHECK(fmodel_get_num_available_pages_[0].defined());
+      int num_available_pages = fmodel_get_num_available_pages_[0]();
+
+      for (int i = 0; i < static_cast<int>(waiting_queue_.size()); ++i) {
+        Request request = waiting_queue_[i];
+        RequestState state = request_states_.at(request);
+        int input_length = GetInputLength(state->mstates[0]->inputs);
+        int num_require_pages =
+            (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size;
+        total_input_length += input_length;
+        total_required_pages += num_require_pages;
+        if (CanPrefill(i + 1, total_input_length, total_required_pages, num_available_pages)) {
+          prefill_requests.push_back(request);
+          states.push_back(state);
+        } else {
+          total_input_length -= input_length;
+          total_required_pages -= num_require_pages;
+          break;
+        }
+      }
+      if (!prefill_requests.empty()) {
        // Need to sample a new token for waiting requests.
-        return {prefill_requests, {true}};
+        return {prefill_requests, states, true};
      }
    }

+    // Try to prefill for small models.
    for (Request request : running_queue_) {
-      Array<RequestModelState> mstates = request_states_.at(request).mstates;
+      RequestState state = request_states_.at(request);
+      Array<RequestModelState> mstates = state->mstates;
      for (int i = 0; i < static_cast<int>(mstates.size()); ++i) {
        if (!mstates[i]->inputs.empty()) {
          ICHECK_NE(i, 0);
-          // This return happens only for "small" models in
-          // speculative inference settings.
-          // Therefore no need to sample new token from logits.
-          return {{request}, {false}};
+          prefill_requests.push_back(request);
+          states.push_back(state);
+          break;
        }
      }
    }
-
-    return {};
+    // This return happens only for "small" models in
+    // speculative inference settings.
+    // Therefore no need to sample new token from logits.
+    return {prefill_requests, states, false};
  }

-  /*! \brief Find requests to decode. */
-  Array<Request> GetRequestsToDecode() {
+  /*! \brief Preempt the requests unfittable for decode. */
+  void PreemptUnfittableRequests() {
    if (running_queue_.empty()) {
-      return {};
+      return;
    }

-    // NOTE: Right now we only support decode all the running requests at a time.
-    Array<Request> requests(running_queue_);
-    if (CanDecode(requests)) {
-      return requests;
+    int num_available_pages = fmodel_get_num_available_pages_[0]();
+    while (true) {
+      if (CanDecode(running_queue_.size())) {
+        break;
+      }
+      StepPreempt(running_queue_.back());
    }
-    return {};
  }

  /*! \brief Assign the given id for the given request. */
  void AssignIDForRequest(Request request, int req_id) {
    // Set id in the request state.
    RequestState& state = request_states_.at(request);
-    for (RequestModelState mstate : state.mstates) {
+    for (RequestModelState mstate : state->mstates) {
      mstate->request_id = req_id;
    }
    // Add a new sequence to each model.
@@ -521,9 +582,9 @@ class Engine {
  }
  /*! \brief Check if the input requests can be prefilled under conditions. */
-  bool CanPrefill(Array<Request> requests) {
+  bool CanPrefill(int num_prefill_req, int total_input_length, int num_required_pages,
+                  int num_available_pages) {
    int num_running_requests = running_queue_.size();
-    int num_prefill_req = requests.size();
    ICHECK_LE(num_running_requests, kv_cache_config_->max_num_sequence);

    // No exceeding of the maximum allowed requests that can
@@ -532,55 +593,53 @@ class Engine {
      return false;
    }

-    // The total length + the maximum allowed single-sequence length does
-    // not exceed the maximum allowed total length.
-    // NOTE: this condition is heuristic and can be revised.
-    int total_input_length = 0;
-    int total_max_new_tokens = 0;
-    for (Request request : requests) {
-      total_input_length += GetRequestPrefillInputLength(request);
-      total_max_new_tokens += request->generation_cfg->max_new_tokens;
-    }
-    return current_total_seq_len_ + std::min(total_input_length + total_max_new_tokens,
-                                             num_prefill_req * max_single_sequence_length_) <
-           kv_cache_config_->max_total_sequence_length;
+    // NOTE: The conditions are heuristic and can be revised.
+    // Cond 1: total input length <= max allowed single sequence length.
+    // Cond 2: remaining pages >= 10, where 10 is a watermark number that can
+    //         be configured and adjusted in the future.
+    // Cond 3: at least one decode can be performed after prefill.
+    // Todo: move watermark to config.
+    int new_batch_size = num_running_requests + num_prefill_req;
+    return total_input_length <= max_single_sequence_length_ &&
+           num_required_pages + new_batch_size <= num_available_pages &&
+           current_total_seq_len_ + total_input_length + 8 * new_batch_size <=
+               kv_cache_config_->max_total_sequence_length;
  }

  /*! \brief Check if the input requests can be decoded under conditions. */
-  bool CanDecode(Array<Request> requests) {
-    return current_total_seq_len_ + requests.size() < kv_cache_config_->max_total_sequence_length;
+  bool CanDecode(int num_requests) {
+    int num_available_pages = fmodel_get_num_available_pages_[0]();
+    return num_requests <= num_available_pages;
  }

-  /*!
-   * \brief Get the total equivalent **input length to prefill**
-   * of the given request's current state.
-   */
-  int GetRequestPrefillInputLength(const Request& request) {
-    auto it = request_states_.find(request);
-    ICHECK(it != request_states_.end());
-
-    const RequestState& state = it->second;
-    ICHECK_EQ(state.mstates.size(), models_.size());
-    int input_length = -1;
-    for (const RequestModelState& mstate : state.mstates) {
-      int length_sum = 0;
-      for (Data input : mstate->inputs) {
-        length_sum += GetInputLength(input);
-      }
-      if (input_length == -1) {
-        input_length = length_sum;
-      } else {
-        ICHECK_EQ(length_sum, input_length);
+  /*! \brief Filter the requests to prefill on the given model. */
+  std::tuple<Array<Request>, Array<RequestModelState>, ShapeTuple> FilterPrefillRequests(
+      Array<Request> requests, Array<RequestState> states, int model_id) {
+    ICHECK_EQ(requests.size(), states.size());
+    int num_requests = requests.size();
+    Array<Request> filtered_requests;
+    Array<RequestModelState> filtered_mstates;
+    std::vector<int> prefill_length;
+    filtered_requests.reserve(num_requests);
+    filtered_mstates.reserve(num_requests);
+    prefill_length.reserve(num_requests);
+
+    for (int i = 0; i < num_requests; ++i) {
+      int length = GetInputLength(states[i]->mstates[model_id]->inputs);
+      if (length > 0) {
+        filtered_requests.push_back(requests[i]);
+        filtered_mstates.push_back(states[i]->mstates[model_id]);
+        prefill_length.push_back(length);
      }
    }
-    ICHECK_NE(input_length, -1);
-    return input_length;
+    return {filtered_requests, filtered_mstates,
+            ShapeTuple(prefill_length.begin(), prefill_length.end())};
  }

-  /*! \brief Get the total equivalent **input length** of the given request. */
-  int GetRequestRawInputLength(const Request& request) {
+  /*! \brief Get the total input length of the given inputs. */
+  int GetInputLength(Array<Data> inputs) {
    int length_sum = 0;
-    for (Data input : request->inputs) {
+    for (Data input : inputs) {
      length_sum += GetInputLength(input);
    }
    return length_sum;
@@ -595,6 +654,7 @@ class Engine {
      return tokens_input->token_ids.size();
    } else {
      ICHECK(false) << "Cannot reach here";
+      throw;
    }
  }
@@ -633,6 +693,7 @@ class Engine {
      return fmodel_token_embed(Array{tokens_input->token_ids});
    } else {
      ICHECK(false) << "Cannot reach here";
+      throw;
    }
  }
@@ -785,7 +846,7 @@ class Engine {
  void UpdateRequestIDAfterRemoval(int removed_req_id) {
    for (auto& it : request_states_) {
      RequestState& state = it.second;
-      for (RequestModelState mstate : state.mstates) {
+      for (RequestModelState mstate : state->mstates) {
        ICHECK_NE(mstate->request_id, removed_req_id);
        if (mstate->request_id > removed_req_id) {
          --mstate->request_id;
@@ -819,10 +880,10 @@ class Engine {
    running_queue_.erase(it);

    RequestState& state = request_states_.at(request);
-    int num_input_tokens = GetRequestRawInputLength(request);
-    int num_output_tokens = state.mstates[0]->committed_tokens.size() - 1;
+    int num_input_tokens = state->raw_input_length;
+    int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1;
    current_total_seq_len_ -= num_input_tokens + num_output_tokens;
-    for (RequestModelState mstate : state.mstates) {
+    for (RequestModelState mstate : state->mstates) {
      ICHECK_EQ(mstate->request_id, req_id);
      mstate->request_id = -1;
    }
@@ -830,15 +891,19 @@ class Engine {
    UpdateRequestIDAfterRemoval(req_id);

    auto trequest_finish = std::chrono::high_resolution_clock::now();
-    prefill_total_time += static_cast<double>((state.tprefill_finish - state.tadd).count()) / 1e9;
-    prefill_total_length += num_input_tokens;
-    decode_total_time +=
-        static_cast<double>((trequest_finish - state.tprefill_finish).count()) / 1e9;
-    decode_total_length += num_output_tokens;
+    request_total_prefill_time_ +=
+        static_cast<double>((state->tprefill_finish - state->tadd).count()) / 1e9;
+    total_prefill_length_ += num_input_tokens;
+    request_total_decode_time_ +=
+        static_cast<double>((trequest_finish - state->tprefill_finish).count()) / 1e9;
+    total_decode_length_ += num_output_tokens;

    // NOTE: right now we only return the generated text.
    // In the future we might optional return text or token ids.
-    request->fcallback(request, TextData(state.output));
+    String output = ftokenizer_decode(ShapeTuple(state->mstates[0]->committed_tokens.begin(),
+                                                 state->mstates[0]->committed_tokens.end()));
+    state->output = output.operator std::string();
+    request->fcallback(request, TextData(state->output));

    // Remove the request from states.
    request_states_.erase(request);
@@ -851,23 +916,18 @@ class Engine {

    // - Case 0. There is remaining draft output ==> Unfinished
    //   All draft outputs are supposed to be processed before finish.
-    for (RequestModelState mstate : state.mstates) {
+    for (RequestModelState mstate : state->mstates) {
      if (!mstate->draft_output_tokens.empty()) {
        return false;
      }
    }

    // - Decode committed tokens.
-    const std::vector<int32_t>& committed_tokens = state.mstates[0]->committed_tokens;
-    String output = ftokenizer_decode(ShapeTuple(committed_tokens.begin(), committed_tokens.end()));
-    state.output = output.operator std::string();
+    const std::vector<int32_t>& committed_tokens = state->mstates[0]->committed_tokens;

    // Case 1. Any of the stop strings appears in output ==> Finished
-    for (String stop_str : request->generation_cfg->stop_strs) {
-      if (state.output.rfind(stop_str) != std::string::npos) {
-        return true;
-      }
-    }
+    // Todo: handle stop_str by tokenizing. So that we don't detokenize during check
+
    // Case 2. Any of the stop tokens appears in the committed tokens ===> Finished
    if (std::any_of(request->generation_cfg->stop_tokens.begin(),
                    request->generation_cfg->stop_tokens.end(),
                    [&committed_tokens](int32_t token) {
@@ -880,7 +940,7 @@ class Engine {
      return true;
    }
    // Case 4. Total length of the request reaches the maximum single sequence length ==> Finished
-    if (GetRequestRawInputLength(request) + static_cast<int>(committed_tokens.size()) >=
+    if (state->raw_input_length + static_cast<int>(committed_tokens.size()) >=
        max_single_sequence_length_) {
      return true;
    }
@@ -905,12 +965,13 @@ class Engine {
  NDArray logits_or_probs_on_cpu_{nullptr};

  // PackedFuncs from model/tokenizer/sampler/env.
-  std::vector<PackedFunc> fmodel_single_seq_prefill_;
+  std::vector<PackedFunc> fmodel_batch_prefill_;
  std::vector<PackedFunc> fmodel_decode_;
  std::vector<PackedFunc> fmodel_token_embed_;
  std::vector<PackedFunc> fmodel_add_new_sequence_;
  std::vector<PackedFunc> fmodel_remove_sequence_;
  std::vector<PackedFunc> fmodel_softmax_with_temperature_;
+  std::vector<PackedFunc> fmodel_get_num_available_pages_;
  std::vector<PackedFunc> fsampler_require_gpu_softmax_;
  std::vector<PackedFunc> fsampler_compute_probs_from_logits_inplace_;
  std::vector<PackedFunc> fsampler_sample_token_from_probs_;
@@ -919,10 +980,18 @@ class Engine {

  // Runtime statistics
  int64_t current_total_seq_len_;
-  double prefill_total_time = 0;
-  double decode_total_time = 0;
-  int64_t prefill_total_length = 0;
-  int64_t decode_total_length = 0;
+  /*! \brief The sum of "prefill time of each request". */
+  double request_total_prefill_time_ = 0.0f;
+  /*! \brief The sum of "decode time of each request". */
+  double request_total_decode_time_ = 0.0f;
+  /*! \brief The total engine time on prefill. */
+  double engine_total_prefill_time_ = 0.0f;
+  /*! \brief The total engine time on decode. */
+  double engine_total_decode_time_ = 0.0f;
+  /*! \brief The total number of processed tokens in prefill. */
+  int64_t total_prefill_length_ = 0;
+  /*! \brief The total number of processed tokens in decode. */
+  int64_t total_decode_length_ = 0;

  // Tokenization cache
  std::unordered_map tokenize_cache_;
diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc
index be77c16970..2e9e1469ba 100644
--- a/cpp/serve/function_table.cc
+++ b/cpp/serve/function_table.cc
@@ -7,11 +7,11 @@
 #include "function_table.h"
 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
@@ -108,8 +108,8 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, int num_shards)
    this->local_vm = fload_exec();
    this->local_vm->GetFunction("vm_initialization")(
        static_cast<int>(device.device_type), device.device_id,
-        static_cast<int>(relax_vm::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
-        static_cast<int>(relax_vm::AllocatorType::kPooled));
+        static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
+        static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled));
    this->mod_get_func = [this](const std::string& name) -> PackedFunc {
      return this->local_vm->GetFunction(name, false);
    };
@@ -169,6 +169,8 @@ void FunctionTable::_InitFunctions() {
      get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device");
  this->remove_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_remove");
  this->popn_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn");
+  this->get_num_available_pages_kv_cache_func_ =
+      get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages");
  support_backtracking_kv_ = true;
}
diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h
index ef9b3a5976..640785389a 100644
--- a/cpp/serve/function_table.h
+++ b/cpp/serve/function_table.h
@@ -69,6 +69,7 @@ struct FunctionTable {
  PackedFunc sync_device_kv_cache_func_;
  PackedFunc remove_from_kv_cache_func_;
  PackedFunc popn_from_kv_cache_func_;
+  PackedFunc get_num_available_pages_kv_cache_func_;
};

}  // namespace serve
diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc
index 4fb85bfc52..9c2a1e768b 100644
--- a/cpp/serve/model.cc
+++ b/cpp/serve/model.cc
@@ -77,10 +77,10 @@ class ModelModule : public ModuleNode {
        CHECK_EQ(args.size(), 1);
        *rv = TokenEmbed(args[0]);
      });
-    } else if (name == "single_seq_prefill") {
+    } else if (name == "batch_prefill") {
      return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
-        CHECK_EQ(args.size(), 2);
-        *rv = SingleSequencePrefill(args[0], args[1]);
+        CHECK_EQ(args.size(), 3);
+        *rv = BatchPrefill(args[0], args[1], args[2]);
      });
    } else if (name == "decode") {
      return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
@@ -115,15 +115,18 @@ class ModelModule : public ModuleNode {
        ICHECK_EQ(args.size(), 0);
        Reset();
      });
+    } else if (name == "get_num_available_pages") {
+      return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
+        ICHECK_EQ(args.size(), 0);
+        ICHECK(kv_cache_.defined());
+        *rv = ft_.get_num_available_pages_kv_cache_func_(kv_cache_);
+      });
    } else if (name == "get_max_window_size") {
      return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.size(), 0);
        CHECK_NE(max_window_size_, -1) << "The model has not been initialized";
        *rv = max_window_size_;
      });
-    } else if (name == "runtime_stats_text") {
-      // Todo: JSON style
-      return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = GetStats(); });
    } else {
      return PackedFunc(nullptr);
    }
@@ -156,7 +159,8 @@ class ModelModule : public ModuleNode {
    }

    // Copy input token ids to device.
    DLDataType dtype(DataType::Int(32));
-    NDArray token_ids_nd = CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, 2048);
+    NDArray token_ids_nd =
+        CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, max_window_size_);
    ICHECK_EQ(token_ids_nd->ndim, 1);
    ICHECK_EQ(token_ids_nd->shape[0], total_length);
    token_ids_nd = token_ids_nd.CreateView({1, total_length}, dtype);

    CHECK(ft_.embed_func_.defined())
        << "`embed` function is not found in the model. Please make sure the model is compiled "
           "with flag `--sep-embed` and `--enable-batching`";

-    auto tstart = std::chrono::high_resolution_clock::now();
    NDArray embeddings = ft_.embed_func_(ft_.CopyToWorker0(token_ids_nd), params_);
-    auto tend = std::chrono::high_resolution_clock::now();
-
-    this->embed_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
-    this->embed_total_tokens += total_length;

    // embeddings: (1, total_length, hidden_size)
    ICHECK_EQ(embeddings->ndim, 3);
@@ -183,14 +182,33 @@ class ModelModule : public ModuleNode {
   * \brief Single-sequence prefill function. Embedding in, logits out.
   * \param embeddings The embedding of the input to be prefilled.
   * \param seq_id The id of the sequence in the KV cache.
+   * \param lengths The length of each sequence to prefill.
   * \return The logits for the next token.
   */
-  NDArray SingleSequencePrefill(NDArray embeddings, int seq_id) {
+  NDArray BatchPrefill(Array<NDArray> embedding_arr, ShapeTuple seq_ids, ShapeTuple lengths) {
+    CHECK(!seq_ids.empty());
+    CHECK_EQ(seq_ids.size(), lengths.size());
+    int num_sequences = seq_ids.size();
+    int total_length = 0;
+    std::vector<int> logit_pos;
+    logit_pos.reserve(num_sequences);
+    for (int i = 0; i < num_sequences; ++i) {
+      total_length += lengths[i];
+      logit_pos.push_back(total_length);
+      if (i > 0) {
+        CHECK_GT(seq_ids[i], seq_ids[i - 1]) << "The input sequence ids must be increasing.";
+      }
+    }
+
    // embeddings: (1, n, h)
-    CHECK_EQ(embeddings->ndim, 3);
-    CHECK_EQ(embeddings->shape[0], 1);
-    CHECK_EQ(embeddings->device.device_type, device_.device_type);
-    CHECK_EQ(embeddings->device.device_id, device_.device_id);
+    NDArray embeddings = ConcatEmbeddings(std::move(embedding_arr), total_length);
+    ICHECK_EQ(embeddings->ndim, 3);
+    ICHECK_EQ(embeddings->shape[0], 1);
+    ICHECK_EQ(embeddings->shape[1], total_length);
+    ICHECK_EQ(embeddings->device.device_type, device_.device_type);
+    ICHECK_EQ(embeddings->device.device_id, device_.device_id);
+
+    NDArray logit_pos_nd = CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32);

    CHECK(ft_.prefill_func_.defined())
        << "`prefill_with_embed` function is not found in the model. Please make sure the model is "
@@ -202,22 +220,20 @@ class ModelModule : public ModuleNode {

    // Reserve in KV cache for the length of the input.
    ft_.reset_append_length_kv_cache_func_(kv_cache_);
-    ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_id, /*length=*/embeddings->shape[1]);
+    for (int i = 0; i < num_sequences; ++i) {
+      ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_ids[i], lengths[i]);
+    }
    ft_.sync_device_kv_cache_func_(kv_cache_);

-    auto tstart = std::chrono::high_resolution_clock::now();
-    // args: embeddings, kv_cache, params
-    Array<ObjectRef> ret = ft_.prefill_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_);
-    auto tend = std::chrono::high_resolution_clock::now();
-
-    this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
-    this->prefill_total_tokens += embeddings->shape[1];
+    // args: embeddings, logit_pos, kv_cache, params
+    Array<ObjectRef> ret =
+        ft_.prefill_func_(ft_.CopyToWorker0(embeddings), logit_pos_nd, kv_cache_, params_);

-    // logits: (1, 1, v)
+    // logits: (1, num_sequences, v)
    NDArray logits = Downcast<NDArray>(ret[0]);
    ICHECK_EQ(logits->ndim, 3);
    ICHECK_EQ(logits->shape[0], 1);
-    ICHECK_EQ(logits->shape[1], 1);
+    ICHECK_EQ(logits->shape[1], num_sequences);
    return logits;
  }
@@ -251,13 +267,8 @@ class ModelModule : public ModuleNode {
    }
    ft_.sync_device_kv_cache_func_(kv_cache_);

-    auto tstart = std::chrono::high_resolution_clock::now();
    // args: embeddings, kv_cache, params
    Array<ObjectRef> ret = ft_.decode_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_);
-    auto tend = std::chrono::high_resolution_clock::now();
-
-    this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
-    this->decode_total_tokens += embeddings->shape[0];

    // logits: (b, 1, v)
    NDArray logits = Downcast<NDArray>(ret[0]);
@@ -286,7 +297,7 @@ class ModelModule : public ModuleNode {
    for (GenerationConfig cfg : generation_cfg) {
      temperatures.push_back(cfg->temperature);
    }
-    NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 16);
+    NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32);
    ICHECK_EQ(temperatures_nd->ndim, 1);
    ICHECK_EQ(temperatures_nd->shape[0], batch_size);
@@ -318,6 +329,57 @@ class ModelModule : public ModuleNode {
    return view;
  }

+  /*! \brief Concatenate the input embeddings. */
+  NDArray ConcatEmbeddings(Array<NDArray> embedding_arr, int64_t total_length) {
+    ICHECK(!embedding_arr.empty());
+    int hidden_size = -1;
+    DataType dtype;
+    for (NDArray inp_embeddings : embedding_arr) {
+      // inp_embedding: (1, n, h)
+      CHECK_EQ(inp_embeddings->ndim, 3);
+      CHECK_EQ(inp_embeddings->shape[0], 1);
+      CHECK_EQ(inp_embeddings->device.device_type, device_.device_type);
+      CHECK_EQ(inp_embeddings->device.device_id, device_.device_id);
+      if (hidden_size == -1) {
+        hidden_size = inp_embeddings->shape[2];
+        dtype = inp_embeddings.DataType();
+      } else {
+        CHECK_EQ(inp_embeddings->shape[2], hidden_size);
+        CHECK_EQ(inp_embeddings.DataType(), dtype);
+      }
+    }
+
+    // - Resize the shared embedding array.
+    if (embeddings_.defined()) {
+      ICHECK_EQ(embeddings_->ndim, 3);
+      ICHECK_EQ(embeddings_->shape[0], 1);
+      ICHECK_EQ(embeddings_->shape[2], hidden_size);
+    }
+    int64_t init_size = embeddings_.defined() ? embeddings_->shape[1] : max_window_size_;
+    while (init_size < total_length) {
+      init_size *= 2;
+    }
+    if (!embeddings_.defined() || init_size != embeddings_->shape[1]) {
+      embeddings_ = NDArray::Empty({1, init_size, hidden_size}, dtype, device_);
+    }
+
+    // - Copy input embeddings.
+    int64_t start_pos = 0;
+    for (NDArray inp_embeddings : embedding_arr) {
+      int64_t length = inp_embeddings->shape[1];
+      CHECK_LE(start_pos + length, total_length);
+
+      DLTensor copy_dst = *(embeddings_.operator->());
+      copy_dst.byte_offset = start_pos * hidden_size * dtype.bytes();
+      copy_dst.shape = inp_embeddings->shape;
+      NDArray::CopyFromTo(inp_embeddings.operator->(), &copy_dst);
+
+      start_pos += length;
+    }
+    CHECK_EQ(start_pos, total_length);
+    return embeddings_.CreateView({1, total_length, hidden_size}, dtype);
+  }
+
  /*! \brief Load model configuration from JSON. */
  void LoadModelConfigJSON(const std::string& config_str) {
    picojson::value config_json;
@@ -350,37 +412,12 @@ class ModelModule : public ModuleNode {

  /*! \brief reset the runtime states. */
  void Reset() {
-    // Reset the statistics.
-    this->embed_total_tokens = 0;
-    this->prefill_total_tokens = 0;
-    this->decode_total_tokens = 0;
-    this->embed_total_time = 0;
-    this->prefill_total_time = 0;
-    this->decode_total_time = 0;
    // Reset the KV cache.
    if (kv_cache_.defined()) {
      ft_.reset_kv_cache_func_(kv_cache_);
    }
  }

-  /*! \brief Return statistics in JSON format. */
-  String GetStats() {
-    picojson::object stats;
-    stats["prefill_speed"] = picojson::value(prefill_total_tokens / prefill_total_time);
-    stats["decode_speed"] = picojson::value(decode_total_tokens / decode_total_time);
-    stats["embed_speed"] = picojson::value(embed_total_tokens / embed_total_time);
-    return picojson::value(stats).serialize(true);
-  }
-
-  //----------------------------
-  // Statistics
-  //----------------------------
-  double embed_total_time = 0;
-  double decode_total_time = 0;
-  double prefill_total_time = 0;
-  int64_t embed_total_tokens = 0;
-  int64_t decode_total_tokens = 0;
-  int64_t prefill_total_tokens = 0;
  //----------------------------
  // Model configurations
  //----------------------------
@@ -400,6 +437,8 @@ class ModelModule : public ModuleNode {
  ObjectRef params_;
  // Shared NDArray
  NDArray input_token_ids_{nullptr};
+  NDArray embeddings_{nullptr};
+  NDArray logit_pos_arr_{nullptr};
  NDArray temperature_arr_{nullptr};
};
diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc
index 8e86682d78..38b5a7bf66 100644
--- a/cpp/serve/request_state.cc
+++ b/cpp/serve/request_state.cc
@@ -23,6 +23,21 @@ RequestModelState::RequestModelState(int model_id, Array<Data> inputs) {
  data_ = std::move(n);
}

+TVM_REGISTER_OBJECT_TYPE(RequestStateNode);
+
+RequestState::RequestState(int num_models, Array<Data> inputs, int raw_input_length) {
+  ObjectPtr<RequestStateNode> n = make_object<RequestStateNode>();
+  Array<RequestModelState> mstates;
+  mstates.reserve(num_models);
+  for (int i = 0; i < num_models; ++i) {
+    mstates.push_back(RequestModelState(i, inputs));
+  }
+  n->mstates = std::move(mstates);
+  n->raw_input_length = raw_input_length;
+  n->tadd = std::chrono::high_resolution_clock::now();
+  data_ = std::move(n);
+}
+
}  // namespace serve
}  // namespace llm
}  // namespace mlc
diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h
index 1de1f5e3f2..16f39548c2 100644
--- a/cpp/serve/request_state.h
+++ b/cpp/serve/request_state.h
@@ -86,13 +86,16 @@ class RequestModelState : public ObjectRef {
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode);
};

-struct RequestState {
+class RequestStateNode : public Object {
+ public:
  /*!
   * \brief The state with regard to each model.
   * \sa RequestModelState
   */
  Array<RequestModelState> mstates;
+  /*! \brief The summed up input length of the request. */
+  int raw_input_length = 0;
  /*! \brief The decoded text string output. */
  std::string output = "";

  /*! \brief The time of adding the request to engine. */
  std::chrono::_V2::system_clock::time_point tadd;
  /*! \brief The time of finishing prefill stage. */
  std::chrono::_V2::system_clock::time_point tprefill_finish;

-  explicit RequestState(int num_models, Array<Data> inputs) {
-    mstates.reserve(num_models);
-    for (int i = 0; i < num_models; ++i) {
-      mstates.push_back(RequestModelState(i, inputs));
-    }
-    tadd = std::chrono::high_resolution_clock::now();
-  }
+  static constexpr const char* _type_key = "mlc.serve.RequestState";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateNode, Object);
+};
+
+class RequestState : public ObjectRef {
+ public:
+  explicit RequestState(int num_models, Array<Data> inputs, int raw_input_length);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode);
};

}  // namespace serve
diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py
index da6ada6dad..f9c0ec6c99 100644
--- a/python/mlc_chat/serve/engine.py
+++ b/python/mlc_chat/serve/engine.py
@@ -221,12 +221,12 @@ def reset(self) -> None:
     def stats(self) -> Dict[str, float]:
         """The engine runtime statistics.
         We collect the following entries:
-        - prefill token latency (s/tok)
-          avg latency of processing one token in prefill
-        - decode token latency (s/tok)
-          avg latency of processing one token in decode
-        - token throughput (tok/s)
-          avg number of tokens processed per second (prefill + decode)
+        - single token prefill latency (s/tok): avg latency of processing one token in prefill
+        - single token decode latency (s/tok): avg latency of processing one token in decode
+        - engine time for prefill (sec)
+        - engine time for decode (sec)
+        - total number of processed tokens in prefill.
+        - total number of processed tokens in decode.
""" stats_json_str = self._get_stats_func() stats = json.loads(stats_json_str) diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py index 19d46f99ba..1b1a7b71bb 100644 --- a/tests/python/serve/benchmark.py +++ b/tests/python/serve/benchmark.py @@ -1,81 +1,156 @@ -# pylint: disable=missing-docstring +# pylint: disable=line-too-long,missing-docstring,no-member,too-many-locals +# type: ignore import argparse +import json +import random import time -from typing import Any, Callable, List +from typing import Any, Callable, List, Tuple +import numpy as np +from transformers import AutoTokenizer + +from mlc_chat.chat_module import _get_model_path from mlc_chat.serve import GenerationConfig, KVCacheConfig from mlc_chat.serve.engine import Engine, ModelInfo def _parse_args(): args = argparse.ArgumentParser() - args.add_argument("--model-id", type=str, default="Llama-2-7b-chat-hf-q4f16_1") + args.add_argument("--model-id", type=str, default="Llama-2-7b-chat-hf-q0f16") + # Download dataset from + # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + args.add_argument("--dataset", type=str, required=True) args.add_argument("--device", type=str, default="auto") - args.add_argument("--input-length", type=int, default=32) - args.add_argument("--output-length", type=int, default=256) - args.add_argument("--batch-size", type=int, default=16) + args.add_argument("--num-prompts", type=int, default=500) + args.add_argument("--batch-size", type=int, default=80) + args.add_argument("--page-size", type=int, default=16) + args.add_argument("--max-total-seq-length", type=int, default=16800) + args.add_argument("--seed", type=int, default=0) + parsed = args.parse_args() + assert parsed.batch_size % 16 == 0 + assert parsed.page_size == 16 + assert parsed.max_total_seq_length >= 2048 return parsed -def time_evaluator(func: Callable, args: List[Any], num_runs: int = 3, num_warmups: int = 1): - # warmup run - print("Start warmup...") - for _ in range(num_warmups): - func(*args) +def sample_requests( + dataset_path: str, num_requests: int, model_id: str +) -> Tuple[List[str], List[GenerationConfig]]: + """Sample requests from dataset. + Acknowledgement to the benchmark scripts in the vLLM project. + """ + model_path, _ = _get_model_path(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + with open(dataset_path, encoding="utf-8") as f: + dataset = json.load(f) + + # Filter out the conversations with less than 2 turns. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + if len(data["conversations"]) >= 2 + ] + # Tokenize the prompts and completions. + prompts = [prompt for prompt, _ in dataset] + prompt_token_ids = tokenizer(prompts).input_ids + completions = [completion for _, completion in dataset] + completion_token_ids = tokenizer(completions).input_ids + tokenized_dataset = [] + for i in range(len(dataset)): + output_len = len(completion_token_ids[i]) + tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + + # Filter out too long sequences. + filtered_dataset: List[Tuple[str, int, int]] = [] + for prompt, prompt_token_ids, output_len in tokenized_dataset: + prompt_len = len(prompt_token_ids) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+            continue
+        filtered_dataset.append((prompt, prompt_len, output_len))
+
+    # Sample the requests.
+    sampled_requests = random.sample(filtered_dataset, num_requests)
 
-    total_time = 0.0
-    for run in range(num_runs):
-        print(f"Evaluator: start round {run}")
+    # Construct generation config.
+    prompts = [prompt for prompt, _, _ in sampled_requests]
+    generation_config_list = [
+        GenerationConfig(temperature=1.0, top_p=1.0, max_new_tokens=output_len)
+        for _, _, output_len in sampled_requests
+    ]
+    return prompts, generation_config_list
+
+
+def time_evaluator(func: Callable, args: List[Any], num_runs: int = 3):
+    times = []
+    for _ in range(num_runs):
         start = time.perf_counter()
         func(*args)
         end = time.perf_counter()
-        total_time += end - start
-        print(f"Evaluator: finish round {run}")
+        times.append(end - start)
 
-    return total_time / num_runs
+    return np.array(times)
 
 
 def benchmark(args: argparse.Namespace):
+    random.seed(args.seed)
+
     # Initialize model loading info and KV cache config
     model = ModelInfo(args.model_id, args.device)
     kv_cache_config = KVCacheConfig(
-        page_size=16,
+        page_size=args.page_size,
         max_num_sequence=args.batch_size,
-        max_total_sequence_length=args.output_length * args.batch_size * 2,
+        max_total_sequence_length=args.max_total_seq_length,
     )
-    generation_config = GenerationConfig(
-        temperature=1.0, top_p=1.0, max_new_tokens=args.output_length
-    )
-    prompts = [[0] * args.input_length] * args.batch_size
 
+    # Create engine
     engine = Engine(model, kv_cache_config)
+    # Sample prompts from dataset
+    prompts, generation_config = sample_requests(args.dataset, args.num_prompts, args.model_id)
 
     # Engine statistics
-    num_runs = 3
-    prefill_token_latency = []
-    decode_token_latency = []
-    token_throughput = []
+    num_runs = 1
+    single_token_prefill_latency = []
+    single_token_decode_latency = []
+    engine_total_prefill_time = []
+    engine_total_decode_time = []
+    total_prefill_tokens = []
+    total_decode_tokens = []
 
     def engine_generate():
         engine.reset()
         engine.generate(prompts, generation_config)
         engine_stats = engine.stats()
-        prefill_token_latency.append(engine_stats["prefill_token_latency"])
-        decode_token_latency.append(engine_stats["decode_token_latency"])
-        token_throughput.append(engine_stats["token_throughput"])
-
-    avg_e2e_latency = time_evaluator(engine_generate, args=[], num_runs=num_runs)
+        single_token_prefill_latency.append(engine_stats["single_token_prefill_latency"])
+        single_token_decode_latency.append(engine_stats["single_token_decode_latency"])
+        engine_total_prefill_time.append(engine_stats["engine_total_prefill_time"])
+        engine_total_decode_time.append(engine_stats["engine_total_decode_time"])
+        total_prefill_tokens.append(engine_stats["total_prefill_tokens"])
+        total_decode_tokens.append(engine_stats["total_decode_tokens"])
 
-    avg_prefill_token_latency = sum(prefill_token_latency[-num_runs:]) / num_runs
-    avg_decode_token_latency = sum(decode_token_latency[-num_runs:]) / num_runs
-    avg_token_throughput = sum(token_throughput[-num_runs:]) / num_runs
+    e2e_latency = time_evaluator(engine_generate, args=[], num_runs=num_runs)
+    single_token_prefill_latency = np.array(single_token_prefill_latency)
+    single_token_decode_latency = np.array(single_token_decode_latency)
+    engine_total_prefill_time = np.array(engine_total_prefill_time)
+    engine_total_decode_time = np.array(engine_total_decode_time)
+    total_prefill_tokens = np.array(total_prefill_tokens)
+    total_decode_tokens = np.array(total_decode_tokens)
+    prefill_throughput = total_prefill_tokens / engine_total_prefill_time
+    decode_throughput = total_decode_tokens / engine_total_decode_time
+    overall_throughput = (total_prefill_tokens + total_decode_tokens) / e2e_latency
 
     print(args)
-    print(f"Average end-to-end latency: {avg_e2e_latency} seconds for the entire batch")
-    print(f"Prefill token latency: {avg_prefill_token_latency * 1e3} ms/tok")
-    print(f"Decode token latency: {avg_decode_token_latency * 1e3} ms/tok")
-    print(f"Request throughput: {args.batch_size / (avg_e2e_latency / 60)} req/min")
-    print(f"Token throughput: {avg_token_throughput} tok/s")
+    print(f"Average end-to-end latency: {e2e_latency.mean():.4f} seconds for the entire batch")
+    print(f"Single token prefill latency: {single_token_prefill_latency.mean() * 1e3:.4f} ms/tok")
+    print(f"Single token decode latency: {single_token_decode_latency.mean() * 1e3:.4f} ms/tok")
+    print(f"Request throughput: {args.num_prompts / e2e_latency.mean():.4f} req/s")
+    print(f"Prefill token throughput: {prefill_throughput.mean():.4f} tok/s")
+    print(f"Decode token throughput: {decode_throughput.mean():.4f} tok/s")
+    print(f"Overall token throughput: {overall_throughput.mean():.4f} tok/s")
 
 
 if __name__ == "__main__":
diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py
new file mode 100644
index 0000000000..19f572bb77
--- /dev/null
+++ b/tests/python/serve/evaluate_engine.py
@@ -0,0 +1,74 @@
+# pylint: disable=line-too-long,missing-docstring
+import argparse
+import random
+from typing import List, Tuple
+
+from mlc_chat.serve import GenerationConfig, KVCacheConfig
+from mlc_chat.serve.engine import Engine, ModelInfo
+
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument("--model-id", type=str, default="Llama-2-7b-chat-hf-q0f16")
+    args.add_argument("--device", type=str, default="auto")
+    args.add_argument("--batch-size", type=int, default=128)
+    args.add_argument("--page-size", type=int, default=16)
+    args.add_argument("--max-total-seq-length", type=int, default=16000)
+    args.add_argument("--seed", type=int, default=0)
+
+    parsed = args.parse_args()
+    assert parsed.batch_size % 16 == 0
+    assert parsed.page_size == 16
+    assert parsed.max_total_seq_length >= 2048
+    return parsed
+
+
+def generate_requests(
+    num_requests: int, input_length: int, output_length: int
+) -> Tuple[List[List[int]], List[GenerationConfig]]:
+    prompt_ids = []
+    for _ in range(num_requests):
+        token_ids = []
+        for _ in range(input_length):
+            token_ids.append(random.randint(0, 30000))
+        prompt_ids.append(token_ids)
+    generation_config_list = [
+        GenerationConfig(temperature=1.0, top_p=1.0, max_new_tokens=output_length)
+    ] * num_requests
+    return prompt_ids, generation_config_list
+
+
+def benchmark(args: argparse.Namespace):
+    random.seed(args.seed)
+
+    # Initialize model loading info and KV cache config
+    model = ModelInfo(args.model_id, args.device)
+    kv_cache_config = KVCacheConfig(
+        page_size=args.page_size,
+        max_num_sequence=args.batch_size,
+        max_total_sequence_length=args.max_total_seq_length,
+    )
+
+    # Create engine
+    engine = Engine(model, kv_cache_config)
+
+    print(args)
+    for num_requests in [1, 2, 4, 8, 16, 32, 64]:
+        if num_requests > args.batch_size:
+            continue
+        for input_length in [64, 128, 256, 512, 1024]:
+            if num_requests * input_length >= 16384:
+                continue
+            for output_length in [4]:
+                print(f"nreq={num_requests}\t" f"in={input_length}\t" f"out={output_length}")
+                prompt_ids, generation_config = generate_requests(
+                    num_requests, input_length, output_length
+                )
+                engine.reset()
+                engine.generate(prompt_ids, generation_config)
+            print()
+
+
+if __name__ == "__main__":
+    ARGS = _parse_args()
+    benchmark(ARGS)
diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py
index 5104002922..b47595202d 100644
--- a/tests/python/serve/test_serve_engine.py
+++ b/tests/python/serve/test_serve_engine.py
@@ -3,6 +3,7 @@
 from typing import Callable, List, Optional
 
 import numpy as np
+
 from mlc_chat.serve import GenerationConfig, KVCacheConfig, Request, data
 from mlc_chat.serve.engine import Engine, ModelInfo
 
@@ -66,7 +67,7 @@ def test_engine_basic():
     # Hyperparameters for tests (you can try different combinations).
     num_requests = 10  # [4, 8, 10]
-    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]
+    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]
     repetition_penalty = 1.0  # [1.0, 1.01]
     max_new_tokens: int = 256  # [32, 128, 256]
     np.random.seed(0)
@@ -181,7 +182,7 @@ def step(self) -> None:
         print(f"Prompt {req_id}: {request.inputs[0]}")
         print(f"Output {req_id}:{output}\n")
         assert isinstance(output, data.TextData)
-        assert fin_time == num_requests + request.generation_config.max_new_tokens - 2
+        assert fin_time == request.generation_config.max_new_tokens - 1
 
 
 def test_engine_continuous_batching_2():
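
For reference only (not part of the patch): a minimal sketch of how the renamed statistics entries exposed by `Engine.stats()` in this change can be consumed from Python, mirroring what `tests/python/serve/benchmark.py` does. The model id, token-id prompts, and KV cache sizes below are placeholder assumptions, not values prescribed by this patch.

```python
# Minimal usage sketch of the new stats entries (assumed placeholder model/config values).
from mlc_chat.serve import GenerationConfig, KVCacheConfig
from mlc_chat.serve.engine import Engine, ModelInfo

# Placeholder model id and KV cache configuration.
model = ModelInfo("Llama-2-7b-chat-hf-q0f16", "auto")
kv_cache_config = KVCacheConfig(
    page_size=16, max_num_sequence=16, max_total_sequence_length=16000
)
engine = Engine(model, kv_cache_config)

# Toy token-id prompts; one GenerationConfig per request, as in evaluate_engine.py.
prompt_ids = [[0, 1, 2, 3]] * 4
generation_config = [GenerationConfig(temperature=1.0, top_p=1.0, max_new_tokens=16)] * 4
engine.generate(prompt_ids, generation_config)

# The stats JSON now reports engine-side times and token counts separately,
# so throughput can be derived per stage instead of as a single blended number.
stats = engine.stats()
prefill_throughput = stats["total_prefill_tokens"] / stats["engine_total_prefill_time"]
decode_throughput = stats["total_decode_tokens"] / stats["engine_total_decode_time"]
print(f"prefill: {prefill_throughput:.1f} tok/s, decode: {decode_throughput:.1f} tok/s")
```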