diff --git a/crates/ggml/src/context.rs b/crates/ggml/src/context.rs
index c367c83c..dea550cd 100644
--- a/crates/ggml/src/context.rs
+++ b/crates/ggml/src/context.rs
@@ -18,11 +18,8 @@ pub struct Context {
     /// with it if the underlying context has been deallocated.
     inner: Arc<ContextInner>,
 
-    /// Memory mapping information
-    pub mmap: Option<Mmap>,
-
-    /// Backing buffer (in case we own it)
-    pub buffer: Option<Buffer>,
+    /// The storage for this context. This is stored so that the buffer can be dropped when the context is dropped.
+    storage: Option<ContextStorage>,
 
     /// Whether the context can offload tensors to the GPU
    pub can_offload: bool,
@@ -46,7 +43,6 @@ pub(crate) struct ContextInner {
     // interface and its scratch buffer solution.
     pub offloaded_tensors: Mutex<HashMap<String, Tensor>>,
 }
-
 impl ContextInner {
     pub(crate) fn new(ptr: *mut ggml_sys::ggml_context) -> Arc<Self> {
         Arc::new(Self {
@@ -56,60 +52,72 @@ impl ContextInner {
     }
 }
 
+/// Controls how the context uses memory.
+pub enum ContextStorage {
+    /// Use the provided buffer as memory.
+    Buffer(Buffer),
+    /// Use the provided memory mapped file as memory.
+    Mmap(Mmap),
+    /// Allocate `mem_size` bytes of memory.
+    Allocate {
+        /// The size, in bytes, of the memory to allocate.
+        mem_size: usize,
+    },
+}
+
 impl Context {
-    /// Creates a new [Context] using the buffer provided as memory
-    pub fn init_buffer(buffer: Buffer) -> Self {
-        let raw = unsafe {
-            sys::ggml_init(sys::ggml_init_params {
+    /// Creates a new [Context] with the given storage.
+    pub fn new(storage: ContextStorage) -> Self {
+        let init_params = match &storage {
+            ContextStorage::Buffer(buffer) => sys::ggml_init_params {
                 mem_size: buffer.size(),
                 mem_buffer: buffer.data,
                 no_alloc: false,
-            })
+            },
+            ContextStorage::Mmap(mmap) => sys::ggml_init_params {
+                mem_size: mmap.len(),
+                mem_buffer: std::ptr::null_mut(),
+                // We are mmapping so ggml does not need to allocate any memory for us
+                no_alloc: true,
+            },
+            ContextStorage::Allocate { mem_size } => sys::ggml_init_params {
+                mem_size: *mem_size,
+                // Null here means we want ggml to own this memory.
+                mem_buffer: std::ptr::null_mut(),
+                // It doesn't make sense to `no_alloc` when passing in a `mem_size` in this mode.
+                no_alloc: false,
+            },
         };
 
+        let raw = unsafe { sys::ggml_init(init_params) };
         Self {
             inner: ContextInner::new(raw),
-            mmap: None,
-            buffer: Some(buffer),
+            storage: Some(storage),
             can_offload: false,
         }
     }
 
-    /// Creates a new [Context] with the memory mapped file provided
-    pub fn init_mmap(mmap: Mmap) -> Self {
-        let raw = unsafe {
-            sys::ggml_init(sys::ggml_init_params {
-                mem_size: mmap.len(),
-                mem_buffer: std::ptr::null_mut(),
-                no_alloc: true, // We are mmapping so ggml does not need to allocate any memory for us
-            })
-        };
+    /// Creates a new [Context] with the specified buffer.
+    /// The buffer will be used by GGML.
+    pub fn new_with_buffer(buffer: Buffer) -> Self {
+        Self::new(ContextStorage::Buffer(buffer))
+    }
 
-        Self {
-            inner: ContextInner::new(raw),
-            mmap: Some(mmap),
-            buffer: None,
-            can_offload: false,
-        }
+    /// Creates a new [Context] with the specified memory mapped file.
+    pub fn new_with_mmap(mmap: Mmap) -> Self {
+        Self::new(ContextStorage::Mmap(mmap))
     }
 
-    /// Creates a new [Context] with the specified `mem_size` as a working area.
-    pub fn init(mem_size: usize, alloc: bool) -> Self {
-        let raw = unsafe {
-            sys::ggml_init(sys::ggml_init_params {
-                mem_size,
-                // Null here means we want ggml to own this memory.
-                mem_buffer: std::ptr::null_mut(),
-                no_alloc: !alloc,
-            })
-        };
+    /// Creates a new [Context] with the specified memory size.
+    /// The memory will be allocated by GGML.
+    pub fn new_with_allocate(mem_size: usize) -> Self {
+        Self::new(ContextStorage::Allocate { mem_size })
+    }
 
-        Self {
-            inner: ContextInner::new(raw),
-            mmap: None,
-            buffer: None,
-            can_offload: false,
-        }
+    /// Recreates this context using the same storage.
+    pub fn recreate(&mut self) {
+        // This is the only operation that can consume the `self.storage`, so we can unwrap here.
+        *self = Self::new(self.storage.take().unwrap());
     }
 
     /// If offloading is enabled, all tensors created by this context will be offloaded to the GPU
@@ -182,6 +190,14 @@ impl Context {
         let raw = unsafe { sys::ggml_new_f32(self.as_ptr(), x) };
         self.new_tensor_raw(raw)
     }
+
+    /// Returns the mmap used by this [Context], if any.
+    pub fn mmap(&self) -> Option<&Mmap> {
+        match &self.storage {
+            Some(ContextStorage::Mmap(mmap)) => Some(mmap),
+            _ => None,
+        }
+    }
 }
 // Operations
 impl Context {
diff --git a/crates/ggml/src/lib.rs b/crates/ggml/src/lib.rs
index 168a292a..af833b00 100644
--- a/crates/ggml/src/lib.rs
+++ b/crates/ggml/src/lib.rs
@@ -20,7 +20,7 @@ pub mod util;
 
 pub mod accelerator;
 
-pub use context::Context;
+pub use context::{Context, ContextStorage};
 pub use tensor::Tensor;
 
 pub use ggml_sys as sys;
diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index 9b45687c..ee988eae 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -141,7 +141,7 @@ impl InferenceSession {
             ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024);
         }
 
-        let session_ctx = Arc::new(ggml::Context::init(ctx_size, true));
+        let session_ctx = Arc::new(ggml::Context::new_with_allocate(ctx_size));
 
         // Initialize key + value memory tensors
         let n_mem = n_layer * n_ctx;
@@ -167,7 +167,7 @@ impl InferenceSession {
         };
 
         let eval = Buffer::new(buf_size);
-        let ctx0 = ggml::Context::init_buffer(eval);
+        let ctx0 = ggml::Context::new_with_buffer(eval);
 
         // Set up Metal support
         #[cfg(feature = "metal")]
@@ -216,7 +216,7 @@ impl InferenceSession {
         F: FnOnce(BuildContext) -> (ComputationGraph, GraphOutputs),
     {
         // Build a graph
-        self.ctx0 = ggml::Context::init_buffer(self.ctx0.buffer.take().unwrap());
+        self.ctx0.recreate();
         let ctx0 = &mut self.ctx0;
         let mut embd = ctx0.new_tensor_1d(ggml::Type::I32, input_tokens.len());
         ggml::set_tensor_name(&embd, "embd");
diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs
index 8c5d2653..c3ff904b 100644
--- a/crates/llm-base/src/loader.rs
+++ b/crates/llm-base/src/loader.rs
@@ -527,10 +527,10 @@ pub fn load(
         unsafe {
             let mmap = Mmap::map(&file)?;
             let file_size = mmap.len() as u64;
-            (Context::init_mmap(mmap), file_size)
+            (Context::new_with_mmap(mmap), file_size)
         }
     } else {
-        (Context::init(ctx_size, true), file.metadata()?.len())
+        (Context::new_with_allocate(ctx_size), file.metadata()?.len())
     };
 
     let tensors_len = tensors.len();
@@ -646,7 +646,7 @@ impl TensorLoader for MmapCompatibleLoader<'_> {
             &self.context,
             &mut self.file,
             &self.path,
-            self.context.mmap.as_ref(),
+            self.context.mmap(),
         );
 
         let mut tensor = main_context.get_tensor(info)?;
diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs
index 8cdc2c88..9dcba74a 100644
--- a/crates/llm-base/src/lora.rs
+++ b/crates/llm-base/src/lora.rs
@@ -105,7 +105,7 @@ impl LoraAdapter {
 
         // Create a temporary context for the patching operations
         // TODO: test if GPU can be enabled (make it configurable)
-        let patch_context = ggml::Context::init(patch_context_size, true);
+        let patch_context = ggml::Context::new_with_allocate(patch_context_size);
         let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path, None);
 
         // Load the A and B tensors
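
Usage sketch (not part of the diff): the snippet below illustrates how callers might drive the `ContextStorage` API introduced above, using only the signatures visible in this change. The `ggml::Buffer` and `memmap2::Mmap` paths, the helper name `build_contexts`, and the sizes are assumptions for illustration, not code from the repository.

use ggml::{Context, ContextStorage};

// Hypothetical helper showing the three storage modes and the new accessors.
fn build_contexts(eval_buffer: ggml::Buffer, mapped_file: Option<memmap2::Mmap>) {
    // Scratch context backed by a caller-owned buffer (was `Context::init_buffer`).
    let mut ctx0 = Context::new_with_buffer(eval_buffer);

    // Model context backed by an mmapped file, or by ggml-allocated memory
    // (was `Context::init_mmap` / `Context::init(ctx_size, true)`).
    let model_ctx = match mapped_file {
        Some(mmap) => Context::new_with_mmap(mmap),
        None => Context::new_with_allocate(16 * 1024 * 1024),
    };

    // The mmap is now reached through an accessor rather than a public field.
    let _maybe_mmap = model_ctx.mmap();

    // Per-evaluation reset: rebuild the context from the storage it already owns,
    // instead of taking the old public `buffer` field and re-initialising.
    ctx0.recreate();

    // The convenience constructors are shorthand for `Context::new` with a variant.
    let _allocated = Context::new(ContextStorage::Allocate {
        mem_size: 1024 * 1024,
    });
}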