[Serving][Refactor] EngineState and EngineAction
Following mlc-ai#1277, this PR continues to refactor the engine.

Particularly, this PR introduces the `EngineState` object
to denote the running state of an Engine, and the `EngineAction`
abstraction to denote the action (such as prefill, decode,
speculate, etc.) that the engine can take at each time step.

The EngineState consists of
* the queues of running requests, waiting requests and the requests
to be aborted,
* the request states of all requests,
* the engine runtime statistics (see the sketch after this list).
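The field names in the sketch below follow their usage in the new code in this commit, but the exact types and the name `EngineStateSketch` are simplified placeholders, not the real definition in `cpp/serve/engine_state.h`.

```c++
// Illustrative sketch only; the real EngineState is a TVM object defined in
// cpp/serve/engine_state.h and uses TVM container types.
struct EngineStateSketch {
  std::vector<Request> running_queue;   // requests currently decoding
  std::vector<Request> waiting_queue;   // requests waiting for prefill
  std::vector<Request> abort_queue;     // requests pending abortion
  std::unordered_map<String, RequestState> request_states;  // keyed by request id
  EngineStats stats;                    // runtime statistics (lengths, timings)
};
```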

The EngineAction contains the core interface
```c++
bool Step(EngineState estate)
```
which takes the current engine state as input and updates it
after the action is taken.
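One plausible way to drive these actions each step is sketched below; the actual per-step scheduling policy in `engine.cc` may differ, and stopping at the first successful action is an assumption made only for illustration.

```c++
// Sketch: try each configured action in order and stop once one of them
// reports that it took effect in this step.
void EngineStepSketch(const Array<EngineAction>& actions, EngineState estate) {
  for (EngineAction action : actions) {
    if (action->Step(estate)) {
      break;  // an action (e.g. prefill or decode) made progress this step
    }
  }
}
```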
MasterJH5574 committed Nov 17, 2023
1 parent 11ef0cd commit 6d57417
Showing 11 changed files with 832 additions and 542 deletions.
513 changes: 26 additions & 487 deletions cpp/serve/engine.cc

Large diffs are not rendered by default.

64 changes: 64 additions & 0 deletions cpp/serve/engine_actions/abort_requests.cc
@@ -0,0 +1,64 @@
/*!
* Copyright (c) 2023 by Contributors
* \file serve/engine_actions/abort_requests.cc
*/

#include "../config.h"
#include "../model.h"
#include "../sampler.h"
#include "action.h"
#include "action_commons.h"

namespace mlc {
namespace llm {
namespace serve {

/*!
* \brief The action that aborts the requests in the `abort_queue`
* of the engine state.
*/
class AbortRequestActionObj : public EngineActionObj {
public:
explicit AbortRequestActionObj(Array<Model> models) : models_(std::move(models)) {}

bool Step(EngineState estate) final {
// Abort all requests in the abort queue.
while (!estate->abort_queue.empty()) {
// - Check if the request is running or pending.
Request request = estate->abort_queue.front();
auto it_running =
std::find(estate->running_queue.begin(), estate->running_queue.end(), request);
auto it_waiting =
std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request);
ICHECK(it_running != estate->running_queue.end() ||
it_waiting != estate->waiting_queue.end());

if (it_running != estate->running_queue.end()) {
// The request to abort is in running queue
int req_id = it_running - estate->running_queue.begin();
estate->running_queue.erase(it_running);
RequestState state = estate->GetRequestState(request);
estate->stats.current_total_seq_len -=
request->input_total_length + state->mstates[0]->committed_tokens.size() - 1;
RemoveRequestFromModel(estate, req_id, models_);
} else {
// The request to abort is in waiting queue
estate->waiting_queue.erase(it_waiting);
}
estate->request_states.erase(request->id);
}
return true;
}

private:
/*! \brief The models from which the requests to abort also need to be removed. */
Array<Model> models_;
};

EngineAction EngineAction::AbortRequest(Array<Model> models) {
return EngineAction(make_object<AbortRequestActionObj>(std::move(models)));
}

} // namespace serve
} // namespace llm
} // namespace mlc
16 changes: 16 additions & 0 deletions cpp/serve/engine_actions/action.cc
@@ -0,0 +1,16 @@
/*!
* Copyright (c) 2023 by Contributors
* \file serve/engine_actions/action.cc
*/

#include "action.h"

namespace mlc {
namespace llm {
namespace serve {

TVM_REGISTER_OBJECT_TYPE(EngineActionObj);

} // namespace serve
} // namespace llm
} // namespace mlc
93 changes: 93 additions & 0 deletions cpp/serve/engine_actions/action.h
@@ -0,0 +1,93 @@
/*!
* Copyright (c) 2023 by Contributors
* \file serve/engine_actions/action.h
* \brief The abstraction of actions (e.g., prefill/decode) that an
* Engine can take at each time step.
*/
#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_

#include "../config.h"
#include "../engine_state.h"
#include "../model.h"
#include "../sampler.h"

namespace mlc {
namespace llm {
namespace serve {

using namespace tvm::runtime;

/*!
* \brief The abstraction of actions that an Engine can take at each time step.
* The only core interface of an action is the `Step` function.
* At a high level, the Step function takes the current engine state
* as input, invokes model functions (such as batched-prefill or
* batched-decode), runs the sampler to sample new tokens, and updates
* the engine state.
*/
class EngineActionObj : public Object {
public:
/*!
* \brief The behavior of the engine action in a single step.
* \param estate The engine state to be analyzed and updated.
* \return A boolean indicating if the action is successfully taken.
*/
virtual bool Step(EngineState estate) = 0;

static constexpr const char* _type_key = "mlc.serve.EngineAction";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(EngineActionObj, Object);
};

/*!
* \brief Managed reference of EngineActionObj.
* It declares the full list of supported actions.
* \sa EngineActionObj
*/
class EngineAction : public ObjectRef {
public:
/*!
* \brief Create the action that aborts the requests in the `abort_queue`
* of the engine state.
* \param models The models from which the requests to abort also need
* to be removed.
* \return The created action object.
*/
static EngineAction AbortRequest(Array<Model> models);
/*!
* \brief Create the action that prefills requests in the `waiting_queue`
* of the engine state.
* \param models The models to run prefill in.
* \param sampler The sampler to sample new tokens.
* \param kv_cache_config The KV cache config to help decide if prefill is doable.
* \param max_single_sequence_length The max single sequence length to help
* decide if prefill is doable.
* \return The created action object.
*/
static EngineAction NewRequestPrefill(Array<Model> models, Sampler sampler,
KVCacheConfig kv_cache_config,
int max_single_sequence_length);
/*!
* \brief Create the action that runs one-step decode for requests in the
* `running_queue` of the engine state. Preempt low-priority requests
* accordingly when it is impossible to decode all the running requests.
* \note The BatchDecode action **does not** take effect for speculative
* decoding scenarios where there are multiple models. For speculative
* decoding in the future, we will use other specific actions.
* \param models The models to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
* \param sampler The sampler to sample new tokens.
* \return The created action object.
*/
static EngineAction BatchDecode(Array<Model> models, Sampler sampler);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj);
};

} // namespace serve
} // namespace llm
} // namespace mlc

#endif // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_
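As a usage illustration (not part of this diff), the declared factories could assemble the engine's action list roughly as follows, assuming `models`, `sampler`, `kv_cache_config`, and `max_single_sequence_length` are already constructed by the engine; the set and ordering shown here are only a guess.

```c++
// Hypothetical assembly of the per-step actions; engine.cc may use a
// different set or ordering.
Array<EngineAction> actions = {
    EngineAction::AbortRequest(models),
    EngineAction::NewRequestPrefill(models, sampler, kv_cache_config,
                                    max_single_sequence_length),
    EngineAction::BatchDecode(models, sampler)};
```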
79 changes: 79 additions & 0 deletions cpp/serve/engine_actions/action_commons.cc
@@ -0,0 +1,79 @@
/*!
* Copyright (c) 2023 by Contributors
* \file serve/engine_actions/action_commons.cc
*/

#include "action_commons.h"

namespace mlc {
namespace llm {
namespace serve {

void RemoveRequestFromModel(EngineState estate, int req_id, Array<Model> models) {
// Remove the request from all models (usually the KV cache).
for (Model model : models) {
model->RemoveSequence(req_id);
}
// Update the internal request id of other requests.
for (auto& it : estate->request_states) {
RequestState state = it.second;
for (RequestModelState mstate : state->mstates) {
ICHECK_NE(mstate->request_id, req_id);
if (mstate->request_id > req_id) {
--mstate->request_id;
}
}
}
}

void ProcessFinishedRequest(EngineState estate, Array<Model> models,
const std::unique_ptr<Tokenizer>& tokenizer,
int max_single_sequence_length) {
// - Collect finished requests.
// We don't remove on the fly to avoid concurrent modification.
std::vector<Request> request_to_remove;
for (Request request : estate->running_queue) {
if (estate->GetRequestState(request)->GenerationFinished(max_single_sequence_length)) {
request_to_remove.push_back(request);
}
}

// - Remove the finished requests.
for (Request request : request_to_remove) {
// Remove from running queue.
auto it = std::find(estate->running_queue.begin(), estate->running_queue.end(), request);
ICHECK(it != estate->running_queue.end());
estate->running_queue.erase(it);

// Update engine states.
RequestState state = estate->GetRequestState(request);
int req_id = state->mstates[0]->request_id;
for (RequestModelState mstate : state->mstates) {
ICHECK_EQ(mstate->request_id, req_id);
mstate->request_id = -1;
}
RemoveRequestFromModel(estate, req_id, models);
estate->request_states.erase(request->id);

// Update engine statistics.
int num_input_tokens = request->input_total_length;
int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1;
estate->stats.current_total_seq_len -= num_input_tokens + num_output_tokens;
auto trequest_finish = std::chrono::high_resolution_clock::now();
estate->stats.request_total_prefill_time +=
static_cast<double>((state->tprefill_finish - state->tadd).count()) / 1e9;
estate->stats.total_prefill_length += num_input_tokens;
estate->stats.request_total_decode_time +=
static_cast<double>((trequest_finish - state->tprefill_finish).count()) / 1e9;
estate->stats.total_decode_length += num_output_tokens;

// NOTE: right now we only return the generated text.
// In the future we may optionally return text or token ids.
String output = tokenizer->Decode(state->mstates[0]->committed_tokens);
request->fcallback(request, TextData(output));
}
}

} // namespace serve
} // namespace llm
} // namespace mlc
47 changes: 47 additions & 0 deletions cpp/serve/engine_actions/action_commons.h
@@ -0,0 +1,47 @@
/*!
* Copyright (c) 2023 by Contributors
* \file serve/engine_actions/action_commons.h
* \brief Common functions that may be used in multiple EngineActions.
*/
#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_

#include "../../tokenizers.h"
#include "../engine_state.h"
#include "../model.h"

namespace mlc {
namespace llm {
namespace serve {

using namespace tvm::runtime;

/*!
* \brief Remove the given request from models.
* \param estate The engine state to update after removal.
* \param req_id The internal id of the request to remove.
* \param models The models to remove the given request from.
*/
void RemoveRequestFromModel(EngineState estate, int req_id, Array<Model> models);

/*!
* \brief For each request in the `running_queue` of the engine state,
* check if the request has finished its generation. Update the state
* and return the generation result via request callback for the finished
* requests.
* \note This function removes requests from the `running_queue`.
* \param estate The engine state.
* \param models The models to remove the finished requests from.
* \param tokenizer The tokenizer used to decode the generated tokens of requests.
* \param max_single_sequence_length The max single sequence length to help decide
* if a request is finished.
*/
void ProcessFinishedRequest(EngineState estate, Array<Model> models,
const std::unique_ptr<Tokenizer>& tokenizer,
int max_single_sequence_length);

} // namespace serve
} // namespace llm
} // namespace mlc

#endif // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_
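For context, a decode-style action could call `ProcessFinishedRequest` at the end of its `Step`, roughly as sketched below; the member names `models_`, `tokenizer_`, and `max_single_sequence_length_` are assumed here for illustration.

```c++
// Sketch of the tail of some EngineActionObj::Step(EngineState estate),
// after new tokens have been sampled and committed:
ProcessFinishedRequest(estate, models_, tokenizer_, max_single_sequence_length_);
// Finished requests are now removed from the running queue, their sequences
// are dropped from the models, and their generated text has been returned
// through the request callbacks.
return true;
```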