Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug for space token decoding & remove redundant code #72

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ Currently, candle-vllm supports chat serving for the following models.

| Model ID | Model Type | Supported | Speed (A100, BF16) | Throughput (bs=16)
|--|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 386 tks/s (7B) |
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 456 tks/s (7B) |
| #2 | **Mistral** |✅|70 tks/s (7B)| 291 tks/s (7B) |
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 467 tks/s (3.8B)|
| #5 | **Yi** |✅|75 tks/s (6B)| 375 tks/s (6B) |
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 502 tks/s (3.8B)|
| #5 | **Yi** |✅|75 tks/s (6B)| 395 tks/s (6B) |
| #6 | **StableLM** |✅|99 tks/s (3B)|TBD|
| #7 | BigCode/StarCode |TBD|TBD|TBD |
| #8 | ChatGLM |TBD|TBD|TBD |
Expand Down
11 changes: 2 additions & 9 deletions src/openai/models/gemma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::Config;
use crate::openai::models::linear::{linear_b, linear_no_bias as linear, Linear};
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_core as candle;
use candle_nn::Activation;
use candle_nn::{RmsNorm, VarBuilder};
Expand Down Expand Up @@ -354,18 +354,11 @@ impl Gemma {
}

fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
let seqlen_offset = 0;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.to_dtype(self.dtype)
}

Expand Down
12 changes: 3 additions & 9 deletions src/openai/models/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::Config;
use crate::openai::models::linear::{linear_no_bias as linear, Linear};
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_core as candle;
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use candle_transformers::models::with_tracing::RmsNorm;
Expand Down Expand Up @@ -305,18 +305,12 @@ pub struct Llama {

impl Llama {
fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
let seqlen_offset = 0;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.contiguous()?
.to_dtype(self.dtype)
}

Expand Down
17 changes: 5 additions & 12 deletions src/openai/models/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::Config;
use crate::openai::models::linear::{linear_no_bias, Linear};
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::RmsNorm;
use either::Either;
Expand Down Expand Up @@ -335,8 +335,7 @@ impl Mistral {
})
}

fn prepare_decoder_attention_mask(&self, tgt_len: usize) -> Result<Tensor> {
let seqlen_offset = 0;
fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| {
Expand All @@ -350,13 +349,7 @@ impl Mistral {
})
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.to_dtype(self.dtype)
}

Expand All @@ -367,11 +360,11 @@ impl Mistral {
kv_caches: Option<&Vec<(Tensor, Tensor)>>,
input_metadata: &mut InputMetadata,
) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(seq_len)?;
let mask = self.prepare_decoder_attention_mask(b_size, seq_len)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
Expand Down
9 changes: 1 addition & 8 deletions src/openai/models/phi2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,18 +336,11 @@ impl Phi2 {
}

fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
let seqlen_offset = 0;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.to_dtype(DType::F32)
}

Expand Down
9 changes: 1 addition & 8 deletions src/openai/models/phi3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,18 +443,11 @@ impl Phi {
}

fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
let seqlen_offset = 0;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.to_dtype(self.dtype)
}

Expand Down
11 changes: 2 additions & 9 deletions src/openai/models/qwen2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::Config;
use crate::openai::models::linear::{linear, linear_no_bias, Linear};
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_core as candle;
use candle_nn::VarBuilder;
use candle_transformers::models::with_tracing::RmsNorm;
Expand Down Expand Up @@ -345,7 +345,6 @@ impl Qwen2 {

fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
// Sliding window mask?
let seqlen_offset = 0;
let mask: Vec<_> = if self.sliding_window.is_some() {
let sliding_window = self.sliding_window.unwrap();
(0..tgt_len)
Expand All @@ -366,13 +365,7 @@ impl Qwen2 {
};

let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.to_dtype(self.dtype)
}

Expand Down
9 changes: 1 addition & 8 deletions src/openai/models/stable_lm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,11 @@ impl StableLM {

fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
// Sliding window mask?
let seqlen_offset = 0;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.to_dtype(self.dtype)
}

Expand Down
11 changes: 2 additions & 9 deletions src/openai/models/yi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::Config;
use crate::openai::models::linear::{linear_no_bias, Linear};
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::RmsNorm;
use either::Either;
Expand Down Expand Up @@ -333,18 +333,11 @@ impl Yi {

fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result<Tensor> {
// Sliding window mask?
let seqlen_offset = 0;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
mask.expand((b_size, 1, tgt_len, tgt_len))?
.to_dtype(self.dtype)
}

Expand Down
22 changes: 0 additions & 22 deletions src/openai/pipelines/llm_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use tokio::sync::Notify;
struct PreparedInputs {
tokens: Tensor,
positions: Vec<Vec<usize>>,
positions_tensor: Tensor,
metadata: InputMetadata,
}

Expand Down Expand Up @@ -204,7 +203,6 @@ impl LLMEngine {
let PreparedInputs {
tokens,
positions,
positions_tensor,
metadata,
} = if seqs.values().nth(0).unwrap().deref().is_prompt() {
self.prepare_prompt(scheduled)
Expand Down Expand Up @@ -466,15 +464,6 @@ impl LLMEngine {
0,
&self.pipeline.device(),
)?;
let input_positions_tensor = _make_tensor_with_pad(
input_positions
.iter()
.map(|x| x.iter().map(|x| *x as i64).collect::<Vec<_>>())
.collect::<Vec<_>>(),
*max_prompt_len,
0,
&self.pipeline.device(),
)?;
let slot_mapping = _make_tensor_with_pad(
slot_mappings,
*max_prompt_len,
Expand All @@ -484,7 +473,6 @@ impl LLMEngine {

Ok(PreparedInputs {
tokens: input_tokens,
positions_tensor: input_positions_tensor,
positions: input_positions,
metadata: InputMetadata {
prompt_lens,
Expand Down Expand Up @@ -567,15 +555,6 @@ impl LLMEngine {
0,
&self.pipeline.device(),
)?;
let input_positions_tensor = _make_tensor_with_pad(
input_positions
.iter()
.map(|x| x.iter().map(|x| *x as i64).collect::<Vec<_>>())
.collect::<Vec<_>>(),
1,
0,
&self.pipeline.device(),
)?;
let slot_mapping =
_make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, &self.pipeline.device())?;

Expand All @@ -600,7 +579,6 @@ impl LLMEngine {
Ok(PreparedInputs {
tokens: input_tokens,
positions: input_positions,
positions_tensor: input_positions_tensor,
metadata: InputMetadata {
prompt_lens: vec![],
slot_mapping,
Expand Down
15 changes: 11 additions & 4 deletions src/openai/pipelines/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,20 @@ impl ModulePipeline for DefaultPipeline {
};

let next_token = try_api!(self.logits_processor.sample(&logits));
let text = self
let mut text = self
.tokenizer
.tokenizer()
.decode(&[next_token], false)
.unwrap_or(" ".to_string());
let origin_text = self
.tokenizer
.tokenizer()
.id_to_token(next_token)
.unwrap_or("".to_string())
.replace("▁", " ")
.replace("<0x0A>", "\n");
.unwrap_or("".to_string());
//properly handle space token
if origin_text.contains("▁") && origin_text.replace("▁", "") == text {
text = origin_text.replace("▁", " ");
}
if self.stop_token_ids.contains(&next_token) && tokens_generated > 1 {
result.push(Right("stop".to_string()));
break;
Expand Down
Loading