diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index a9f5d7843..3692b0072 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -16,6 +16,7 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::debug::setup_logger_and_debug; use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; @@ -196,6 +197,8 @@ impl GGMLLoader { tokenizer_json: Option, tgt_non_granular_index: Option, ) -> Self { + setup_logger_and_debug(); + let model_id = if let Some(id) = model_id { id } else { diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 60ab71740..46f5f292c 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -17,6 +17,7 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::debug::setup_logger_and_debug; use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; @@ -29,7 +30,7 @@ use crate::{ xlora_models::{XLoraQLlama, XLoraQPhi3}, GgufTokenizerConversion, }; -use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline, DEBUG}; +use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline}; use anyhow::{bail, Context, Result}; use candle_core::quantized::GgmlDType; use candle_core::{DType, Device, Tensor}; @@ -45,8 +46,6 @@ use strum::EnumString; use tokenizers::Tokenizer; use tokio::sync::Mutex; use tracing::info; -use tracing::level_filters::LevelFilter; -use tracing_subscriber::EnvFilter; enum Model { Llama(QLlama), @@ -231,6 
+230,8 @@ impl GGUFLoader { chat_template: Option, tgt_non_granular_index: Option, ) -> Self { + setup_logger_and_debug(); + let model_id = if let Some(id) = model_id { Some(id) } else if let Some(xlora_order) = xlora_order.clone() { @@ -291,20 +292,6 @@ impl Loader for GGUFLoader { mapper: DeviceMapMetadata, in_situ_quant: Option, ) -> Result>> { - let is_debug = std::env::var("MISTRALRS_DEBUG") - .unwrap_or_default() - .contains('1'); - DEBUG.store(is_debug, std::sync::atomic::Ordering::Relaxed); - - let filter = EnvFilter::builder() - .with_default_directive(if is_debug { - LevelFilter::INFO.into() - } else { - LevelFilter::DEBUG.into() - }) - .from_env_lossy(); - tracing_subscriber::fmt().with_env_filter(filter).init(); - if in_situ_quant.is_some() { anyhow::bail!( "You are trying to in-situ quantize a GGUF model. This will not do anything." diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 7986ba60e..5ed6c022e 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -20,12 +20,13 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::debug::setup_logger_and_debug; use crate::utils::tokenizer::get_tokenizer; use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors}; use crate::xlora_models::NonGranularState; use crate::{ do_sample, get_mut_arcmutex, get_paths, lora_model_loader, normal_model_loader, - xlora_model_loader, DeviceMapMetadata, Pipeline, DEBUG, + xlora_model_loader, DeviceMapMetadata, Pipeline, }; use anyhow::Result; use candle_core::quantized::GgmlDType; @@ -40,8 +41,6 @@ use std::sync::Arc; use tokenizers::Tokenizer; use tokio::sync::Mutex; use tracing::info; -use tracing::level_filters::LevelFilter; -use tracing_subscriber::EnvFilter; pub struct NormalPipeline { model: Box, @@ -155,6 +154,8 @@ 
impl NormalLoaderBuilder { } pub fn build(self, loader: NormalLoaderType) -> Box { + setup_logger_and_debug(); + let loader: Box = match loader { NormalLoaderType::Mistral => Box::new(MistralLoader), NormalLoaderType::Gemma => Box::new(GemmaLoader), @@ -213,20 +214,6 @@ impl Loader for NormalLoader { mapper: DeviceMapMetadata, in_situ_quant: Option, ) -> Result>> { - let is_debug = std::env::var("MISTRALRS_DEBUG") - .unwrap_or_default() - .contains('1'); - DEBUG.store(is_debug, std::sync::atomic::Ordering::Relaxed); - - let filter = EnvFilter::builder() - .with_default_directive(if is_debug { - LevelFilter::INFO.into() - } else { - LevelFilter::DEBUG.into() - }) - .from_env_lossy(); - tracing_subscriber::fmt().with_env_filter(filter).init(); - let config = std::fs::read_to_string(paths.get_config_filename())?; let default_dtype = if device.is_cuda() && mapper.is_dummy() { DType::BF16 diff --git a/mistralrs-core/src/utils/debug.rs b/mistralrs-core/src/utils/debug.rs new file mode 100644 index 000000000..ebb71ff01 --- /dev/null +++ b/mistralrs-core/src/utils/debug.rs @@ -0,0 +1,21 @@ +use tracing::level_filters::LevelFilter; +use tracing_subscriber::EnvFilter; + +use crate::DEBUG; + +// Called by each `Loader` when it is created. Reads MISTRALRS_DEBUG once per call; uses `try_init` so repeated calls (multiple loaders, or a host app with its own subscriber) are no-ops instead of panics. +pub(crate) fn setup_logger_and_debug() { + let is_debug = std::env::var("MISTRALRS_DEBUG") + .unwrap_or_default() + .contains('1'); + DEBUG.store(is_debug, std::sync::atomic::Ordering::Relaxed); + + let filter = EnvFilter::builder() + .with_default_directive(if is_debug { + LevelFilter::DEBUG.into() + } else { + LevelFilter::INFO.into() + }) + .from_env_lossy(); + let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init(); +} diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index e5911b836..8b6bb9ae0 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod debug; pub(crate) mod model_config; pub(crate) mod progress; pub(crate) mod tokenizer;