Skip to content

Commit

Permalink
[Serving] Support batched prefill and benchmark
Browse files Browse the repository at this point in the history
This PR supports the current serving framework with batched prefill,
which helps improve the throughput of prefill.

Some data structures are tweaked for less runtime overhead.

This PR also brings the benchmark of serving engine with real-time
dataset as input.
  • Loading branch information
MasterJH5574 committed Nov 11, 2023
1 parent a54b4bd commit 8a671fa
Show file tree
Hide file tree
Showing 11 changed files with 555 additions and 275 deletions.
389 changes: 229 additions & 160 deletions cpp/serve/engine.cc

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ void FunctionTable::_InitFunctions() {
get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device");
this->remove_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_remove");
this->popn_from_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn");
this->get_num_available_pages_kv_cache_func_ =
get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages");
support_backtracking_kv_ = true;
}

Expand Down
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct FunctionTable {
PackedFunc sync_device_kv_cache_func_;
PackedFunc remove_from_kv_cache_func_;
PackedFunc popn_from_kv_cache_func_;
PackedFunc get_num_available_pages_kv_cache_func_;
};

} // namespace serve
Expand Down
155 changes: 97 additions & 58 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ class ModelModule : public ModuleNode {
CHECK_EQ(args.size(), 1);
*rv = TokenEmbed(args[0]);
});
} else if (name == "single_seq_prefill") {
} else if (name == "batch_prefill") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 2);
*rv = SingleSequencePrefill(args[0], args[1]);
CHECK_EQ(args.size(), 3);
*rv = BatchPrefill(args[0], args[1], args[2]);
});
} else if (name == "decode") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
Expand Down Expand Up @@ -115,15 +115,18 @@ class ModelModule : public ModuleNode {
ICHECK_EQ(args.size(), 0);
Reset();
});
} else if (name == "get_num_available_pages") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 0);
ICHECK(kv_cache_.defined());
*rv = ft_.get_num_available_pages_kv_cache_func_(kv_cache_);
});
} else if (name == "get_max_window_size") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 0);
CHECK_NE(max_window_size_, -1) << "The model has not been initialized";
*rv = max_window_size_;
});
} else if (name == "runtime_stats_text") {
// Todo: JSON style
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = GetStats(); });
} else {
return PackedFunc(nullptr);
}
Expand Down Expand Up @@ -156,7 +159,8 @@ class ModelModule : public ModuleNode {
}
// Copy input token ids to device.
DLDataType dtype(DataType::Int(32));
NDArray token_ids_nd = CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, 2048);
NDArray token_ids_nd =
CopyArrayToDevice(flattened_token_ids, &input_token_ids_, dtype, max_window_size_);
ICHECK_EQ(token_ids_nd->ndim, 1);
ICHECK_EQ(token_ids_nd->shape[0], total_length);
token_ids_nd = token_ids_nd.CreateView({1, total_length}, dtype);
Expand All @@ -165,12 +169,7 @@ class ModelModule : public ModuleNode {
<< "`embed` function is not found in the model. Please make sure the model is compiled "
"with flag `--sep-embed` and `--enable-batching`";

auto tstart = std::chrono::high_resolution_clock::now();
NDArray embeddings = ft_.embed_func_(ft_.CopyToWorker0(token_ids_nd), params_);
auto tend = std::chrono::high_resolution_clock::now();

this->embed_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->embed_total_tokens += total_length;

// embeddings: (1, total_length, hidden_size)
ICHECK_EQ(embeddings->ndim, 3);
Expand All @@ -183,14 +182,33 @@ class ModelModule : public ModuleNode {
* \brief Single-sequence prefill function. Embedding in, logits out.
* \param embeddings The embedding of the input to be prefilled.
* \param seq_id The id of the sequence in the KV cache.
* \param lengths The length of each sequence to prefill.
* \return The logits for the next token.
*/
NDArray SingleSequencePrefill(NDArray embeddings, int seq_id) {
NDArray BatchPrefill(Array<NDArray> embedding_arr, ShapeTuple seq_ids, ShapeTuple lengths) {
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
int total_length = 0;
std::vector<int> logit_pos;
logit_pos.reserve(num_sequences);
for (int i = 0; i < num_sequences; ++i) {
total_length += lengths[i];
logit_pos.push_back(total_length);
if (i > 0) {
CHECK_GT(seq_ids[i], seq_ids[i - 1]) << "The input sequence ids must be non-decreasing.";
}
}

// embeddings: (1, n, h)
CHECK_EQ(embeddings->ndim, 3);
CHECK_EQ(embeddings->shape[0], 1);
CHECK_EQ(embeddings->device.device_type, device_.device_type);
CHECK_EQ(embeddings->device.device_id, device_.device_id);
NDArray embeddings = ConcatEmbeddings(std::move(embedding_arr), total_length);
ICHECK_EQ(embeddings->ndim, 3);
ICHECK_EQ(embeddings->shape[0], 1);
ICHECK_EQ(embeddings->shape[1], total_length);
ICHECK_EQ(embeddings->device.device_type, device_.device_type);
ICHECK_EQ(embeddings->device.device_id, device_.device_id);

NDArray logit_pos_nd = CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32);

CHECK(ft_.prefill_func_.defined())
<< "`prefill_with_embed` function is not found in the model. Please make sure the model is "
Expand All @@ -202,22 +220,20 @@ class ModelModule : public ModuleNode {

// Reserve in KV cache for the length of the input.
ft_.reset_append_length_kv_cache_func_(kv_cache_);
ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_id, /*length=*/embeddings->shape[1]);
for (int i = 0; i < num_sequences; ++i) {
ft_.reserve_length_in_kv_cache_func_(kv_cache_, seq_ids[i], lengths[i]);
}
ft_.sync_device_kv_cache_func_(kv_cache_);

auto tstart = std::chrono::high_resolution_clock::now();
// args: embeddings, kv_cache, params
Array<ObjectRef> ret = ft_.prefill_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_);
auto tend = std::chrono::high_resolution_clock::now();

this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += embeddings->shape[1];
// args: embeddings, logit_pos, kv_cache, params
Array<ObjectRef> ret =
ft_.prefill_func_(ft_.CopyToWorker0(embeddings), logit_pos_nd, kv_cache_, params_);

// logits: (1, 1, v)
// logits: (1, num_sequences, v)
NDArray logits = Downcast<NDArray>(ret[0]);
ICHECK_EQ(logits->ndim, 3);
ICHECK_EQ(logits->shape[0], 1);
ICHECK_EQ(logits->shape[1], 1);
ICHECK_EQ(logits->shape[1], num_sequences);
return logits;
}

Expand Down Expand Up @@ -251,13 +267,8 @@ class ModelModule : public ModuleNode {
}
ft_.sync_device_kv_cache_func_(kv_cache_);

auto tstart = std::chrono::high_resolution_clock::now();
// args: embeddings, kv_cache, params
Array<ObjectRef> ret = ft_.decode_func_(ft_.CopyToWorker0(embeddings), kv_cache_, params_);
auto tend = std::chrono::high_resolution_clock::now();

this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->decode_total_tokens += embeddings->shape[0];

// logits: (b, 1, v)
NDArray logits = Downcast<NDArray>(ret[0]);
Expand Down Expand Up @@ -286,7 +297,7 @@ class ModelModule : public ModuleNode {
for (GenerationConfig cfg : generation_cfg) {
temperatures.push_back(cfg->temperature);
}
NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 16);
NDArray temperatures_nd = CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32);
ICHECK_EQ(temperatures_nd->ndim, 1);
ICHECK_EQ(temperatures_nd->shape[0], batch_size);

