
Commit

Merge pull request #243 from EricLBuehler/correct_eos_tokens
Source bos, eos tokens from generation_config.json
EricLBuehler authored Apr 29, 2024
2 parents 65889f3 + 5100ad9 commit 4505a5e
Showing 16 changed files with 184 additions and 176 deletions.
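For context, a model repository's generation_config.json typically carries bos_token_id and eos_token_id, each either a single id or a list of ids; the hunks below read that file and fold those ids into the pipeline's stop-token set. A minimal, self-contained sketch of that shape, using the same untagged-Either deserialization the PR adds in chat_template.rs (the token id values are illustrative only, not taken from any file in this PR):

```rust
use either::Either;
use serde::Deserialize;

// Mirrors the GenerationConfig struct added below in chat_template.rs.
#[derive(Debug, Deserialize)]
struct GenerationConfig {
    #[serde(with = "either::serde_untagged")]
    bos_token_id: Either<u32, Vec<u32>>,
    #[serde(with = "either::serde_untagged")]
    eos_token_id: Either<u32, Vec<u32>>,
}

fn main() -> serde_json::Result<()> {
    // Illustrative values only (a Llama-3-style multi-EOS configuration).
    let raw = r#"{ "bos_token_id": 128000, "eos_token_id": [128001, 128009] }"#;
    let conf: GenerationConfig = serde_json::from_str(raw)?;
    // Prints: GenerationConfig { bos_token_id: Left(128000), eos_token_id: Right([128001, 128009]) }
    println!("{conf:?}");
    Ok(())
}
```

This sketch needs either with its serde feature (already in the workspace Cargo.toml), serde, and serde_json.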
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -30,7 +30,7 @@ either = { version = "1.10.0", features = ["serde"] }
accelerate-src = { version = "0.3.2" }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
futures = "0.3"
clap = { version = "4.5.1", features = ["derive"] }
pyo3 = { version = "0.21.0", features = ["full"] } # pyo3 = { version = "0.21.0", features = ["extension-module", "full"] }
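The new env-filter feature is what enables the EnvFilter-based logging setup used in the binaries changed below: verbosity now defaults to INFO but can be overridden with RUST_LOG (e.g. RUST_LOG=warn or RUST_LOG=info,mistralrs_core=debug). A minimal standalone sketch of that behavior (not code from this PR):

```rust
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;

fn main() {
    // Default to INFO when RUST_LOG is unset; from_env_lossy() skips malformed
    // directives instead of returning an error.
    let filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .from_env_lossy();
    tracing_subscriber::fmt().with_env_filter(filter).init();

    tracing::info!("visible by default");
    tracing::debug!("visible only when RUST_LOG raises the level, e.g. RUST_LOG=debug");
}
```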
12 changes: 9 additions & 3 deletions mistralrs-bench/src/main.rs
@@ -8,7 +8,9 @@ use mistralrs_core::{
};
use std::sync::Arc;
use std::{fmt::Display, sync::mpsc::channel};
use tracing::{info, warn};
use tracing::info;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;

enum TestName {
Prompt(usize),
@@ -251,7 +253,11 @@ fn main() -> anyhow::Result<()> {
#[cfg(not(feature = "metal"))]
let device = Device::cuda_if_available(0)?;

tracing_subscriber::fmt().init();
let filter = EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy();
tracing_subscriber::fmt().with_env_filter(filter).init();

let token_source = TokenSource::CacheToken;
info!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
@@ -274,7 +280,7 @@
| ModelKind::XLoraGGUF
)
{
warn!("Using flash attention with a quantized model has no effect!")
info!("⚠️ WARNING: Using flash attention with a quantized model has no effect!")
}
info!("Model kind is: {}", loader.get_kind().as_ref());
let pipeline = loader.load_model(
4 changes: 2 additions & 2 deletions mistralrs-core/src/aici/bintokens.rs
@@ -6,7 +6,7 @@ use anyhow::{anyhow, bail, Result};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use tokenizers::{normalizers::Sequence, NormalizerWrapper, Tokenizer};
use tracing::{error, warn};
use tracing::{error, info};

#[derive(Serialize, Deserialize)]
pub struct ByteTokenizer {
@@ -155,7 +155,7 @@ impl ByteTokenizer {
panic!();
}
} else {
warn!("missing token: {}", tok_id);
info!("⚠️ WARNING: missing token: {}", tok_id);
}
}

6 changes: 3 additions & 3 deletions mistralrs-core/src/engine/mod.rs
@@ -17,7 +17,7 @@ use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use futures::future;
use rand::SeedableRng;
use rand_isaac::Isaac64Rng;
use tracing::warn;
use tracing::info;

use crate::{
get_mut_arcmutex, handle_pipeline_forward_error, handle_seq_error,
@@ -96,7 +96,7 @@ impl Engine {
let mut scheduled = self.scheduler.schedule();
if let Ok(dtype) = self.isq_rx.try_recv() {
if let Err(e) = get_mut_arcmutex!(self.pipeline).re_isq_model(dtype) {
warn!("ISQ requantization failed: {e:?}");
info!("⚠️ WARNING: ISQ requantization failed: {e:?}");
}
}

@@ -751,7 +751,7 @@ impl Engine {
10
};
prompt = prompt[(currently_over + sampling_max)..].to_vec();
warn!("Prompt for request {} was {} tokens over the model maximum length. The last {} tokens were truncated to make space for generation.", request.id, currently_over, prompt_len - prompt.len());
info!("⚠️ WARNING: Prompt for request {} was {} tokens over the model maximum length. The last {} tokens were truncated to make space for generation.", request.id, currently_over, prompt_len - prompt.len());
}
}
let prefill_cache = handle_seq_error!(
59 changes: 55 additions & 4 deletions mistralrs-core/src/pipeline/chat_template.rs
@@ -85,19 +85,61 @@ impl ChatTemplate {
}
}

pub fn calculate_eos_tokens(chat_template: &ChatTemplate, tokenizer: &Tokenizer) -> Vec<u32> {
pub fn calculate_eos_tokens(
chat_template: &ChatTemplate,
gen_conf: Option<GenerationConfig>,
tokenizer: &Tokenizer,
) -> Vec<u32> {
let mut eos_tok_ids = vec![chat_template.eos_tok()];
let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();

for alternate in SUPPORTED_ALTERNATE_EOS {
if tokenizer.get_vocab(true).get(alternate).is_some() {
eos_tok_ids.push(alternate.to_string())
}
}

if let Some(gen_conf) = gen_conf {
let ids = match gen_conf.eos_token_id {
Either::Left(id) => vec![id],
Either::Right(ids) => ids,
};
for id in ids {
let s = tokenizer
.decode(&[id], false)
.unwrap_or_else(|_| panic!("Unable to decode id {id})"));
if !eos_tok_ids.contains(&s) {
eos_tok_ids.push(s);
}
}

let ids = match gen_conf.bos_token_id {
Either::Left(id) => vec![id],
Either::Right(ids) => ids,
};
for id in ids {
let s = tokenizer
.decode(&[id], false)
.unwrap_or_else(|_| panic!("Unable to decode id {id})"));
if !bos_tok_ids.contains(&s) {
bos_tok_ids.push(s);
}
}
}

let bos_render = bos_tok_ids
.iter()
.map(|val| format!("{:?}", val))
.collect::<Vec<String>>()
.join(", ");
let eos_render = eos_tok_ids
.iter()
.map(|val| format!("{:?}", val))
.collect::<Vec<String>>()
.join(", ");

info!(
"bos_tok = {}, eos_tok = {:?}, unk_tok = {}",
chat_template.bos_tok().unwrap_or("`None`".to_string()),
eos_tok_ids,
"bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
chat_template.unk_tok().unwrap_or("`None`".to_string()),
);

@@ -114,6 +156,15 @@ pub fn calculate_eos_tokens(chat_template: &ChatTemplate, tokenizer: &Tokenizer)
eos_toks
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
#[serde(with = "either::serde_untagged")]
bos_token_id: Either<u32, Vec<u32>>,
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<u32, Vec<u32>>,
}
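Read in isolation, the decode-and-merge step added above amounts to the following small helper. This is an editorial sketch, not code from the PR: the helper name and the error propagation are assumptions (the PR panics if an id cannot be decoded), but the decode call and the duplicate check match the hunk.

```rust
use either::Either;
use tokenizers::Tokenizer;

// Hypothetical helper isolating the new merge logic: decode each configured id
// to its token string and append it to the EOS set if it is not already there.
fn merge_eos_ids(
    eos_tok_ids: &mut Vec<String>,
    configured: Either<u32, Vec<u32>>,
    tokenizer: &Tokenizer,
) -> anyhow::Result<()> {
    let ids = match configured {
        Either::Left(id) => vec![id],
        Either::Right(ids) => ids,
    };
    for id in ids {
        let s = tokenizer
            .decode(&[id], false)
            .map_err(|e| anyhow::anyhow!("unable to decode id {id}: {e}"))?;
        if !eos_tok_ids.contains(&s) {
            eos_tok_ids.push(s);
        }
    }
    Ok(())
}
```

The same pattern is applied to the BOS ids in the hunk above.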

pub fn apply_chat_template_to(
messages: Vec<IndexMap<String, String>>,
add_generation_prompt: bool,
58 changes: 8 additions & 50 deletions mistralrs-core/src/pipeline/ggml.rs
@@ -6,9 +6,9 @@ use crate::aici::bintokens::build_tok_trie
use crate::aici::toktree::TokTrie;
use crate::models::Cache;
use crate::pipeline::chat_template::calculate_eos_tokens;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{ChatTemplate, SimpleModelPaths};
use crate::utils::varbuilder_utils::from_mmaped_safetensors;
use crate::xlora_models::{NonGranularState, XLoraConfig};
use crate::xlora_models::NonGranularState;
use crate::{deserialize_chat_template, get_paths, DeviceMapMetadata};
use crate::{
models::quantized_llama::ModelWeights as QLlama, sequence::Sequence, utils::tokens::get_token,
@@ -18,7 +18,7 @@ use anyhow::Result;
use candle_core::quantized::{ggml_file, GgmlDType};
use candle_core::{DType, Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_lora::{LoraConfig, Ordering};
use mistralrs_lora::Ordering;
use serde_json::Value;
use std::collections::HashMap;
use std::fs;
@@ -27,55 +27,13 @@ use std::str::FromStr;
use std::sync::Arc;
use std::sync::Mutex;
use tokenizers::Tokenizer;
use tracing::{info, warn};
use tracing::info;

enum Model {
Llama(QLlama),
XLoraLlama(XLoraQLlama),
}

pub struct MistralModelPaths<P> {
tokenizer_filename: P,
config_filename: P,
template_filename: P,
filenames: Vec<P>,
xlora_adapter_filenames: Option<Vec<(String, P)>>,
xlora_adapter_configs: Option<Vec<(String, LoraConfig)>>,
classifier_path: Option<P>,
classifier_config: Option<XLoraConfig>,
xlora_ordering: Option<Ordering>,
}

impl ModelPaths for MistralModelPaths<PathBuf> {
fn get_config_filename(&self) -> &PathBuf {
&self.config_filename
}
fn get_tokenizer_filename(&self) -> &PathBuf {
&self.tokenizer_filename
}
fn get_weight_filenames(&self) -> &[PathBuf] {
&self.filenames
}
fn get_adapter_filenames(&self) -> &Option<Vec<(String, PathBuf)>> {
&self.xlora_adapter_filenames
}
fn get_adapter_configs(&self) -> &Option<Vec<(String, LoraConfig)>> {
&self.xlora_adapter_configs
}
fn get_classifier_config(&self) -> &Option<XLoraConfig> {
&self.classifier_config
}
fn get_classifier_path(&self) -> &Option<PathBuf> {
&self.classifier_path
}
fn get_ordering(&self) -> &Option<Ordering> {
&self.xlora_ordering
}
fn get_template_filename(&self) -> &PathBuf {
&self.template_filename
}
}

pub struct GGMLPipeline {
model: Model,
config: GGMLSpecificConfig,
@@ -267,7 +225,7 @@ impl Loader for GGMLLoader {
silent: bool,
) -> Result<Box<dyn ModelPaths>> {
get_paths!(
MistralModelPaths,
SimpleModelPaths,
&token_source,
revision,
self,
@@ -292,7 +250,7 @@
);
}
if !mapper.is_dummy() {
warn!("GGML models do not support device mapping. Device mapping will not work. Please consider using a GGUF model.");
info!("⚠️ WARNING: GGML models do not support device mapping. Device mapping will not work. Please consider using a GGUF model.");
}

let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
@@ -357,12 +315,12 @@ impl Loader for GGMLLoader {
let tokenizer =
Tokenizer::from_file(paths.get_tokenizer_filename()).map_err(anyhow::Error::msg)?;

let chat_template: ChatTemplate = deserialize_chat_template!(paths, self);
let (chat_template, gen_conf) = deserialize_chat_template!(paths, self);

Ok(Arc::new(Mutex::new(GGMLPipeline {
model,
config: self.config,
eos_tok: calculate_eos_tokens(&chat_template, &tokenizer),
eos_tok: calculate_eos_tokens(&chat_template, gen_conf, &tokenizer),
tok_trie: build_tok_trie(tokenizer.clone()).into(),
tokenizer: tokenizer.into(),
no_kv_cache: self.no_kv_cache,
54 changes: 6 additions & 48 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -6,9 +6,9 @@ use crate::aici::bintokens::build_tok_trie
use crate::aici::toktree::TokTrie;
use crate::models::Cache;
use crate::pipeline::chat_template::calculate_eos_tokens;
use crate::pipeline::ChatTemplate;
use crate::pipeline::{ChatTemplate, SimpleModelPaths};
use crate::utils::varbuilder_utils::from_mmaped_safetensors;
use crate::xlora_models::{NonGranularState, XLoraConfig};
use crate::xlora_models::NonGranularState;
use crate::{deserialize_chat_template, get_paths, DeviceMapMetadata};
use crate::{
models::quantized_llama::ModelWeights as QLlama, models::quantized_phi2::ModelWeights as QPhi,
@@ -18,7 +18,7 @@ use anyhow::{bail, Result};
use candle_core::quantized::{gguf_file, GgmlDType};
use candle_core::{DType, Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_lora::{LoraConfig, Ordering};
use mistralrs_lora::Ordering;
use serde_json::Value;
use std::collections::HashMap;
use std::fs;
@@ -35,48 +35,6 @@ enum Model {
XLoraLlama(XLoraQLlama),
}

pub struct MistralModelPaths<P> {
tokenizer_filename: P,
config_filename: P,
template_filename: P,
filenames: Vec<P>,
xlora_adapter_filenames: Option<Vec<(String, P)>>,
xlora_adapter_configs: Option<Vec<(String, LoraConfig)>>,
classifier_path: Option<P>,
classifier_config: Option<XLoraConfig>,
xlora_ordering: Option<Ordering>,
}

impl ModelPaths for MistralModelPaths<PathBuf> {
fn get_config_filename(&self) -> &PathBuf {
&self.config_filename
}
fn get_tokenizer_filename(&self) -> &PathBuf {
&self.tokenizer_filename
}
fn get_weight_filenames(&self) -> &[PathBuf] {
&self.filenames
}
fn get_adapter_filenames(&self) -> &Option<Vec<(String, PathBuf)>> {
&self.xlora_adapter_filenames
}
fn get_adapter_configs(&self) -> &Option<Vec<(String, LoraConfig)>> {
&self.xlora_adapter_configs
}
fn get_classifier_config(&self) -> &Option<XLoraConfig> {
&self.classifier_config
}
fn get_classifier_path(&self) -> &Option<PathBuf> {
&self.classifier_path
}
fn get_ordering(&self) -> &Option<Ordering> {
&self.xlora_ordering
}
fn get_template_filename(&self) -> &PathBuf {
&self.template_filename
}
}

pub struct GGUFPipeline {
model: Model,
config: GGUFSpecificConfig,
@@ -301,7 +259,7 @@ impl Loader for GGUFLoader {
silent: bool,
) -> Result<Box<dyn ModelPaths>> {
get_paths!(
MistralModelPaths,
SimpleModelPaths,
&token_source,
revision,
self,
@@ -410,12 +368,12 @@ impl Loader for GGUFLoader {
let tokenizer =
Tokenizer::from_file(paths.get_tokenizer_filename()).map_err(anyhow::Error::msg)?;

let chat_template: ChatTemplate = deserialize_chat_template!(paths, self);
let (chat_template, gen_conf) = deserialize_chat_template!(paths, self);

Ok(Arc::new(Mutex::new(GGUFPipeline {
model,
config: self.config,
eos_tok: calculate_eos_tokens(&chat_template, &tokenizer),
eos_tok: calculate_eos_tokens(&chat_template, gen_conf, &tokenizer),
tok_trie: build_tok_trie(tokenizer.clone()).into(),
tokenizer: tokenizer.into(),
no_kv_cache: self.no_kv_cache,
(Diffs for the remaining 9 changed files were not loaded on this page.)