forked from mlc-ai/mlc-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Serving][Refactor] EngineState and EngineAction
Following mlc-ai#1277, this PR continues to refactor the engine. In particular, this PR introduces the `EngineState` object to denote the running state of an Engine, and the `EngineAction` abstraction to denote the action (such as prefill, decode, speculate, etc.) that the engine can take at each time step.

The EngineState consists of
* the queues of running requests, waiting requests and the requests to be aborted,
* the request states of all requests,
* the engine runtime statistics.

The EngineAction contains the core interface
```c++
bool Step(EngineState estate)
```
which takes the current engine state as input and then updates the engine state after taking the action.
- Loading branch information
1 parent
11ef0cd
commit 6d57417
Showing
11 changed files
with
832 additions
and
542 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/*! | ||
* Copyright (c) 2023 by Contributors | ||
* \file serve/engine_actions/abort_requests.cc | ||
*/ | ||
|
||
#include "../config.h" | ||
#include "../model.h" | ||
#include "../sampler.h" | ||
#include "action.h" | ||
#include "action_commons.h" | ||
|
||
namespace mlc { | ||
namespace llm { | ||
namespace serve { | ||
|
||
/*! | ||
* \brief The action that aborts the requests in the `abort_queue` | ||
* of the engine state. | ||
*/ | ||
class AbortRequestActionObj : public EngineActionObj { | ||
public: | ||
explicit AbortRequestActionObj(Array<Model> models) : models_(std::move(models)) {} | ||
|
||
bool Step(EngineState estate) final { | ||
// Abort all requests in the abort queue. | ||
while (!estate->abort_queue.empty()) { | ||
// - Check if the request is running or pending. | ||
Request request = estate->abort_queue.front(); | ||
auto it_running = | ||
std::find(estate->running_queue.begin(), estate->running_queue.end(), request); | ||
auto it_waiting = | ||
std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request); | ||
ICHECK(it_running != estate->running_queue.end() || | ||
it_waiting != estate->waiting_queue.end()); | ||
|
||
if (it_running != estate->running_queue.end()) { | ||
// The request to abort is in running queue | ||
int req_id = it_running - estate->running_queue.begin(); | ||
estate->running_queue.erase(it_running); | ||
RequestState state = estate->GetRequestState(request); | ||
estate->stats.current_total_seq_len -= | ||
request->input_total_length + state->mstates[0]->committed_tokens.size() - 1; | ||
RemoveRequestFromModel(estate, req_id, models_); | ||
} else { | ||
// The request to abort is in waiting queue | ||
estate->waiting_queue.erase(it_waiting); | ||
} | ||
estate->request_states.erase(request->id); | ||
} | ||
return true; | ||
} | ||
|
||
private: | ||
/*! \brief The models where the requests to abort also need to be removed from. */ | ||
Array<Model> models_; | ||
}; | ||
|
||
/*!
 * \brief Build the engine action that aborts the requests in the abort queue.
 * \param models The models handed over to the created AbortRequestActionObj.
 * \return The action wrapped as an EngineAction reference.
 */
EngineAction EngineAction::AbortRequest(Array<Model> models) {
  auto abort_action = make_object<AbortRequestActionObj>(std::move(models));
  return EngineAction(std::move(abort_action));
}
|
||
} // namespace serve | ||
} // namespace llm | ||
} // namespace mlc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
/*! | ||
* Copyright (c) 2023 by Contributors | ||
* \file serve/engine_actions/action.cc | ||
*/ | ||
|
||
#include "action.h" | ||
|
||
namespace mlc { | ||
namespace llm { | ||
namespace serve { | ||
|
||
// Register EngineActionObj through TVM's object type registration macro.
// NOTE(review): presumably this makes the "mlc.serve.EngineAction" type key
// known to the TVM runtime — confirm against TVM's object.h.
TVM_REGISTER_OBJECT_TYPE(EngineActionObj); | ||
|
||
} // namespace serve | ||
} // namespace llm | ||
} // namespace mlc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
/*! | ||
* Copyright (c) 2023 by Contributors | ||
* \file serve/engine_actions/action.h | ||
* \brief The abstraction of actions (e.g., prefill/decode) that an | ||
* Engine can take at each time step. | ||
*/ | ||
#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ | ||
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ | ||
|
||
#include "../config.h" | ||
#include "../engine_state.h" | ||
#include "../model.h" | ||
#include "../sampler.h" | ||
|
||
namespace mlc { | ||
namespace llm { | ||
namespace serve { | ||
|
||
using namespace tvm::runtime; | ||
|
||
/*! | ||
* \brief The abstraction of actions that an Engine can take at each time step. | ||
* The only core interface of an action is the `Step` function. | ||
* At high level, the Step function takes the current engine state | ||
* as input, invokes model functions (such as batched-prefill or | ||
* batched-decode), run sampler to sample new tokens, and update | ||
* the engine state. | ||
*/ | ||
class EngineActionObj : public Object { | ||
public: | ||
/*! | ||
* \brief The behavior of the engine action in a single step. | ||
* \param estate The engine state to be analyzed and updated. | ||
* \return A boolean indicating if the action is successfully taken. | ||
*/ | ||
virtual bool Step(EngineState estate) = 0; | ||
|
||
static constexpr const char* _type_key = "mlc.serve.EngineAction"; | ||
static constexpr const bool _type_has_method_sequal_reduce = false; | ||
static constexpr const bool _type_has_method_shash_reduce = false; | ||
TVM_DECLARE_BASE_OBJECT_INFO(EngineActionObj, Object); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference of EngineActionObj. | ||
* It declares the full list of supported actions. | ||
* \sa EngineActionObj | ||
*/ | ||
class EngineAction : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief Create the action that aborts the requests in the `abort_queue` | ||
* of the engine state. | ||
* \param models The models where the requests to abort also need | ||
* to be removed from. | ||
* \return The created action object. | ||
*/ | ||
static EngineAction AbortRequest(Array<Model> models); | ||
/*! | ||
* \brief Create the action that prefills requests in the `waiting_queue` | ||
* of the engine state. | ||
* \param models The models to run prefill in. | ||
* \param sampler The sampler to sample new tokens. | ||
* \param kv_cache_config The KV cache config to help decide prefill is doable. | ||
* \param max_single_sequence_length The max single sequence length to help | ||
* decide if prefill is doable. | ||
* \return The created action object. | ||
*/ | ||
static EngineAction NewRequestPrefill(Array<Model> models, Sampler sampler, | ||
KVCacheConfig kv_cache_config, | ||
int max_single_sequence_length); | ||
/*! | ||
* \brief Create the action that runs one-step decode for requests in the | ||
* `running_queue` of engine state. Preempt low-priority requests | ||
* accordingly when it is impossible to decode all the running requests. | ||
* \note The BatchDecode action **does not** take effect for speculative | ||
* decoding scenarios where there are multiple models. For speculative | ||
* decoding in the future, we will use other specific actions. | ||
* \param models The model to run decode in. When there are multiple | ||
* models, the `Step` function of the created action will not take effect. | ||
* \param sampler The sampler to sample new tokens. | ||
* \return The created action object. | ||
*/ | ||
static EngineAction BatchDecode(Array<Model> models, Sampler sampler); | ||
|
||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj); | ||
}; | ||
|
||
} // namespace serve | ||
} // namespace llm | ||
} // namespace mlc | ||
|
||
#endif // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/*! | ||
* Copyright (c) 2023 by Contributors | ||
* \file serve/engine_actions/action_commons.cc | ||
*/ | ||
|
||
#include "action_commons.h" | ||
|
||
namespace mlc { | ||
namespace llm { | ||
namespace serve { | ||
|
||
/*!
 * \brief Remove the request with the given internal id from all models and
 * renumber the internal ids of the requests that remain in the engine state.
 * \param estate The engine state whose request states are renumbered.
 * \param req_id The internal id of the request to remove.
 * \param models The models to remove the request from.
 */
void RemoveRequestFromModel(EngineState estate, int req_id, Array<Model> models) {
  // Drop the sequence from every model (usually the KV cache).
  for (Model model : models) {
    model->RemoveSequence(req_id);
  }

  // Close the gap left by the removed id: every internal id that was
  // larger than the removed one moves down by one.
  for (auto& entry : estate->request_states) {
    for (RequestModelState mstate : entry.second->mstates) {
      // No remaining request state may still carry the removed id.
      ICHECK_NE(mstate->request_id, req_id);
      if (mstate->request_id <= req_id) {
        continue;
      }
      mstate->request_id -= 1;
    }
  }
}
|
||
/*!
 * \brief Retire every request in the running queue whose generation has
 * finished: remove it from the running queue, the models and the request
 * state table, fold its numbers into the engine statistics, and hand the
 * decoded output text to the request callback.
 * \param estate The engine state to update.
 * \param models The models to remove the finished requests from.
 * \param tokenizer The tokenizer used to decode the committed tokens.
 * \param max_single_sequence_length The max single sequence length passed
 * to the finish check of each request state.
 */
void ProcessFinishedRequest(EngineState estate, Array<Model> models,
                            const std::unique_ptr<Tokenizer>& tokenizer,
                            int max_single_sequence_length) {
  // Pass 1: gather the finished requests. Removal happens in a second
  // pass so the running queue is not modified while it is iterated.
  std::vector<Request> finished_requests;
  for (Request candidate : estate->running_queue) {
    RequestState candidate_state = estate->GetRequestState(candidate);
    if (!candidate_state->GenerationFinished(max_single_sequence_length)) {
      continue;
    }
    finished_requests.push_back(candidate);
  }

  // Pass 2: retire each finished request.
  for (Request request : finished_requests) {
    // Take the request out of the running queue.
    auto it = std::find(estate->running_queue.begin(), estate->running_queue.end(), request);
    ICHECK(it != estate->running_queue.end());
    estate->running_queue.erase(it);

    // Invalidate the internal ids held by the request's model states, then
    // remove the request from the models and from the request state table.
    // `state` keeps the request state reachable after the table entry is gone.
    RequestState state = estate->GetRequestState(request);
    int req_id = state->mstates[0]->request_id;
    for (RequestModelState mstate : state->mstates) {
      ICHECK_EQ(mstate->request_id, req_id);
      mstate->request_id = -1;
    }
    RemoveRequestFromModel(estate, req_id, models);
    estate->request_states.erase(request->id);

    // Fold the request into the engine statistics.
    int input_token_count = request->input_total_length;
    int output_token_count = state->mstates[0]->committed_tokens.size() - 1;
    estate->stats.current_total_seq_len -= input_token_count + output_token_count;

    auto finish_time = std::chrono::high_resolution_clock::now();
    // The raw tick counts are divided by 1e9 to obtain seconds.
    double prefill_seconds =
        static_cast<double>((state->tprefill_finish - state->tadd).count()) / 1e9;
    estate->stats.request_total_prefill_time += prefill_seconds;
    estate->stats.total_prefill_length += input_token_count;
    double decode_seconds =
        static_cast<double>((finish_time - state->tprefill_finish).count()) / 1e9;
    estate->stats.request_total_decode_time += decode_seconds;
    estate->stats.total_decode_length += output_token_count;

    // NOTE: right now we only return the generated text.
    // In the future we might optionally return text or token ids.
    String output = tokenizer->Decode(state->mstates[0]->committed_tokens);
    request->fcallback(request, TextData(output));
  }
}
|
||
} // namespace serve | ||
} // namespace llm | ||
} // namespace mlc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/*! | ||
* Copyright (c) 2023 by Contributors | ||
* \file serve/engine_actions/action_commons.h | ||
* \brief Common functions that may be used in multiple EngineActions. | ||
*/ | ||
#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ | ||
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ | ||
|
||
#include "../../tokenizers.h" | ||
#include "../engine_state.h" | ||
#include "../model.h" | ||
|
||
namespace mlc { | ||
namespace llm { | ||
namespace serve { | ||
|
||
using namespace tvm::runtime; | ||
|
||
/*! | ||
* \brief Remove the given request from models. | ||
* \param estate The engine state to update after removal. | ||
* \param req_id The internal id of the request to remove. | ||
* \param models The models to remove the given request from. | ||
*/ | ||
void RemoveRequestFromModel(EngineState estate, int req_id, Array<Model> models); | ||
|
||
/*! | ||
* \brief For each request in the `running_queue` of the engine state, | ||
* check if the request has finished its generation. Update the state | ||
* and return the generation result via request callback for the finished | ||
* requests. | ||
* \note This function removes requests from the `running_queue`. | ||
* \param estate The engine state. | ||
* \param models The models to remove the finished requests from. | ||
* \param tokenizer The tokenizer used to decode the generated tokens of requests. | ||
* \param max_single_sequence_length The max single sequence length to help decide | ||
* if a request is finished. | ||
*/ | ||
void ProcessFinishedRequest(EngineState estate, Array<Model> models, | ||
const std::unique_ptr<Tokenizer>& tokenizer, | ||
int max_single_sequence_length); | ||
|
||
} // namespace serve | ||
} // namespace llm | ||
} // namespace mlc | ||
|
||
#endif // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ |
Oops, something went wrong.