Expand Down Expand Up @@ -318,6 +329,57 @@ class ModelModule : public ModuleNode {
return view;
}

/*! \brief Concatenate the input embeddings. */
NDArray ConcatEmbeddings(Array<NDArray> embedding_arr, int64_t total_length) {
ICHECK(!embedding_arr.empty());
int hidden_size = -1;
DataType dtype;
for (NDArray inp_embeddings : embedding_arr) {
// inp_embedding: (1, n, h)
CHECK_EQ(inp_embeddings->ndim, 3);
CHECK_EQ(inp_embeddings->shape[0], 1);
CHECK_EQ(inp_embeddings->device.device_type, device_.device_type);
CHECK_EQ(inp_embeddings->device.device_id, device_.device_id);
if (hidden_size == -1) {
hidden_size = inp_embeddings->shape[2];
dtype = inp_embeddings.DataType();
} else {
CHECK_EQ(inp_embeddings->shape[2], hidden_size);
CHECK_EQ(inp_embeddings.DataType(), dtype);
}
}

// - Resize the shared embedding array.
if (embeddings_.defined()) {
ICHECK_EQ(embeddings_->ndim, 3);
ICHECK_EQ(embeddings_->shape[0], 1);
ICHECK_EQ(embeddings_->shape[2], hidden_size);
}
int64_t init_size = embeddings_.defined() ? embeddings_->shape[1] : max_window_size_;
while (init_size < total_length) {
init_size *= 2;
}
if (!embeddings_.defined() || init_size != embeddings_->shape[1]) {
embeddings_ = NDArray::Empty({1, init_size, hidden_size}, dtype, device_);
}

// - Copy input embeddings.
int64_t start_pos = 0;
for (NDArray inp_embeddings : embedding_arr) {
int64_t length = inp_embeddings->shape[1];
CHECK_LE(start_pos + length, total_length);

DLTensor copy_dst = *(embeddings_.operator->());
copy_dst.byte_offset = start_pos * hidden_size * dtype.bytes();
copy_dst.shape = inp_embeddings->shape;
NDArray::CopyFromTo(inp_embeddings.operator->(), &copy_dst);

start_pos += length;
}
CHECK_EQ(start_pos, total_length);
return embeddings_.CreateView({1, total_length, hidden_size}, dtype);
}

