diff --git a/Cargo.lock b/Cargo.lock
index 27adb40..fdb4896 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -380,6 +380,7 @@ dependencies = [
  "futures",
  "llama_cpp",
  "tokio",
+ "tracing-subscriber",
 ]
 
 [[package]]
@@ -394,9 +395,18 @@ dependencies = [
 
 [[package]]
 name = "log"
-version = "0.4.20"
+version = "0.4.21"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
+checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
+
+[[package]]
+name = "matchers"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
+dependencies = [
+ "regex-automata 0.1.10",
+]
 
 [[package]]
 name = "memchr"
@@ -440,6 +450,16 @@ dependencies = [
  "minimal-lexical",
 ]
 
+[[package]]
+name = "nu-ansi-term"
+version = "0.46.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
+dependencies = [
+ "overload",
+ "winapi",
+]
+
 [[package]]
 name = "num_cpus"
 version = "1.16.0"
@@ -465,6 +485,12 @@ version = "1.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
 
+[[package]]
+name = "overload"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.1"
@@ -545,8 +571,17 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-automata",
- "regex-syntax",
+ "regex-automata 0.4.5",
+ "regex-syntax 0.8.2",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
+dependencies = [
+ "regex-syntax 0.6.29",
 ]
 
 [[package]]
@@ -557,9 +592,15 @@ checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd"
 dependencies = [
  "aho-corasick",
  "memchr",
- "regex-syntax",
+ "regex-syntax 0.8.2",
 ]
 
+[[package]]
+name = "regex-syntax"
+version = "0.6.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
+
 [[package]]
 name = "regex-syntax"
 version = "0.8.2"
@@ -612,6 +653,15 @@ version = "1.0.21"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0"
 
+[[package]]
+name = "sharded-slab"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
+dependencies = [
+ "lazy_static",
+]
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -694,6 +744,16 @@ dependencies = [
  "syn 2.0.48",
 ]
 
+[[package]]
+name = "thread_local"
+version = "1.1.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
+dependencies = [
+ "cfg-if",
+ "once_cell",
+]
+
 [[package]]
 name = "tokio"
 version = "1.36.0"
@@ -753,6 +813,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
 dependencies = [
  "once_cell",
"valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -761,6 +851,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -779,6 +875,28 @@ dependencies = [ "rustix", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/crates/llama_cpp/src/detail.rs b/crates/llama_cpp/src/detail.rs index eff1f6d..fdd14ba 100644 --- a/crates/llama_cpp/src/detail.rs +++ b/crates/llama_cpp/src/detail.rs @@ -96,9 +96,9 @@ pub(crate) unsafe extern "C" fn llama_log_callback( }; match level { - ggml_log_level_GGML_LOG_LEVEL_ERROR => error!("ggml: {text}"), - ggml_log_level_GGML_LOG_LEVEL_INFO => info!("ggml: {text}"), - ggml_log_level_GGML_LOG_LEVEL_WARN => warn!("ggml: {text}"), + ggml_log_level_GGML_LOG_LEVEL_ERROR => error!(target: "llama.cpp", "{text}"), + ggml_log_level_GGML_LOG_LEVEL_INFO => info!(target: "llama.cpp", "{text}"), + ggml_log_level_GGML_LOG_LEVEL_WARN => warn!(target: "llama.cpp", "{text}"), _ => trace!("ggml: {text}"), } } diff --git a/crates/llama_cpp/src/model/mod.rs b/crates/llama_cpp/src/model/mod.rs index 4af18f4..84fe0a0 100644 --- a/crates/llama_cpp/src/model/mod.rs +++ b/crates/llama_cpp/src/model/mod.rs @@ -2,25 +2,28 @@ use std::borrow::Borrow; use std::cmp::min; -use std::ffi::{CStr, CString}; +use std::ffi::{c_char, CStr, CString}; use std::path::{Path, PathBuf}; use std::ptr::slice_from_raw_parts; use std::sync::{atomic::AtomicUsize, Arc}; +use std::usize; use derive_more::{Deref, DerefMut}; use futures::executor::block_on; use thiserror::Error; use tokio::sync::Mutex; use tokio::sync::RwLock; -use tracing::{info, trace, warn}; +use tracing::{error, info, trace, warn}; use backend::BackendRef; use llama_cpp_sys::{ - llama_context, llama_context_default_params, llama_context_params, llama_decode, - llama_free_model, llama_get_embeddings_ith, llama_kv_cache_clear, 
-    llama_model, llama_n_ctx_train, llama_n_embd, llama_n_vocab, llama_new_context_with_model,
-    llama_token_bos, llama_token_eos, llama_token_eot, llama_token_get_text, llama_token_middle,
-    llama_token_nl, llama_token_prefix, llama_token_suffix, llama_token_to_piece, llama_tokenize,
+    ggml_graph_overhead_custom, ggml_row_size, ggml_tensor_overhead, ggml_type, llama_context,
+    llama_context_default_params, llama_context_params, llama_decode, llama_free_model,
+    llama_get_embeddings_ith, llama_kv_cache_clear, llama_load_model_from_file, llama_model,
+    llama_model_meta_val_str, llama_n_ctx_train, llama_n_embd, llama_n_vocab,
+    llama_new_context_with_model, llama_token_bos, llama_token_eos, llama_token_eot,
+    llama_token_get_text, llama_token_middle, llama_token_nl, llama_token_prefix,
+    llama_token_suffix, llama_token_to_piece, llama_tokenize,
 };
 
 pub use params::*;
 
@@ -131,6 +134,23 @@ pub struct LlamaModel {
 
     /// The number of tokens in the context the model was trained with.
     training_size: usize,
+
+    /// The number of layers in the model's network.
+    layers: usize,
+
+    /// The number of key-value attention heads in the model's network.
+    kv_heads: usize,
+    /// Dimension of keys (d_k). d_q is assumed to be the same; there are n_head query heads, but only n_head_kv key-value heads.
+    k_attention: usize,
+    /// Dimension of values (d_v), also known as n_embd_head.
+    v_attention: usize,
+
+    /// State Space Model convolution kernel size.
+    ssm_d_conv: usize,
+    /// State Space Model inner size.
+    ssm_d_inner: usize,
+    /// State Space Model state size.
+    ssm_d_state: usize,
 }
 
 unsafe impl Send for LlamaModel {}
@@ -182,6 +202,36 @@ impl LlamaModel {
             llama_n_vocab(model)
         };
 
+        let n_embd = unsafe { llama_n_embd(model) } as usize;
+
+        // Lots of redundant fetches here because llama.cpp doesn't expose any of this directly
+
+        let heads = get_metadata(model, "%s.attention.head_count")
+            .parse::<usize>()
+            .unwrap_or(0);
+
+        let layers = get_metadata(model, "%s.block_count")
+            .parse::<usize>()
+            .unwrap_or(0);
+        let kv_heads = get_metadata(model, "%s.attention.head_count_kv")
+            .parse::<usize>()
+            .unwrap_or(heads);
+        let k_attention = get_metadata(model, "%s.attention.key_length")
+            .parse::<usize>()
+            .unwrap_or(n_embd / heads);
+        let v_attention = get_metadata(model, "%s.attention.value_length")
+            .parse::<usize>()
+            .unwrap_or(n_embd / heads);
+        let ssm_d_conv = get_metadata(model, "%s.ssm.conv_kernel")
+            .parse::<usize>()
+            .unwrap_or(0);
+        let ssm_d_inner = get_metadata(model, "%s.ssm.inner_size")
+            .parse::<usize>()
+            .unwrap_or(0);
+        let ssm_d_state = get_metadata(model, "%s.ssm.state_size")
+            .parse::<usize>()
+            .unwrap_or(0);
+
         Ok(Self {
             model: Arc::new(LlamaModelInner {
                 model,
@@ -195,8 +245,15 @@
             infill_middle_token: Token(unsafe { llama_token_middle(model) }),
             infill_suffix_token: Token(unsafe { llama_token_suffix(model) }),
             eot_token: Token(unsafe { llama_token_eot(model) }),
-            embedding_length: unsafe { llama_n_embd(model) } as usize,
+            embedding_length: n_embd,
             training_size: unsafe { llama_n_ctx_train(model) } as usize,
+            layers,
+            kv_heads,
+            k_attention,
+            v_attention,
+            ssm_d_conv,
+            ssm_d_inner,
+            ssm_d_state,
         })
     }
 }
@@ -429,6 +486,57 @@ impl LlamaModel {
         })
     }
 
