diff --git a/Cargo.lock b/Cargo.lock index fdb4896..4ee4333 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -59,7 +59,7 @@ version = "0.69.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "cexpr", "clang-sys", "itertools", @@ -72,7 +72,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.48", + "syn 2.0.53", "which", ] @@ -84,9 +84,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "bytes" @@ -96,9 +96,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.83" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" dependencies = [ "jobserver", "libc", @@ -227,7 +227,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.53", ] [[package]] @@ -274,9 +274,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "hermit-abi" -version = "0.3.6" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5256b483761cd23699d0da46cc6fd2ee3be420bbe6d020ae4a091e70b7e9fd" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "home" @@ -325,12 +325,12 @@ checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-sys 0.48.0", + "windows-targets 0.52.4", ] [[package]] @@ -431,9 +431,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", "wasi", @@ -533,14 +533,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.48", + "syn 2.0.53", ] [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] @@ -571,7 +571,7 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.5", + "regex-automata 0.4.6", "regex-syntax 0.8.2", ] @@ -586,9 +586,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -630,11 +630,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.31" +version = "0.38.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", @@ -649,9 +649,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "semver" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" [[package]] name = "sharded-slab" @@ -688,18 +688,18 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" dependencies = [ "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -715,9 +715,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -726,22 +726,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.53", ] [[package]] @@ -781,7 +781,7 @@ checksum = 
"5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.53", ] [[package]] @@ -803,7 +803,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.53", ] [[package]] @@ -912,7 +912,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.4", ] [[package]] @@ -932,17 +932,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -953,9 +953,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" @@ -965,9 +965,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" @@ -977,9 +977,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" @@ -989,9 +989,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" @@ -1001,9 +1001,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = 
"5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" @@ -1013,9 +1013,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" @@ -1025,6 +1025,6 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" diff --git a/Cargo.toml b/Cargo.toml index abec77b..b53e8ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,4 +8,4 @@ members = [ [workspace.dependencies] futures = "0.3.30" -tokio = "1.34.0" +tokio = "1.36.0" diff --git a/crates/llama_cpp/src/detail.rs b/crates/llama_cpp/src/detail.rs index fbb3782..afec13a 100644 --- a/crates/llama_cpp/src/detail.rs +++ b/crates/llama_cpp/src/detail.rs @@ -6,14 +6,11 @@ use std::ffi::{c_char, c_void, CStr}; -use tracing::{error, info, trace, warn}; +use tracing::{debug, error, info, warn}; -use llama_cpp_sys::{ - ggml_log_level, ggml_log_level_GGML_LOG_LEVEL_ERROR, ggml_log_level_GGML_LOG_LEVEL_INFO, - ggml_log_level_GGML_LOG_LEVEL_WARN, -}; +use llama_cpp_sys::ggml_log_level; -#[no_mangle] +#[allow(improper_ctypes_definitions)] pub(crate) unsafe extern "C" fn llama_log_callback( level: ggml_log_level, text: *const c_char, @@ -37,9 +34,10 @@ pub(crate) unsafe extern "C" fn llama_log_callback( }; match level { - ggml_log_level_GGML_LOG_LEVEL_ERROR => error!(target: "llama.cpp", "{text}"), - ggml_log_level_GGML_LOG_LEVEL_INFO => info!(target: "llama.cpp", "{text}"), - ggml_log_level_GGML_LOG_LEVEL_WARN => warn!(target: "llama.cpp", "{text}"), - _ => trace!("ggml: {text}"), + ggml_log_level::GGML_LOG_LEVEL_INFO => info!(target: "llama.cpp", "{text}"), + ggml_log_level::GGML_LOG_LEVEL_DEBUG => debug!(target: "llama.cpp", "{text}"), + ggml_log_level::GGML_LOG_LEVEL_WARN => warn!(target: "llama.cpp", "{text}"), + ggml_log_level::GGML_LOG_LEVEL_ERROR => error!(target: "llama.cpp", "{text}"), + _ => unimplemented!(), } } diff --git a/crates/llama_cpp/src/lib.rs b/crates/llama_cpp/src/lib.rs index be3abde..121d1ac 100644 --- a/crates/llama_cpp/src/lib.rs +++ b/crates/llama_cpp/src/lib.rs @@ -133,3 +133,20 @@ pub trait Sampler { candidates_p: llama_token_data_array, ) -> Token; } + +/// Memory requirements for something. +/// +/// This is typically returned by [`LlamaModel::estimate_session_size`] and +/// [`LlamaModel::estimate_embeddings_session_size`] as an estimation of memory usage. +#[derive(Debug)] +pub struct ResourceUsage { + /// The host memory required, in bytes. + pub host_memory: usize, + + /// The device memory required, in bytes. + /// + /// The device depends on features used to build this crate, as well as the main gpu selected during model creation. + /// + /// If the device is the CPU, this is additional host memory required. 
+ pub device_memory: usize, +} diff --git a/crates/llama_cpp/src/model/backend.rs b/crates/llama_cpp/src/model/backend.rs index f00210a..aa61c80 100644 --- a/crates/llama_cpp/src/model/backend.rs +++ b/crates/llama_cpp/src/model/backend.rs @@ -7,8 +7,7 @@ use std::sync::Mutex; use tracing::error; use llama_cpp_sys::{ - ggml_numa_strategy_GGML_NUMA_STRATEGY_DISTRIBUTE, llama_backend_free, llama_backend_init, - llama_log_set, llama_numa_init, + ggml_numa_strategy, llama_backend_free, llama_backend_init, llama_log_set, llama_numa_init, }; use crate::detail; @@ -35,7 +34,7 @@ impl Backend { llama_backend_init(); // TODO look into numa strategies, this should probably be part of the API - llama_numa_init(ggml_numa_strategy_GGML_NUMA_STRATEGY_DISTRIBUTE); + llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISTRIBUTE); // SAFETY: performs a simple assignment to static variables. Should only execute once // before any logs are made. diff --git a/crates/llama_cpp/src/model/mod.rs b/crates/llama_cpp/src/model/mod.rs index 5b03117..3bfed6f 100644 --- a/crates/llama_cpp/src/model/mod.rs +++ b/crates/llama_cpp/src/model/mod.rs @@ -3,6 +3,7 @@ use std::borrow::Borrow; use std::cmp::min; use std::ffi::{c_char, CStr, CString}; +use std::mem::size_of; use std::path::{Path, PathBuf}; use std::ptr::slice_from_raw_parts; use std::sync::{atomic::AtomicUsize, Arc, Mutex, RwLock}; @@ -14,12 +15,11 @@ use tracing::{error, info, trace, warn}; use backend::BackendRef; use llama_cpp_sys::{ - ggml_graph_overhead_custom, ggml_row_size, ggml_tensor_overhead, ggml_type, llama_context, - llama_context_default_params, llama_context_params, llama_decode, llama_free_model, - llama_get_embeddings_ith, llama_kv_cache_clear, llama_load_model_from_file, llama_model, - llama_model_meta_val_str, llama_n_ctx_train, llama_n_embd, llama_n_vocab, - llama_new_context_with_model, llama_token_bos, llama_token_eos, llama_token_eot, - llama_token_get_text, llama_token_middle, llama_token_nl, llama_token_prefix, + ggml_row_size, llama_context, llama_context_params, llama_decode, llama_free_model, + llama_get_embeddings_ith, llama_get_embeddings_seq, llama_kv_cache_clear, + llama_load_model_from_file, llama_model, llama_model_meta_val_str, llama_n_ctx_train, + llama_n_embd, llama_n_vocab, llama_new_context_with_model, llama_token_bos, llama_token_eos, + llama_token_eot, llama_token_get_text, llama_token_middle, llama_token_nl, llama_token_prefix, llama_token_suffix, llama_token_to_piece, llama_tokenize, }; pub use params::*; @@ -27,7 +27,7 @@ pub use params::*; use crate::batch::Batch; use crate::{ LlamaContextError, LlamaContextInner, LlamaInternalError, LlamaSession, LlamaSessionInner, - SessionParams, Token, + ResourceUsage, SessionParams, Token, }; mod backend; @@ -490,7 +490,7 @@ impl LlamaModel { /// # Parameters /// /// * `session_params` - the parameters of the session to be created. 
- pub fn estimate_session_size(&self, session_params: &SessionParams) -> usize { + pub fn estimate_session_size(&self, session_params: &SessionParams) -> ResourceUsage { let kv_size = session_params.n_ctx as i64; // TODO exception for mamba arch // dimension of key embeddings across all k-v heads @@ -509,29 +509,43 @@ impl LlamaModel { let k_row_size = unsafe { ggml_row_size( - session_params.type_k as ggml_type, + session_params.type_k.into(), (n_embd_k_gqa + n_embd_k_s) as i64 * kv_size, ) }; let v_row_size = unsafe { ggml_row_size( - session_params.type_v as ggml_type, + session_params.type_v.into(), (n_embd_v_gqa + n_embd_v_s) as i64 * kv_size, ) }; let cache_size = self.layers * (k_row_size + v_row_size); + trace!("KV cache size: {}MB", cache_size / 1024 / 1024); - const LLAMA_MAX_NODES: usize = 8192; - - let compute_size = unsafe { - ggml_tensor_overhead() * LLAMA_MAX_NODES - + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false) + let batch = min(session_params.n_ctx, session_params.n_batch) as usize; + let logits_size = self.vocabulary_size * batch; + let embed_size = if session_params.embedding { + self.embedding_length * batch + } else { + 0 }; - - // TODO while llama doesn't offer memory estimation utilities, this is the best that can be done realistically - // https://github.com/ggerganov/llama.cpp/issues/4315 - (cache_size + compute_size) * 2 + let output_size = (logits_size + embed_size) * size_of::(); + trace!("Output buffer size: {}MB", output_size / 1024 / 1024); + + // const LLAMA_MAX_NODES: usize = 8192; + // + // let compute_size = unsafe { + // ggml_tensor_overhead() * LLAMA_MAX_NODES + // + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false) + // }; + + ResourceUsage { + host_memory: cache_size + output_size, + // TODO while llama doesn't offer memory estimation utilities, this is the best that can be done realistically + // https://github.com/ggerganov/llama.cpp/issues/4315 + device_memory: output_size, + } } /// Performs embeddings decoding on the given batch and returns the result. 
@@ -539,7 +553,7 @@ impl LlamaModel { &self, context: *mut llama_context, batch: &Batch, - input_count: usize, + token_counts: &[usize], ) -> Result>, LlamaContextError> { let res = unsafe { // clear previous kv_cache values (irrelevant for embeddings) @@ -551,11 +565,22 @@ impl LlamaModel { return Err(LlamaContextError::DecodeFailed(res)); } - let mut out = Vec::with_capacity(input_count); + let mut out = Vec::with_capacity(token_counts.len()); - for i in 0..input_count { + for (i, count) in token_counts.iter().enumerate() { let embedding = unsafe { - let ptr = llama_get_embeddings_ith(context, i as i32); + let mut ptr = llama_get_embeddings_seq(context, i as i32); + + if ptr.is_null() { + ptr = llama_get_embeddings_ith(context, (count - 1) as i32); + } + + if ptr.is_null() { + return Err(LlamaContextError::EmbeddingsFailed( + "Could not retrieve embeddings".to_string(), + )); + } + slice_from_raw_parts(ptr, self.embedding_length) .as_ref() .ok_or(LlamaContextError::EmbeddingsFailed( @@ -596,10 +621,11 @@ impl LlamaModel { ) -> Result>, LlamaContextError> { let mut total_tokens = 0; let mut max_tokens = 0; - for tokens in &inputs { - total_tokens += tokens.len(); - if max_tokens < tokens.len() { - max_tokens = tokens.len(); + let token_counts: Vec = inputs.iter().map(|v| v.len()).collect(); + for count in &token_counts { + total_tokens += count; + if max_tokens < *count { + max_tokens = *count; } } @@ -609,20 +635,14 @@ impl LlamaModel { } else { min(self.training_size, total_tokens) }; - let mut batch = Batch::new(batch_capacity, 0, inputs.len()); + let mut batch = Batch::new(batch_capacity, 0, 1); let mut out = Vec::with_capacity(inputs.len()); + let context_params = params.as_context_params(batch_capacity); let context = unsafe { - // SAFETY: Stack constructor, always safe. - let mut ctx_params = llama_context_default_params(); - ctx_params.embedding = true; - ctx_params.n_threads = params.n_threads; - ctx_params.n_threads_batch = params.n_threads_batch; - ctx_params.n_ctx = batch_capacity as u32; - ctx_params.n_batch = batch_capacity as u32; // SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live // for at least the lifetime of `LlamaContext`. - llama_new_context_with_model(**self.model, ctx_params) + llama_new_context_with_model(**self.model, context_params) }; if context.is_null() { @@ -630,11 +650,17 @@ impl LlamaModel { } let mut batch_input_count = 0; + let mut submitted = 0; for input in inputs { if batch.tokens() + input.len() > batch_capacity { trace!("Decoding {} embedding tokens", batch.tokens()); - out.append(&mut self.embeddings_decode(context, &batch, batch_input_count)?); + out.append(&mut self.embeddings_decode( + context, + &batch, + &token_counts[submitted..batch_input_count], + )?); batch.clear(); + submitted = batch_input_count; batch_input_count = 0; } @@ -647,7 +673,11 @@ impl LlamaModel { if 0 < batch_input_count { trace!("Decoding remaining {} embedding tokens", batch.tokens()); - out.append(&mut self.embeddings_decode(context, &batch, batch_input_count)?); + out.append(&mut self.embeddings_decode( + context, + &batch, + &token_counts[submitted..batch_input_count], + )?); } Ok(out) @@ -680,6 +710,36 @@ impl LlamaModel { .unwrap() } + /// Return an estimation of how much memory embeddings generation is gonna require for the provided parameters and + /// input tokens. 
+ pub fn estimate_embeddings_session_size( + &self, + inputs: &[Vec], + params: &EmbeddingsParams, + ) -> ResourceUsage { + let mut total_tokens = 0; + let mut max_tokens = 0; + for tokens in inputs { + total_tokens += tokens.len(); + if max_tokens < tokens.len() { + max_tokens = tokens.len(); + } + } + + let batch_capacity = if max_tokens > self.training_size { + warn!("Large embedding input requires a context larger than the model's training context."); + max_tokens + } else { + min(self.training_size, total_tokens) + }; + + let context_params = params.as_context_params(batch_capacity); + + let mut ret = self.estimate_session_size(&context_params.into()); + ret.device_memory += ret.device_memory / 4; // bad workaround for device memory, see estimate_session_size + ret + } + /// Returns the beginning of sentence (BOS) token for this context. pub fn bos(&self) -> Token { self.bos_token diff --git a/crates/llama_cpp/src/model/params.rs b/crates/llama_cpp/src/model/params.rs index 65afbfc..ed4abc6 100644 --- a/crates/llama_cpp/src/model/params.rs +++ b/crates/llama_cpp/src/model/params.rs @@ -3,9 +3,8 @@ use std::ptr; use llama_cpp_sys::{ - llama_model_default_params, llama_model_params, llama_split_mode, - llama_split_mode_LLAMA_SPLIT_MODE_LAYER, llama_split_mode_LLAMA_SPLIT_MODE_NONE, - llama_split_mode_LLAMA_SPLIT_MODE_ROW, + llama_context_default_params, llama_context_params, llama_model_default_params, + llama_model_params, llama_split_mode, }; /// Parameters for llama. @@ -65,9 +64,9 @@ pub enum SplitMode { impl From for llama_split_mode { fn from(value: SplitMode) -> Self { match value { - SplitMode::None => llama_split_mode_LLAMA_SPLIT_MODE_NONE, - SplitMode::Layer => llama_split_mode_LLAMA_SPLIT_MODE_LAYER, - SplitMode::Row => llama_split_mode_LLAMA_SPLIT_MODE_ROW, + SplitMode::None => llama_split_mode::LLAMA_SPLIT_MODE_NONE, + SplitMode::Layer => llama_split_mode::LLAMA_SPLIT_MODE_LAYER, + SplitMode::Row => llama_split_mode::LLAMA_SPLIT_MODE_ROW, } } } @@ -76,9 +75,9 @@ impl From for SplitMode { fn from(value: llama_split_mode) -> Self { #![allow(non_upper_case_globals)] match value { - llama_split_mode_LLAMA_SPLIT_MODE_NONE => SplitMode::None, - llama_split_mode_LLAMA_SPLIT_MODE_LAYER => SplitMode::Layer, - llama_split_mode_LLAMA_SPLIT_MODE_ROW => SplitMode::Row, + llama_split_mode::LLAMA_SPLIT_MODE_NONE => SplitMode::None, + llama_split_mode::LLAMA_SPLIT_MODE_LAYER => SplitMode::Layer, + llama_split_mode::LLAMA_SPLIT_MODE_ROW => SplitMode::Row, _ => unimplemented!(), } } @@ -126,6 +125,22 @@ pub struct EmbeddingsParams { pub n_threads_batch: u32, } +impl EmbeddingsParams { + pub(crate) fn as_context_params(&self, batch_capacity: usize) -> llama_context_params { + // SAFETY: Stack constructor, always safe. 
+ let mut ctx_params = unsafe { llama_context_default_params() }; + + ctx_params.embeddings = true; + ctx_params.n_threads = self.n_threads; + ctx_params.n_threads_batch = self.n_threads_batch; + ctx_params.n_ctx = batch_capacity as u32; + ctx_params.n_batch = batch_capacity as u32; + ctx_params.n_ubatch = batch_capacity as u32; + + ctx_params + } +} + impl Default for EmbeddingsParams { fn default() -> Self { let threads = num_cpus::get_physical() as u32 - 1; diff --git a/crates/llama_cpp/src/session/mod.rs b/crates/llama_cpp/src/session/mod.rs index 99aa1af..1dcf7ce 100644 --- a/crates/llama_cpp/src/session/mod.rs +++ b/crates/llama_cpp/src/session/mod.rs @@ -110,6 +110,14 @@ pub enum LlamaContextError { /// An error occurred on the other side of the FFI boundary; check your logs. #[error("failed to process embeddings (reason: {0})")] EmbeddingsFailed(String), + + /// An error occurred operating over kv cache due to invalid range. + #[error("failed to operate over kv cache due to invalid range")] + InvalidRange, + + /// Tried to start completing before advancing the context. + #[error("cannot start completing without any history")] + NoContext, } impl LlamaSession { @@ -233,7 +241,7 @@ impl LlamaSession { /// Starts generating tokens at the end of the context using a greedy /// sampler - pub fn start_completing(&mut self) -> CompletionHandle { + pub fn start_completing(&mut self) -> Result { self.start_completing_with( StandardSampler::new_greedy(), self.params().n_ctx as usize - self.context_size(), @@ -245,14 +253,19 @@ impl LlamaSession { &mut self, mut sampler: S, max_predictions: usize, - ) -> CompletionHandle + ) -> Result where S: Sampler + Send + Sync + 'static, { - let (tx, rx) = unbounded_channel(); let history_size = self.context_size(); + + if history_size == 0 { + return Err(LlamaContextError::NoContext); + } + + let (tx, rx) = unbounded_channel(); let session = self.clone(); - // TODO deal with 0 history size + info!("Generating completions with {history_size} tokens of history"); thread::spawn(move || { @@ -267,7 +280,9 @@ impl LlamaSession { // the model if session.inner.last_batch_size.load(Ordering::SeqCst) == 0 { // Remove last token - unsafe { llama_kv_cache_seq_rm(context.ptr, -1, token_buf.len() as i32 - 1, -1) } + unsafe { + llama_kv_cache_seq_rm(context.ptr, -1, token_buf.len() as i32 - 1, -1); + } // Decode last token batch.add(*token_buf.last().unwrap(), current_pos, &[0], true); @@ -345,10 +360,10 @@ impl LlamaSession { } }); - CompletionHandle { + Ok(CompletionHandle { rx, model: self.model(), - } + }) } /// Returns the model this session was created from. @@ -377,7 +392,10 @@ impl LlamaSession { /// /// Note that calling this is not equivalent to calling [`LlamaSession::set_context`] with the /// same list of tokens that this method produces. 
- pub fn remove_tokens_in_range(&mut self, range: impl RangeBounds) { + pub fn remove_tokens_in_range( + &mut self, + range: impl RangeBounds, + ) -> Result<(), LlamaContextError> { let start_bound = match range.start_bound() { Bound::Included(i) => *i as i32, Bound::Excluded(i) => *i as i32 + 1, @@ -393,7 +411,11 @@ impl LlamaSession { let context = self.inner.ctx.lock().unwrap(); // -1 here to match all sequences - unsafe { llama_kv_cache_seq_rm(context.ptr, -1, start_bound, end_bound) } + let success = unsafe { llama_kv_cache_seq_rm(context.ptr, -1, start_bound, end_bound) }; + + if !success { + return Err(LlamaContextError::InvalidRange); + } // If we delete to the end, store 0 to indicate that there are no logits if end_bound == -1 || end_bound as usize >= self.context_size() { @@ -401,10 +423,12 @@ impl LlamaSession { } self.inner.tokens.write().unwrap().drain(range); + + Ok(()) } /// Removes all but the first `n_tokens` tokens from the context. - pub fn truncate_context(&mut self, n_tokens: usize) { + pub fn truncate_context(&mut self, n_tokens: usize) -> Result<(), LlamaContextError> { self.remove_tokens_in_range(n_tokens..) } @@ -427,7 +451,7 @@ impl LlamaSession { std::mem::drop(old_tokens); - self.truncate_context(shared_prefix); + self.truncate_context(shared_prefix)?; self.advance_context_with_tokens(&new_tokens[shared_prefix..]) } @@ -513,7 +537,9 @@ impl LlamaSession { Ok(copy) } - /// Returns the maximum size in bytes this session is occupying in memory. + /// Returns the maximum size in bytes this session is occupying in host memory. + /// + /// Currently there is no way to check the amount of memory occupied in devices. pub fn memory_size(&self) -> usize { let ctx = self.inner.ctx.lock().unwrap(); unsafe { llama_get_state_size(ctx.ptr) } diff --git a/crates/llama_cpp/src/session/params.rs b/crates/llama_cpp/src/session/params.rs index 066ca1b..d678968 100644 --- a/crates/llama_cpp/src/session/params.rs +++ b/crates/llama_cpp/src/session/params.rs @@ -4,8 +4,7 @@ use std::ptr::null_mut; use llama_cpp_sys::{ ggml_type, llama_context_default_params, llama_context_params, llama_pooling_type, - llama_pooling_type_LLAMA_POOLING_TYPE_CLS, llama_pooling_type_LLAMA_POOLING_TYPE_MEAN, - llama_pooling_type_LLAMA_POOLING_TYPE_NONE, llama_pooling_type_LLAMA_POOLING_TYPE_UNSPECIFIED, + llama_rope_scaling_type, }; /// whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) @@ -24,10 +23,10 @@ pub enum PoolingType { impl From for llama_pooling_type { fn from(value: PoolingType) -> Self { match value { - PoolingType::Unspecified => llama_pooling_type_LLAMA_POOLING_TYPE_UNSPECIFIED, - PoolingType::None => llama_pooling_type_LLAMA_POOLING_TYPE_NONE, - PoolingType::Mean => llama_pooling_type_LLAMA_POOLING_TYPE_MEAN, - PoolingType::Cls => llama_pooling_type_LLAMA_POOLING_TYPE_CLS, + PoolingType::Unspecified => llama_pooling_type::LLAMA_POOLING_TYPE_UNSPECIFIED, + PoolingType::None => llama_pooling_type::LLAMA_POOLING_TYPE_NONE, + PoolingType::Mean => llama_pooling_type::LLAMA_POOLING_TYPE_MEAN, + PoolingType::Cls => llama_pooling_type::LLAMA_POOLING_TYPE_CLS, } } } @@ -36,10 +35,182 @@ impl From for PoolingType { fn from(value: llama_pooling_type) -> Self { #![allow(non_upper_case_globals)] match value { - llama_pooling_type_LLAMA_POOLING_TYPE_UNSPECIFIED => PoolingType::Unspecified, - llama_pooling_type_LLAMA_POOLING_TYPE_NONE => PoolingType::None, - llama_pooling_type_LLAMA_POOLING_TYPE_MEAN => PoolingType::Mean, - llama_pooling_type_LLAMA_POOLING_TYPE_CLS => 
PoolingType::Cls, + llama_pooling_type::LLAMA_POOLING_TYPE_UNSPECIFIED => PoolingType::Unspecified, + llama_pooling_type::LLAMA_POOLING_TYPE_NONE => PoolingType::None, + llama_pooling_type::LLAMA_POOLING_TYPE_MEAN => PoolingType::Mean, + llama_pooling_type::LLAMA_POOLING_TYPE_CLS => PoolingType::Cls, + _ => unimplemented!(), + } + } +} + +/// A rope scaling type. +#[derive(Clone, Copy)] +pub enum RopeScaling { + /// Unspecified. + Unspecified, + /// None. + None, + /// Linear. + Linear, + /// Yarn. + Yarn, +} + +impl From for llama_rope_scaling_type { + fn from(value: RopeScaling) -> Self { + match value { + RopeScaling::Unspecified => { + llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED + } + RopeScaling::None => llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_NONE, + RopeScaling::Linear => llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_LINEAR, + RopeScaling::Yarn => llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_YARN, + } + } +} + +impl From for RopeScaling { + fn from(value: llama_rope_scaling_type) -> Self { + match value { + llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED => { + RopeScaling::Unspecified + } + llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_NONE => RopeScaling::None, + llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_LINEAR => RopeScaling::Linear, + llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_YARN => RopeScaling::Yarn, + _ => unimplemented!(), + } + } +} + +/// The type of key or value in the cache. +#[derive(Clone, Copy)] +pub enum CacheType { + /// 32 bit float. + F32, + /// 16 bit float. + F16, + /// TODO ??? + Q4_0, + /// TODO ??? + Q4_1, + /// TODO ??? + Q5_0, + /// TODO ??? + Q5_1, + /// TODO ??? + Q8_0, + /// TODO ??? + Q8_1, + /// TODO ??? + Q2K, + /// TODO ??? + Q3K, + /// TODO ??? + Q4K, + /// TODO ??? + Q5K, + /// TODO ??? + Q6K, + /// TODO ??? + Q8K, + /// TODO ??? + IQ2XXS, + /// TODO ??? + IQ2XS, + /// TODO ??? + IQ3XXS, + /// TODO ??? + IQ1S, + /// TODO ??? + IQ4NL, + /// TODO ??? + IQ3S, + /// TODO ??? + IQ2S, + /// TODO ??? + IQ4XS, + /// 8 bit integer. + I8, + /// 16 bit integer. + I16, + /// 32 bit integer. + I32, + /// 64 bit integer. + I64, + /// 64 bit float. + F64, + /// Number of values in this enum. Not applicable to rust. 
+ Count, +} + +impl From for ggml_type { + fn from(value: CacheType) -> Self { + match value { + CacheType::F32 => ggml_type::GGML_TYPE_F32, + CacheType::F16 => ggml_type::GGML_TYPE_F16, + CacheType::Q4_0 => ggml_type::GGML_TYPE_Q4_0, + CacheType::Q4_1 => ggml_type::GGML_TYPE_Q4_1, + CacheType::Q5_0 => ggml_type::GGML_TYPE_Q5_0, + CacheType::Q5_1 => ggml_type::GGML_TYPE_Q5_1, + CacheType::Q8_0 => ggml_type::GGML_TYPE_Q8_0, + CacheType::Q8_1 => ggml_type::GGML_TYPE_Q8_1, + CacheType::Q2K => ggml_type::GGML_TYPE_Q2_K, + CacheType::Q3K => ggml_type::GGML_TYPE_Q3_K, + CacheType::Q4K => ggml_type::GGML_TYPE_Q4_K, + CacheType::Q5K => ggml_type::GGML_TYPE_Q5_K, + CacheType::Q6K => ggml_type::GGML_TYPE_Q6_K, + CacheType::Q8K => ggml_type::GGML_TYPE_Q8_K, + CacheType::IQ2XXS => ggml_type::GGML_TYPE_IQ2_XXS, + CacheType::IQ2XS => ggml_type::GGML_TYPE_IQ2_XS, + CacheType::IQ3XXS => ggml_type::GGML_TYPE_IQ3_XXS, + CacheType::IQ1S => ggml_type::GGML_TYPE_IQ1_S, + CacheType::IQ4NL => ggml_type::GGML_TYPE_IQ4_NL, + CacheType::IQ3S => ggml_type::GGML_TYPE_IQ3_S, + CacheType::IQ2S => ggml_type::GGML_TYPE_IQ2_S, + CacheType::IQ4XS => ggml_type::GGML_TYPE_IQ4_XS, + CacheType::I8 => ggml_type::GGML_TYPE_I8, + CacheType::I16 => ggml_type::GGML_TYPE_I16, + CacheType::I32 => ggml_type::GGML_TYPE_I32, + CacheType::I64 => ggml_type::GGML_TYPE_I64, + CacheType::F64 => ggml_type::GGML_TYPE_F64, + CacheType::Count => ggml_type::GGML_TYPE_COUNT, + } + } +} + +impl From for CacheType { + fn from(value: ggml_type) -> Self { + match value { + ggml_type::GGML_TYPE_F32 => CacheType::F32, + ggml_type::GGML_TYPE_F16 => CacheType::F16, + ggml_type::GGML_TYPE_Q4_0 => CacheType::Q4_0, + ggml_type::GGML_TYPE_Q4_1 => CacheType::Q4_1, + ggml_type::GGML_TYPE_Q5_0 => CacheType::Q5_0, + ggml_type::GGML_TYPE_Q5_1 => CacheType::Q5_1, + ggml_type::GGML_TYPE_Q8_0 => CacheType::Q8_0, + ggml_type::GGML_TYPE_Q8_1 => CacheType::Q8_1, + ggml_type::GGML_TYPE_Q2_K => CacheType::Q2K, + ggml_type::GGML_TYPE_Q3_K => CacheType::Q3K, + ggml_type::GGML_TYPE_Q4_K => CacheType::Q4K, + ggml_type::GGML_TYPE_Q5_K => CacheType::Q5K, + ggml_type::GGML_TYPE_Q6_K => CacheType::Q6K, + ggml_type::GGML_TYPE_Q8_K => CacheType::Q8K, + ggml_type::GGML_TYPE_IQ2_XXS => CacheType::IQ2XXS, + ggml_type::GGML_TYPE_IQ2_XS => CacheType::IQ2XS, + ggml_type::GGML_TYPE_IQ3_XXS => CacheType::IQ3XXS, + ggml_type::GGML_TYPE_IQ1_S => CacheType::IQ1S, + ggml_type::GGML_TYPE_IQ4_NL => CacheType::IQ4NL, + ggml_type::GGML_TYPE_IQ3_S => CacheType::IQ3S, + ggml_type::GGML_TYPE_IQ2_S => CacheType::IQ2S, + ggml_type::GGML_TYPE_IQ4_XS => CacheType::IQ4XS, + ggml_type::GGML_TYPE_I8 => CacheType::I8, + ggml_type::GGML_TYPE_I16 => CacheType::I16, + ggml_type::GGML_TYPE_I32 => CacheType::I32, + ggml_type::GGML_TYPE_I64 => CacheType::I64, + ggml_type::GGML_TYPE_F64 => CacheType::F64, + ggml_type::GGML_TYPE_COUNT => CacheType::Count, _ => unimplemented!(), } } @@ -57,6 +228,12 @@ pub struct SessionParams { /// prompt processing maximum batch size pub n_batch: u32, + /// physical maximum batch size used for computations + pub n_ubatch: u32, + + /// max number of sequences (i.e. 
distinct states for recurrent models) + pub n_seq_max: u32, + /// number of threads to use for generation pub n_threads: u32, @@ -64,7 +241,7 @@ pub struct SessionParams { pub n_threads_batch: u32, /// RoPE scaling type, from [`llama_rope_scaling_type`] - pub rope_scaling_type: i32, + pub rope_scaling_type: RopeScaling, /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 @@ -90,10 +267,10 @@ pub struct SessionParams { pub yarn_orig_ctx: u32, /// data type for K cache - pub type_k: u32, + pub type_k: CacheType, /// data type for V cache - pub type_v: u32, + pub type_v: CacheType, /// embedding mode only pub embedding: bool, @@ -121,9 +298,11 @@ impl Default for SessionParams { seed: c_defaults.seed, n_ctx: c_defaults.n_ctx, n_batch: c_defaults.n_batch, + n_ubatch: c_defaults.n_ubatch, + n_seq_max: c_defaults.n_seq_max, n_threads: threads, n_threads_batch: threads, - rope_scaling_type: c_defaults.rope_scaling_type, + rope_scaling_type: c_defaults.rope_scaling_type.into(), rope_freq_base: c_defaults.rope_freq_base, rope_freq_scale: c_defaults.rope_freq_scale, yarn_ext_factor: c_defaults.yarn_ext_factor, @@ -131,9 +310,9 @@ impl Default for SessionParams { yarn_beta_fast: c_defaults.yarn_beta_fast, yarn_beta_slow: c_defaults.yarn_beta_slow, yarn_orig_ctx: c_defaults.yarn_orig_ctx, - type_k: c_defaults.type_k as u32, - type_v: c_defaults.type_v as u32, - embedding: c_defaults.embedding, + type_k: c_defaults.type_k.into(), + type_v: c_defaults.type_v.into(), + embedding: c_defaults.embeddings, offload_kqv: c_defaults.offload_kqv, pooling: c_defaults.pooling_type.into(), defrag_threshold: c_defaults.defrag_thold, @@ -147,9 +326,11 @@ impl From for llama_context_params { seed: value.seed, n_ctx: value.n_ctx, n_batch: value.n_batch, + n_ubatch: value.n_ubatch, + n_seq_max: value.n_seq_max, n_threads: value.n_threads, n_threads_batch: value.n_threads_batch, - rope_scaling_type: value.rope_scaling_type, + rope_scaling_type: value.rope_scaling_type.into(), rope_freq_base: value.rope_freq_base, rope_freq_scale: value.rope_freq_scale, yarn_ext_factor: value.yarn_ext_factor, @@ -160,10 +341,10 @@ impl From for llama_context_params { defrag_thold: value.defrag_threshold, cb_eval: None, cb_eval_user_data: null_mut(), - type_k: value.type_k as ggml_type, - type_v: value.type_v as ggml_type, + type_k: value.type_k.into(), + type_v: value.type_v.into(), logits_all: false, // Deprecated - embedding: value.embedding, + embeddings: value.embedding, offload_kqv: value.offload_kqv, pooling_type: value.pooling.into(), abort_callback: None, @@ -171,3 +352,31 @@ impl From for llama_context_params { } } } + +impl From for SessionParams { + fn from(value: llama_context_params) -> Self { + Self { + seed: value.seed, + n_ctx: value.n_ctx, + n_batch: value.n_batch, + n_ubatch: value.n_ubatch, + n_seq_max: value.n_seq_max, + n_threads: value.n_threads, + n_threads_batch: value.n_threads_batch, + rope_scaling_type: value.rope_scaling_type.into(), + rope_freq_base: value.rope_freq_base, + rope_freq_scale: value.rope_freq_scale, + yarn_ext_factor: value.yarn_ext_factor, + yarn_attn_factor: value.yarn_attn_factor, + yarn_beta_fast: value.yarn_beta_fast, + yarn_beta_slow: value.yarn_beta_slow, + yarn_orig_ctx: value.yarn_orig_ctx, + type_k: value.type_k.into(), + type_v: value.type_v.into(), + embedding: value.embeddings, + offload_kqv: value.offload_kqv, + pooling: value.pooling_type.into(), + defrag_threshold: value.defrag_thold, + } + } +} diff --git a/crates/llama_cpp_sys/Cargo.toml b/crates/llama_cpp_sys/Cargo.toml 
index f18dd7b..0e1d4b7 100644 --- a/crates/llama_cpp_sys/Cargo.toml +++ b/crates/llama_cpp_sys/Cargo.toml @@ -12,7 +12,7 @@ publish = true links = "llama" [dependencies] -ash = { version = "0.37.3+1.3.251", default-features = false, features = ["linked"], optional = true } +ash = { version = "0.37.3", default-features = false, features = ["linked"], optional = true } cudarc = { version = "0.10.0", features = ["cublaslt"], optional = true } link-cplusplus = "1.0.9" diff --git a/crates/llama_cpp_sys/build.rs b/crates/llama_cpp_sys/build.rs index 11b7eab..0d38b5c 100644 --- a/crates/llama_cpp_sys/build.rs +++ b/crates/llama_cpp_sys/build.rs @@ -5,6 +5,7 @@ use std::path::{Path, PathBuf}; use std::process::Command; use bindgen::callbacks::{ItemInfo, ItemKind, ParseCallbacks}; +use bindgen::EnumVariation; use cc::Build; use once_cell::sync::Lazy; @@ -86,13 +87,13 @@ compile_error!("feature \"clblas\" cannot be enabled alongside other GPU based f compile_error!("feature \"vulkan\" cannot be enabled alongside other GPU based features"); /// The general prefix used to rename conflicting symbols. -const PREFIX: &str = "llama_"; +const PREFIX: &str = "llm_"; static LLAMA_PATH: Lazy = Lazy::new(|| PathBuf::from("./thirdparty/llama.cpp")); fn compile_bindings(out_path: &Path) { println!("Generating bindings.."); - let bindings = bindgen::Builder::default() + let mut bindings = bindgen::Builder::default() .header(LLAMA_PATH.join("ggml.h").to_string_lossy()) .header(LLAMA_PATH.join("llama.h").to_string_lossy()) .derive_partialeq(true) @@ -100,18 +101,37 @@ fn compile_bindings(out_path: &Path) { .allowlist_type("ggml_.*") .allowlist_function("llama_.*") .allowlist_type("llama_.*") - .parse_callbacks(Box::new(GGMLLinkRename {})) - .generate() - .expect("Unable to generate bindings"); + .default_enum_style(EnumVariation::Rust { + non_exhaustive: true, + }) + .constified_enum("llama_gretype"); + + #[cfg(all( + feature = "compat", + not(any(target_os = "macos", target_os = "ios", target_os = "dragonfly")) + ))] + { + bindings = bindings.parse_callbacks(Box::new(GGMLLinkRename {})); + } + + let bindings = bindings.generate().expect("Unable to generate bindings"); bindings .write_to_file(out_path.join("bindings.rs")) .expect("Couldn't write bindings!"); } +#[cfg(all( + feature = "compat", + not(any(target_os = "macos", target_os = "ios", target_os = "dragonfly")) +))] #[derive(Debug)] struct GGMLLinkRename {} +#[cfg(all( + feature = "compat", + not(any(target_os = "macos", target_os = "ios", target_os = "dragonfly")) +))] impl ParseCallbacks for GGMLLinkRename { fn generated_link_name_override(&self, item_info: ItemInfo<'_>) -> Option { match item_info.kind { @@ -129,8 +149,14 @@ impl ParseCallbacks for GGMLLinkRename { /// Add platform appropriate flags and definitions present in all compilation configurations. 
fn push_common_flags(cx: &mut Build, cxx: &mut Build) { - cx.static_flag(true).cpp(false).std("c11"); - cxx.static_flag(true).cpp(true).std("c++14"); // MSVC does not support C++11 + cx.static_flag(true) + .cpp(false) + .std("c11") + .define("GGML_SCHED_MAX_COPIES", "4"); + cxx.static_flag(true) + .cpp(true) + .std("c++11") + .define("GGML_SCHED_MAX_COPIES", "4"); if !cfg!(debug_assertions) { cx.define("NDEBUG", None); @@ -351,9 +377,9 @@ fn compile_blis(cx: &mut Build) { fn compile_hipblas(cx: &mut Build, cxx: &mut Build, mut hip: Build) -> &'static str { const DEFAULT_ROCM_PATH_STR: &str = "/opt/rocm/"; - let rocm_path_str = env::var("ROCM_PATH").map_err(|_| { - DEFAULT_ROCM_PATH_STR.to_string() - }).unwrap(); + let rocm_path_str = env::var("ROCM_PATH") + .map_err(|_| DEFAULT_ROCM_PATH_STR.to_string()) + .unwrap(); println!("Compiling HIPBLAS GGML. Using ROCm from {rocm_path_str}"); let rocm_path = PathBuf::from(rocm_path_str); @@ -518,11 +544,11 @@ fn compile_vulkan(cx: &mut Build, cxx: &mut Build) -> &'static str { if cfg!(debug_assertions) { cx.define("GGML_VULKAN_DEBUG", None) + .define("GGML_VULKAN_CHECK_RESULTS", None) .define("GGML_VULKAN_VALIDATE", None); - //.define("GGML_VULKAN_CHECK_RESULTS", None) cxx.define("GGML_VULKAN_DEBUG", None) + .define("GGML_VULKAN_CHECK_RESULTS", None) .define("GGML_VULKAN_VALIDATE", None); - //.define("GGML_VULKAN_CHECK_RESULTS", None) } cx.define("GGML_USE_VULKAN", None); @@ -552,6 +578,7 @@ fn compile_ggml(mut cx: Build) { fn compile_llama(mut cxx: Build, _out_path: impl AsRef) { println!("Compiling Llama.cpp.."); cxx.include(LLAMA_PATH.as_path()) + .file(LLAMA_PATH.join("unicode.cpp")) .file(LLAMA_PATH.join("llama.cpp")) .compile("llama"); } @@ -605,13 +632,23 @@ fn main() { compile_ggml(cx); compile_llama(cxx, &out_path); - #[cfg(feature = "compat")] + #[cfg(all( + feature = "compat", + not(any(target_os = "macos", target_os = "ios", target_os = "dragonfly")) + ))] { compat::redefine_symbols(out_path, feat_lib); } } -#[cfg(feature = "compat")] +// MacOS will prefix all exported symbols with a leading underscore. +// Additionally, it seems that there are no collision issues when building with both llama and whisper crates, so the +// compat feature can be ignored. 
+ +#[cfg(all( + feature = "compat", + not(any(target_os = "macos", target_os = "ios", target_os = "dragonfly")) +))] mod compat { use std::collections::HashSet; use std::fmt::{Display, Formatter}; @@ -713,7 +750,7 @@ mod compat { }, ], ); - objcopy_redefine(&objcopy, &lib_name, "llama_", symbols, &out_path); + objcopy_redefine(&objcopy, &lib_name, PREFIX, symbols, &out_path); } } diff --git a/crates/llama_cpp_sys/include/build-info.h b/crates/llama_cpp_sys/include/build-info.h index 94ac3ca..03eb6dd 100644 --- a/crates/llama_cpp_sys/include/build-info.h +++ b/crates/llama_cpp_sys/include/build-info.h @@ -13,7 +13,7 @@ #ifndef BUILD_INFO_H #define BUILD_INFO_H -#define BUILD_NUMBER 2333 -#define BUILD_COMMIT "4ffcdce" +#define BUILD_NUMBER 2465 +#define BUILD_COMMIT "d0d5de4" #endif // BUILD_INFO_H diff --git a/crates/llama_cpp_sys/thirdparty/llama.cpp b/crates/llama_cpp_sys/thirdparty/llama.cpp index 4ffcdce..d0d5de4 160000 --- a/crates/llama_cpp_sys/thirdparty/llama.cpp +++ b/crates/llama_cpp_sys/thirdparty/llama.cpp @@ -1 +1 @@ -Subproject commit 4ffcdce2ff877ebb683cd217ea38faf20faa5ffe +Subproject commit d0d5de42e5a65865b5fddb6f5c785083539b74c3 diff --git a/crates/llama_cpp_tests/src/lib.rs b/crates/llama_cpp_tests/src/lib.rs index d7667e2..71054fe 100644 --- a/crates/llama_cpp_tests/src/lib.rs +++ b/crates/llama_cpp_tests/src/lib.rs @@ -108,12 +108,27 @@ mod tests { .await .expect("Failed to load model"); - let mut params = SessionParams::default(); - params.n_ctx = 2048; + let params = SessionParams { + n_ctx: 2048, + ..Default::default() + }; + + let estimate = model.estimate_session_size(¶ms); + println!( + "Predict chat session size: Host {}MB, Device {}MB", + estimate.host_memory / 1024 / 1024, + estimate.device_memory / 1024 / 1024, + ); + let mut session = model .create_session(params) .expect("Failed to create session"); + println!( + "Real chat session size: Host {}MB", + session.memory_size() / 1024 / 1024 + ); + session .advance_context_async("<|SYSTEM|>You are a helpful assistant.") .await @@ -129,6 +144,7 @@ mod tests { let mut completions = session .start_completing_with(StandardSampler::default(), 1024) + .expect("Failed to start completing") .into_strings(); let timeout_by = Instant::now() + Duration::from_secs(500); @@ -176,9 +192,9 @@ mod tests { let mut input = vec![]; - for _phrase_idx in 0..2 { + for _phrase_idx in 0..10 { let mut phrase = String::new(); - for _word_idx in 0..3000 { + for _word_idx in 0..200 { phrase.push_str("word "); } phrase.truncate(phrase.len() - 1); @@ -186,17 +202,38 @@ mod tests { } let params = EmbeddingsParams::default(); + + let tokenized_input = model + .tokenize_slice(&input, true, false) + .expect("Failed to tokenize input"); + let estimate = model.estimate_embeddings_session_size(&tokenized_input, ¶ms); + println!( + "Predict embeddings session size: Host {}MB, Device {}MB", + estimate.host_memory / 1024 / 1024, + estimate.device_memory / 1024 / 1024, + ); + let res = model .embeddings_async(&input, params) .await .expect("Failed to infer embeddings"); + println!("{:?}", res[0]); + for embedding in &res { - assert!(embedding[0].is_normal(), "Embedding value isn't normal"); - assert!(embedding[0] >= 0f32, "Embedding value isn't normalised"); - assert!(embedding[0] <= 1f32, "Embedding value isn't normalised"); + let mut sum = 0f32; + for value in embedding { + assert!(value.is_normal(), "Embedding value isn't normal"); + assert!(*value >= -1f32, "Embedding value isn't normalised"); + assert!(*value <= 1f32, "Embedding value isn't 
normalised"); + sum += value * value; + } + + const ERROR: f32 = 0.0001; + let mag = sum.sqrt(); + assert!(mag < 1. + ERROR, "Vector magnitude is not close to 1"); + assert!(mag > 1. - ERROR, "Vector magnitude is not close to 1"); } - println!("{:?}", res[0]); } } } diff --git a/flake.nix b/flake.nix index 97c38e3..ce1935d 100644 --- a/flake.nix +++ b/flake.nix @@ -59,7 +59,7 @@ pkg-config ]; - devInputs = clangBuildInputs ++ nativeBuildInputs ++ (with pkgs; [ nixfmt openssl vulkan-loader ]); + devInputs = clangBuildInputs ++ nativeBuildInputs ++ (with pkgs; [ nixfmt vulkan-loader ]); stdenv = pkgs.stdenv; lib = pkgs.lib;