From 6d57417456737c4253d99ebedea473ddbf516aac Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 17 Nov 2023 15:38:21 -0500
Subject: [PATCH] [Serving][Refactor] EngineState and EngineAction

Following #1277, this PR continues to refactor the engine. In particular, this
PR introduces the `EngineState` object to denote the running state of an
Engine, and the `EngineAction` abstraction to denote the action (such as
prefill, decode, speculate, etc.) that the engine can take at each time step.

The EngineState consists of
* the queues of running requests, waiting requests, and requests to be aborted,
* the request states of all requests,
* the engine runtime statistics.

The EngineAction defines the core interface

```c++
bool Step(EngineState estate)
```

which takes the current engine state as input and then updates the engine
state after taking the action.
---
 cpp/serve/engine.cc                         | 513 +-----------------
 cpp/serve/engine_actions/abort_requests.cc  |  64 +++
 cpp/serve/engine_actions/action.cc          |  16 +
 cpp/serve/engine_actions/action.h           |  93 ++++
 cpp/serve/engine_actions/action_commons.cc  |  79 +++
 cpp/serve/engine_actions/action_commons.h   |  47 ++
 cpp/serve/engine_actions/batch_decode.cc    | 166 ++++++
 .../engine_actions/new_request_prefill.cc   | 220 ++++++++
 .../{engine_stats.cc => engine_state.cc}    |  20 +-
 cpp/serve/engine_state.h                    | 103 ++++
 cpp/serve/engine_stats.h                    |  53 --
 11 files changed, 832 insertions(+), 542 deletions(-)
 create mode 100644 cpp/serve/engine_actions/abort_requests.cc
 create mode 100644 cpp/serve/engine_actions/action.cc
 create mode 100644 cpp/serve/engine_actions/action.h
 create mode 100644 cpp/serve/engine_actions/action_commons.cc
 create mode 100644 cpp/serve/engine_actions/action_commons.h
 create mode 100644 cpp/serve/engine_actions/batch_decode.cc
 create mode 100644 cpp/serve/engine_actions/new_request_prefill.cc
 rename cpp/serve/{engine_stats.cc => engine_state.cc} (72%)
 create mode 100644 cpp/serve/engine_state.h
 delete mode 100644 cpp/serve/engine_stats.h

diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 967e460800..04a414f510 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -12,7 +12,9 @@
 #include

 #include "../tokenizers.h"
-#include "engine_stats.h"
+#include "engine_actions/action.h"
+#include "engine_actions/action_commons.h"
+#include "engine_state.h"
 #include "model.h"
 #include "request.h"
 #include "request_state.h"
@@ -48,10 +50,6 @@ class EngineModule;
  * - engine management,
  * - high-level request management,
  * - engine "step" action.
- *
- * The internal implementation of Engine has the following categories:
- * - internal request management,
- * - actions and request schedule policy (such as prefill, decode, etc.)
  */
 class Engine {
  friend class EngineModule;
@@ -99,17 +97,19 @@ class Engine {
     // requirement of speculative encoding.
     sampler_ = Sampler::Create(/*sampler_kind=*/"cpu");
     tokenizer_ = TokenizerFromPath(model_paths[0]);
+    // Step 6. Initialize action lists.
+    action_abort_request_ = EngineAction::AbortRequest(models_);
+    action_new_request_prefill_ = EngineAction::NewRequestPrefill(
+        models_, sampler_, kv_cache_config_, max_single_sequence_length_);
+    action_batch_decode_ = EngineAction::BatchDecode(models_, sampler_);
     ResetEngine();
   }

   /*! \brief Reset the engine, clean up all running data and statistics.
*/ void ResetEngine() { - running_queue_.clear(); - waiting_queue_.clear(); - abort_queue_.clear(); - request_states_.clear(); - stats_.Reset(); + ICHECK(estate_.defined()); + estate_->Reset(); for (Model model : models_) { model->Reset(); } @@ -126,12 +126,12 @@ class Engine { request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); // Append to the waiting queue and create the request state. - waiting_queue_.push_back(request); - request_states_.emplace(request->id, RequestState(request, models_.size())); + estate_->waiting_queue.push_back(request); + estate_->request_states.emplace(request->id, RequestState(request, models_.size())); } /*! \brief Abort the input request. */ - void AbortRequest(Request request) { abort_queue_.push_back(request); } + void AbortRequest(Request request) { estate_->abort_queue.push_back(request); } /*********************** Engine Action ***********************/ @@ -146,500 +146,39 @@ class Engine { * generation results for those finished requests. */ void Step() { - // - Abort requests. - while (!abort_queue_.empty()) { - StepAbort(abort_queue_.front()); - abort_queue_.erase(abort_queue_.begin()); - } + // - Action 0. Abort requests. + action_abort_request_->Step(estate_); // - Action 1. Prefill the front-most waiting request. - bool prefill_processed = StepPrefill(); + bool prefill_processed = action_new_request_prefill_->Step(estate_); if (prefill_processed) { return; } // - Action 2. Run decode step. - bool decode_processed = StepDecode(); + bool decode_processed = action_batch_decode_->Step(estate_); if (decode_processed) { - ProcessFinishedRequest(); + ProcessFinishedRequest(estate_, models_, tokenizer_, max_single_sequence_length_); return; } - ICHECK(running_queue_.empty()) + ICHECK(estate_->running_queue.empty()) << "Not taking any action in a step is not expected with running requests."; } private: - /***************** Internal Request Management *****************/ - - /*! \brief Assign the given internal id for the given request. */ - void AssignIDForRequest(Request request, int req_id) { - // Set internal id in the request state. - RequestState state = request_states_.at(request->id); - for (RequestModelState mstate : state->mstates) { - mstate->request_id = req_id; - } - // Add a new sequence to each model. - for (int i = 0; i < static_cast(models_.size()); ++i) { - int seq_id_in_model = models_[i]->AddNewSequence(); - ICHECK_EQ(seq_id_in_model, req_id); - } - } - - /*! - * \brief Remove the given request from models and update request states. - * \param req_id The internal id of the request to remove. - */ - void RemoveRequestFromModel(int req_id) { - // Remove the request from all models (usually the KV cache). - for (Model model : models_) { - model->RemoveSequence(req_id); - } - // Update the internal request id of other requests. - for (auto& it : request_states_) { - RequestState state = it.second; - for (RequestModelState mstate : state->mstates) { - ICHECK_NE(mstate->request_id, req_id); - if (mstate->request_id > req_id) { - --mstate->request_id; - } - } - } - } - - /*! - * \brief Preempt the generation of the given request, moving - * it from running request set to the foremost of waiting - * request queue. - */ - void PreemptRequest(std::vector::iterator request_it) { - Request request = *request_it; - - // Remove from models. - // - Reset `request_id` of states. - // - Clear model speculation draft. - // - Update `inputs` for future prefill. 
- RequestState state = request_states_.at(request->id); - int req_id = state->mstates[0]->request_id; - stats_.current_total_seq_len -= - request->input_total_length + state->mstates[0]->committed_tokens.size() - 1; - for (RequestModelState mstate : state->mstates) { - mstate->request_id = -1; - mstate->draft_output_tokens.clear(); - mstate->draft_output_token_prob.clear(); - mstate->draft_output_prob_dist.clear(); - ICHECK(mstate->inputs.empty()); - ICHECK(!mstate->committed_tokens.empty()); - - Array inputs = request->inputs; - if (const auto* token_input = inputs.back().as()) { - // Merge the TokenData so that a single time TokenEmbed is needed. - std::vector token_ids{token_input->token_ids->data, - token_input->token_ids->data + token_input->token_ids.size()}; - token_ids.insert(token_ids.end(), mstate->committed_tokens.begin(), - mstate->committed_tokens.end()); - inputs.Set(inputs.size() - 1, TokenData(token_ids)); - } else { - inputs.push_back(TokenData(mstate->committed_tokens)); - } - mstate->inputs = std::move(inputs); - } - RemoveRequestFromModel(req_id); - - // Move from running queue to the front of waiting queue. - running_queue_.erase(request_it); - waiting_queue_.insert(waiting_queue_.begin(), request); - } - - /*! - * \brief For each request, check if the request has finished - * its generation. And update the state and return the generation - * result for the finished requests. - * \note This function removes requests from the running request - * queue. - */ - void ProcessFinishedRequest() { - // - Collect finished requests. - // We don't remove on the fly to avoid concurrent modification. - std::vector request_to_remove; - for (Request request : running_queue_) { - if (request_states_.at(request->id)->GenerationFinished(max_single_sequence_length_)) { - request_to_remove.push_back(request); - } - } - - // - Remove the finished request. - for (Request request : request_to_remove) { - // Remove from running queue. - auto it = std::find(running_queue_.begin(), running_queue_.end(), request); - ICHECK(it != running_queue_.end()); - running_queue_.erase(it); - - // Update engine states. - RequestState state = request_states_.at(request->id); - int req_id = state->mstates[0]->request_id; - for (RequestModelState mstate : state->mstates) { - ICHECK_EQ(mstate->request_id, req_id); - mstate->request_id = -1; - } - RemoveRequestFromModel(req_id); - request_states_.erase(request->id); - - // Update engine statistics. - int num_input_tokens = request->input_total_length; - int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1; - stats_.current_total_seq_len -= num_input_tokens + num_output_tokens; - auto trequest_finish = std::chrono::high_resolution_clock::now(); - stats_.request_total_prefill_time += - static_cast((state->tprefill_finish - state->tadd).count()) / 1e9; - stats_.total_prefill_length += num_input_tokens; - stats_.request_total_decode_time += - static_cast((trequest_finish - state->tprefill_finish).count()) / 1e9; - stats_.total_decode_length += num_output_tokens; - - // NOTE: right now we only return the generated text. - // In the future we might optional return text or token ids. - String output = tokenizer_->Decode(state->mstates[0]->committed_tokens); - request->fcallback(request, TextData(output)); - } - } - - /************** Engine Actions and Request Schedule Policy **************/ - - /********************* - * Action 1. Prefill * - *********************/ - - /*! \brief Pick applicable requests and run prefill. 
*/ - bool StepPrefill() { - auto [requests, states, sample_new_token] = GetRequestsToPrefill(); - if (requests.empty()) { - return false; - } - - auto tstart = std::chrono::high_resolution_clock::now(); - - for (Request request : requests) { - int req_id = running_queue_.size(); - auto it = std::find(waiting_queue_.begin(), waiting_queue_.end(), request); - if (it == waiting_queue_.end()) { - continue; - } - - // - Move request from waiting queue to running queue. - waiting_queue_.erase(it); - running_queue_.push_back(request); - // - Assign request id for the requests. - AssignIDForRequest(request, req_id); - } - - int sum_prefill_lengths = 0; - NDArray logits_for_sample{nullptr}; - Array mstates_for_sample; - Array generation_cfg_for_sample; - mstates_for_sample.reserve(requests.size()); - generation_cfg_for_sample.reserve(requests.size()); - for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - Model model = models_[model_id]; - auto [request_list, mstates, prefill_lengths] = - FilterPrefillRequests(requests, states, model_id); - Array embeddings; - std::vector request_ids; - embeddings.reserve(request_list.size()); - request_ids.reserve(request_list.size()); - for (int i = 0; i < static_cast(request_list.size()); ++i) { - Request request = request_list[i]; - int prefill_length = prefill_lengths[i]; - RequestModelState mstate = mstates[i]; - if (model_id == 0) { - // Accumulate the sequence length. - sum_prefill_lengths += prefill_length; - stats_.current_total_seq_len += prefill_length; - mstates_for_sample.push_back(mstate); - generation_cfg_for_sample.push_back(request->generation_cfg); - } - ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_token_prob.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); - ICHECK(!mstate->inputs.empty()); - request_ids.push_back(mstate->request_id); - for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - embeddings.push_back(mstate->inputs[i]->GetEmbedding(model)); - } - // Clean up `inputs` after prefill - mstate->inputs.clear(); - } - - NDArray logits = model->BatchPrefill(embeddings, request_ids, prefill_lengths); - ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], request_list.size()); - - if (model_id == 0) { - // We only need to sample for model 0 in prefill. - logits_for_sample = logits; - } - } - - if (sample_new_token) { - // - Sample tokens. - int num_requests = requests.size(); - ICHECK(logits_for_sample.defined()); - ICHECK_EQ(logits_for_sample->shape[1], num_requests); - ICHECK_EQ(mstates_for_sample.size(), num_requests); - ICHECK_EQ(generation_cfg_for_sample.size(), num_requests); - logits_for_sample = logits_for_sample.CreateView( - {num_requests, 1, logits_for_sample->shape[2]}, logits_for_sample->dtype); - std::vector next_tokens = sampler_->SampleTokens( - logits_for_sample, models_[0], mstates_for_sample, generation_cfg_for_sample); - ICHECK_EQ(next_tokens.size(), num_requests); - // - Update the committed tokens of states. - // - If a request is first-time prefilled, set the prefill finish time. 
- auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < num_requests; ++i) { - mstates_for_sample[i]->committed_tokens.push_back(next_tokens[i]); - if (mstates_for_sample[i]->committed_tokens.size() == 1) { - request_states_.at(requests[i]->id)->tprefill_finish = tnow; - } - } - } - - auto tend = std::chrono::high_resolution_clock::now(); - stats_.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; - - return true; - } - - /*! - * \brief Find one or multiple requests to run prefill. - * \return The requests to prefill. For each request, we - * additionally return a boolean flag indicating if a new - * token needs to be sampled from logits after prefill. - */ - std::tuple, Array, bool> GetRequestsToPrefill() { - // - Try to prefill pending requests. - std::vector prefill_requests; - std::vector states; - if (!waiting_queue_.empty()) { - int total_input_length = 0; - int total_required_pages = 0; - int num_available_pages = models_[0]->GetNumAvailablePages(); - - for (int i = 0; i < static_cast(waiting_queue_.size()); ++i) { - Request request = waiting_queue_[i]; - RequestState state = request_states_.at(request->id); - int input_length = state->mstates[0]->GetInputLength(); - int num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - if (CanPrefill(i + 1, total_input_length, total_required_pages, num_available_pages)) { - prefill_requests.push_back(request); - states.push_back(state); - } else { - total_input_length -= input_length; - total_required_pages -= num_require_pages; - break; - } - } - if (!prefill_requests.empty()) { - // Need to sample a new token for waiting requests. - return {prefill_requests, states, true}; - } - } - - // Try to prefill for small models. - for (Request request : running_queue_) { - RequestState state = request_states_.at(request->id); - Array mstates = state->mstates; - for (int i = 0; i < static_cast(mstates.size()); ++i) { - if (!mstates[i]->inputs.empty()) { - ICHECK_NE(i, 0); - prefill_requests.push_back(request); - states.push_back(state); - break; - } - } - } - // This return happens only for "small" models in - // speculative inference settings. - // Therefore no need to sample new token from logits. - return {prefill_requests, states, false}; - } - - /*! \brief Check if the input requests can be prefilled under conditions. */ - bool CanPrefill(int num_prefill_req, int total_input_length, int num_required_pages, - int num_available_pages) { - int num_running_requests = running_queue_.size(); - ICHECK_LE(num_running_requests, kv_cache_config_->max_num_sequence); - - // No exceeding of the maximum allowed requests that can - // run simultaneously. - if (num_running_requests + num_prefill_req > kv_cache_config_->max_num_sequence) { - return false; - } - - // NOTE: The conditions are heuristic and can be revised. - // Cond 1: total input length <= max allowed single sequence length. - // Cond 2: at least one decode can be performed after prefill. - // Cond 3: number of total tokens after 8 times of decode does not - // exceed the limit, where 8 is a watermark number can - // be configured and adjusted in the future. 
- int new_batch_size = num_running_requests + num_prefill_req; - return total_input_length <= max_single_sequence_length_ && - num_required_pages + new_batch_size <= num_available_pages && - stats_.current_total_seq_len + total_input_length + 8 * new_batch_size <= - kv_cache_config_->max_total_sequence_length; - } - - /*! \brief Filter the requests to prefill on the given model. */ - std::tuple, Array, std::vector> FilterPrefillRequests( - Array requests, Array states, int model_id) { - ICHECK_EQ(requests.size(), states.size()); - int num_requests = requests.size(); - Array filtered_requests; - Array filtered_mstates; - std::vector prefill_length; - filtered_requests.reserve(num_requests); - filtered_mstates.reserve(num_requests); - prefill_length.reserve(num_requests); - - for (int i = 0; i < num_requests; ++i) { - int length = states[i]->mstates[model_id]->GetInputLength(); - if (length > 0) { - filtered_requests.push_back(requests[i]); - filtered_mstates.push_back(states[i]->mstates[model_id]); - prefill_length.push_back(length); - } - } - return {filtered_requests, filtered_mstates, prefill_length}; - } - - /******************** - * Action 2. Decode * - ********************/ - - /*! \brief Pick applicable requests and run decode. */ - bool StepDecode() { - // - Do not run decode when there are multiple models. - if (models_.size() > 1) { - return false; - } - - if (running_queue_.empty()) { - return false; - } - - // Preempt requests when decode cannot apply. - while (!CanDecode(running_queue_.size())) { - PreemptRequest(running_queue_.end() - 1); - } - - auto tstart = std::chrono::high_resolution_clock::now(); - - // NOTE: Right now we only support decode all the running requests at a time. - int num_requests = running_queue_.size(); - // Check if the requests ids are in an ascending order. - for (int i = 1; i < num_requests; ++i) { - ICHECK_GT(request_states_.at(running_queue_[i]->id)->mstates[0]->request_id, - request_states_.at(running_queue_[i - 1]->id)->mstates[0]->request_id); - } - - stats_.current_total_seq_len += num_requests; - // Collect - // - the last committed token, - // - the request states, - // - the sampling parameters, - // of each request. - std::vector input_tokens; - Array mstates; - Array generation_cfg; - input_tokens.reserve(num_requests); - mstates.reserve(num_requests); - generation_cfg.reserve(num_requests); - for (Request request : running_queue_) { - RequestState state = request_states_.at(request->id); - input_tokens.push_back(state->mstates[0]->committed_tokens.back()); - mstates.push_back(state->mstates[0]); - generation_cfg.push_back(request->generation_cfg); - } - - // - Compute embeddings. - NDArray embeddings = - models_[0]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); - ICHECK_EQ(embeddings->ndim, 3); - ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], num_requests); - embeddings = embeddings.CreateView({num_requests, 1, embeddings->shape[2]}, embeddings->dtype); - - // - Invoke model decode. - NDArray logits = models_[0]->BatchDecode(embeddings); - ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], embeddings->shape[0]); - ICHECK_EQ(logits->shape[1], 1); - - // - Sample tokens. - std::vector next_tokens = - sampler_->SampleTokens(logits, models_[0], mstates, generation_cfg); - ICHECK_EQ(next_tokens.size(), num_requests); - - // - Update the committed tokens of states. 
- for (int i = 0; i < num_requests; ++i) { - mstates[i]->committed_tokens.push_back(next_tokens[i]); - } - - auto tend = std::chrono::high_resolution_clock::now(); - stats_.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; - - return true; - } - - /*! \brief Check if the input requests can be decoded under conditions. */ - bool CanDecode(int num_requests) { - int num_available_pages = models_[0]->GetNumAvailablePages(); - return num_requests <= num_available_pages; - } - - /******************* - * Action 3. Abort * - *******************/ - - /*! \brief Abort the generation of the given request. */ - void StepAbort(Request request) { - auto it_running = std::find(running_queue_.begin(), running_queue_.end(), request); - auto it_waiting = std::find(waiting_queue_.begin(), waiting_queue_.end(), request); - ICHECK(it_running != running_queue_.end() || it_waiting != waiting_queue_.end()); - if (it_running != running_queue_.end()) { - // The request to abort is in running queue - int req_id = it_running - running_queue_.begin(); - running_queue_.erase(it_running); - RequestState state = request_states_.at(request->id); - stats_.current_total_seq_len -= - request->input_total_length + state->mstates[0]->committed_tokens.size() - 1; - RemoveRequestFromModel(req_id); - } else { - // The request to abort is in waiting queue - waiting_queue_.erase(it_waiting); - } - request_states_.erase(request->id); - } - - /***************** Engine Data Structures *****************/ - - // Request queues - std::vector running_queue_; - std::vector waiting_queue_; - std::vector abort_queue_; - // Request states - std::unordered_map request_states_; + // Engine state, managing requests and request states. + EngineState estate_; // Models, sampler and tokenizer. Array models_; Sampler sampler_; std::unique_ptr tokenizer_; - // Runtime statistics - EngineStats stats_; + // Engine actions. + EngineAction action_abort_request_; + EngineAction action_new_request_prefill_; + EngineAction action_batch_decode_; // Configurations KVCacheConfig kv_cache_config_; @@ -714,7 +253,7 @@ class EngineModule : public ModuleNode { } else if (name == "stats") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 0); - *rv = GetEngine()->stats_.AsJSON(); + *rv = GetEngine()->estate_->stats.AsJSON(); }); } else if (name == "reset") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { diff --git a/cpp/serve/engine_actions/abort_requests.cc b/cpp/serve/engine_actions/abort_requests.cc new file mode 100644 index 0000000000..e7b62acf7d --- /dev/null +++ b/cpp/serve/engine_actions/abort_requests.cc @@ -0,0 +1,64 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/abort_requests.cc + */ + +#include "../config.h" +#include "../model.h" +#include "../sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that aborts the requests in the `abort_queue` + * of the engine state. + */ +class AbortRequestActionObj : public EngineActionObj { + public: + explicit AbortRequestActionObj(Array models) : models_(std::move(models)) {} + + bool Step(EngineState estate) final { + // Abort all requests in the abort queue. + while (!estate->abort_queue.empty()) { + // - Check if the request is running or pending. 
+ Request request = estate->abort_queue.front(); + auto it_running = + std::find(estate->running_queue.begin(), estate->running_queue.end(), request); + auto it_waiting = + std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request); + ICHECK(it_running != estate->running_queue.end() || + it_waiting != estate->waiting_queue.end()); + + if (it_running != estate->running_queue.end()) { + // The request to abort is in running queue + int req_id = it_running - estate->running_queue.begin(); + estate->running_queue.erase(it_running); + RequestState state = estate->GetRequestState(request); + estate->stats.current_total_seq_len -= + request->input_total_length + state->mstates[0]->committed_tokens.size() - 1; + RemoveRequestFromModel(estate, req_id, models_); + } else { + // The request to abort is in waiting queue + estate->waiting_queue.erase(it_waiting); + } + estate->request_states.erase(request->id); + } + return true; + } + + private: + /*! \brief The models where the requests to abort also need to be removed from. */ + Array models_; +}; + +EngineAction EngineAction::AbortRequest(Array models) { + return EngineAction(make_object(std::move(models))); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/action.cc b/cpp/serve/engine_actions/action.cc new file mode 100644 index 0000000000..8a0580eb86 --- /dev/null +++ b/cpp/serve/engine_actions/action.cc @@ -0,0 +1,16 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/action.cc + */ + +#include "action.h" + +namespace mlc { +namespace llm { +namespace serve { + +TVM_REGISTER_OBJECT_TYPE(EngineActionObj); + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h new file mode 100644 index 0000000000..6cd93771ed --- /dev/null +++ b/cpp/serve/engine_actions/action.h @@ -0,0 +1,93 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/action.h + * \brief The abstraction of actions (e.g., prefill/decode) that an + * Engine can take at each time step. + */ +#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ +#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ + +#include "../config.h" +#include "../engine_state.h" +#include "../model.h" +#include "../sampler.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The abstraction of actions that an Engine can take at each time step. + * The only core interface of an action is the `Step` function. + * At high level, the Step function takes the current engine state + * as input, invokes model functions (such as batched-prefill or + * batched-decode), run sampler to sample new tokens, and update + * the engine state. + */ +class EngineActionObj : public Object { + public: + /*! + * \brief The behavior of the engine action in a single step. + * \param estate The engine state to be analyzed and updated. + * \return A boolean indicating if the action is successfully taken. + */ + virtual bool Step(EngineState estate) = 0; + + static constexpr const char* _type_key = "mlc.serve.EngineAction"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(EngineActionObj, Object); +}; + +/*! + * \brief Managed reference of EngineActionObj. + * It declares the full list of supported actions. 
+ * \sa EngineActionObj + */ +class EngineAction : public ObjectRef { + public: + /*! + * \brief Create the action that aborts the requests in the `abort_queue` + * of the engine state. + * \param models The models where the requests to abort also need + * to be removed from. + * \return The created action object. + */ + static EngineAction AbortRequest(Array models); + /*! + * \brief Create the action that prefills requests in the `waiting_queue` + * of the engine state. + * \param models The models to run prefill in. + * \param sampler The sampler to sample new tokens. + * \param kv_cache_config The KV cache config to help decide prefill is doable. + * \param max_single_sequence_length The max single sequence length to help + * decide if prefill is doable. + * \return The created action object. + */ + static EngineAction NewRequestPrefill(Array models, Sampler sampler, + KVCacheConfig kv_cache_config, + int max_single_sequence_length); + /*! + * \brief Create the action that runs one-step decode for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + * \note The BatchDecode action **does not** take effect for speculative + * decoding scenarios where there are multiple models. For speculative + * decoding in the future, we will use other specific actions. + * \param models The model to run decode in. When there are multiple + * models, the `Step` function of the created action will not take effect. + * \param sampler The sampler to sample new tokens. + * \return The created action object. + */ + static EngineAction BatchDecode(Array models, Sampler sampler); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj); +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc new file mode 100644 index 0000000000..7556b76bcb --- /dev/null +++ b/cpp/serve/engine_actions/action_commons.cc @@ -0,0 +1,79 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/action_commons.cc + */ + +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +void RemoveRequestFromModel(EngineState estate, int req_id, Array models) { + // Remove the request from all models (usually the KV cache). + for (Model model : models) { + model->RemoveSequence(req_id); + } + // Update the internal request id of other requests. + for (auto& it : estate->request_states) { + RequestState state = it.second; + for (RequestModelState mstate : state->mstates) { + ICHECK_NE(mstate->request_id, req_id); + if (mstate->request_id > req_id) { + --mstate->request_id; + } + } + } +} + +void ProcessFinishedRequest(EngineState estate, Array models, + const std::unique_ptr& tokenizer, + int max_single_sequence_length) { + // - Collect finished requests. + // We don't remove on the fly to avoid concurrent modification. + std::vector request_to_remove; + for (Request request : estate->running_queue) { + if (estate->GetRequestState(request)->GenerationFinished(max_single_sequence_length)) { + request_to_remove.push_back(request); + } + } + + // - Remove the finished request. + for (Request request : request_to_remove) { + // Remove from running queue. 
+ auto it = std::find(estate->running_queue.begin(), estate->running_queue.end(), request); + ICHECK(it != estate->running_queue.end()); + estate->running_queue.erase(it); + + // Update engine states. + RequestState state = estate->GetRequestState(request); + int req_id = state->mstates[0]->request_id; + for (RequestModelState mstate : state->mstates) { + ICHECK_EQ(mstate->request_id, req_id); + mstate->request_id = -1; + } + RemoveRequestFromModel(estate, req_id, models); + estate->request_states.erase(request->id); + + // Update engine statistics. + int num_input_tokens = request->input_total_length; + int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1; + estate->stats.current_total_seq_len -= num_input_tokens + num_output_tokens; + auto trequest_finish = std::chrono::high_resolution_clock::now(); + estate->stats.request_total_prefill_time += + static_cast((state->tprefill_finish - state->tadd).count()) / 1e9; + estate->stats.total_prefill_length += num_input_tokens; + estate->stats.request_total_decode_time += + static_cast((trequest_finish - state->tprefill_finish).count()) / 1e9; + estate->stats.total_decode_length += num_output_tokens; + + // NOTE: right now we only return the generated text. + // In the future we might optional return text or token ids. + String output = tokenizer->Decode(state->mstates[0]->committed_tokens); + request->fcallback(request, TextData(output)); + } +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h new file mode 100644 index 0000000000..c8367a0b48 --- /dev/null +++ b/cpp/serve/engine_actions/action_commons.h @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/action_commons.h + * \brief Common functions that may be used in multiple EngineActions. + */ +#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ +#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ + +#include "../../tokenizers.h" +#include "../engine_state.h" +#include "../model.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief Remove the given request from models. + * \param estate The engine state to update after removal. + * \param req_id The internal id of the request to remove. + * \param models The models to remove the given request from. + */ +void RemoveRequestFromModel(EngineState estate, int req_id, Array models); + +/*! + * \brief For each request in the `running_queue` of the engine state, + * check if the request has finished its generation. Update the state + * and return the generation result via request callback for the finished + * requests. + * \note This function removes requests from the `running_queue`. + * \param estate The engine state. + * \param models The models to remove the finished from. + * \param tokenizer The tokenizer used to decode the generated tokens of requests. + * \param max_single_sequence_length The max single sequence length to help decide + * if a request is finished. 
+ */ +void ProcessFinishedRequest(EngineState estate, Array models, + const std::unique_ptr& tokenizer, + int max_single_sequence_length); + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc new file mode 100644 index 0000000000..5da056fff4 --- /dev/null +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -0,0 +1,166 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/batch_decode.cc + */ + +#include "../config.h" +#include "../model.h" +#include "../sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that runs one-step decode for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + * \note The BatchDecode action **does not** take effect for speculative + * decoding scenarios where there are multiple models. For speculative + * decoding in the future, we will use other specific actions. + */ +class BatchDecodeActionObj : public EngineActionObj { + public: + explicit BatchDecodeActionObj(Array models, Sampler sampler) + : models_(std::move(models)), sampler_(std::move(sampler)) {} + + bool Step(EngineState estate) final { + // - Do not run decode when there are multiple models or no running requests. + if (models_.size() > 1 || estate->running_queue.empty()) { + return false; + } + + // Preempt requests when decode cannot apply. + int num_available_pages = models_[0]->GetNumAvailablePages(); + while (!CanDecode(estate->running_queue.size())) { + PreemptLastRunningRequest(estate); + } + + auto tstart = std::chrono::high_resolution_clock::now(); + + // NOTE: Right now we only support decode all the running requests at a time. + int num_requests = estate->running_queue.size(); + // Check if the requests ids are in an ascending order. + for (int i = 1; i < num_requests; ++i) { + ICHECK_GT(estate->GetRequestState(estate->running_queue[i])->mstates[0]->request_id, + estate->GetRequestState(estate->running_queue[i - 1])->mstates[0]->request_id); + } + + estate->stats.current_total_seq_len += num_requests; + // Collect + // - the last committed token, + // - the request states, + // - the sampling parameters, + // of each request. + std::vector input_tokens; + Array mstates; + Array generation_cfg; + input_tokens.reserve(num_requests); + mstates.reserve(num_requests); + generation_cfg.reserve(num_requests); + for (Request request : estate->running_queue) { + RequestState rstate = estate->GetRequestState(request); + input_tokens.push_back(rstate->mstates[0]->committed_tokens.back()); + mstates.push_back(rstate->mstates[0]); + generation_cfg.push_back(request->generation_cfg); + } + + // - Compute embeddings. + NDArray embeddings = + models_[0]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); + ICHECK_EQ(embeddings->ndim, 3); + ICHECK_EQ(embeddings->shape[0], 1); + ICHECK_EQ(embeddings->shape[1], num_requests); + embeddings = embeddings.CreateView({num_requests, 1, embeddings->shape[2]}, embeddings->dtype); + + // - Invoke model decode. + NDArray logits = models_[0]->BatchDecode(embeddings); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], embeddings->shape[0]); + ICHECK_EQ(logits->shape[1], 1); + + // - Sample tokens. 
+ std::vector next_tokens = + sampler_->SampleTokens(logits, models_[0], mstates, generation_cfg); + ICHECK_EQ(next_tokens.size(), num_requests); + + // - Update the committed tokens of states. + for (int i = 0; i < num_requests; ++i) { + mstates[i]->committed_tokens.push_back(next_tokens[i]); + } + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; + + return true; + } + + private: + /*! \brief Check if the input requests can be decoded under conditions. */ + bool CanDecode(int num_requests) { + int num_available_pages = models_[0]->GetNumAvailablePages(); + return num_requests <= num_available_pages; + } + + /*! + * \brief Preempt the last running requests from `running_queue`, + * moving it from running request set to the foremost of waiting + * request queue. + */ + void PreemptLastRunningRequest(EngineState estate) { + Request request = estate->running_queue.back(); + + // Remove from models. + // - Reset internal `request_id` of states. + // - Clear model speculation draft. + // - Update `inputs` for future prefill. + RequestState rstate = estate->GetRequestState(request); + int req_id = rstate->mstates[0]->request_id; + estate->stats.current_total_seq_len -= + request->input_total_length + rstate->mstates[0]->committed_tokens.size() - 1; + for (RequestModelState mstate : rstate->mstates) { + mstate->request_id = -1; + mstate->draft_output_tokens.clear(); + mstate->draft_output_token_prob.clear(); + mstate->draft_output_prob_dist.clear(); + ICHECK(mstate->inputs.empty()); + ICHECK(!mstate->committed_tokens.empty()); + + Array inputs = request->inputs; + if (const auto* token_input = inputs.back().as()) { + // Merge the TokenData so that a single time TokenEmbed is needed. + std::vector token_ids{token_input->token_ids->data, + token_input->token_ids->data + token_input->token_ids.size()}; + token_ids.insert(token_ids.end(), mstate->committed_tokens.begin(), + mstate->committed_tokens.end()); + inputs.Set(inputs.size() - 1, TokenData(token_ids)); + } else { + inputs.push_back(TokenData(mstate->committed_tokens)); + } + mstate->inputs = std::move(inputs); + } + RemoveRequestFromModel(estate, req_id, models_); + + // Move from running queue to the front of waiting queue. + estate->running_queue.erase(estate->running_queue.end() - 1); + estate->waiting_queue.insert(estate->waiting_queue.begin(), request); + } + + /*! + * \brief The model to run decode in. When there are multiple + * models, the `Step` function of the created action will not take effect. + */ + Array models_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; +}; + +EngineAction EngineAction::BatchDecode(Array models, Sampler sampler) { + return EngineAction(make_object(std::move(models), std::move(sampler))); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc new file mode 100644 index 0000000000..54f18a3614 --- /dev/null +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -0,0 +1,220 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/new_request_prefill.cc + */ + +#include "../config.h" +#include "../model.h" +#include "../sampler.h" +#include "action.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that prefills requests in the `waiting_queue` of + * the engine state. 
+ */ +class NewRequestPrefillActionObj : public EngineActionObj { + public: + explicit NewRequestPrefillActionObj(Array models, Sampler sampler, + KVCacheConfig kv_cache_config, int max_single_sequence_length) + : models_(std::move(models)), + sampler_(std::move(sampler)), + kv_cache_config_(std::move(kv_cache_config)), + max_single_sequence_length_(max_single_sequence_length) {} + + bool Step(EngineState estate) final { + // - Find the requests in `waiting_queue` that can prefill in this step. + auto [requests, rstates, prefill_lengths] = GetRequestsToPrefill(estate); + ICHECK_EQ(requests.size(), rstates.size()); + ICHECK_EQ(requests.size(), prefill_lengths.size()); + if (requests.empty()) { + return false; + } + + int num_requests = requests.size(); + auto tstart = std::chrono::high_resolution_clock::now(); + + // - Move requests from waiting queue to running queue. + // And assign internal ID for the requests. + std::vector request_ids; + request_ids.reserve(num_requests); + for (int i = 0; i < num_requests; ++i) { + int req_id = estate->running_queue.size(); + auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), requests[i]); + ICHECK(it != estate->waiting_queue.end()); + + // - Move request from waiting queue to running queue. + estate->waiting_queue.erase(it); + estate->running_queue.push_back(requests[i]); + // - Assign internal request id for the requests. + AssignInternalIDForRequest(rstates[i], requests[i], req_id); + request_ids.push_back(req_id); + } + + // - Get embedding and run prefill for each model. + NDArray logits_for_sample{nullptr}; + for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { + Array embeddings; + embeddings.reserve(num_requests); + for (int i = 0; i < num_requests; ++i) { + RequestModelState mstate = rstates[i]->mstates[model_id]; + ICHECK_EQ(mstate->GetInputLength(), prefill_lengths[i]); + ICHECK(mstate->draft_output_tokens.empty()); + ICHECK(mstate->draft_output_token_prob.empty()); + ICHECK(mstate->draft_output_prob_dist.empty()); + ICHECK(!mstate->inputs.empty()); + for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { + embeddings.push_back(mstate->inputs[i]->GetEmbedding(models_[model_id])); + } + // Clean up `inputs` after prefill + mstate->inputs.clear(); + } + + NDArray logits = models_[model_id]->BatchPrefill(embeddings, request_ids, prefill_lengths); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], 1); + ICHECK_EQ(logits->shape[1], num_requests); + + if (model_id == 0) { + // We only need to sample for model 0 in prefill. + logits_for_sample = logits; + } + } + + // - Sample tokens. + ICHECK(logits_for_sample.defined()); + logits_for_sample = logits_for_sample.CreateView({num_requests, 1, logits_for_sample->shape[2]}, + logits_for_sample->dtype); + Array mstates_for_sample = + rstates.Map([](RequestState rstate) { return rstate->mstates[0]; }); + std::vector next_tokens = sampler_->SampleTokens( + logits_for_sample, models_[0], mstates_for_sample, + requests.Map([](Request request) { return request->generation_cfg; })); + ICHECK_EQ(next_tokens.size(), num_requests); + + // - Update the committed tokens of states. + // - If a request is first-time prefilled, set the prefill finish time. + // - Accumulate the sequence length in engine statistics. 
+ int sum_prefill_lengths = 0; + auto tnow = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < num_requests; ++i) { + mstates_for_sample[i]->committed_tokens.push_back(next_tokens[i]); + if (mstates_for_sample[i]->committed_tokens.size() == 1) { + estate->GetRequestState(requests[i])->tprefill_finish = tnow; + } + sum_prefill_lengths += prefill_lengths[i]; + } + estate->stats.current_total_seq_len += sum_prefill_lengths; + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; + + return true; + } + + private: + /*! + * \brief Find one or multiple requests to run prefill. + * \param estate The engine state. + * \return The requests to prefill, together with their respective + * state and input length. + */ + std::tuple, Array, std::vector> GetRequestsToPrefill( + EngineState estate) { + if (estate->waiting_queue.empty()) { + // No request to prefill. + return {{}, {}, {}}; + } + + // - Try to prefill pending requests. + std::vector prefill_requests; + std::vector rstates; + std::vector prefill_lengths; + int total_input_length = 0; + int total_required_pages = 0; + int num_available_pages = models_[0]->GetNumAvailablePages(); + + for (int i = 1; i <= static_cast(estate->waiting_queue.size()); ++i) { + Request request = estate->waiting_queue[i - 1]; + RequestState rstate = estate->GetRequestState(request); + int input_length = rstate->mstates[0]->GetInputLength(); + int num_require_pages = + (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + if (CanPrefill(estate, i, total_input_length, total_required_pages, num_available_pages)) { + prefill_requests.push_back(request); + rstates.push_back(rstate); + prefill_lengths.push_back(input_length); + } else { + total_input_length -= input_length; + total_required_pages -= num_require_pages; + break; + } + } + + return {prefill_requests, rstates, prefill_lengths}; + } + + /*! \brief Check if the input requests can be prefilled under conditions. */ + bool CanPrefill(EngineState estate, int num_prefill_req, int total_input_length, + int num_required_pages, int num_available_pages) { + int num_running_requests = estate->running_queue.size(); + ICHECK_LE(num_running_requests, kv_cache_config_->max_num_sequence); + + // No exceeding of the maximum allowed requests that can + // run simultaneously. + if (num_running_requests + num_prefill_req > kv_cache_config_->max_num_sequence) { + return false; + } + + // NOTE: The conditions are heuristic and can be revised. + // Cond 1: total input length <= max allowed single sequence length. + // Cond 2: at least one decode can be performed after prefill. + // Cond 3: number of total tokens after 8 times of decode does not + // exceed the limit, where 8 is a watermark number can + // be configured and adjusted in the future. + int new_batch_size = num_running_requests + num_prefill_req; + return total_input_length <= max_single_sequence_length_ && + num_required_pages + new_batch_size <= num_available_pages && + estate->stats.current_total_seq_len + total_input_length + 8 * new_batch_size <= + kv_cache_config_->max_total_sequence_length; + } + + /*! \brief Assign the given internal id for the given request. */ + void AssignInternalIDForRequest(RequestState rstate, Request request, int req_id) { + // Set internal id in the request state. 
+ for (RequestModelState mstate : rstate->mstates) { + mstate->request_id = req_id; + } + // Add a new sequence to each model. + for (int i = 0; i < static_cast(models_.size()); ++i) { + int seq_id_in_model = models_[i]->AddNewSequence(); + ICHECK_EQ(seq_id_in_model, req_id); + } + } + + /*! \brief The models to run prefill in. */ + Array models_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; + /*! \brief The KV cache config to help decide prefill is doable. */ + KVCacheConfig kv_cache_config_; + /*! \brief The max single sequence length to help decide if prefill is doable. */ + int max_single_sequence_length_; +}; + +EngineAction EngineAction::NewRequestPrefill(Array models, Sampler sampler, + KVCacheConfig kv_cache_config, + int max_single_sequence_length) { + return EngineAction(make_object(std::move(models), std::move(sampler), + std::move(kv_cache_config), + max_single_sequence_length)); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_stats.cc b/cpp/serve/engine_state.cc similarity index 72% rename from cpp/serve/engine_stats.cc rename to cpp/serve/engine_state.cc index 8339008fba..1c88663be9 100644 --- a/cpp/serve/engine_stats.cc +++ b/cpp/serve/engine_state.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/engine_stats.cc + * \file serve/engine_state.cc */ #define PICOJSON_USE_INT64 -#include "engine_stats.h" +#include "engine_state.h" #include @@ -35,6 +35,22 @@ void EngineStats::Reset() { total_decode_length = 0; } +TVM_REGISTER_OBJECT_TYPE(EngineStateObj); + +EngineState::EngineState() { data_ = make_object(); } + +void EngineStateObj::Reset() { + running_queue.clear(); + waiting_queue.clear(); + abort_queue.clear(); + request_states.clear(); + stats.Reset(); +} + +RequestState EngineStateObj::GetRequestState(Request request) { + return request_states.at(request->id); +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine_state.h b/cpp/serve/engine_state.h new file mode 100644 index 0000000000..25d94fcab0 --- /dev/null +++ b/cpp/serve/engine_state.h @@ -0,0 +1,103 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_state.h + */ +#ifndef MLC_LLM_SERVE_ENGINE_STATE_H_ +#define MLC_LLM_SERVE_ENGINE_STATE_H_ + +#include + +#include "request.h" +#include "request_state.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! \brief Runtime statistics of engine. */ +struct EngineStats { + /*! \brief The current total sequence length in the first model. */ + int64_t current_total_seq_len; + /*! \brief The sum of "prefill time of each request". */ + double request_total_prefill_time = 0.0f; + /*! \brief The sum of "decode time of each request". */ + double request_total_decode_time = 0.0f; + /*! \brief The total engine time on prefill. */ + double engine_total_prefill_time = 0.0f; + /*! \brief The total engine time on decode. */ + double engine_total_decode_time = 0.0f; + /*! \brief The total number of processed tokens in prefill. */ + int64_t total_prefill_length = 0; + /*! \brief The total number of processed tokens in decode. */ + int64_t total_decode_length = 0; + + /*! + * \brief Return the engine runtime statistics in JSON string. 
+ * We collect the following entries: + * - single token prefill latency (s/tok): avg latency of processing one token in prefill + * - single token decode latency (s/tok): avg latency of processing one token in decode + * - engine time for prefill (sec) + * - engine time for decode (sec) + * - total number of processed tokens in prefill. + * - total number of processed tokens in decode. + * \return The statistics in JSON string. + */ + String AsJSON() const; + /*! \brief Reset all the statistics. */ + void Reset(); +}; + +/*! + * \brief The state of the running engine. + * It contains the requests and their states submitted to the Engine. + */ +class EngineStateObj : public Object { + public: + /*! \brief The requests being processed. */ + std::vector running_queue; + /*! \brief The requests that have not started for process yet. */ + std::vector waiting_queue; + /*! \brief The requests to abort. */ + std::vector abort_queue; + /*! \brief The states of all requests. */ + std::unordered_map request_states; + /*! \brief Runtime statistics. */ + EngineStats stats; + + /*! \brief Reset the engine state and clear the statistics. */ + void Reset(); + /*! \brief Get the request state of the given request. */ + RequestState GetRequestState(Request request); + + static constexpr const char* _type_key = "mlc.serve.EngineState"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(EngineStateObj, Object); +}; + +/*! + * \brief Managed reference of EngineStateObj. + * \sa EngineStateObj + */ +class EngineState : public ObjectRef { + public: + explicit EngineState(); + + // Default constructors. + EngineState(const EngineState& other) = default; + EngineState(EngineState&& other) = default; + EngineState& operator=(const EngineState& other) = default; + EngineState& operator=(EngineState&& other) = default; + + explicit EngineState(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ObjectRef(n) {} + EngineStateObj* operator->() const { return static_cast(data_.get()); } + using ContainerType = EngineStateObj; +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_ENGINE_STATE_H_ diff --git a/cpp/serve/engine_stats.h b/cpp/serve/engine_stats.h deleted file mode 100644 index 6aa3f7397f..0000000000 --- a/cpp/serve/engine_stats.h +++ /dev/null @@ -1,53 +0,0 @@ -/*! - * Copyright (c) 2023 by Contributors - * \file serve/engine_stats.h - */ -#ifndef MLC_LLM_SERVE_ENGINE_STATS_H_ -#define MLC_LLM_SERVE_ENGINE_STATS_H_ - -#include - -namespace mlc { -namespace llm { -namespace serve { - -using namespace tvm::runtime; - -/*! \brief Runtime statistics of engine. */ -struct EngineStats { - /*! \brief The current total sequence length in the first model. */ - int64_t current_total_seq_len; - /*! \brief The sum of "prefill time of each request". */ - double request_total_prefill_time = 0.0f; - /*! \brief The sum of "decode time of each request". */ - double request_total_decode_time = 0.0f; - /*! \brief The total engine time on prefill. */ - double engine_total_prefill_time = 0.0f; - /*! \brief The total engine time on decode. */ - double engine_total_decode_time = 0.0f; - /*! \brief The total number of processed tokens in prefill. */ - int64_t total_prefill_length = 0; - /*! \brief The total number of processed tokens in decode. */ - int64_t total_decode_length = 0; - - /*! - * \brief Return the engine runtime statistics in JSON string. 
- * We collect the following entries: - * - single token prefill latency (s/tok): avg latency of processing one token in prefill - * - single token decode latency (s/tok): avg latency of processing one token in decode - * - engine time for prefill (sec) - * - engine time for decode (sec) - * - total number of processed tokens in prefill. - * - total number of processed tokens in decode. - * \return The statistics in JSON string. - */ - String AsJSON() const; - /*! \brief Reset all the statistics. */ - void Reset(); -}; - -} // namespace serve -} // namespace llm -} // namespace mlc - -#endif // MLC_LLM_SERVE_ENGINE_STATS_H_
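
For readers new to the abstraction above, the following is a minimal sketch, not part of this PR, of what an additional action could look like if it were added under `cpp/serve/engine_actions/`. The class name `NoOpLoggingActionObj`, the free factory function `NoOpLogging`, and the use of `LOG(INFO)` are illustrative assumptions; only the `EngineActionObj`/`EngineAction` interface it implements comes from the `engine_actions/action.h` introduced in this patch.

```c++
/*!
 * Hypothetical example action (not part of this PR): it inspects the engine
 * state and reports it, returning true only when it observed running requests.
 */
#include "action.h"

namespace mlc {
namespace llm {
namespace serve {

class NoOpLoggingActionObj : public EngineActionObj {
 public:
  bool Step(EngineState estate) final {
    // Take no action when there is nothing running.
    if (estate->running_queue.empty()) {
      return false;
    }
    // Read-only inspection of the engine state introduced by this PR.
    LOG(INFO) << "running=" << estate->running_queue.size()
              << " waiting=" << estate->waiting_queue.size()
              << " total_seq_len=" << estate->stats.current_total_seq_len;
    return true;
  }
};

// Factory following the pattern of EngineAction::AbortRequest and friends,
// kept as a free function here so the EngineAction class is left untouched.
EngineAction NoOpLogging() { return EngineAction(make_object<NoOpLoggingActionObj>()); }

}  // namespace serve
}  // namespace llm
}  // namespace mlc
```

Such an action would be constructed once during engine initialization, stored on the engine alongside the three actions added in `engine.cc`, and invoked from `Engine::Step` in the same way as `action_abort_request_`, `action_new_request_prefill_`, and `action_batch_decode_`.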