+    /// Calculates and returns an estimate of how much local memory a [`LlamaSession`] will take.
+    ///
+    /// At the moment, the returned value should always be greater than the real value, possibly by as much as double.
+    ///
+    /// # Parameters
+    ///
+    /// * `session_params` - the parameters of the session to be created.
+    pub fn estimate_session_size(&self, session_params: &SessionParams) -> usize {
+        let kv_size = session_params.n_ctx as i64; // TODO exception for mamba arch
+
+        // dimension of key embeddings across all k-v heads
+        let n_embd_k_gqa = self.k_attention * self.kv_heads;
+        // dimension of value embeddings across all k-v heads
+        let n_embd_v_gqa = self.v_attention * self.kv_heads;
+
+        // dimension of the rolling state embeddings
+        let n_embd_k_s = if self.ssm_d_conv > 0 {
+            (self.ssm_d_conv - 1) * self.ssm_d_inner
+        } else {
+            0
+        };
+        // dimension of the recurrent state embeddings
+        let n_embd_v_s = self.ssm_d_state * self.ssm_d_inner;
+
+        let k_row_size = unsafe {
+            ggml_row_size(
+                session_params.type_k as ggml_type,
+                (n_embd_k_gqa + n_embd_k_s) as i64 * kv_size,
+            )
+        };
+        let v_row_size = unsafe {
+            ggml_row_size(
+                session_params.type_v as ggml_type,
+                (n_embd_v_gqa + n_embd_v_s) as i64 * kv_size,
+            )
+        };
+
+        let cache_size = self.layers * (k_row_size + v_row_size);
+
+        const LLAMA_MAX_NODES: usize = 8192;
+
+        let compute_size = unsafe {
+            ggml_tensor_overhead() * LLAMA_MAX_NODES
+                + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false)
+        };
+
+        // TODO: while llama.cpp doesn't offer memory estimation utilities, this is the best that can be done realistically
+        // https://github.com/ggerganov/llama.cpp/issues/4315
+        (cache_size + compute_size) * 2
+    }
+
     /// Performs embeddings decoding on the given batch and returns the result.
     fn embeddings_decode(
         &self,
@@ -624,4 +732,78 @@ impl LlamaModel {
     pub fn train_len(&self) -> usize {
         self.training_size
     }
+
+    /// Returns the number of layers in the model.
+    pub fn layers(&self) -> usize {
+        self.layers
+    }
+}
+
+/// Retrieves a value in string form from a model's metadata.
+///
+/// # Parameters
+///
+/// * `model` - a pointer to the model to retrieve values from.
+/// * `key` - the key of the metadata value.
+///
+/// # Limitations
+///
+/// At the moment, the implementation only retrieves values of limited length, so this shouldn't be used to retrieve
+/// something like the model's grammar.
+fn get_metadata(model: *mut llama_model, key: &str) -> String {
+    let c_key = if let Some(stripped) = key.strip_prefix("%s") {
+        let arch_key = CStr::from_bytes_with_nul(b"general.architecture\0").unwrap(); // Should never fail
+        let mut arch_val = vec![0u8; 128];
+
+        let res = unsafe {
+            llama_model_meta_val_str(
+                model,
+                arch_key.as_ptr(),
+                arch_val.as_mut_ptr() as *mut c_char,
+                arch_val.len(),
+            )
+        };
+
+        if let Ok(len) = usize::try_from(res) {
+            if let Ok(c_str) = CStr::from_bytes_with_nul(&arch_val[..=len]) {
+                let formatted = format!("{}{stripped}", c_str.to_string_lossy());
+                CString::new(formatted.as_bytes()).unwrap()
+            } else {
+                // This should be unreachable
+                error!("Could not parse architecture metadata");
+                return String::new();
+            }
+        } else {
+            // This should be unreachable
+            error!("Could not find architecture metadata");
+            return String::new();
+        }
+    } else {
+        CString::new(key).unwrap()
+    };
+
+    // This implementation assumes large values such as the model's vocabulary will never be queried
+    let mut val = vec![0u8; 128];
+    let res = unsafe {
+        llama_model_meta_val_str(
+            model,
+            c_key.as_ptr(),
+            val.as_mut_ptr() as *mut c_char,
+            val.len(),
+        )
+    };
+
+    if let Ok(len) = usize::try_from(res) {
+        if let Ok(val_str) = CStr::from_bytes_with_nul(&val[..=len])
+            .map(move |val| val.to_string_lossy().to_string())
+        {
+            val_str
+        } else {
+            error!("Failed to parse retrieved metadata");
+            String::new()
+        }
+    } else {
+        warn!(key, "Could not find metadata");
+        String::new()
+    }
 }
diff --git a/crates/llama_cpp/src/session/mod.rs b/crates/llama_cpp/src/session/mod.rs
index cba6368..2050b25 100644
--- a/crates/llama_cpp/src/session/mod.rs
+++ b/crates/llama_cpp/src/session/mod.rs
@@ -510,4 +510,18 @@ impl LlamaSession {
 
         Ok(copy)
     }
+
+    /// Returns the maximum size in bytes this session is occupying in memory.
+    ///
+    /// This function may **not** be called in async environments; for an async version, see [`async_memory_size`].
+    pub fn memory_size(&self) -> usize {
+        let ctx = self.inner.ctx.blocking_lock();
+        unsafe { llama_get_state_size(ctx.ptr) }
+    }
+
+    /// Asynchronously returns the maximum size in bytes this session is occupying in memory.
+    pub async fn async_memory_size(&self) -> usize {
+        let ctx = self.inner.ctx.lock().await;
+        unsafe { llama_get_state_size(ctx.ptr) }
+    }
 }
