Replies: 4 comments 3 replies
-
A proposed return type from mid_process() is a list of:

```rust
pub struct Branch {
    /// How many KV cache entries to remove before any new ones are added.
    pub backtrack: u32,
    /// If None, no sampling is performed.
    /// If Some(vob), only tokens from vob (a set of tokens represented as a bitvector) are allowed.
    pub sample_mask: Option<SimpleVob>,
    /// If no sampling is performed, there should be exactly one sequence of tokens to be appended.
    /// Otherwise, for every allowed token there can be a sequence that starts with that token -
    /// when that starting token is sampled, all the other tokens in that sequence are appended as well.
    pub ff_tokens: Vec<Vec<TokenId>>,
}
```

Sampling is performed as follows:

```rust
seq.pop_kv_cache(b.backtrack);
let to_append = match b.sample_mask {
    Some(mask) => {
        let tok = seq.sample_with_mask(mask);
        match b.ff_tokens.iter().find(|t| t[0] == tok) {
            Some(toks) => toks.clone(),
            None => vec![tok],
        }
    }
    None => {
        assert!(b.ff_tokens.len() == 1);
        b.ff_tokens[0].clone()
    }
};
seq.append_tokens(to_append);
```

Or the same thing in Python:

```python
class Branch:
    backtrack: int
    sample_mask: TokenSet | None
    ff_tokens: list[list[Token]]

def sample(branches: list[Branch], seq):
    if len(branches) == 0:
        seq.stop()
    else:
        seqs = [seq] + [seq.fork() for _ in range(len(branches) - 1)]
        for idx, (seq, b) in enumerate(zip(seqs, branches)):
            seq.branch_idx = idx
            seq.last_backtrack = b.backtrack
            seq.remove_tokens(b.backtrack)
            if b.sample_mask is None:
                assert len(b.ff_tokens) == 1
                seq.last_append = b.ff_tokens[0]
            else:
                tok = seq.sample_with_bias(b.sample_mask)
                ff_tokens = (t for t in b.ff_tokens if t[0] == tok)
                # if no fast-forward sequence starts with the sampled token,
                # append just the sampled token (matches the Rust version above)
                seq.last_append = next(ff_tokens, [tok])
            seq.append_tokens(seq.last_append)
            # we remember branch_idx, last_backtrack and last_append for
            # the next mid_process() invocation

# callback provided by the controller:
def mid_process(branch_idx: int, last_backtrack: int, last_append: list[Token]) -> list[Branch]:
    ...
```
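To make the ff_tokens semantics concrete, here is a hypothetical mid_process() return value; the token ids and the TokenSet constructor are made up for illustration. It allows sampling either a YES or a NO token, and if NO is sampled, the rest of ", thanks" is fast-forwarded deterministically:

```python
# hypothetical token ids, for illustration only
YES, NO, COMMA, THANKS = 345, 456, 11, 9921

def mid_process(branch_idx: int, last_backtrack: int, last_append: list[Token]) -> list[Branch]:
    b = Branch()
    b.backtrack = 0                      # keep the whole KV cache
    b.sample_mask = TokenSet([YES, NO])  # only these two tokens may be sampled
    # if NO is sampled, COMMA and THANKS are appended as well; no sequence
    # starts with YES, so sampling YES appends only that token
    b.ff_tokens = [[NO, COMMA, THANKS]]
    return [b]                           # a single branch -> no forking
```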
-
Another comment regarding performance: speculative decoding uses a draft model (10-100x smaller, with the same tokenizer) to generate a number of tokens (say 5 or 10), and then uses the main model to validate the draft model's guesses in parallel. When placing constraints on the output, we want to apply them to the small model as well as the main one. However, the time bounds on the draft model are going to be much tighter.
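For concreteness, a rough sketch of where the controller's token mask would have to be evaluated under speculative decoding; the model and controller APIs below are made up, and the acceptance check is simplified to greedy matching:

```python
import numpy as np

def speculative_step(draft_model, main_model, controller, tokens, k=5):
    # 1. draft phase: the controller mask is needed for every drafted token,
    #    so it runs on the draft model's much tighter per-token time budget
    drafted = []
    ctx = list(tokens)
    for _ in range(k):
        logits = draft_model.logits(ctx)                  # hypothetical API
        logits[~controller.token_mask(ctx)] = -np.inf     # constrain the draft model too
        tok = int(np.argmax(logits))                      # greedy, for simplicity
        drafted.append(tok)
        ctx.append(tok)

    # 2. verification phase: the main model scores all k positions in one
    #    batched forward pass; the same masks apply before accepting tokens
    main_logits = main_model.logits_batch(tokens, drafted)  # hypothetical API
    accepted = []
    ctx = list(tokens)
    for pos, tok in enumerate(drafted):
        logits = np.array(main_logits[pos], dtype=float)
        logits[~controller.token_mask(ctx)] = -np.inf
        if int(np.argmax(logits)) != tok:
            break                                          # first mismatch stops acceptance
        accepted.append(tok)
        ctx.append(tok)
    return accepted
```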
-
To further avoid pipeline stalls, should
-
This has been implemented in #92. Note that pyctrl/jsctrl still use post_process() as a client-side abstraction. However, that post_process() method is only called from the Wasm-level mid_process() callback (either at the beginning, to handle the previous round, or at the end, when the appended tokens are deterministic).
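Conceptually, that wiring might look something like the sketch below; this is illustrative pseudocode, not the actual pyctrl/jsctrl implementation, and all names are assumptions:

```python
# Illustrative only - not the actual pyctrl/jsctrl code.
class ClientController:
    """The client-side abstraction still exposes the old callbacks."""
    def mid_process(self) -> "Branch": ...
    def post_process(self, backtrack: int, tokens: list[int]) -> None: ...

def wasm_mid_process(ctrl: ClientController, state: dict,
                     branch_idx, last_backtrack, last_append):
    # at the beginning: report the previous round, unless it was already
    # reported at the end of the previous callback (deterministic tokens)
    if not state.get("reported", False):
        ctrl.post_process(last_backtrack, last_append)
    branch = ctrl.mid_process()
    if branch.sample_mask is None:
        # at the end: tokens are fully deterministic, report them right away
        ctrl.post_process(0, branch.ff_tokens[0])
        state["reported"] = True
    else:
        state["reported"] = False
    return [branch]
```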
-
The `pre_process` and `post_process` callbacks currently run in the critical path of inference - we have measured the overhead at about 0.3 ms per token in rLLM, but it may be worse with Python-based LLM infrastructure. The overhead is primarily the inter-process communication delay (especially the fact that the OS can decide to de-schedule one of the involved processes).

A solution would be to keep only the `mid_process` callback, and add possible return values for it indicating that further generation needs to be forked, or that the current token needs to be discarded (the latter is already more or less supported via backtrack=1).

The downside is that certain operations may incur a one-token overhead in some cases:
Many of these can be mitigated to some extent (e.g., when requesting a fork we could return splice commands for each branch; when returning a small set of allowed tokens, we could say "if token X is selected, then fast-forward by YZW").
It would also be impossible to directly implement lock-step generation between different forks, making certain beam-search approaches harder.
The advantage is a much simpler interface and no overhead, which might be an easier sell for LLM infrastructure folks.
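A minimal sketch of what such `mid_process`-only return values could look like; the names and shapes below are purely illustrative, not the interface that was eventually adopted:

```python
from dataclasses import dataclass, field

@dataclass
class BranchSpec:
    # None means no constraint; otherwise only these token ids may be sampled
    allowed_tokens: set[int] | None = None
    # conditional fast-forward: "if token X is selected, splice in Y, Z, W"
    splices: dict[int, list[int]] = field(default_factory=dict)

@dataclass
class MidProcessResult:
    stop: bool = False          # stop the sequence entirely
    backtrack: int = 0          # discard this many recent tokens (1 = discard current)
    # one entry = continue normally; several entries = fork generation
    branches: list[BranchSpec] = field(default_factory=list)
```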