/*! \brief Load model configuration from JSON. */
void LoadModelConfigJSON(const std::string& config_str) {
picojson::value config_json;
Expand Down Expand Up @@ -350,37 +412,12 @@ class ModelModule : public ModuleNode {

/*! \brief reset the runtime states. */
void Reset() {
// Reset the statistics.
this->embed_total_tokens = 0;
this->prefill_total_tokens = 0;
this->decode_total_tokens = 0;
this->embed_total_time = 0;
this->prefill_total_time = 0;
this->decode_total_time = 0;
// Reset the KV cache.
if (kv_cache_.defined()) {
ft_.reset_kv_cache_func_(kv_cache_);
}
}

/*! \brief Return statistics in JSON format. */
String GetStats() {
picojson::object stats;
stats["prefill_speed"] = picojson::value(prefill_total_tokens / prefill_total_time);
stats["decode_speed"] = picojson::value(decode_total_tokens / decode_total_time);
stats["embed_speed"] = picojson::value(embed_total_tokens / embed_total_time);
return picojson::value(stats).serialize(true);
}

//----------------------------
// Statistics
//----------------------------
double embed_total_time = 0;
double decode_total_time = 0;
double prefill_total_time = 0;
int64_t embed_total_tokens = 0;
int64_t decode_total_tokens = 0;
int64_t prefill_total_tokens = 0;
//----------------------------
// Model configurations
//----------------------------
Expand All @@ -400,6 +437,8 @@ class ModelModule : public ModuleNode {
ObjectRef params_;
// Shared NDArray
NDArray input_token_ids_{nullptr};
NDArray embeddings_{nullptr};
NDArray logit_pos_arr_{nullptr};
NDArray temperature_arr_{nullptr};
};

Expand Down
15 changes: 15 additions & 0 deletions cpp/serve/request_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ RequestModelState::RequestModelState(int model_id, Array<Data> inputs) {
data_ = std::move(n);
}

TVM_REGISTER_OBJECT_TYPE(RequestStateNode);

RequestState::RequestState(int num_models, Array<Data> inputs, int raw_input_length) {
ObjectPtr<RequestStateNode> n = make_object<RequestStateNode>();
Array<RequestModelState> mstates;
mstates.reserve(num_models);
for (int i = 0; i < num_models; ++i) {
mstates.push_back(RequestModelState(i, inputs));
}
n->mstates = std::move(mstates);
n->raw_input_length = raw_input_length;
n->tadd = std::chrono::high_resolution_clock::now();
data_ = std::move(n);
}

} // namespace serve
} // namespace llm
} // namespace mlc
23 changes: 15 additions & 8 deletions cpp/serve/request_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@ class RequestModelState : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode);
};

struct RequestState {
class RequestStateNode : public Object {
public:
/*!
* \brief The state with regard to each model.
* \sa RequestModelState
*/
Array<RequestModelState> mstates;

/*! \brief The summed up input length of the request. */
int raw_input_length = 0;
/*! \brief The decoded text string output. */
std::string output = "";

Expand All @@ -101,13 +104,17 @@ struct RequestState {
/*! \brief The time of finishing prefill stage. */
std::chrono::_V2::system_clock::time_point tprefill_finish;

explicit RequestState(int num_models, Array<Data> inputs) {
mstates.reserve(num_models);
for (int i = 0; i < num_models; ++i) {
mstates.push_back(RequestModelState(i, inputs));
}
tadd = std::chrono::high_resolution_clock::now();
}
static constexpr const char* _type_key = "mlc.serve.RequestState";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateNode, Object);
};

class RequestState : public ObjectRef {
public:
explicit RequestState(int num_models, Array<Data> inputs, int raw_input_length);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode);
};

} // namespace serve
Expand Down
12 changes: 6 additions & 6 deletions python/mlc_chat/serve/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ def reset(self) -> None:
def stats(self) -> Dict[str, float]:
"""The engine runtime statistics.
We collect the following entries:
- prefill token latency (s/tok)
avg latency of processing one token in prefill
- decode token latency (s/tok)
avg latency of processing one token in decode
- token throughput (tok/s)
avg number of tokens processed per second (prefill + decode)
- single token prefill latency (s/tok): avg latency of processing one token in prefill
- single token decode latency (s/tok): avg latency of processing one token in decode
- engine time for prefill (sec)
- engine time for decode (sec)
- total number of processed tokens in prefill.
- total number of processed tokens in decode.
"""
stats_json_str = self._get_stats_func()
stats = json.loads(stats_json_str)
Expand Down
Loading

0 comments on commit 8a671fa

Please sign in to comment.