diff --git a/crates/llama_cpp_sys/build.rs b/crates/llama_cpp_sys/build.rs
index 9464d50..685a6d0 100644
--- a/crates/llama_cpp_sys/build.rs
+++ b/crates/llama_cpp_sys/build.rs
@@ -4,6 +4,7 @@ use std::io::Write;
 use std::path::{Path, PathBuf};
 use std::process::Command;
 
+use bindgen::callbacks::{ItemInfo, ItemKind, ParseCallbacks};
 use cc::Build;
 use once_cell::sync::Lazy;
 
@@ -84,6 +85,9 @@ compile_error!("feature \"clblas\" cannot be enabled alongside other GPU based f
 ))]
 compile_error!("feature \"vulkan\" cannot be enabled alongside other GPU based features");
 
+/// The general prefix used to rename conflicting symbols.
+const PREFIX: &str = "llama_";
+
 static LLAMA_PATH: Lazy<PathBuf> = Lazy::new(|| PathBuf::from("./thirdparty/llama.cpp"));
 
 fn compile_bindings(out_path: &Path) {
@@ -92,12 +96,11 @@ fn compile_bindings(out_path: &Path) {
         .header(LLAMA_PATH.join("ggml.h").to_string_lossy())
         .header(LLAMA_PATH.join("llama.h").to_string_lossy())
         .derive_partialeq(true)
+        .allowlist_function("ggml_.*")
         .allowlist_type("ggml_.*")
         .allowlist_function("llama_.*")
         .allowlist_type("llama_.*")
-        .parse_callbacks(Box::new(
-            bindgen::CargoCallbacks::new().rerun_on_header_files(false),
-        ))
+        .parse_callbacks(Box::new(GGMLLinkRename {}))
         .generate()
         .expect("Unable to generate bindings");
 
@@ -106,6 +109,24 @@
         .expect("Couldn't write bindings!");
 }
 
+#[derive(Debug)]
+struct GGMLLinkRename {}
+
+impl ParseCallbacks for GGMLLinkRename {
+    fn generated_link_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
+        match item_info.kind {
+            ItemKind::Function => {
+                if item_info.name.starts_with("ggml_") {
+                    Some(format!("{PREFIX}{}", item_info.name))
+                } else {
+                    None
+                }
+            }
+            _ => None,
+        }
+    }
+}
+
 /// Add platform appropriate flags and definitions present in all compilation configurations.
 fn push_common_flags(cx: &mut Build, cxx: &mut Build) {
     cx.static_flag(true).cpp(false).std("c11");
@@ -599,7 +620,7 @@
                 },
             ],
         );
