Refactor sampler (#611)
EricLBuehler authored Jul 23, 2024
1 parent a23876f commit a053312
Showing 9 changed files with 322 additions and 335 deletions.
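
In summary: the refactor replaces the do_sample! macro with a shared async helper, sample_and_add_toks, imported from crate::pipeline::sampling; it deletes the mod sampling_pipeline declaration (which appears to have housed the old macro) and drops the now-unused tok_trie: Arc<TokTrie> field from the GGML, GGUF, and normal pipelines. The four diffs captured below all follow this pattern.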
9 changes: 4 additions & 5 deletions mistralrs-core/src/pipeline/ggml.rs
@@ -12,6 +12,7 @@ use crate::aici::bintokens::build_tok_trie;
 use crate::aici::toktree::TokTrie;
 use crate::lora::Ordering;
 use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
+use crate::pipeline::sampling::sample_and_add_toks;
 use crate::pipeline::{get_chat_template, Cache};
 use crate::pipeline::{ChatTemplate, LocalModelPaths};
 use crate::prefix_cacher::PrefixCacheManager;
@@ -21,8 +22,8 @@ use crate::utils::model_config as ModelConfig;
 use crate::utils::tokenizer::get_tokenizer;
 use crate::xlora_models::NonGranularState;
 use crate::{
-    do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, PagedAttentionConfig, Pipeline,
-    TryIntoDType, DEBUG,
+    get_mut_arcmutex, get_paths, DeviceMapMetadata, PagedAttentionConfig, Pipeline, TryIntoDType,
+    DEBUG,
 };
 use crate::{
     models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
@@ -50,7 +51,6 @@ enum Model {
 pub struct GGMLPipeline {
     model: Model,
     tokenizer: Arc<Tokenizer>,
-    tok_trie: Arc<TokTrie>,
     no_kv_cache: bool,
     chat_template: Arc<ChatTemplate>,
     model_id: String,
@@ -329,7 +329,6 @@ impl Loader for GGMLLoader {
         let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
         Ok(Arc::new(Mutex::new(GGMLPipeline {
             model,
-            tok_trie: tok_trie.clone(),
             tokenizer: tokenizer.into(),
             no_kv_cache: self.no_kv_cache,
             chat_template: Arc::new(chat_template),
@@ -522,7 +521,7 @@ impl Pipeline for GGMLPipeline {
         disable_eos_stop: bool,
         rng: Arc<std::sync::Mutex<Isaac64Rng>>,
     ) -> Result<(), candle_core::Error> {
-        do_sample!(self, seqs, logits, prefix_cacher, disable_eos_stop, rng)
+        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
     }
     fn category(&self) -> ModelCategory {
         ModelCategory::Text
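
sampling.rs itself is not among the diffs this page loaded, but the helper's shape can be inferred from the call site above. A minimal sketch, assuming the parameter types match those of the sample() method that calls it (the actual definition in mistralrs-core/src/pipeline/sampling.rs may differ):

    // Sketch only: signature inferred from the call sites in this commit.
    // Sequence, Tensor, PrefixCacheManager, and Isaac64Rng are the types
    // already in scope for Pipeline::sample; the visibility is assumed.
    pub(crate) async fn sample_and_add_toks(
        this: &dyn Pipeline,                    // bound to `self` at each call site
        seqs: &mut [&mut Sequence],
        logits: Tensor,
        prefix_cacher: &mut PrefixCacheManager,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        // Sample the next token for each running sequence from `logits`,
        // append it, and handle EOS and prefix-cache bookkeeping: the work
        // the do_sample! macro previously expanded inline.
        todo!()
    }

Because this is an ordinary async fn rather than a macro, the logic is type-checked once and every pipeline calls the same implementation.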
9 changes: 4 additions & 5 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -18,6 +18,7 @@ use crate::paged_attention::{
     calculate_cache_config, AttentionImplementation, CacheEngine, ModelConfigLike,
 };
 use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig};
+use crate::pipeline::sampling::sample_and_add_toks;
 use crate::pipeline::ChatTemplate;
 use crate::pipeline::{get_chat_template, Cache};
 use crate::prefix_cacher::PrefixCacheManager;
@@ -27,8 +28,8 @@ use crate::utils::model_config as ModelConfig;
 use crate::utils::tokenizer::get_tokenizer;
 use crate::xlora_models::NonGranularState;
 use crate::{
-    do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, LocalModelPaths,
-    PagedAttentionConfig, Pipeline, TryIntoDType, DEBUG,
+    get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, LocalModelPaths, PagedAttentionConfig,
+    Pipeline, TryIntoDType, DEBUG,
 };
 use crate::{
     models::quantized_llama::ModelWeights as QLlama,
@@ -70,7 +71,6 @@ enum Model {
 pub struct GGUFPipeline {
     model: Model,
     tokenizer: Arc<Tokenizer>,
-    tok_trie: Arc<TokTrie>,
     no_kv_cache: bool,
     chat_template: Arc<ChatTemplate>,
     model_id: String,
@@ -563,7 +563,6 @@ impl Loader for GGUFLoader {
         let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
         Ok(Arc::new(Mutex::new(GGUFPipeline {
             model,
-            tok_trie: tok_trie.clone(),
             tokenizer: tokenizer.into(),
             no_kv_cache: self.no_kv_cache,
             chat_template: Arc::new(chat_template),
@@ -786,7 +785,7 @@ impl Pipeline for GGUFPipeline {
         disable_eos_stop: bool,
         rng: Arc<std::sync::Mutex<Isaac64Rng>>,
     ) -> Result<(), candle_core::Error> {
-        do_sample!(self, seqs, logits, prefix_cacher, disable_eos_stop, rng)
+        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
     }
     fn category(&self) -> ModelCategory {
         ModelCategory::Text
3 changes: 1 addition & 2 deletions mistralrs-core/src/pipeline/mod.rs
@@ -19,10 +19,9 @@ use crate::amoe::{
     AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs,
     AnyMoeTrainingResult,
 };
+use crate::lora::{LoraConfig, Ordering};
 use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigMetadata, PagedAttentionConfig};
 use crate::prefix_cacher::PrefixCacheManager;
-mod sampling_pipeline;
-use crate::lora::{LoraConfig, Ordering};
 use crate::{DeviceMapMetadata, TryIntoDType};
 pub use amoe::{AnyMoeLoader, AnyMoePipeline};
 use candle_core::quantized::GgmlDType;
7 changes: 3 additions & 4 deletions mistralrs-core/src/pipeline/normal.rs
@@ -18,6 +18,7 @@ use crate::amoe::AnyMoeExpertType;
 use crate::lora::Ordering;
 use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
 use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
+use crate::pipeline::sampling::sample_and_add_toks;
 use crate::pipeline::{get_chat_template, Cache};
 use crate::pipeline::{ChatTemplate, LocalModelPaths};
 use crate::prefix_cacher::PrefixCacheManager;
@@ -27,7 +28,7 @@ use crate::utils::tokenizer::get_tokenizer;
 use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
 use crate::xlora_models::NonGranularState;
 use crate::{
-    api_dir_list, api_get_file, do_sample, get_mut_arcmutex, get_paths, lora_model_loader,
+    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, lora_model_loader,
     normal_model_loader, xlora_model_loader, DeviceMapMetadata, PagedAttentionConfig, Pipeline,
     TryIntoDType,
 };
@@ -49,7 +50,6 @@ use tracing::{info, warn};
 pub struct NormalPipeline {
     model: Box<dyn NormalModel + Send + Sync>,
     tokenizer: Arc<Tokenizer>,
-    tok_trie: Arc<TokTrie>,
     no_kv_cache: bool,
     chat_template: Arc<ChatTemplate>,
     non_granular_state: Option<NonGranularState>,
@@ -343,7 +343,6 @@ impl Loader for NormalLoader {
         let sliding_window = model.config().sliding_window;
         Ok(Arc::new(Mutex::new(NormalPipeline {
             model,
-            tok_trie: tok_trie.clone(),
             tokenizer: tokenizer.into(),
             no_kv_cache: self.no_kv_cache,
             chat_template: Arc::new(chat_template),
@@ -498,7 +497,7 @@ impl Pipeline for NormalPipeline {
         disable_eos_stop: bool,
         rng: Arc<std::sync::Mutex<Isaac64Rng>>,
     ) -> Result<(), candle_core::Error> {
-        do_sample!(self, seqs, logits, prefix_cacher, disable_eos_stop, rng)
+        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
     }
     fn category(&self) -> ModelCategory {
         ModelCategory::Text
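
With the same substitution applied in ggml.rs, gguf.rs, and normal.rs (and presumably in the changed files that did not load below), every pipeline now routes through a single sampling implementation, so future changes to sampling behavior land in one place instead of being re-expanded into each pipeline.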
[Diffs for the remaining five changed files did not load in this capture.]