Skip to content

Commit

Permalink
llama : simplify Mamba with advanced batch splits (ggerganov#8526)
Browse files Browse the repository at this point in the history
* llama : advanced batch splits

This includes equal-sequence-length batch splits which are useful
to simplify recurrent model operators.

* llama : always make recurrent state slots contiguous

* ggml : simplify mamba operators

* llama : fix integer signedness mixing

* llama : logits_all has priority over batch->logits

Otherwise, the server embeddings tests failed.
This was likely an existing problem but was only detected here
because of an additional assertion.

* llama : apply suggestions

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : fix t5 segfault

* llama : fix Mamba session save and restore

* llama : minor cosmetic changes

* llama : rename llama_reorder_outputs to llama_output_reorder

Also move it closer to llama_output_reserve.

* llama : fix pooled embeddings when using batches with equal_seqs

* minor : add struct members for clarity

ggml-ci

* llama : fix T5 segfault again

* llama : fix Mamba pooled embeddings with multiple sequences

Until the pooled embeddings are refactored to allow splitting
across ubatches for causal embeddings,
recurrent models can only process a single sequence per ubatch
when calculating pooled embeddings.

* llama : add llama_model_is_recurrent to simplify figuring that out

This will make it easier to more cleanly support RWKV-v6 and Mamba-2.

* llama : fix simple splits when the batch contains embeddings

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
compilade and ggerganov authored Aug 21, 2024
1 parent fc54ef0 commit a1631e5
Show file tree
Hide file tree
Showing 4 changed files with 1,137 additions and 678 deletions.
9 changes: 3 additions & 6 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1777,10 +1777,8 @@ extern "C" {

GGML_API struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx,
struct ggml_tensor * s,
struct ggml_tensor * x,
struct ggml_tensor * c,
struct ggml_tensor * sq);
struct ggml_tensor * sx,
struct ggml_tensor * c);

GGML_API struct ggml_tensor * ggml_ssm_scan(
struct ggml_context * ctx,
Expand All @@ -1789,8 +1787,7 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C,
struct ggml_tensor * sq);
struct ggml_tensor * C);

// partition into non-overlapping windows with padding if needed
// example:
Expand Down
Loading

0 comments on commit a1631e5

Please sign in to comment.