-        objcopy_redefine(&objcopy, ggml_lib_name, "llama_", symbols, &out_path);
+        objcopy_redefine(&objcopy, ggml_lib_name, PREFIX, symbols, &out_path);
 
         // Modifying the symbols llama depends on from ggml
 
@@ -617,7 +638,7 @@
                 },
             ],
         );
-        objcopy_redefine(&objcopy, llama_lib_name, "llama_", symbols, &out_path);
+        objcopy_redefine(&objcopy, llama_lib_name, PREFIX, symbols, &out_path);
 
         if let Some(gpu_lib_name) = additional_lib {
             // Modifying the symbols of the GPU library
diff --git a/crates/llama_cpp_tests/Cargo.toml b/crates/llama_cpp_tests/Cargo.toml
index 73e2449..4fc3a3d 100644
--- a/crates/llama_cpp_tests/Cargo.toml
+++ b/crates/llama_cpp_tests/Cargo.toml
@@ -9,6 +9,7 @@ license = "MIT OR Apache-2.0"
 futures = { workspace = true }
 llama_cpp = { version = "^0.3.1", path = "../llama_cpp", default-features = false, features = ["native", "compat"] }
 tokio = { workspace = true, features = ["full"] }
+tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
 
 [features]
 vulkan = ["llama_cpp/vulkan"]
diff --git a/crates/llama_cpp_tests/src/lib.rs b/crates/llama_cpp_tests/src/lib.rs
index 79fc1e3..d7667e2 100644
--- a/crates/llama_cpp_tests/src/lib.rs
+++ b/crates/llama_cpp_tests/src/lib.rs
@@ -8,17 +8,37 @@ mod tests {
     use std::io;
     use std::io::Write;
     use std::path::Path;
+    use std::sync::atomic::{AtomicBool, Ordering};
     use std::time::Duration;
 
     use futures::StreamExt;
     use tokio::select;
     use tokio::time::Instant;
+    use tracing_subscriber::layer::SubscriberExt;
+    use tracing_subscriber::util::SubscriberInitExt;
 
     use llama_cpp::standard_sampler::StandardSampler;
     use llama_cpp::{
         CompletionHandle, EmbeddingsParams, LlamaModel, LlamaParams, SessionParams, TokensToStrings,
     };
 
+    fn init_tracing() {
+        static SUBSCRIBER_SET: AtomicBool = AtomicBool::new(false);
+
+        if !SUBSCRIBER_SET.swap(true, Ordering::SeqCst) {
+            let format = tracing_subscriber::fmt::layer().compact();
+            let filter = tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or(
+                tracing_subscriber::EnvFilter::default()
+                    .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()),
+            );
+
+            tracing_subscriber::registry()
+                .with(format)
+                .with(filter)
+                .init();
+        }
+    }
+
     async fn list_models(dir: impl AsRef<Path>) -> Vec<String> {
         let dir = dir.as_ref();
@@ -66,6 +86,8 @@
 
     #[tokio::test]
     async fn execute_completions() {
+        init_tracing();
+
         let dir = std::env::var("LLAMA_CPP_TEST_MODELS").unwrap_or_else(|_| {
             panic!(
                 "LLAMA_CPP_TEST_MODELS environment variable not set. \
@@ -135,6 +157,8 @@
 
     #[tokio::test]
     async fn embed() {
+        init_tracing();
+
         let dir = std::env::var("LLAMA_EMBED_MODELS_DIR").unwrap_or_else(|_| {
             panic!(
                 "LLAMA_EMBED_MODELS_DIR environment variable not set. \