diff --git a/crates/ggml/Cargo.toml b/crates/ggml/Cargo.toml
index 2c4f43bd..fe60f7a9 100644
--- a/crates/ggml/Cargo.toml
+++ b/crates/ggml/Cargo.toml
@@ -7,8 +7,9 @@ description = "Semi-idiomatic Rust bindings for the ggml library (from `ggml-sys`)"
 license = "MIT"
 
 [dependencies]
-thiserror = { workspace = true }
 ggml-sys = { path = "sys", version = "0.2.0-dev" }
+
+thiserror = { workspace = true }
 memmap2 = { workspace = true }
 
 [dev-dependencies]
diff --git a/crates/ggml/src/context.rs b/crates/ggml/src/context.rs
index e5c4cdb0..2f2d04f0 100644
--- a/crates/ggml/src/context.rs
+++ b/crates/ggml/src/context.rs
@@ -56,6 +56,13 @@ impl PartialEq for ContextInner {
 impl Eq for ContextInner {}
 impl ContextInner {
     pub(crate) fn new(ptr: *mut ggml_sys::ggml_context) -> Arc<Self> {
+        // This context can only be used from one thread at a time - hence why
+        // it doesn't implement `Send`/`Sync` - but higher-level abstractions
+        // may layer their own synchronization on top of it to provide
+        // thread-safety guarantees. To ensure that we don't break those, we
+        // still use an `Arc` here.
+        // TODO: check if this is correct?
+        #[allow(clippy::arc_with_non_send_sync)]
         Arc::new(Self {
             ptr: NonNull::new(ptr).expect("Should not be null"),
             offloaded_tensors: Default::default(),
@@ -118,7 +125,9 @@ impl PartialEq for ContextStorage {
 impl Eq for ContextStorage {}
 
 impl Context {
-    /// Creates a new [Context] with the given storage..
+    // See explanation in [`ContextInner::new`].
+    #[allow(clippy::arc_with_non_send_sync)]
+    /// Creates a new [Context] with the given storage.
     pub fn new(storage: ContextStorage) -> Self {
         let init_params = match &storage {
             ContextStorage::Buffer(buffer) => sys::ggml_init_params {
@@ -296,7 +305,7 @@ impl Context {
         self.new_tensor_raw(tensor)
     }
 
-    /// Repeats the `a` tensor along the first dimension of the `b` tensor. 
+    /// Repeats the `a` tensor along the first dimension of the `b` tensor.
     pub fn op_repeat(&self, a: &Tensor, b: &Tensor) -> Tensor {
         let tensor = unsafe { sys::ggml_repeat(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
         self.new_tensor_raw(tensor)
diff --git a/crates/ggml/src/format/loader.rs b/crates/ggml/src/format/loader.rs
index 8a1a42ae..8b94e6a3 100644
--- a/crates/ggml/src/format/loader.rs
+++ b/crates/ggml/src/format/loader.rs
@@ -167,7 +167,7 @@ pub fn load(
     match container_type {
         ContainerType::Ggml
         | ContainerType::Ggmf(1)
-        | ContainerType::Ggjt(1 | 2 | 3)
+        | ContainerType::Ggjt(1..=3)
         | ContainerType::Ggla(1) => {}
         _ => return Err(LoadError::InvalidFormatVersion(container_type)),
     }
diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index e3f5a785..67408b34 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -8,8 +8,8 @@ use tracing::{instrument, log};
 use ggml::accelerator::metal::MetalContext;
 
 use crate::{
-    mulf, util, InferenceParameters, Model, ModelParameters, OutputRequest, Prompt, TokenId,
-    TokenUtf8Buffer, TokenizationError,
+    mulf, util, InferenceParameters, Model, ModelContext, ModelParameters, OutputRequest, Prompt,
+    TokenId, TokenUtf8Buffer, TokenizationError,
 };
 
 // The size of a scratch buffer used for inference. This is used for temporary
@@ -148,6 +148,10 @@ impl InferenceSession {
             ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024);
         }
 
+        // TODO: revisit this with `Rc`, maybe? We should be able to prove that the session
+        // context is only accessed from one thread at a time, but I've already spent enough
+        // time on this as-is.
+        #[allow(clippy::arc_with_non_send_sync)]
         let session_ctx = Arc::new(ggml::Context::new_with_allocate(context_byte_size));
 
         // Initialize key + value memory tensors
@@ -215,7 +219,7 @@ impl InferenceSession {
     /// Compute a model (possibly building a graph in the provided closure when called for the first time and/or when parameters have)
     pub fn compute<F>(
         &mut self,
-        #[allow(unused_variables)] model_context: Arc<ggml::Context>,
+        #[allow(unused_variables)] model_context: ModelContext,
         input_tokens: &[TokenId],
         builder: F,
     ) -> GraphOutputs
@@ -242,7 +246,7 @@ impl InferenceSession {
         #[cfg(feature = "metal")]
         {
             if let Some(ref mut metal_context) = self.metal_context {
-                metal_context.add_context(model_context);
+                metal_context.add_context(model_context.0);
             }
         }
 
diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs
index dd13d7bc..e07c8852 100644
--- a/crates/llm-base/src/lib.rs
+++ b/crates/llm-base/src/lib.rs
@@ -35,7 +35,7 @@ pub use loader::{
 };
 pub use lora::{LoraAdapter, LoraParameters};
 pub use memmap2::Mmap;
-pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
+pub use model::{Hyperparameters, KnownModel, Model, ModelContext, ModelParameters, OutputRequest};
 pub use quantize::{quantize, QuantizeError, QuantizeProgress};
 pub use regex::Regex;
 pub use tokenizer::{
diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs
index c98ab048..d95ed348 100644
--- a/crates/llm-base/src/loader.rs
+++ b/crates/llm-base/src/loader.rs
@@ -5,11 +5,12 @@ use std::{
     fs::File,
     io::{BufRead, BufReader, Read, Seek, SeekFrom},
     path::{Path, PathBuf},
+    sync::Arc,
 };
 
 use crate::{
-    util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId,
-    Tokenizer, TokenizerLoadError, TokenizerSource,
+    util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelContext, ModelParameters,
+    TokenId, Tokenizer, TokenizerLoadError, TokenizerSource,
 };
 pub use ggml::{format::FormatMagic, ContainerType};
 use ggml::{
@@ -398,7 +399,7 @@ pub trait TensorLoader<E: std::error::Error> {
     /// Gets a tensor from the loader.
     fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
     /// Finish loading the model, returning the context.
-    fn finish(self) -> Context;
+    fn finish(self) -> ModelContext;
 }
 
 /// Load a GGML model from the `path` and configure it per the `params`. The status
@@ -653,12 +654,7 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
                 path: Default::default(),
             })?;
 
-        let mut main_context = FileContext::new(
-            &self.context,
-            &mut self.file,
-            &self.path,
-            self.context.storage().as_mmap(),
-        );
+        let mut main_context = FileContext::new(&self.context, &mut self.file, &self.path);
 
         let mut tensor = main_context.get_tensor(info)?;
 
@@ -681,8 +677,11 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
         Ok(tensor)
     }
 
-    fn finish(self) -> Context {
-        self.context
+    fn finish(self) -> ModelContext {
+        // We can ignore this warning as it's OK to share this particular
+        // context around, given that it is immutable.
+        #[allow(clippy::arc_with_non_send_sync)]
+        ModelContext(Arc::new(self.context))
     }
 }
 
@@ -690,20 +689,13 @@ pub(crate) struct FileContext<'a> {
     context: &'a Context,
     file: &'a mut File,
     path: &'a Path,
-    mmap: Option<&'a Mmap>,
 }
 
 impl<'a> FileContext<'a> {
-    pub(crate) fn new(
-        context: &'a Context,
-        file: &'a mut File,
-        path: &'a Path,
-        mmap: Option<&'a Mmap>,
-    ) -> Self {
+    pub(crate) fn new(context: &'a Context, file: &'a mut File, path: &'a Path) -> Self {
         Self {
             context,
             file,
             path,
-            mmap,
         }
     }
@@ -738,7 +730,7 @@ impl<'a> FileContext<'a> {
             }
         };
 
-        match self.mmap {
+        match self.context.storage().as_mmap() {
             Some(mmap) => unsafe {
                 let ptr = mmap.as_ptr().offset(info.start_offset as isize);
                 tensor.set_data(ptr as *mut std::ffi::c_void);
diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs
index b6ed4a0f..c6d1d8a2 100644
--- a/crates/llm-base/src/lora.rs
+++ b/crates/llm-base/src/lora.rs
@@ -106,7 +106,7 @@ impl LoraAdapter {
         // Create a temporary context for the patching operations
         // TODO: test if GPU can be enabled (make it configurable)
         let patch_context = ggml::Context::new_with_allocate(patch_context_size);
-        let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path, None);
+        let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path);
 
         // Load the A and B tensors
         let a = patch_file.get_tensor(&a_info)?;
diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs
index b31faf56..ab30e4f2 100644
--- a/crates/llm-base/src/model/mod.rs
+++ b/crates/llm-base/src/model/mod.rs
@@ -5,6 +5,7 @@ use std::{
     fmt::Debug,
     io::{BufRead, Write},
     path::{Path, PathBuf},
+    sync::Arc,
 };
 
 use ggml::accelerator::Backend;
@@ -263,3 +264,13 @@ pub struct OutputRequest {
     /// `n_batch * n_embd`.
     pub embeddings: Option<Vec<f32>>,
 }
+
+/// Contains the GGML context for a [`Model`]. Implements `Send` and `Sync`
+/// to allow for the free transfer of models; this is made possible by this
+/// context being effectively inert after creation, so that it cannot be
+/// modified across threads.
+#[derive(Clone)]
+#[allow(clippy::arc_with_non_send_sync)]
+pub struct ModelContext(pub(crate) Arc<ggml::Context>);
+unsafe impl Send for ModelContext {}
+unsafe impl Sync for ModelContext {}
diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs
index 349ebab6..efa1f338 100644
--- a/crates/models/bloom/src/lib.rs
+++ b/crates/models/bloom/src/lib.rs
@@ -2,13 +2,11 @@
 //! for the `llm` ecosystem.
 #![deny(missing_docs)]
 
-use std::sync::Arc;
-
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel,
-    ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
 };
 
 /// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom)
@@ -37,7 +35,7 @@ pub struct Bloom {
     layers: Vec<Layer>,
 
     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }
 
 unsafe impl Send for Bloom {}
@@ -101,7 +99,7 @@ impl KnownModel for Bloom {
             output_norm_bias,
             output,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }
 
diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs
index 914c22bb..0322e2f2 100644
--- a/crates/models/falcon/src/lib.rs
+++ b/crates/models/falcon/src/lib.rs
@@ -7,14 +7,12 @@
 //! supported. It is currently only available as a preview.
 #![deny(missing_docs)]
 
-use std::sync::Arc;
-
 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
 };
 
 /// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae)
@@ -39,7 +37,7 @@ pub struct Falcon {
     layers: Vec<Layer>,
 
     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }
 
 unsafe impl Send for Falcon {}
@@ -138,7 +136,7 @@ impl KnownModel for Falcon {
             output_norm_b,
             lm_head,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }
 
diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs
index ccd1d012..b4434ad5 100644
--- a/crates/models/gpt2/src/lib.rs
+++ b/crates/models/gpt2/src/lib.rs
@@ -1,14 +1,12 @@
 //! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem.
 #![deny(missing_docs)]
 
-use std::sync::Arc;
-
 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
 };
 
 /// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
@@ -38,7 +36,7 @@ pub struct Gpt2 {
     layers: Vec<Layer>,
 
     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }
 
 unsafe impl Send for Gpt2 {}
@@ -123,7 +121,7 @@ impl KnownModel for Gpt2 {
             wte,
             wpe,
             lm_head,
-            context: Arc::new(context),
+            context,
         })
     }
 
diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs
index b5fd4fc5..c013625a 100644
--- a/crates/models/gptj/src/lib.rs
+++ b/crates/models/gptj/src/lib.rs
@@ -1,14 +1,14 @@
 //! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem.
 #![deny(missing_docs)]
 
-use std::{error::Error, sync::Arc};
+use std::error::Error;
 
 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
 };
 
 /// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b)
@@ -35,7 +35,7 @@ pub struct GptJ {
     layers: Vec<Layer>,
 
     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }
 
 unsafe impl Send for GptJ {}
@@ -117,7 +117,7 @@ impl KnownModel for GptJ {
             lmh_g,
             lmh_b,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }
 
diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs
index b420ec13..9075eb01 100644
--- a/crates/models/gptneox/src/lib.rs
+++ b/crates/models/gptneox/src/lib.rs
@@ -2,14 +2,14 @@
 //! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model.
 #![deny(missing_docs)]
 
-use std::{error::Error, sync::Arc};
+use std::error::Error;
 
 use ggml::Tensor;
 use llm_base::{
     ggml,
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
 };
 
 /// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox)
@@ -35,7 +35,7 @@ pub struct GptNeoX {
     layers: Vec<Layer>,
 
     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }
 
 unsafe impl Send for GptNeoX {}
@@ -137,7 +137,7 @@ impl KnownModel for GptNeoX {
             wte,
             lmh_g,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }
 
diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs
index db78b3a0..a70f315f 100644
--- a/crates/models/llama/src/lib.rs
+++ b/crates/models/llama/src/lib.rs
@@ -1,13 +1,13 @@
 //! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem.
 #![deny(missing_docs)]
 
-use std::{error::Error, sync::Arc};
+use std::error::Error;
 
 use llm_base::{
     ggml::{self},
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
 };
 
 /// The LLaMA model. Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/)
@@ -31,7 +31,7 @@ pub struct Llama {
     layers: Vec<Layer>,
 
     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }
 
 unsafe impl Send for Llama {}
@@ -125,7 +125,7 @@ impl KnownModel for Llama {
             norm,
             output,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }
 
diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs
index 351ddd6c..3d22efff 100644
--- a/crates/models/mpt/src/lib.rs
+++ b/crates/models/mpt/src/lib.rs
@@ -1,14 +1,12 @@
 //! An implementation of [MPT](https://huggingface.co/mosaicml) for the `llm` ecosystem.
 #![deny(missing_docs)]
 
-use std::sync::Arc;
-
 use ggml::Tensor;
 use llm_base::{
     ggml::{self},
     model::{common, HyperparametersWriteError},
     util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
-    ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
+    ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
 };
 
 /// The MosaicML Pretrained Transformer (MPT) model. Ref: [Mosaic ML](https://www.mosaicml.com/blog/mpt-7b)
@@ -31,7 +29,7 @@ pub struct Mpt {
     layers: Vec<Layer>,
 
     // must be kept alive for the model
-    context: Arc<ggml::Context>,
+    context: ModelContext,
 }
 
 unsafe impl Send for Mpt {}
@@ -78,7 +76,7 @@ impl KnownModel for Mpt {
             wte,
             norm,
             layers,
-            context: Arc::new(context),
+            context,
         })
     }
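
Appended note (not part of the patch): the `ModelContext` newtype introduced above makes a `!Send`/`!Sync` ggml context shareable across threads by promising that it is never mutated after creation. Below is a minimal, self-contained sketch of that pattern; `RawContext` and `SharedContext` are hypothetical stand-ins, not types from the crates touched here.

```rust
use std::{cell::Cell, sync::Arc, thread};

// Hypothetical stand-in for a type that is not `Sync` (and, once wrapped in
// `Arc`, not `Send` either), similar in spirit to `ggml::Context`.
struct RawContext {
    // `Cell` gives interior mutability and makes this type `!Sync`.
    last_op: Cell<u32>,
}

// Newtype mirroring the `ModelContext` pattern from the diff: we assert that
// the wrapped context is never mutated after construction, so sharing it
// across threads cannot race.
#[derive(Clone)]
struct SharedContext(Arc<RawContext>);

// SAFETY: sound only as long as no code mutates `last_op` (or any other
// interior-mutable state) after the wrapper has been created.
unsafe impl Send for SharedContext {}
unsafe impl Sync for SharedContext {}

fn main() {
    let ctx = SharedContext(Arc::new(RawContext {
        last_op: Cell::new(0),
    }));

    // The clone can move to another thread; the original stays usable here.
    let worker_ctx = ctx.clone();
    let handle = thread::spawn(move || Arc::strong_count(&worker_ctx.0));

    println!("strong count observed by worker: {}", handle.join().unwrap());
    println!("last_op on main thread: {}", ctx.0.last_op.get());
}
```

In the patch itself, this is what lets the model structs keep their `unsafe impl Send`/`Sync` while holding the ggml context behind `ModelContext` instead of a bare `Arc`.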