From ffb0519dde7a33097253b7102ee8d17bd7686c48 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 20 Aug 2023 20:16:25 +0200 Subject: [PATCH 01/33] refactor: move ggml format to module --- binaries/llm-cli/src/cli_args.rs | 6 +- binaries/llm-cli/src/main.rs | 2 +- crates/ggml/src/format/{ => ggml}/loader.rs | 72 +--------------- crates/ggml/src/format/ggml/mod.rs | 93 +++++++++++++++++++++ crates/ggml/src/format/{ => ggml}/saver.rs | 4 +- crates/ggml/src/{ => format/ggml}/tests.rs | 39 +++++---- crates/ggml/src/format/mod.rs | 59 +++++++++++-- crates/ggml/src/lib.rs | 88 ------------------- crates/ggml/src/util.rs | 24 +++++- crates/llm-base/src/loader.rs | 17 ++-- crates/llm-base/src/lora.rs | 2 +- crates/llm-base/src/model/mod.rs | 2 +- crates/llm-base/src/quantize.rs | 10 ++- crates/llm-base/src/tokenizer/mod.rs | 2 +- 14 files changed, 220 insertions(+), 200 deletions(-) rename crates/ggml/src/format/{ => ggml}/loader.rs (72%) create mode 100644 crates/ggml/src/format/ggml/mod.rs rename crates/ggml/src/format/{ => ggml}/saver.rs (98%) rename crates/ggml/src/{ => format/ggml}/tests.rs (83%) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 21b4a897..71440d47 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -672,11 +672,11 @@ impl fmt::Display for SaveContainerType { } } } -impl From for ggml_format::SaveContainerType { +impl From for ggml_format::ggml::SaveContainerType { fn from(value: SaveContainerType) -> Self { match value { - SaveContainerType::Ggml => ggml_format::SaveContainerType::Ggml, - SaveContainerType::GgjtV3 => ggml_format::SaveContainerType::GgjtV3, + SaveContainerType::Ggml => ggml_format::ggml::SaveContainerType::Ggml, + SaveContainerType::GgjtV3 => ggml_format::ggml::SaveContainerType::GgjtV3, } } } diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index b0eabece..a623d721 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -143,7 +143,7 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { // We purposely do not print progress here, as we are only interested in the metadata }); - llm::ggml_format::load(&mut reader, &mut loader)?; + llm::ggml_format::ggml::load(&mut reader, &mut loader)?; log::info!("Container type: {:?}", loader.container_type); log::info!("Hyperparameters: {:?}", loader.hyperparameters); diff --git a/crates/ggml/src/format/loader.rs b/crates/ggml/src/format/ggml/loader.rs similarity index 72% rename from crates/ggml/src/format/loader.rs rename to crates/ggml/src/format/ggml/loader.rs index 8a1a42ae..0700d60e 100644 --- a/crates/ggml/src/format/loader.rs +++ b/crates/ggml/src/format/ggml/loader.rs @@ -6,67 +6,16 @@ use std::{ error::Error, - fmt, io::{BufRead, Seek, SeekFrom}, }; use crate::{ + format::{data_size, header_size, LoadError}, util::{has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32}, - ContainerType, ElementType, + ElementType, }; -/// Helper struct that wraps the magic number of a file format, -/// so that it can be printed in a human-readable format. 
-pub struct FormatMagic(pub u32); -impl fmt::Display for FormatMagic { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{:x} ({})", - self.0, - String::from_utf8_lossy(&self.0.to_le_bytes()) - ) - } -} -impl fmt::Debug for FormatMagic { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - fmt::Display::fmt(self, f) - } -} - -#[derive(Debug, thiserror::Error)] -/// Errors that can occur while loading a model. -pub enum LoadError { - #[error("invalid file magic number: {0}")] - /// The file magic number is invalid. - InvalidMagic(FormatMagic), - #[error("invalid ggml format: format={0:?}")] - /// An unsupported format version was found. - InvalidFormatVersion(ContainerType), - #[error("non-specific I/O error")] - /// A non-specific IO error. - Io(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("implementation error")] - /// An error `E` was returned by the implementation of the loader. - ImplementationError(#[source] E), - #[error("unsupported tensor type {ftype} for tensor {tensor_name}")] - /// One of the tensors encountered had an unsupported data type. - UnsupportedElementType { - /// The name of the tensor. - tensor_name: String, - /// The format type that was encountered. - ftype: u32, - }, - #[error("invariant broken: {0}")] - /// An invariant was broken. - InvariantBroken(String), -} +use super::ContainerType; #[derive(Debug, Clone)] /// Information about a [tensor](https://en.wikipedia.org/wiki/Tensor_(machine_learning)) that is being read. @@ -118,21 +67,6 @@ impl TensorLoadInfo { } } -/// Returns the size occupied by a tensor's data in bytes given the element type and number of elements. -pub(crate) fn data_size(element_type: ElementType, n_elements: usize) -> usize { - (crate::type_size(element_type) * n_elements) / crate::blck_size(element_type) -} - -/// Returns the size of the ggml tensor header in bytes. -pub(crate) fn header_size() -> usize { - crate::Tensor::C_TYPE_SIZE + crate::OBJECT_SIZE -} - -/// Returns the size of a tensor in bytes given the element type and number of elements. This includes the tensor's header. -pub fn tensor_size(element_type: ElementType, n_elements: usize) -> usize { - header_size() + data_size(element_type, n_elements) -} - #[derive(Debug, Clone)] /// Information present within GGML [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) /// that is required to continue loading the model. diff --git a/crates/ggml/src/format/ggml/mod.rs b/crates/ggml/src/format/ggml/mod.rs new file mode 100644 index 00000000..73277cfe --- /dev/null +++ b/crates/ggml/src/format/ggml/mod.rs @@ -0,0 +1,93 @@ +//! Loading and saving of [GGML](https://github.com/ggerganov/ggml) files. + +mod loader; +mod saver; + +pub use loader::*; +pub use saver::*; + +#[cfg(test)] +mod tests; + +use crate::{format::LoadError, util}; + +/// Magic constant for `ggml` files (unversioned). +pub const FILE_MAGIC_GGML: u32 = 0x67676d6c; +/// Magic constant for `ggml` files (versioned, ggmf). +pub const FILE_MAGIC_GGMF: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggjt). 
+pub const FILE_MAGIC_GGJT: u32 = 0x67676a74; +/// Magic constant for `ggla` files (LoRA adapter). +pub const FILE_MAGIC_GGLA: u32 = 0x67676C61; + +#[derive(Debug, PartialEq, Clone, Copy)] +/// The format of the file containing the model. +pub enum ContainerType { + /// Legacy format, oldest ggml tensor file format + Ggml, + /// Legacy format. Introduces versioning. Newer than GGML, older than GGJT. + Ggmf(u32), + /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. Current version of the format. + Ggjt(u32), + /// LoRA adapter format. + Ggla(u32), +} +impl ContainerType { + /// Does this container type support mmap? + pub fn support_mmap(&self) -> bool { + match self { + ContainerType::Ggml => false, + ContainerType::Ggmf(_) => false, + ContainerType::Ggla(_) => false, + ContainerType::Ggjt(_) => true, + } + } + + /// Read the container type from a reader. + pub fn read( + reader: &mut dyn std::io::BufRead, + ) -> Result> { + // Verify magic + let magic = util::read_u32(reader)?; + let container_type: ContainerType = match magic { + FILE_MAGIC_GGML => ContainerType::Ggml, + FILE_MAGIC_GGMF => { + let version = util::read_u32(reader)?; + ContainerType::Ggmf(version) + } + FILE_MAGIC_GGJT => { + let version = util::read_u32(reader)?; + ContainerType::Ggjt(version) + } + FILE_MAGIC_GGLA => { + let version = util::read_u32(reader)?; + ContainerType::Ggla(version) + } + magic => return Err(LoadError::InvalidMagic(util::FormatMagic(magic))), + }; + + Ok(container_type) + } + + /// Write the container type to a writer. + pub fn write(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { + match self { + ContainerType::Ggml => { + util::write_u32(writer, FILE_MAGIC_GGML)?; + } + ContainerType::Ggmf(version) => { + util::write_u32(writer, FILE_MAGIC_GGMF)?; + util::write_u32(writer, *version)?; + } + ContainerType::Ggjt(version) => { + util::write_u32(writer, FILE_MAGIC_GGJT)?; + util::write_u32(writer, *version)?; + } + ContainerType::Ggla(version) => { + util::write_u32(writer, FILE_MAGIC_GGLA)?; + util::write_u32(writer, *version)?; + } + } + Ok(()) + } +} diff --git a/crates/ggml/src/format/saver.rs b/crates/ggml/src/format/ggml/saver.rs similarity index 98% rename from crates/ggml/src/format/saver.rs rename to crates/ggml/src/format/ggml/saver.rs index 86b4bd24..d8b87a52 100644 --- a/crates/ggml/src/format/saver.rs +++ b/crates/ggml/src/format/ggml/saver.rs @@ -9,7 +9,9 @@ use std::{ io::{Seek, Write}, }; -use crate::{util, ContainerType, ElementType}; +use crate::{util, ElementType}; + +use super::ContainerType; #[derive(Debug, thiserror::Error)] /// Errors that can occur while writing a model. 
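After this move, callers reach the container-type API through the new ggml::format::ggml path instead of the crate root. A minimal usage sketch (not part of the diff, assuming a crate that depends on ggml) of the relocated ContainerType, round-tripping a GGJT v3 header through an in-memory buffer:

    use ggml::format::ggml::ContainerType;

    fn main() -> std::io::Result<()> {
        // Write a GGJT v3 header: magic first, then the version number.
        let mut buf: Vec<u8> = Vec::new();
        ContainerType::Ggjt(3).write(&mut buf)?;

        // FILE_MAGIC_GGJT (0x67676a74) is written little-endian, so the
        // stream starts with the bytes "tjgg", followed by the LE version.
        assert_eq!(&buf[..4], b"tjgg".as_slice());
        assert_eq!(&buf[4..], 3u32.to_le_bytes().as_slice());
        Ok(())
    }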
diff --git a/crates/ggml/src/tests.rs b/crates/ggml/src/format/ggml/tests.rs similarity index 83% rename from crates/ggml/src/tests.rs rename to crates/ggml/src/format/ggml/tests.rs index b842f45d..0c535975 100644 --- a/crates/ggml/src/tests.rs +++ b/crates/ggml/src/format/ggml/tests.rs @@ -4,9 +4,12 @@ use std::{ io::{BufRead, Write}, }; -use crate::*; use rand::{distributions::Uniform, prelude::*}; +use crate::format::data_size; + +use super::*; + #[derive(Debug)] struct DummyError; impl std::fmt::Display for DummyError { @@ -25,7 +28,7 @@ fn can_roundtrip_loader_and_saver_ggml() { ("efficient".as_bytes().to_vec(), 0.0), ]; - roundtrip_test(format::SaveContainerType::Ggml, tokenizer).unwrap(); + roundtrip_test(SaveContainerType::Ggml, tokenizer).unwrap(); } #[test] @@ -38,10 +41,10 @@ fn will_fail_on_scored_ggml_save() { ]; assert_eq!( - roundtrip_test(format::SaveContainerType::Ggml, tokenizer) + roundtrip_test(SaveContainerType::Ggml, tokenizer) .unwrap_err() .to_string(), - format::SaveError::::VocabularyScoringNotSupported.to_string() + SaveError::::VocabularyScoringNotSupported.to_string() ); } @@ -54,11 +57,11 @@ fn can_roundtrip_loader_and_saver_ggjt_v3() { ("efficient".as_bytes().to_vec(), 0.4), ]; - roundtrip_test(format::SaveContainerType::GgjtV3, tokenizer).unwrap(); + roundtrip_test(SaveContainerType::GgjtV3, tokenizer).unwrap(); } fn roundtrip_test( - save_container_type: format::SaveContainerType, + save_container_type: SaveContainerType, tokenizer: Vec<(Vec, f32)>, ) -> anyhow::Result<()> { let mut rng = rand::thread_rng(); @@ -79,13 +82,13 @@ fn roundtrip_test( .collect::>(); let n_elements = dims.iter().product::(); - let data = (0..format::data_size(element_type, n_elements)) + let data = (0..data_size(element_type, n_elements)) .map(|_| random()) .collect::>(); ( format!("tensor_{}", i), - format::TensorSaveInfo { + TensorSaveInfo { n_dims, dims: dims.try_into().unwrap(), element_type, @@ -100,7 +103,7 @@ fn roundtrip_test( let mut buffer = Vec::new(); let mut cursor = std::io::Cursor::new(&mut buffer); let mut save_handler = MockSaveHandler { model: &model }; - format::save( + save( &mut cursor, &mut save_handler, save_container_type, @@ -115,7 +118,7 @@ fn roundtrip_test( loaded_model: Model::default(), expected_container_type: save_container_type.into(), }; - format::load(&mut cursor, &mut load_handler)?; + load(&mut cursor, &mut load_handler)?; assert_eq!(load_handler.loaded_model, model); Ok(()) @@ -148,19 +151,19 @@ impl Hyperparameters { struct Model { hyperparameters: Hyperparameters, tokenizer: Vec<(Vec, f32)>, - tensors: BTreeMap, + tensors: BTreeMap, } struct MockSaveHandler<'a> { model: &'a Model, } -impl format::SaveHandler for MockSaveHandler<'_> { +impl SaveHandler for MockSaveHandler<'_> { fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), DummyError> { self.model.hyperparameters.write(writer).unwrap(); Ok(()) } - fn tensor_data(&mut self, tensor_name: &str) -> Result { + fn tensor_data(&mut self, tensor_name: &str) -> Result { self.model .tensors .get(tensor_name) @@ -174,7 +177,7 @@ struct MockLoadHandler<'a> { loaded_model: Model, expected_container_type: ContainerType, } -impl format::LoadHandler for MockLoadHandler<'_> { +impl LoadHandler for MockLoadHandler<'_> { fn container_type(&mut self, container_type: ContainerType) -> Result<(), DummyError> { assert_eq!(container_type, self.expected_container_type); Ok(()) @@ -189,9 +192,9 @@ impl format::LoadHandler for MockLoadHandler<'_> { fn read_hyperparameters( &mut self, 
reader: &mut dyn BufRead, - ) -> Result { + ) -> Result { self.loaded_model.hyperparameters = Hyperparameters::read(reader).unwrap(); - Ok(format::PartialHyperparameters { + Ok(PartialHyperparameters { n_vocab: self .loaded_model .hyperparameters @@ -201,8 +204,8 @@ impl format::LoadHandler for MockLoadHandler<'_> { }) } - fn tensor_buffer(&mut self, info: format::TensorLoadInfo) -> Result<(), DummyError> { - let data = format::TensorSaveInfo { + fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), DummyError> { + let data = TensorSaveInfo { n_dims: info.n_dims, dims: info.dims, element_type: info.element_type, diff --git a/crates/ggml/src/format/mod.rs b/crates/ggml/src/format/mod.rs index f1a939b7..5370343f 100644 --- a/crates/ggml/src/format/mod.rs +++ b/crates/ggml/src/format/mod.rs @@ -1,7 +1,56 @@ -//! Loading and saving of [GGML](https://github.com/ggerganov/ggml) files. +//! Loading and saving of GGML-related files. -mod loader; -mod saver; +use std::error::Error; -pub use loader::*; -pub use saver::*; +use crate::{util::FormatMagic, ElementType}; + +pub mod ggml; + +#[derive(Debug, thiserror::Error)] +/// Errors that can occur while loading a model. +pub enum LoadError { + #[error("invalid file magic number: {0}")] + /// The file magic number is invalid. + InvalidMagic(FormatMagic), + #[error("invalid ggml format: format={0:?}")] + /// An unsupported format version was found. + InvalidFormatVersion(ggml::ContainerType), + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("implementation error")] + /// An error `E` was returned by the implementation of the loader. + ImplementationError(#[source] E), + #[error("unsupported tensor type {ftype} for tensor {tensor_name}")] + /// One of the tensors encountered had an unsupported data type. + UnsupportedElementType { + /// The name of the tensor. + tensor_name: String, + /// The format type that was encountered. + ftype: u32, + }, + #[error("invariant broken: {0}")] + /// An invariant was broken. + InvariantBroken(String), +} + +/// Returns the size occupied by a tensor's data in bytes given the element type and number of elements. +pub(crate) fn data_size(element_type: ElementType, n_elements: usize) -> usize { + (crate::type_size(element_type) * n_elements) / crate::blck_size(element_type) +} + +/// Returns the size of the ggml tensor header in bytes. +pub(crate) fn header_size() -> usize { + crate::Tensor::C_TYPE_SIZE + crate::OBJECT_SIZE +} + +/// Returns the size of a tensor in bytes given the element type and number of elements. This includes the tensor's header. +pub fn tensor_size(element_type: ElementType, n_elements: usize) -> usize { + header_size() + data_size(element_type, n_elements) +} diff --git a/crates/ggml/src/lib.rs b/crates/ggml/src/lib.rs index 8b6910eb..0cef3591 100644 --- a/crates/ggml/src/lib.rs +++ b/crates/ggml/src/lib.rs @@ -26,97 +26,9 @@ pub use tensor::Tensor; pub use ggml_sys as sys; -#[cfg(test)] -mod tests; - /// The type of a tensor element. pub type ElementType = Type; -#[derive(Debug, PartialEq, Clone, Copy)] -/// The format of the file containing the model. 
-pub enum ContainerType { - /// Legacy format, oldest ggml tensor file format - Ggml, - /// Legacy format. Introduces versioning. Newer than GGML, older than GGJT. - Ggmf(u32), - /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. Current version of the format. - Ggjt(u32), - /// LoRA adapter format. - Ggla(u32), -} -impl ContainerType { - /// Does this container type support mmap? - pub fn support_mmap(&self) -> bool { - match self { - ContainerType::Ggml => false, - ContainerType::Ggmf(_) => false, - ContainerType::Ggla(_) => false, - ContainerType::Ggjt(_) => true, - } - } - - /// Read the container type from a reader. - pub fn read( - reader: &mut dyn std::io::BufRead, - ) -> Result> { - // Verify magic - let magic = util::read_u32(reader)?; - let container_type: ContainerType = match magic { - crate::FILE_MAGIC_GGML => ContainerType::Ggml, - crate::FILE_MAGIC_GGMF => { - let version = util::read_u32(reader)?; - ContainerType::Ggmf(version) - } - crate::FILE_MAGIC_GGJT => { - let version = util::read_u32(reader)?; - ContainerType::Ggjt(version) - } - crate::FILE_MAGIC_GGLA => { - let version = util::read_u32(reader)?; - ContainerType::Ggla(version) - } - magic => { - return Err(crate::format::LoadError::InvalidMagic(format::FormatMagic( - magic, - ))) - } - }; - - Ok(container_type) - } - - /// Write the container type to a writer. - pub fn write(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { - match self { - ContainerType::Ggml => { - util::write_u32(writer, FILE_MAGIC_GGML)?; - } - ContainerType::Ggmf(version) => { - util::write_u32(writer, FILE_MAGIC_GGMF)?; - util::write_u32(writer, *version)?; - } - ContainerType::Ggjt(version) => { - util::write_u32(writer, FILE_MAGIC_GGJT)?; - util::write_u32(writer, *version)?; - } - ContainerType::Ggla(version) => { - util::write_u32(writer, FILE_MAGIC_GGLA)?; - util::write_u32(writer, *version)?; - } - } - Ok(()) - } -} - -/// Magic constant for `ggml` files (unversioned). -pub const FILE_MAGIC_GGML: u32 = 0x67676d6c; -/// Magic constant for `ggml` files (versioned, ggmf). -pub const FILE_MAGIC_GGMF: u32 = 0x67676d66; -/// Magic constant for `ggml` files (versioned, ggjt). -pub const FILE_MAGIC_GGJT: u32 = 0x67676a74; -/// Magic constant for `ggla` files (LoRA adapter). -pub const FILE_MAGIC_GGLA: u32 = 0x67676C61; - /// The current quantization version. pub const QNT_VERSION: u32 = sys::GGML_QNT_VERSION; /// The factor by which to divide `ftype` to determine the current quantization version. diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index 69b20de4..8d2e92f0 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -1,6 +1,28 @@ //! Utilities for reading and writing. -use std::io::{BufRead, Write}; +use std::{ + fmt, + io::{BufRead, Write}, +}; + +/// Helper struct that wraps the magic number of a file format, +/// so that it can be printed in a human-readable format. +pub struct FormatMagic(pub u32); +impl fmt::Display for FormatMagic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{:x} ({})", + self.0, + String::from_utf8_lossy(&self.0.to_le_bytes()) + ) + } +} +impl fmt::Debug for FormatMagic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + fmt::Display::fmt(self, f) + } +} /// Read a fixed-size array of bytes from a reader. 
pub fn read_bytes(reader: &mut dyn BufRead) -> Result<[u8; N], std::io::Error> { diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index c98ab048..47155146 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -11,9 +11,12 @@ use crate::{ util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId, Tokenizer, TokenizerLoadError, TokenizerSource, }; -pub use ggml::{format::FormatMagic, ContainerType}; +pub use ggml::{format::ggml::ContainerType, util::FormatMagic}; use ggml::{ - format::{LoadError as FormatLoadError, PartialHyperparameters, TensorLoadInfo}, + format::{ + ggml::{PartialHyperparameters, TensorLoadInfo}, + LoadError as FormatLoadError, + }, Context, MAX_NAME_LENGTH, }; use memmap2::Mmap; @@ -442,7 +445,7 @@ pub fn load( let tokenizer = tokenizer_source.retrieve(path)?; let mut loader = Loader::new(tokenizer, load_progress_callback); - ggml::format::load(&mut reader, &mut loader) + ggml::format::ggml::load(&mut reader, &mut loader) .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?; log::trace!("Loaded GGML model from reader"); @@ -462,9 +465,9 @@ pub fn load( let quantization_version = if quantization_version == 0 { // HACK: I think llama.cpp does not actually write the quantization version correctly, // so we need to guess it from the container type. - if container_type == ggml::ContainerType::Ggjt(2) { + if container_type == ContainerType::Ggjt(2) { 1 - } else if container_type == ggml::ContainerType::Ggjt(3) { + } else if container_type == ContainerType::Ggjt(3) { 2 } else { quantization_version @@ -506,7 +509,7 @@ pub fn load( // Most LoRAs are small enough that this is not necessary, but it would be nice to have. let mut lora_loader: Loader = Loader::new(Tokenizer::empty_embedded(), |_| {}); - ggml::format::load(&mut lora_reader, &mut lora_loader) + ggml::format::ggml::load(&mut lora_reader, &mut lora_loader) .map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?; // Collect the names of the tensors that should be patched @@ -595,7 +598,7 @@ impl Loader { } } } -impl ggml::format::LoadHandler +impl ggml::format::ggml::LoadHandler for Loader { fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> { diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index b6ed4a0f..2fab1a04 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -3,7 +3,7 @@ use crate::{ LoadError, }; -use ggml::{format::TensorLoadInfo, GraphExecutionPlan}; +use ggml::{format::ggml::TensorLoadInfo, GraphExecutionPlan}; use std::{ collections::{HashMap, HashSet}, fs::File, diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index b31faf56..272e7735 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -194,7 +194,7 @@ pub enum HyperparametersWriteError { /// Parameters for model-wide behaviour. #[derive(Debug, Clone)] pub struct ModelParameters { - /// For [GGML formats](ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap) + /// For [GGML formats](ggml::format::ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap) /// is the default. Although mmap typically improves performance, setting this value to `false` may /// be preferred in resource-constrained environments. 
pub prefer_mmap: bool, diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index d3d2a0cf..efb30044 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -4,7 +4,9 @@ use crate::{ loader::FileTypeFormat, model::HyperparametersWriteError, Hyperparameters, KnownModel, LoadError, LoadProgress, Loader, Tokenizer, }; -use ggml::format::{SaveError, SaveHandler, TensorLoadInfo, TensorSaveInfo}; +use ggml::format::ggml::{ + SaveContainerType, SaveError, SaveHandler, TensorLoadInfo, TensorSaveInfo, +}; use half::f16; use regex::Regex; use std::{ @@ -140,7 +142,7 @@ pub fn quantize( reader: &mut R, writer: &mut W, tokenizer: Tokenizer, - save_container_type: ggml::format::SaveContainerType, + save_container_type: SaveContainerType, quantization_type: ggml::Type, progress_callback: impl Fn(QuantizeProgress), ) -> Result<(), QuantizeError> { @@ -162,7 +164,7 @@ pub fn quantize( } } }); - ggml::format::load(reader, &mut loader) + ggml::format::ggml::load(reader, &mut loader) .map_err(|err| LoadError::from_format_error(err, PathBuf::default()))?; // Save the quantized model, quantizing as we go @@ -196,7 +198,7 @@ pub fn quantize( reader, |p| progress_callback(p), ); - ggml::format::save( + ggml::format::ggml::save( writer, &mut saver, save_container_type, diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 03b2f0b9..48be8926 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -76,7 +76,7 @@ impl TokenizerSource { /// Retrieve the tokenizer from the source. /// /// Note that this may make a blocking HTTP request to Hugging Face to retrieve the tokenizer. - /// if `self` is [`Self::HuggingFaceRemote`]. + /// if `self` is `Self::HuggingFaceRemote`. pub fn retrieve(self, model_path: &Path) -> Result { let _ = model_path; From d5c2562ccc38b18fa74b387f6dfe5a7c3282ad4a Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 20 Aug 2023 22:14:43 +0200 Subject: [PATCH 02/33] fix(ggml): use byte-arrays for magic --- crates/ggml/src/format/ggml/mod.rs | 20 +++++++++++--------- crates/ggml/src/util.rs | 9 ++------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/crates/ggml/src/format/ggml/mod.rs b/crates/ggml/src/format/ggml/mod.rs index 73277cfe..6354370f 100644 --- a/crates/ggml/src/format/ggml/mod.rs +++ b/crates/ggml/src/format/ggml/mod.rs @@ -12,13 +12,13 @@ mod tests; use crate::{format::LoadError, util}; /// Magic constant for `ggml` files (unversioned). -pub const FILE_MAGIC_GGML: u32 = 0x67676d6c; +pub const FILE_MAGIC_GGML: [u8; 4] = *b"lmgg"; /// Magic constant for `ggml` files (versioned, ggmf). -pub const FILE_MAGIC_GGMF: u32 = 0x67676d66; +pub const FILE_MAGIC_GGMF: [u8; 4] = *b"fmgg"; /// Magic constant for `ggml` files (versioned, ggjt). -pub const FILE_MAGIC_GGJT: u32 = 0x67676a74; +pub const FILE_MAGIC_GGJT: [u8; 4] = *b"tjgg"; /// Magic constant for `ggla` files (LoRA adapter). -pub const FILE_MAGIC_GGLA: u32 = 0x67676C61; +pub const FILE_MAGIC_GGLA: [u8; 4] = *b"algg"; #[derive(Debug, PartialEq, Clone, Copy)] /// The format of the file containing the model. 
@@ -48,7 +48,7 @@ impl ContainerType { reader: &mut dyn std::io::BufRead, ) -> Result> { // Verify magic - let magic = util::read_u32(reader)?; + let magic = util::read_bytes::<4>(reader)?; let container_type: ContainerType = match magic { FILE_MAGIC_GGML => ContainerType::Ggml, FILE_MAGIC_GGMF => { @@ -73,18 +73,20 @@ impl ContainerType { pub fn write(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { match self { ContainerType::Ggml => { - util::write_u32(writer, FILE_MAGIC_GGML)?; + writer.write_all(&FILE_MAGIC_GGML)?; } ContainerType::Ggmf(version) => { - util::write_u32(writer, FILE_MAGIC_GGMF)?; + writer.write_all(&FILE_MAGIC_GGMF)?; util::write_u32(writer, *version)?; } ContainerType::Ggjt(version) => { - util::write_u32(writer, FILE_MAGIC_GGJT)?; + writer.write_all(&FILE_MAGIC_GGJT)?; util::write_u32(writer, *version)?; } ContainerType::Ggla(version) => { - util::write_u32(writer, FILE_MAGIC_GGLA)?; + writer.write_all(&FILE_MAGIC_GGLA)?; + util::write_u32(writer, *version)?; + } util::write_u32(writer, *version)?; } } diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index 8d2e92f0..814b349f 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -7,15 +7,10 @@ use std::{ /// Helper struct that wraps the magic number of a file format, /// so that it can be printed in a human-readable format. -pub struct FormatMagic(pub u32); +pub struct FormatMagic(pub [u8; 4]); impl fmt::Display for FormatMagic { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{:x} ({})", - self.0, - String::from_utf8_lossy(&self.0.to_le_bytes()) - ) + write!(f, "{:x?} ({})", self.0, String::from_utf8_lossy(&self.0)) } } impl fmt::Debug for FormatMagic { From dd7aa263ffcbd57e0f9e517996ca2d526f34a17f Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 20 Aug 2023 22:15:40 +0200 Subject: [PATCH 03/33] feat(ggml): impl unwired gguf loader --- crates/ggml/examples/gguf.rs | 14 ++ crates/ggml/src/format/ggml/loader.rs | 6 + crates/ggml/src/format/ggml/mod.rs | 13 +- crates/ggml/src/format/gguf/mod.rs | 288 ++++++++++++++++++++++++++ crates/ggml/src/format/mod.rs | 1 + crates/ggml/src/util.rs | 13 ++ 6 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 crates/ggml/examples/gguf.rs create mode 100644 crates/ggml/src/format/gguf/mod.rs diff --git a/crates/ggml/examples/gguf.rs b/crates/ggml/examples/gguf.rs new file mode 100644 index 00000000..f7c5328f --- /dev/null +++ b/crates/ggml/examples/gguf.rs @@ -0,0 +1,14 @@ +use std::io::BufReader; + +use ggml::format::gguf; + +fn main() -> anyhow::Result<()> { + let mut file = BufReader::new(std::fs::File::open( + std::env::args().nth(1).expect("need a file to read"), + )?); + + let gguf = gguf::Gguf::load(&mut file)?; + dbg!(gguf); + + Ok(()) +} diff --git a/crates/ggml/src/format/ggml/loader.rs b/crates/ggml/src/format/ggml/loader.rs index 0700d60e..7b079ee8 100644 --- a/crates/ggml/src/format/ggml/loader.rs +++ b/crates/ggml/src/format/ggml/loader.rs @@ -126,6 +126,9 @@ pub fn load( // Legacy model, set empty score 0. 
} + ContainerType::Gguf(_) => { + unreachable!("This loader should not be used with GGUF") + } }; handler .vocabulary_token(i, token, token_score) @@ -138,6 +141,9 @@ pub fn load( ContainerType::Ggjt(_version) | ContainerType::Ggla(_version) => { load_weights(reader, handler, true) } + ContainerType::Gguf(_) => { + unreachable!("This loader should not be used with GGUF") + } } } diff --git a/crates/ggml/src/format/ggml/mod.rs b/crates/ggml/src/format/ggml/mod.rs index 6354370f..774ce69d 100644 --- a/crates/ggml/src/format/ggml/mod.rs +++ b/crates/ggml/src/format/ggml/mod.rs @@ -19,6 +19,8 @@ pub const FILE_MAGIC_GGMF: [u8; 4] = *b"fmgg"; pub const FILE_MAGIC_GGJT: [u8; 4] = *b"tjgg"; /// Magic constant for `ggla` files (LoRA adapter). pub const FILE_MAGIC_GGLA: [u8; 4] = *b"algg"; +/// Magic constant for `gguf` files. +pub const FILE_MAGIC_GGUF: [u8; 4] = *b"GGUF"; #[derive(Debug, PartialEq, Clone, Copy)] /// The format of the file containing the model. @@ -27,10 +29,12 @@ pub enum ContainerType { Ggml, /// Legacy format. Introduces versioning. Newer than GGML, older than GGJT. Ggmf(u32), - /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. Current version of the format. + /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. Ggjt(u32), /// LoRA adapter format. Ggla(u32), + /// GGUF format. Current version of the format. + Gguf(u32), } impl ContainerType { /// Does this container type support mmap? @@ -40,6 +44,7 @@ impl ContainerType { ContainerType::Ggmf(_) => false, ContainerType::Ggla(_) => false, ContainerType::Ggjt(_) => true, + ContainerType::Gguf(_) => true, } } @@ -63,6 +68,10 @@ impl ContainerType { let version = util::read_u32(reader)?; ContainerType::Ggla(version) } + FILE_MAGIC_GGUF => { + let version = util::read_u32(reader)?; + ContainerType::Gguf(version) + } magic => return Err(LoadError::InvalidMagic(util::FormatMagic(magic))), }; @@ -87,6 +96,8 @@ impl ContainerType { writer.write_all(&FILE_MAGIC_GGLA)?; util::write_u32(writer, *version)?; } + ContainerType::Gguf(version) => { + writer.write_all(&FILE_MAGIC_GGUF)?; util::write_u32(writer, *version)?; } } diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs new file mode 100644 index 00000000..97156ff7 --- /dev/null +++ b/crates/ggml/src/format/gguf/mod.rs @@ -0,0 +1,288 @@ +#![allow(missing_docs)] + +use std::{ + collections::HashMap, + convert::Infallible, + io::{BufRead, Seek}, +}; + +use crate::{util, ElementType}; + +use super::{ggml::ContainerType, LoadError}; + +pub const DEFAULT_ALIGNMENT: u32 = 32; + +#[derive(Debug, Clone, PartialEq)] +pub struct Gguf { + pub metadata: HashMap, + pub tensor_infos: HashMap, + pub tensor_data_position: u64, +} +impl Gguf { + pub fn load(reader: &mut R) -> Result> { + let container = ContainerType::read(reader)?; + if container != ContainerType::Gguf(1) { + return Err(LoadError::InvalidFormatVersion(container)); + } + + let tensor_count = util::read_u32(reader)? as usize; + let metadata_kv_count = util::read_u32(reader)? 
as usize; + + let mut metadata = HashMap::with_capacity(metadata_kv_count); + for _ in 0..metadata_kv_count { + let (key, value) = MetadataValue::read_key_value(reader)?; + metadata.insert(key, value); + } + + let alignment = metadata + .get("general.alignment") + .and_then(|v| v.as_uint32()) + .unwrap_or(DEFAULT_ALIGNMENT) as u64; + + let mut tensor_infos = HashMap::with_capacity(tensor_count); + for _ in 0..tensor_count { + let (key, value) = TensorInfo::read_name_value(reader)?; + tensor_infos.insert(key, value); + } + + let tensor_data_position = { + let stream_position = reader.stream_position()?; + stream_position + (alignment - (stream_position % alignment)) % alignment + }; + + Ok(Gguf { + metadata, + tensor_infos, + tensor_data_position, + }) + } +} + +#[repr(u32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum MetadataValueType { + // The value is a 8-bit unsigned integer. + UInt8 = 0, + // The value is a 8-bit signed integer. + Int8 = 1, + // The value is a 16-bit unsigned little-endian integer. + UInt16 = 2, + // The value is a 16-bit signed little-endian integer. + Int16 = 3, + // The value is a 32-bit unsigned little-endian integer. + UInt32 = 4, + // The value is a 32-bit signed little-endian integer. + Int32 = 5, + // The value is a 32-bit IEEE754 floating point number. + Float32 = 6, + // The value is a boolean. + // 1-byte value where 0 is false and 1 is true. + // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + Bool = 7, + // The value is a UTF-8 non-null-terminated string, with length prepended. + String = 8, + // The value is an array of other values, with the length and type prepended. + /// + // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. + Array = 9, +} +impl TryFrom for MetadataValueType { + type Error = (); + + fn try_from(value: u32) -> Result { + // TODO: consider a macro solution to this? + for test_value in [ + MetadataValueType::UInt8, + MetadataValueType::Int8, + MetadataValueType::UInt16, + MetadataValueType::Int16, + MetadataValueType::UInt32, + MetadataValueType::Int32, + MetadataValueType::Float32, + MetadataValueType::Bool, + MetadataValueType::String, + MetadataValueType::Array, + ] { + if value == test_value as u32 { + return Ok(test_value); + } + } + Err(()) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum MetadataValue { + UInt8(u8), + Int8(i8), + UInt16(u16), + Int16(i16), + UInt32(u32), + Int32(i32), + Float32(f32), + Bool(bool), + String(String), + Array(MetadataArrayValue), +} +impl MetadataValue { + fn read_key_value(reader: &mut dyn BufRead) -> Result<(String, Self), LoadError> { + let key = util::read_string(reader)?; + let value_type = MetadataValueType::try_from(util::read_u32(reader)?) 
+ .expect("TODO: handle invalid value types"); + + let value = Self::read_value(reader, value_type)?; + + Ok((key, value)) + } + + fn read_value( + reader: &mut dyn BufRead, + value_type: MetadataValueType, + ) -> Result> { + match value_type { + MetadataValueType::UInt8 => Self::read_u8(reader).map(MetadataValue::UInt8), + MetadataValueType::Int8 => Self::read_i8(reader).map(MetadataValue::Int8), + MetadataValueType::UInt16 => Self::read_u16(reader).map(MetadataValue::UInt16), + MetadataValueType::Int16 => Self::read_i16(reader).map(MetadataValue::Int16), + MetadataValueType::UInt32 => Self::read_u32(reader).map(MetadataValue::UInt32), + MetadataValueType::Int32 => Self::read_i32(reader).map(MetadataValue::Int32), + MetadataValueType::Float32 => Self::read_f32(reader).map(MetadataValue::Float32), + MetadataValueType::Bool => Self::read_bool(reader).map(MetadataValue::Bool), + MetadataValueType::String => Self::read_string(reader).map(MetadataValue::String), + MetadataValueType::Array => Self::read_array(reader).map(MetadataValue::Array), + } + } + + fn read_u8(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_bytes::<1>(reader)?[0]) + } + + fn read_i8(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_bytes::<1>(reader)?[0] as i8) + } + + fn read_u16(reader: &mut dyn BufRead) -> Result> { + Ok(u16::from_le_bytes(util::read_bytes::<2>(reader)?)) + } + + fn read_i16(reader: &mut dyn BufRead) -> Result> { + Ok(i16::from_le_bytes(util::read_bytes::<2>(reader)?)) + } + + fn read_u32(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_u32(reader)?) + } + + fn read_i32(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_i32(reader)?) + } + + fn read_f32(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_f32(reader)?) + } + + fn read_bool(reader: &mut dyn BufRead) -> Result> { + let v = Self::read_u8(reader)?; + if v == 0 { + Ok(false) + } else if v == 1 { + Ok(true) + } else { + panic!("TODO: error for invalid bools") + } + } + + fn read_string(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_string(reader)?) + } + + fn read_array(reader: &mut dyn BufRead) -> Result> { + MetadataArrayValue::read_value(reader) + } + + pub fn as_uint32(&self) -> Option { + match self { + Self::UInt32(v) => Some(*v), + _ => None, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum MetadataArrayValue { + UInt8(Vec), + Int8(Vec), + UInt16(Vec), + Int16(Vec), + UInt32(Vec), + Int32(Vec), + Float32(Vec), + Bool(Vec), + String(Vec), + Array(Vec), +} +impl MetadataArrayValue { + fn read_value(reader: &mut dyn BufRead) -> Result> { + let value_type = MetadataValueType::try_from(util::read_u32(reader)?) + .expect("TODO: handle invalid value types"); + let length = util::read_u32(reader)? 
as usize; + + fn read_array( + reader: &mut dyn BufRead, + length: usize, + value_reader: impl Fn(&mut dyn BufRead) -> Result>, + ) -> Result, LoadError> { + (0..length).map(|_| value_reader(reader)).collect() + } + + use MetadataValue as MV; + use MetadataValueType as MVT; + Ok(match value_type { + MVT::UInt8 => read_array(reader, length, MV::read_u8).map(Self::UInt8), + MVT::Int8 => read_array(reader, length, MV::read_i8).map(Self::Int8), + MVT::UInt16 => read_array(reader, length, MV::read_u16).map(Self::UInt16), + MVT::Int16 => read_array(reader, length, MV::read_i16).map(Self::Int16), + MVT::UInt32 => read_array(reader, length, MV::read_u32).map(Self::UInt32), + MVT::Int32 => read_array(reader, length, MV::read_i32).map(Self::Int32), + MVT::Float32 => read_array(reader, length, MV::read_f32).map(Self::Float32), + MVT::Bool => read_array(reader, length, MV::read_bool).map(Self::Bool), + MVT::String => read_array(reader, length, MV::read_string).map(Self::String), + MVT::Array => read_array(reader, length, MV::read_array).map(Self::Array), + }?) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct TensorInfo { + pub dimensions: Vec, + pub element_type: ElementType, + pub offset: u64, +} +impl TensorInfo { + fn read_name_value(reader: &mut dyn BufRead) -> Result<(String, Self), LoadError> { + let name = util::read_string(reader)?; + + let dimension_count = util::read_u32(reader)? as usize; + let dimensions = (0..dimension_count) + .map(|_| util::read_u32(reader)) + .collect::, _>>()?; + + let element_type = util::read_u32(reader)?; + let element_type = + ElementType::try_from(element_type).map_err(|_| LoadError::UnsupportedElementType { + tensor_name: name.clone(), + ftype: element_type, + })?; + + let offset = util::read_u64(reader)?; + + Ok(( + name, + Self { + dimensions, + element_type, + offset, + }, + )) + } +} diff --git a/crates/ggml/src/format/mod.rs b/crates/ggml/src/format/mod.rs index 5370343f..9fe51d2b 100644 --- a/crates/ggml/src/format/mod.rs +++ b/crates/ggml/src/format/mod.rs @@ -5,6 +5,7 @@ use std::error::Error; use crate::{util::FormatMagic, ElementType}; pub mod ggml; +pub mod gguf; #[derive(Debug, thiserror::Error)] /// Errors that can occur while loading a model. diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index 814b349f..7812b71b 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -36,6 +36,11 @@ pub fn read_u32(reader: &mut dyn BufRead) -> Result { Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) } +/// Read a `u64` from a reader. +pub fn read_u64(reader: &mut dyn BufRead) -> Result { + Ok(u64::from_le_bytes(read_bytes::<8>(reader)?)) +} + /// Read a `f32` from a reader. pub fn read_f32(reader: &mut dyn BufRead) -> Result { Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) @@ -64,6 +69,14 @@ pub fn read_bytes_with_len( Ok(bytes) } +/// Read a string from a reader. +pub fn read_string(reader: &mut dyn BufRead) -> Result { + let len = read_u32(reader)? as usize; + let bytes = read_bytes_with_len(reader, len)?; + Ok(String::from_utf8(bytes) + .expect("string was not valid utf-8 (TODO: make this a library error)")) +} + /// Write a `i32` from a writer. 
pub fn write_i32(writer: &mut dyn Write, value: i32) -> Result<(), std::io::Error> { writer.write_all(&value.to_le_bytes()) From e166b7c40130f7e015f6e422cbfd777ee7c35ed9 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 27 Aug 2023 19:29:02 +0200 Subject: [PATCH 04/33] feat(gguf): gguf-v2 support --- crates/ggml/src/format/gguf/mod.rs | 243 ++++++++++++++++++++--------- crates/ggml/src/util.rs | 31 +++- 2 files changed, 199 insertions(+), 75 deletions(-) diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index 97156ff7..19dce419 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -21,16 +21,20 @@ pub struct Gguf { impl Gguf { pub fn load(reader: &mut R) -> Result> { let container = ContainerType::read(reader)?; - if container != ContainerType::Gguf(1) { + if ![ContainerType::Gguf(1), ContainerType::Gguf(2)].contains(&container) { return Err(LoadError::InvalidFormatVersion(container)); } - let tensor_count = util::read_u32(reader)? as usize; - let metadata_kv_count = util::read_u32(reader)? as usize; + let ctx = GgufContext { + use_64_bit_length: container == ContainerType::Gguf(2), + }; + + let tensor_count = util::read_length(reader, ctx.use_64_bit_length)?; + let metadata_kv_count = util::read_length(reader, ctx.use_64_bit_length)?; let mut metadata = HashMap::with_capacity(metadata_kv_count); for _ in 0..metadata_kv_count { - let (key, value) = MetadataValue::read_key_value(reader)?; + let (key, value) = MetadataValue::read_key_value(&ctx, reader)?; metadata.insert(key, value); } @@ -41,7 +45,7 @@ impl Gguf { let mut tensor_infos = HashMap::with_capacity(tensor_count); for _ in 0..tensor_count { - let (key, value) = TensorInfo::read_name_value(reader)?; + let (key, value) = TensorInfo::read_name_value(&ctx, reader)?; tensor_infos.insert(key, value); } @@ -58,33 +62,46 @@ impl Gguf { } } +struct GgufContext { + use_64_bit_length: bool, +} + #[repr(u32)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum MetadataValueType { - // The value is a 8-bit unsigned integer. + /// The value is a 8-bit unsigned integer. UInt8 = 0, - // The value is a 8-bit signed integer. + /// The value is a 8-bit signed integer. Int8 = 1, - // The value is a 16-bit unsigned little-endian integer. + /// The value is a 16-bit unsigned little-endian integer. UInt16 = 2, - // The value is a 16-bit signed little-endian integer. + /// The value is a 16-bit signed little-endian integer. Int16 = 3, - // The value is a 32-bit unsigned little-endian integer. + /// The value is a 32-bit unsigned little-endian integer. UInt32 = 4, - // The value is a 32-bit signed little-endian integer. + /// The value is a 32-bit signed little-endian integer. Int32 = 5, - // The value is a 32-bit IEEE754 floating point number. + /// The value is a 32-bit IEEE754 floating point number. Float32 = 6, - // The value is a boolean. - // 1-byte value where 0 is false and 1 is true. - // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + /// The value is a boolean. + /// 1-byte value where 0 is false and 1 is true. + /// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. Bool = 7, - // The value is a UTF-8 non-null-terminated string, with length prepended. + /// The value is a UTF-8 non-null-terminated string, with length prepended. String = 8, - // The value is an array of other values, with the length and type prepended. 
+ /// The value is an array of other values, with the length and type prepended. /// - // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. + /// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. Array = 9, + /// The value is a 64-bit unsigned little-endian integer. + /// Implemented in GGUFv2. + UInt64 = 10, + /// The value is a 64-bit signed little-endian integer. + /// Implemented in GGUFv2. + Int64 = 11, + /// The value is a 64-bit IEEE754 floating point number. + /// Implemented in GGUFv2. + Float64 = 12, } impl TryFrom for MetadataValueType { type Error = (); @@ -102,6 +119,9 @@ impl TryFrom for MetadataValueType { MetadataValueType::Bool, MetadataValueType::String, MetadataValueType::Array, + MetadataValueType::UInt64, + MetadataValueType::Int64, + MetadataValueType::Float64, ] { if value == test_value as u32 { return Ok(test_value); @@ -123,81 +143,128 @@ pub enum MetadataValue { Bool(bool), String(String), Array(MetadataArrayValue), + UInt64(u64), + Int64(i64), + Float64(f64), } impl MetadataValue { - fn read_key_value(reader: &mut dyn BufRead) -> Result<(String, Self), LoadError> { - let key = util::read_string(reader)?; + fn read_key_value( + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result<(String, Self), LoadError> { + let key = util::read_string(reader, ctx.use_64_bit_length)?; let value_type = MetadataValueType::try_from(util::read_u32(reader)?) .expect("TODO: handle invalid value types"); - - let value = Self::read_value(reader, value_type)?; + let value = Self::read_value(ctx, reader, value_type)?; Ok((key, value)) } fn read_value( + ctx: &GgufContext, reader: &mut dyn BufRead, value_type: MetadataValueType, ) -> Result> { match value_type { - MetadataValueType::UInt8 => Self::read_u8(reader).map(MetadataValue::UInt8), - MetadataValueType::Int8 => Self::read_i8(reader).map(MetadataValue::Int8), - MetadataValueType::UInt16 => Self::read_u16(reader).map(MetadataValue::UInt16), - MetadataValueType::Int16 => Self::read_i16(reader).map(MetadataValue::Int16), - MetadataValueType::UInt32 => Self::read_u32(reader).map(MetadataValue::UInt32), - MetadataValueType::Int32 => Self::read_i32(reader).map(MetadataValue::Int32), - MetadataValueType::Float32 => Self::read_f32(reader).map(MetadataValue::Float32), - MetadataValueType::Bool => Self::read_bool(reader).map(MetadataValue::Bool), - MetadataValueType::String => Self::read_string(reader).map(MetadataValue::String), - MetadataValueType::Array => Self::read_array(reader).map(MetadataValue::Array), + MetadataValueType::UInt8 => Self::read_u8(ctx, reader).map(MetadataValue::UInt8), + MetadataValueType::Int8 => Self::read_i8(ctx, reader).map(MetadataValue::Int8), + MetadataValueType::UInt16 => Self::read_u16(ctx, reader).map(MetadataValue::UInt16), + MetadataValueType::Int16 => Self::read_i16(ctx, reader).map(MetadataValue::Int16), + MetadataValueType::UInt32 => Self::read_u32(ctx, reader).map(MetadataValue::UInt32), + MetadataValueType::Int32 => Self::read_i32(ctx, reader).map(MetadataValue::Int32), + MetadataValueType::Float32 => Self::read_f32(ctx, reader).map(MetadataValue::Float32), + MetadataValueType::Bool => Self::read_bool(ctx, reader).map(MetadataValue::Bool), + MetadataValueType::String => Self::read_string(ctx, reader).map(MetadataValue::String), + MetadataValueType::Array => Self::read_array(ctx, reader).map(MetadataValue::Array), + MetadataValueType::UInt64 => Self::read_u64(ctx, 
reader).map(MetadataValue::UInt64), + MetadataValueType::Int64 => Self::read_i64(ctx, reader).map(MetadataValue::Int64), + MetadataValueType::Float64 => Self::read_f64(ctx, reader).map(MetadataValue::Float64), } } - fn read_u8(reader: &mut dyn BufRead) -> Result> { + fn read_u8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result> { Ok(util::read_bytes::<1>(reader)?[0]) } - fn read_i8(reader: &mut dyn BufRead) -> Result> { + fn read_i8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result> { Ok(util::read_bytes::<1>(reader)?[0] as i8) } - fn read_u16(reader: &mut dyn BufRead) -> Result> { + fn read_u16( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { Ok(u16::from_le_bytes(util::read_bytes::<2>(reader)?)) } - fn read_i16(reader: &mut dyn BufRead) -> Result> { + fn read_i16( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { Ok(i16::from_le_bytes(util::read_bytes::<2>(reader)?)) } - fn read_u32(reader: &mut dyn BufRead) -> Result> { + fn read_u32( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { Ok(util::read_u32(reader)?) } - fn read_i32(reader: &mut dyn BufRead) -> Result> { + fn read_i32( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { Ok(util::read_i32(reader)?) } - fn read_f32(reader: &mut dyn BufRead) -> Result> { + fn read_f32( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { Ok(util::read_f32(reader)?) } - fn read_bool(reader: &mut dyn BufRead) -> Result> { - let v = Self::read_u8(reader)?; - if v == 0 { - Ok(false) - } else if v == 1 { - Ok(true) - } else { - panic!("TODO: error for invalid bools") - } + fn read_bool( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { + Ok(util::read_bool(reader)?) + } + + fn read_string( + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { + Ok(util::read_string(reader, ctx.use_64_bit_length)?) + } + + fn read_array( + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { + MetadataArrayValue::read_value(ctx, reader) } - fn read_string(reader: &mut dyn BufRead) -> Result> { - Ok(util::read_string(reader)?) + fn read_u64( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { + Ok(util::read_u64(reader)?) } - fn read_array(reader: &mut dyn BufRead) -> Result> { - MetadataArrayValue::read_value(reader) + fn read_i64( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { + Ok(util::read_i64(reader)?) + } + + fn read_f64( + _ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { + Ok(util::read_f64(reader)?) } pub fn as_uint32(&self) -> Option { @@ -220,51 +287,81 @@ pub enum MetadataArrayValue { Bool(Vec), String(Vec), Array(Vec), + UInt64(Vec), + Int64(Vec), + Float64(Vec), } impl MetadataArrayValue { - fn read_value(reader: &mut dyn BufRead) -> Result> { + fn read_value( + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result> { let value_type = MetadataValueType::try_from(util::read_u32(reader)?) .expect("TODO: handle invalid value types"); - let length = util::read_u32(reader)? 
as usize; + let length = util::read_length(reader, ctx.use_64_bit_length)?; - fn read_array( - reader: &mut dyn BufRead, + struct ArrayReader<'a> { + ctx: &'a GgufContext, + reader: &'a mut dyn BufRead, length: usize, - value_reader: impl Fn(&mut dyn BufRead) -> Result>, - ) -> Result, LoadError> { - (0..length).map(|_| value_reader(reader)).collect() + } + impl ArrayReader<'_> { + fn read( + &mut self, + value_reader: impl Fn( + &GgufContext, + &mut dyn BufRead, + ) -> Result>, + value_constructor: impl Fn(Vec) -> MetadataArrayValue, + ) -> Result> { + (0..self.length) + .map(|_| value_reader(self.ctx, self.reader)) + .collect::, _>>() + .map(value_constructor) + } } + let mut reader = ArrayReader { + ctx, + reader, + length, + }; use MetadataValue as MV; use MetadataValueType as MVT; Ok(match value_type { - MVT::UInt8 => read_array(reader, length, MV::read_u8).map(Self::UInt8), - MVT::Int8 => read_array(reader, length, MV::read_i8).map(Self::Int8), - MVT::UInt16 => read_array(reader, length, MV::read_u16).map(Self::UInt16), - MVT::Int16 => read_array(reader, length, MV::read_i16).map(Self::Int16), - MVT::UInt32 => read_array(reader, length, MV::read_u32).map(Self::UInt32), - MVT::Int32 => read_array(reader, length, MV::read_i32).map(Self::Int32), - MVT::Float32 => read_array(reader, length, MV::read_f32).map(Self::Float32), - MVT::Bool => read_array(reader, length, MV::read_bool).map(Self::Bool), - MVT::String => read_array(reader, length, MV::read_string).map(Self::String), - MVT::Array => read_array(reader, length, MV::read_array).map(Self::Array), + MVT::UInt8 => reader.read(MV::read_u8, Self::UInt8), + MVT::Int8 => reader.read(MV::read_i8, Self::Int8), + MVT::UInt16 => reader.read(MV::read_u16, Self::UInt16), + MVT::Int16 => reader.read(MV::read_i16, Self::Int16), + MVT::UInt32 => reader.read(MV::read_u32, Self::UInt32), + MVT::Int32 => reader.read(MV::read_i32, Self::Int32), + MVT::Float32 => reader.read(MV::read_f32, Self::Float32), + MVT::Bool => reader.read(MV::read_bool, Self::Bool), + MVT::String => reader.read(MV::read_string, Self::String), + MVT::Array => reader.read(MV::read_array, Self::Array), + MVT::UInt64 => reader.read(MV::read_u64, Self::UInt64), + MVT::Int64 => reader.read(MV::read_i64, Self::Int64), + MVT::Float64 => reader.read(MV::read_f64, Self::Float64), }?) } } #[derive(Debug, Clone, PartialEq)] pub struct TensorInfo { - pub dimensions: Vec, + pub dimensions: Vec, pub element_type: ElementType, pub offset: u64, } impl TensorInfo { - fn read_name_value(reader: &mut dyn BufRead) -> Result<(String, Self), LoadError> { - let name = util::read_string(reader)?; + fn read_name_value( + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result<(String, Self), LoadError> { + let name = util::read_string(reader, ctx.use_64_bit_length)?; let dimension_count = util::read_u32(reader)? as usize; let dimensions = (0..dimension_count) - .map(|_| util::read_u32(reader)) + .map(|_| util::read_length(reader, ctx.use_64_bit_length)) .collect::, _>>()?; let element_type = util::read_u32(reader)?; diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index 7812b71b..d2108a98 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -36,6 +36,11 @@ pub fn read_u32(reader: &mut dyn BufRead) -> Result { Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) } +/// Read a `i64` from a reader. +pub fn read_i64(reader: &mut dyn BufRead) -> Result { + Ok(i64::from_le_bytes(read_bytes::<8>(reader)?)) +} + /// Read a `u64` from a reader. 
pub fn read_u64(reader: &mut dyn BufRead) -> Result { Ok(u64::from_le_bytes(read_bytes::<8>(reader)?)) @@ -46,6 +51,25 @@ pub fn read_f32(reader: &mut dyn BufRead) -> Result { Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) } +/// Read a `f64` from a reader. +pub fn read_f64(reader: &mut dyn BufRead) -> Result { + Ok(f64::from_le_bytes(read_bytes::<8>(reader)?)) +} + +/// Read an integer (32-bit or 64-bit) from a reader, and convert it to a usize. +pub fn read_length( + reader: &mut dyn BufRead, + use_64_bit_length: bool, +) -> Result { + let len: usize = if use_64_bit_length { + read_u64(reader)?.try_into() + } else { + read_u32(reader)?.try_into() + } + .expect("TODO: invalid usize conversion"); + Ok(len) +} + /// Read a `bool` represented as an `i32` from a reader. pub fn read_bool(reader: &mut dyn BufRead) -> Result { let val = i32::from_le_bytes(read_bytes::<4>(reader)?); @@ -70,8 +94,11 @@ pub fn read_bytes_with_len( } /// Read a string from a reader. -pub fn read_string(reader: &mut dyn BufRead) -> Result { - let len = read_u32(reader)? as usize; +pub fn read_string( + reader: &mut dyn BufRead, + use_64_bit_length: bool, +) -> Result { + let len = read_length(reader, use_64_bit_length)?; let bytes = read_bytes_with_len(reader, len)?; Ok(String::from_utf8(bytes) .expect("string was not valid utf-8 (TODO: make this a library error)")) From 90c6797536324092c290d55fdbfec555bb1492b2 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 27 Aug 2023 19:49:52 +0200 Subject: [PATCH 05/33] chore(gguf): clippy fix --- crates/ggml/src/format/gguf/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index 19dce419..3d9810d8 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -328,7 +328,7 @@ impl MetadataArrayValue { }; use MetadataValue as MV; use MetadataValueType as MVT; - Ok(match value_type { + match value_type { MVT::UInt8 => reader.read(MV::read_u8, Self::UInt8), MVT::Int8 => reader.read(MV::read_i8, Self::Int8), MVT::UInt16 => reader.read(MV::read_u16, Self::UInt16), @@ -342,7 +342,7 @@ impl MetadataArrayValue { MVT::UInt64 => reader.read(MV::read_u64, Self::UInt64), MVT::Int64 => reader.read(MV::read_i64, Self::Int64), MVT::Float64 => reader.read(MV::read_f64, Self::Float64), - }?) + } } } From 38dd73060db787707b5bfe9157862884ae3f47b4 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 27 Aug 2023 23:37:13 +0200 Subject: [PATCH 06/33] fix(gguf): drop the null terminator --- crates/ggml/src/util.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index d2108a98..5344e754 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -99,7 +99,14 @@ pub fn read_string( use_64_bit_length: bool, ) -> Result { let len = read_length(reader, use_64_bit_length)?; - let bytes = read_bytes_with_len(reader, len)?; + let mut bytes = read_bytes_with_len(reader, len)?; + // The GGUF C writer prior to `llama.cpp@103cfafc774f6feb3172b5d4d39681c965b17eba` + // wrote a null terminator at the end of strings. As a work-around, we remove + // them here. + if bytes.last() == Some(&0) { + // Remove the null terminator. 
+ bytes.pop(); + } Ok(String::from_utf8(bytes) .expect("string was not valid utf-8 (TODO: make this a library error)")) } From 41462ed172d8e22750b3021ea3bb712168c5a486 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 27 Aug 2023 23:57:57 +0200 Subject: [PATCH 07/33] refactor(ggml): begin decoupling the old formats --- .vscode/settings.json | 2 +- crates/ggml/Cargo.toml | 2 + crates/ggml/src/format/ggml/loader.rs | 9 +- crates/ggml/src/format/ggml/mod.rs | 131 +++++++------------------ crates/ggml/src/format/gguf/mod.rs | 124 ++++++++++++----------- crates/ggml/src/format/mod.rs | 136 +++++++++++++++++++------- crates/ggml/src/util.rs | 6 +- crates/llm-base/Cargo.toml | 2 +- crates/llm-base/src/lib.rs | 2 +- crates/llm-base/src/loader.rs | 9 +- crates/llm-base/src/model/mod.rs | 2 +- crates/llm/src/lib.rs | 2 +- 12 files changed, 220 insertions(+), 207 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index ed83f314..8fd23a09 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,3 @@ { - "rust-analyzer.cargo.features": [] + "rust-analyzer.cargo.features": ["pre-gguf-formats"] } diff --git a/crates/ggml/Cargo.toml b/crates/ggml/Cargo.toml index fe60f7a9..03aeb92a 100644 --- a/crates/ggml/Cargo.toml +++ b/crates/ggml/Cargo.toml @@ -17,6 +17,8 @@ rand = { workspace = true } anyhow = { workspace = true } [features] +# Whether or not the pre-GGUF loading/saving code is exposed. +pre-gguf-formats = [] cublas = ["ggml-sys/cublas"] clblast = ["ggml-sys/clblast"] metal = ["ggml-sys/metal"] diff --git a/crates/ggml/src/format/ggml/loader.rs b/crates/ggml/src/format/ggml/loader.rs index d1b3c3d5..5dbe86aa 100644 --- a/crates/ggml/src/format/ggml/loader.rs +++ b/crates/ggml/src/format/ggml/loader.rs @@ -10,12 +10,12 @@ use std::{ }; use crate::{ - format::{data_size, header_size, LoadError}, + format::{data_size, header_size, ContainerType, ContainerTypeReadError}, util::{has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32}, ElementType, }; -use super::ContainerType; +use super::LoadError; #[derive(Debug, Clone)] /// Information about a [tensor](https://en.wikipedia.org/wiki/Tensor_(machine_learning)) that is being read. @@ -96,7 +96,10 @@ pub fn load( handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { // Verify magic - let container_type = ContainerType::read(reader)?; + let container_type = ContainerType::read(reader).map_err(|e| match e { + ContainerTypeReadError::InvalidMagic(magic) => LoadError::InvalidMagic(magic), + ContainerTypeReadError::Io(io) => LoadError::Io(io), + })?; match container_type { ContainerType::Ggml diff --git a/crates/ggml/src/format/ggml/mod.rs b/crates/ggml/src/format/ggml/mod.rs index 774ce69d..4a526756 100644 --- a/crates/ggml/src/format/ggml/mod.rs +++ b/crates/ggml/src/format/ggml/mod.rs @@ -3,104 +3,47 @@ mod loader; mod saver; +use std::error::Error; + +use super::ContainerType; +use crate::util; + pub use loader::*; pub use saver::*; #[cfg(test)] mod tests; -use crate::{format::LoadError, util}; - -/// Magic constant for `ggml` files (unversioned). -pub const FILE_MAGIC_GGML: [u8; 4] = *b"lmgg"; -/// Magic constant for `ggml` files (versioned, ggmf). -pub const FILE_MAGIC_GGMF: [u8; 4] = *b"fmgg"; -/// Magic constant for `ggml` files (versioned, ggjt). -pub const FILE_MAGIC_GGJT: [u8; 4] = *b"tjgg"; -/// Magic constant for `ggla` files (LoRA adapter). -pub const FILE_MAGIC_GGLA: [u8; 4] = *b"algg"; -/// Magic constant for `gguf` files. 
-pub const FILE_MAGIC_GGUF: [u8; 4] = *b"GGUF"; - -#[derive(Debug, PartialEq, Clone, Copy)] -/// The format of the file containing the model. -pub enum ContainerType { - /// Legacy format, oldest ggml tensor file format - Ggml, - /// Legacy format. Introduces versioning. Newer than GGML, older than GGJT. - Ggmf(u32), - /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. - Ggjt(u32), - /// LoRA adapter format. - Ggla(u32), - /// GGUF format. Current version of the format. - Gguf(u32), -} -impl ContainerType { - /// Does this container type support mmap? - pub fn support_mmap(&self) -> bool { - match self { - ContainerType::Ggml => false, - ContainerType::Ggmf(_) => false, - ContainerType::Ggla(_) => false, - ContainerType::Ggjt(_) => true, - ContainerType::Gguf(_) => true, - } - } - - /// Read the container type from a reader. - pub fn read( - reader: &mut dyn std::io::BufRead, - ) -> Result> { - // Verify magic - let magic = util::read_bytes::<4>(reader)?; - let container_type: ContainerType = match magic { - FILE_MAGIC_GGML => ContainerType::Ggml, - FILE_MAGIC_GGMF => { - let version = util::read_u32(reader)?; - ContainerType::Ggmf(version) - } - FILE_MAGIC_GGJT => { - let version = util::read_u32(reader)?; - ContainerType::Ggjt(version) - } - FILE_MAGIC_GGLA => { - let version = util::read_u32(reader)?; - ContainerType::Ggla(version) - } - FILE_MAGIC_GGUF => { - let version = util::read_u32(reader)?; - ContainerType::Gguf(version) - } - magic => return Err(LoadError::InvalidMagic(util::FormatMagic(magic))), - }; - - Ok(container_type) - } - - /// Write the container type to a writer. - pub fn write(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { - match self { - ContainerType::Ggml => { - writer.write_all(&FILE_MAGIC_GGML)?; - } - ContainerType::Ggmf(version) => { - writer.write_all(&FILE_MAGIC_GGMF)?; - util::write_u32(writer, *version)?; - } - ContainerType::Ggjt(version) => { - writer.write_all(&FILE_MAGIC_GGJT)?; - util::write_u32(writer, *version)?; - } - ContainerType::Ggla(version) => { - writer.write_all(&FILE_MAGIC_GGLA)?; - util::write_u32(writer, *version)?; - } - ContainerType::Gguf(version) => { - writer.write_all(&FILE_MAGIC_GGUF)?; - util::write_u32(writer, *version)?; - } - } - Ok(()) - } +#[derive(Debug, thiserror::Error)] +/// Errors that can occur while loading a model. +pub enum LoadError { + #[error("invalid file magic value: {0}")] + /// The file's magic value is invalid. + InvalidMagic(util::FileMagic), + #[error("invalid ggml format: format={0:?}")] + /// An unsupported format version was found. + InvalidFormatVersion(ContainerType), + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("implementation error")] + /// An error `E` was returned by the implementation of the loader. + ImplementationError(#[source] E), + #[error("unsupported tensor type {ftype} for tensor {tensor_name}")] + /// One of the tensors encountered had an unsupported data type. + UnsupportedElementType { + /// The name of the tensor. + tensor_name: String, + /// The format type that was encountered. 
+ ftype: u32, + }, + #[error("invariant broken: {0}")] + /// An invariant was broken. + InvariantBroken(String), } diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index 3d9810d8..40ffeb68 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -2,16 +2,46 @@ use std::{ collections::HashMap, - convert::Infallible, io::{BufRead, Seek}, }; use crate::{util, ElementType}; -use super::{ggml::ContainerType, LoadError}; +use super::{ContainerType, ContainerTypeReadError}; pub const DEFAULT_ALIGNMENT: u32 = 32; +#[derive(Debug, thiserror::Error)] +/// Errors that can occur while loading a model. +pub enum GgufLoadError { + #[error("invalid GGUF file magic value: {0}")] + /// The file magic number is invalid. + InvalidMagic(util::FileMagic), + #[error("invalid ggml format: format={0:?}")] + /// An unsupported format version was found. + InvalidFormatVersion(ContainerType), + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("unsupported tensor type {ftype} for tensor {tensor_name}")] + /// One of the tensors encountered had an unsupported data type. + UnsupportedElementType { + /// The name of the tensor. + tensor_name: String, + /// The format type that was encountered. + ftype: u32, + }, + #[error("invariant broken: {0}")] + /// An invariant was broken. + InvariantBroken(String), +} + #[derive(Debug, Clone, PartialEq)] pub struct Gguf { pub metadata: HashMap, @@ -19,10 +49,13 @@ pub struct Gguf { pub tensor_data_position: u64, } impl Gguf { - pub fn load(reader: &mut R) -> Result> { - let container = ContainerType::read(reader)?; + pub fn load(reader: &mut R) -> Result { + let container = ContainerType::read(reader).map_err(|e| match e { + ContainerTypeReadError::InvalidMagic(magic) => GgufLoadError::InvalidMagic(magic), + ContainerTypeReadError::Io(io) => GgufLoadError::Io(io), + })?; if ![ContainerType::Gguf(1), ContainerType::Gguf(2)].contains(&container) { - return Err(LoadError::InvalidFormatVersion(container)); + return Err(GgufLoadError::InvalidFormatVersion(container)); } let ctx = GgufContext { @@ -151,7 +184,7 @@ impl MetadataValue { fn read_key_value( ctx: &GgufContext, reader: &mut dyn BufRead, - ) -> Result<(String, Self), LoadError> { + ) -> Result<(String, Self), GgufLoadError> { let key = util::read_string(reader, ctx.use_64_bit_length)?; let value_type = MetadataValueType::try_from(util::read_u32(reader)?) 
.expect("TODO: handle invalid value types"); @@ -164,7 +197,7 @@ impl MetadataValue { ctx: &GgufContext, reader: &mut dyn BufRead, value_type: MetadataValueType, - ) -> Result> { + ) -> Result { match value_type { MetadataValueType::UInt8 => Self::read_u8(ctx, reader).map(MetadataValue::UInt8), MetadataValueType::Int8 => Self::read_i8(ctx, reader).map(MetadataValue::Int8), @@ -182,88 +215,58 @@ impl MetadataValue { } } - fn read_u8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result> { + fn read_u8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_bytes::<1>(reader)?[0]) } - fn read_i8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result> { + fn read_i8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_bytes::<1>(reader)?[0] as i8) } - fn read_u16( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_u16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(u16::from_le_bytes(util::read_bytes::<2>(reader)?)) } - fn read_i16( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_i16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(i16::from_le_bytes(util::read_bytes::<2>(reader)?)) } - fn read_u32( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_u32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_u32(reader)?) } - fn read_i32( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_i32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_i32(reader)?) } - fn read_f32( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_f32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_f32(reader)?) } - fn read_bool( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_bool(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_bool(reader)?) } - fn read_string( - ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_string(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_string(reader, ctx.use_64_bit_length)?) } fn read_array( ctx: &GgufContext, reader: &mut dyn BufRead, - ) -> Result> { + ) -> Result { MetadataArrayValue::read_value(ctx, reader) } - fn read_u64( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_u64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_u64(reader)?) } - fn read_i64( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_i64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_i64(reader)?) } - fn read_f64( - _ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_f64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { Ok(util::read_f64(reader)?) } @@ -292,10 +295,7 @@ pub enum MetadataArrayValue { Float64(Vec), } impl MetadataArrayValue { - fn read_value( - ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result> { + fn read_value(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { let value_type = MetadataValueType::try_from(util::read_u32(reader)?) 
.expect("TODO: handle invalid value types"); let length = util::read_length(reader, ctx.use_64_bit_length)?; @@ -308,12 +308,9 @@ impl MetadataArrayValue { impl ArrayReader<'_> { fn read( &mut self, - value_reader: impl Fn( - &GgufContext, - &mut dyn BufRead, - ) -> Result>, + value_reader: impl Fn(&GgufContext, &mut dyn BufRead) -> Result, value_constructor: impl Fn(Vec) -> MetadataArrayValue, - ) -> Result> { + ) -> Result { (0..self.length) .map(|_| value_reader(self.ctx, self.reader)) .collect::, _>>() @@ -356,7 +353,7 @@ impl TensorInfo { fn read_name_value( ctx: &GgufContext, reader: &mut dyn BufRead, - ) -> Result<(String, Self), LoadError> { + ) -> Result<(String, Self), GgufLoadError> { let name = util::read_string(reader, ctx.use_64_bit_length)?; let dimension_count = util::read_u32(reader)? as usize; @@ -365,11 +362,12 @@ impl TensorInfo { .collect::, _>>()?; let element_type = util::read_u32(reader)?; - let element_type = - ElementType::try_from(element_type).map_err(|_| LoadError::UnsupportedElementType { + let element_type = ElementType::try_from(element_type).map_err(|_| { + GgufLoadError::UnsupportedElementType { tensor_name: name.clone(), ftype: element_type, - })?; + } + })?; let offset = util::read_u64(reader)?; diff --git a/crates/ggml/src/format/mod.rs b/crates/ggml/src/format/mod.rs index 9fe51d2b..7cc1af73 100644 --- a/crates/ggml/src/format/mod.rs +++ b/crates/ggml/src/format/mod.rs @@ -1,44 +1,114 @@ //! Loading and saving of GGML-related files. -use std::error::Error; +use thiserror::Error; -use crate::{util::FormatMagic, ElementType}; +use crate::{util, ElementType}; +#[cfg(feature = "pre-gguf-formats")] pub mod ggml; pub mod gguf; -#[derive(Debug, thiserror::Error)] -/// Errors that can occur while loading a model. -pub enum LoadError { - #[error("invalid file magic number: {0}")] - /// The file magic number is invalid. - InvalidMagic(FormatMagic), - #[error("invalid ggml format: format={0:?}")] - /// An unsupported format version was found. - InvalidFormatVersion(ggml::ContainerType), - #[error("non-specific I/O error")] - /// A non-specific IO error. +/// Magic constant for `ggml` files (unversioned). +pub const FILE_MAGIC_GGML: [u8; 4] = *b"lmgg"; +/// Magic constant for `ggml` files (versioned, ggmf). +pub const FILE_MAGIC_GGMF: [u8; 4] = *b"fmgg"; +/// Magic constant for `ggml` files (versioned, ggjt). +pub const FILE_MAGIC_GGJT: [u8; 4] = *b"tjgg"; +/// Magic constant for `ggla` files (LoRA adapter). +pub const FILE_MAGIC_GGLA: [u8; 4] = *b"algg"; +/// Magic constant for `gguf` files. +pub const FILE_MAGIC_GGUF: [u8; 4] = *b"GGUF"; + +/// Errors that can occur while reading the container type. +#[derive(Debug, Error)] +pub enum ContainerTypeReadError { + /// The magic value was invalid. + #[error("invalid magic value: {0}")] + InvalidMagic(util::FileMagic), + /// An I/O error occurred. + #[error("I/O error")] Io(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("implementation error")] - /// An error `E` was returned by the implementation of the loader. 
- ImplementationError(#[source] E), - #[error("unsupported tensor type {ftype} for tensor {tensor_name}")] - /// One of the tensors encountered had an unsupported data type. - UnsupportedElementType { - /// The name of the tensor. - tensor_name: String, - /// The format type that was encountered. - ftype: u32, - }, - #[error("invariant broken: {0}")] - /// An invariant was broken. - InvariantBroken(String), +} + +#[derive(Debug, PartialEq, Clone, Copy)] +/// The format of the file containing the model. +pub enum ContainerType { + /// Legacy format, oldest ggml tensor file format + Ggml, + /// Legacy format. Introduces versioning. Newer than GGML, older than GGJT. + Ggmf(u32), + /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. + Ggjt(u32), + /// LoRA adapter format. + Ggla(u32), + /// GGUF format. Current version of the format. + Gguf(u32), +} +impl ContainerType { + /// Does this container type support mmap? + pub fn support_mmap(&self) -> bool { + match self { + ContainerType::Ggml => false, + ContainerType::Ggmf(_) => false, + ContainerType::Ggla(_) => false, + ContainerType::Ggjt(_) => true, + ContainerType::Gguf(_) => true, + } + } + + /// Read the container type from a reader. + pub fn read(reader: &mut dyn std::io::BufRead) -> Result { + // Verify magic + let magic = util::read_bytes::<4>(reader)?; + let container_type: ContainerType = match magic { + FILE_MAGIC_GGML => ContainerType::Ggml, + FILE_MAGIC_GGMF => { + let version = util::read_u32(reader)?; + ContainerType::Ggmf(version) + } + FILE_MAGIC_GGJT => { + let version = util::read_u32(reader)?; + ContainerType::Ggjt(version) + } + FILE_MAGIC_GGLA => { + let version = util::read_u32(reader)?; + ContainerType::Ggla(version) + } + FILE_MAGIC_GGUF => { + let version = util::read_u32(reader)?; + ContainerType::Gguf(version) + } + magic => return Err(ContainerTypeReadError::InvalidMagic(util::FileMagic(magic))), + }; + + Ok(container_type) + } + + /// Write the container type to a writer. + pub fn write(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> { + match self { + ContainerType::Ggml => { + writer.write_all(&FILE_MAGIC_GGML)?; + } + ContainerType::Ggmf(version) => { + writer.write_all(&FILE_MAGIC_GGMF)?; + util::write_u32(writer, *version)?; + } + ContainerType::Ggjt(version) => { + writer.write_all(&FILE_MAGIC_GGJT)?; + util::write_u32(writer, *version)?; + } + ContainerType::Ggla(version) => { + writer.write_all(&FILE_MAGIC_GGLA)?; + util::write_u32(writer, *version)?; + } + ContainerType::Gguf(version) => { + writer.write_all(&FILE_MAGIC_GGUF)?; + util::write_u32(writer, *version)?; + } + } + Ok(()) + } } /// Returns the size occupied by a tensor's data in bytes given the element type and number of elements. diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index 5344e754..7d65dfb2 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -7,13 +7,13 @@ use std::{ /// Helper struct that wraps the magic number of a file format, /// so that it can be printed in a human-readable format. 
-pub struct FormatMagic(pub [u8; 4]); -impl fmt::Display for FormatMagic { +pub struct FileMagic(pub [u8; 4]); +impl fmt::Display for FileMagic { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:x?} ({})", self.0, String::from_utf8_lossy(&self.0)) } } -impl fmt::Debug for FormatMagic { +impl fmt::Debug for FileMagic { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { fmt::Display::fmt(self, f) } diff --git a/crates/llm-base/Cargo.toml b/crates/llm-base/Cargo.toml index badcbdc6..47157429 100644 --- a/crates/llm-base/Cargo.toml +++ b/crates/llm-base/Cargo.toml @@ -11,7 +11,7 @@ readme = "../../README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ggml = { path = "../ggml", version = "0.2.0-dev" } +ggml = { path = "../ggml", version = "0.2.0-dev", features = ["pre-gguf-formats"] } bytemuck = { workspace = true } rand = { workspace = true } diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index e07c8852..b69cffa0 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -30,7 +30,7 @@ pub use inference_session::{ }; pub use llm_samplers::prelude::{Sampler, SamplerChain}; pub use loader::{ - load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic, + load, load_progress_callback_stdout, ContainerType, FileMagic, FileType, FileTypeFormat, LoadError, LoadProgress, Loader, TensorLoader, }; pub use lora::{LoraAdapter, LoraParameters}; diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 2c05704e..4b2d363e 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -12,14 +12,11 @@ use crate::{ util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelContext, ModelParameters, TokenId, Tokenizer, TokenizerLoadError, TokenizerSource, }; -pub use ggml::{format::ggml::ContainerType, util::FormatMagic}; use ggml::{ - format::{ - ggml::{PartialHyperparameters, TensorLoadInfo}, - LoadError as FormatLoadError, - }, + format::ggml::{LoadError as FormatLoadError, PartialHyperparameters, TensorLoadInfo}, Context, MAX_NAME_LENGTH, }; +pub use ggml::{format::ContainerType, util::FileMagic}; use memmap2::Mmap; use thiserror::Error; use tracing::log; @@ -262,7 +259,7 @@ pub enum LoadError { /// The path that failed. path: PathBuf, /// The magic number that was encountered. - magic: FormatMagic, + magic: FileMagic, }, #[error("invalid file format {container_type:?}")] /// The version of the format is not supported by this version of `llm`. diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 1ce0e25c..494d5aae 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -195,7 +195,7 @@ pub enum HyperparametersWriteError { /// Parameters for model-wide behaviour. #[derive(Debug, Clone)] pub struct ModelParameters { - /// For [GGML formats](ggml::format::ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap) + /// For [GGML formats](ggml::format::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap) /// is the default. Although mmap typically improves performance, setting this value to `false` may /// be preferred in resource-constrained environments. 
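    // For reference, `ggml::format::ContainerType::support_mmap` (shown earlier in
    // this patch) returns `true` only for `Ggjt` and `Gguf`, so this flag has no
    // effect when loading the older `Ggml`/`Ggmf`/`Ggla` containers.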
pub prefer_mmap: bool, diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index febe2441..8d62f40a 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -81,7 +81,7 @@ pub use llm_base::{ ggml::accelerator::get_accelerator as ggml_get_accelerator, ggml::accelerator::Accelerator as GgmlAccelerator, ggml::format as ggml_format, ggml::RoPEOverrides, load, load_progress_callback_stdout, quantize, samplers, ElementType, - FileType, FileTypeFormat, FormatMagic, Hyperparameters, InferenceError, InferenceFeedback, + FileMagic, FileType, FileTypeFormat, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, From 2de2df7f252935291c27c6cf21bc882e6b47b69c Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 28 Aug 2023 01:20:26 +0200 Subject: [PATCH 08/33] feat(bin): add gguf-explorer as debugging tool --- .vscode/settings.json | 2 +- Cargo.lock | 1735 +++++++++++++++++++++++++++- binaries/gguf-explorer/Cargo.toml | 18 + binaries/gguf-explorer/src/main.rs | 220 ++++ crates/ggml/examples/gguf.rs | 14 - crates/ggml/src/lib.rs | 2 +- 6 files changed, 1951 insertions(+), 40 deletions(-) create mode 100644 binaries/gguf-explorer/Cargo.toml create mode 100644 binaries/gguf-explorer/src/main.rs delete mode 100644 crates/ggml/examples/gguf.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index 8fd23a09..ed83f314 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,3 @@ { - "rust-analyzer.cargo.features": ["pre-gguf-formats"] + "rust-analyzer.cargo.features": [] } diff --git a/Cargo.lock b/Cargo.lock index 049d70df..ab1f6bcf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,91 @@ # It is not intended for manual editing. 
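# The additions below are the dependency tree pulled in by the new `gguf-explorer`
# binary (eframe, egui_extras, and their transitive GUI dependencies).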
version = 3 +[[package]] +name = "ab_glyph" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5110f1c78cf582855d895ecd0746b653db010cec6d9f5575293f27934d980a39" +dependencies = [ + "ab_glyph_rasterizer", + "owned_ttf_parser", +] + +[[package]] +name = "ab_glyph_rasterizer" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71b1793ee61086797f5c80b6efa2b8ffa6d5dd703f118545808a7f2e27f7046" + +[[package]] +name = "accesskit" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76eb1adf08c5bcaa8490b9851fd53cca27fa9880076f178ea9d29f05196728a8" + +[[package]] +name = "accesskit_consumer" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04bb4d9e4772fe0d47df57d0d5dbe5d85dd05e2f37ae1ddb6b105e76be58fb00" +dependencies = [ + "accesskit", +] + +[[package]] +name = "accesskit_macos" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134d0acf6acb667c89d3332999b1a5df4edbc8d6113910f392ebb73f2b03bb56" +dependencies = [ + "accesskit", + "accesskit_consumer", + "objc2", + "once_cell", +] + +[[package]] +name = "accesskit_unix" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e084cb5168790c0c112626175412dc5ad127083441a8248ae49ddf6725519e83" +dependencies = [ + "accesskit", + "accesskit_consumer", + "async-channel", + "atspi", + "futures-lite", + "serde", + "zbus", +] + +[[package]] +name = "accesskit_windows" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9eac0a7f2d7cd7a93b938af401d3d8e8b7094217989a7c25c55a953023436e31" +dependencies = [ + "accesskit", + "accesskit_consumer", + "arrayvec", + "once_cell", + "paste", + "windows", +] + +[[package]] +name = "accesskit_winit" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "825d23acee1bd6d25cbaa3ca6ed6e73faf24122a774ec33d52c5c86c6ab423c0" +dependencies = [ + "accesskit", + "accesskit_macos", + "accesskit_unix", + "accesskit_windows", + "winit", +] + [[package]] name = "addr2line" version = "0.20.0" @@ -28,6 +113,17 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "ahash" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.7.20" @@ -46,6 +142,30 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-activity" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64529721f27c2314ced0890ce45e469574a73e5e6fdd6e9da1860eb29285f5e0" +dependencies = [ + "android-properties", + "bitflags 1.3.2", + "cc", + "jni-sys", + "libc", + "log", + "ndk", + "ndk-context", + "ndk-sys", + "num_enum 0.6.1", +] + +[[package]] +name = "android-properties" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7eb209b1518d6bb87b283c20095f5228ecda460da70b44f0802523dea6da04" + [[package]] name = "anstream" version = "0.3.2" @@ -101,6 +221,197 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +[[package]] +name = "arboard" +version = "3.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6041616acea41d67c4a984709ddab1587fd0b10efe5cc563fee954d2f011854" +dependencies = [ + "clipboard-win", + "log", + "objc", + "objc-foundation", + "objc_id", + "once_cell", + "parking_lot", + "thiserror", + "winapi", + "x11rb", +] + +[[package]] +name = "arrayref" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" + +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + +[[package]] +name = "async-broadcast" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c48ccdbf6ca6b121e0f586cbc0e73ae440e56c67c30fa0873b4e110d9c26d2b" +dependencies = [ + "event-listener", + "futures-core", +] + +[[package]] +name = "async-channel" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" +dependencies = [ + "concurrent-queue", + "event-listener", + "futures-core", +] + +[[package]] +name = "async-executor" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fa3dc5f2a8564f07759c008b9109dc0d39de92a88d5588b8a5036d286383afb" +dependencies = [ + "async-lock", + "async-task", + "concurrent-queue", + "fastrand", + "futures-lite", + "slab", +] + +[[package]] +name = "async-fs" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "279cf904654eeebfa37ac9bb1598880884924aab82e290aa65c9e77a0e142e06" +dependencies = [ + "async-lock", + "autocfg", + "blocking", + "futures-lite", +] + +[[package]] +name = "async-io" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" +dependencies = [ + "async-lock", + "autocfg", + "cfg-if", + "concurrent-queue", + "futures-lite", + "log", + "parking", + "polling", + "rustix 0.37.21", + "slab", + "socket2", + "waker-fn", +] + +[[package]] +name = "async-lock" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" +dependencies = [ + "event-listener", +] + +[[package]] +name = "async-process" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9d28b1d97e08915212e2e45310d47854eafa69600756fc735fb788f75199c9" +dependencies = [ + "async-io", + "async-lock", + "autocfg", + "blocking", + "cfg-if", + "event-listener", + "futures-lite", + "rustix 0.37.21", + "signal-hook", + "windows-sys 0.48.0", +] + +[[package]] +name = "async-recursion" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e97ce7de6cf12de5d7226c73f5ba9811622f4db3a5b91b55c53e987e5f91cba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.29", +] + +[[package]] +name = "async-task" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc7ab41815b3c653ccd2978ec3255c81349336702dfdf62ee6f7069b12a3aae" + +[[package]] +name = "async-trait" +version = "0.1.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" +dependencies 
= [ + "proc-macro2", + "quote", + "syn 2.0.29", +] + +[[package]] +name = "atomic-waker" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1181e1e0d1fce796a03db1ae795d67167da795f9cf4a39c37589e85ef57f26d3" + +[[package]] +name = "atomic_refcell" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112ef6b3f6cb3cb6fc5b6b494ef7a848492cff1ab0ef4de10b0f7d572861c905" + +[[package]] +name = "atspi" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "674e7a3376837b2e7d12d34d58ac47073c491dc3bf6f71a7adaf687d4d817faa" +dependencies = [ + "async-recursion", + "async-trait", + "atspi-macros", + "enumflags2", + "futures-lite", + "serde", + "tracing", + "zbus", + "zbus_names", +] + +[[package]] +name = "atspi-macros" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb4870a32c0eaa17e35bca0e6b16020635157121fb7d45593d242c295bc768" +dependencies = [ + "quote", + "syn 1.0.109", +] + [[package]] name = "atty" version = "0.2.14" @@ -179,7 +490,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.22", + "syn 2.0.29", "which", ] @@ -195,6 +506,12 @@ version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.10.4" @@ -204,6 +521,40 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-sys" +version = "0.1.0-beta.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa55741ee90902547802152aaf3f8e5248aab7e21468089560d4c8840561146" +dependencies = [ + "objc-sys", +] + +[[package]] +name = "block2" +version = "0.2.0-alpha.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dd9e63c1744f755c2f60332b88de39d341e5e86239014ad839bd71c106dec42" +dependencies = [ + "block-sys", + "objc2-encode", +] + +[[package]] +name = "blocking" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77231a1c8f801696fc0123ec6150ce92cffb8e164a02afb9c8ddee0e9b65ad65" +dependencies = [ + "async-channel", + "async-lock", + "async-task", + "atomic-waker", + "fastrand", + "futures-lite", + "log", +] + [[package]] name = "bumpalo" version = "3.13.0" @@ -215,6 +566,20 @@ name = "bytemuck" version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdde5c9cd29ebd706ce1b35600920a33550e402fc998a2e53ad3b42c3c47a192" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.29", +] [[package]] name = "byteorder" @@ -277,6 +642,20 @@ dependencies = [ "zip", ] +[[package]] +name = "calloop" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e0d00eb1ea24371a97d2da6201c6747a633dc6dc1988ef503403b4c59504a8" +dependencies = [ + "bitflags 1.3.2", + "log", + "nix 0.25.1", + "slotmap", + "thiserror", + "vec_map", +] + [[package]] name = "cc" version = "1.0.79" @@ -286,6 +665,12 @@ 
dependencies = [ "jobserver", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cexpr" version = "0.6.0" @@ -301,6 +686,21 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + +[[package]] +name = "cgl" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ced0551234e87afee12411d535648dd89d2e7f34c78b753395567aff3d447ff" +dependencies = [ + "libc", +] + [[package]] name = "ci_info" version = "0.10.2" @@ -363,7 +763,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", ] [[package]] @@ -383,6 +783,37 @@ dependencies = [ "winapi", ] +[[package]] +name = "cocoa" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f425db7937052c684daec3bd6375c8abe2d146dca4b8b143d6db777c39138f3a" +dependencies = [ + "bitflags 1.3.2", + "block", + "cocoa-foundation", + "core-foundation", + "core-graphics", + "foreign-types", + "libc", + "objc", +] + +[[package]] +name = "cocoa-foundation" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "931d3837c286f56e3c58423ce4eba12d08db2374461a785c86f672b08b5650d6" +dependencies = [ + "bitflags 1.3.2", + "block", + "core-foundation", + "core-graphics-types", + "foreign-types", + "libc", + "objc", +] + [[package]] name = "color-eyre" version = "0.6.2" @@ -396,6 +827,12 @@ dependencies = [ "owo-colors", ] +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + [[package]] name = "colorchoice" version = "1.0.0" @@ -403,14 +840,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] -name = "colored" -version = "2.0.0" +name = "colored" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd" +dependencies = [ + "atty", + "lazy_static", + "winapi", +] + +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "memchr", +] + +[[package]] +name = "concurrent-queue" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd" +checksum = "62ec6771ecfa0762d24683ee5a32ad78487a3d3afdc0fb8cae19d2c5deb50b7c" dependencies = [ - "atty", - "lazy_static", - "winapi", + "crossbeam-utils", ] [[package]] @@ -447,6 +903,30 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "core-graphics" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-graphics-types", + "foreign-types", + "libc", +] + +[[package]] +name = "core-graphics-types" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bb142d41022986c1d8ff29103a1411c8a3dfad3552f87a4f8dc50d61d4f4e33" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.8" @@ -495,7 +975,7 @@ dependencies = [ "autocfg", "cfg-if", "crossbeam-utils", - "memoffset", + "memoffset 0.9.0", "scopeguard", ] @@ -559,6 +1039,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_builder" version = "0.12.0" @@ -642,12 +1133,134 @@ dependencies = [ "winapi", ] +[[package]] +name = "dispatch" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" + +[[package]] +name = "dlib" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" +dependencies = [ + "libloading", +] + +[[package]] +name = "downcast-rs" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" + +[[package]] +name = "ecolor" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e479a7fa3f23d4e794f8b2f8b3568dd4e47886ad1b12c9c095e141cb591eb63" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "eframe" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4596583a2c680c55b6feaa748f74890c4f9cb9c7cb69d6117110444cb65b2f" +dependencies = [ + "bytemuck", + "cocoa", + "egui", + "egui-winit", + "egui_glow", + "glow", + "glutin", + "glutin-winit", + "image", + "js-sys", + "log", + "objc", + "percent-encoding", + "raw-window-handle", + "thiserror", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winapi", + "winit", +] + +[[package]] +name = "egui" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3aef8ec3ae1b772f340170c65bf27d5b8c28f543a0116c844d2ac08d01123e7" +dependencies = [ + "accesskit", + "ahash", + "epaint", + "log", + "nohash-hasher", +] + +[[package]] +name = "egui-winit" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a49155fd4a0a4fb21224407a91de0030847972ef90fc64edb63621caea61cb2" +dependencies = [ + "accesskit_winit", + "arboard", + "egui", + "instant", + "log", + "raw-window-handle", + "smithay-clipboard", + "webbrowser", + "winit", +] + +[[package]] +name = "egui_extras" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9278f4337b526f0d57e5375e5a7340a311fa6ee8f9fcc75721ac50af13face02" +dependencies = [ + "egui", + "serde", +] + +[[package]] +name = "egui_glow" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f8c2752cdf1b0ef5fcda59a898cacabad974d4f5880e92a420b2c917022da64" +dependencies = [ + 
"bytemuck", + "egui", + "glow", + "log", + "memoffset 0.6.5", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "either" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "emath" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3857d743a6e0741cdd60b622a74c7a36ea75f5f8f11b793b41d905d2c9721a4b" +dependencies = [ + "bytemuck", +] + [[package]] name = "encode_unicode" version = "0.3.6" @@ -669,6 +1282,27 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enumflags2" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c041f5090df68b32bcd905365fd51769c8b9d553fe87fde0b683534f10c01bd2" +dependencies = [ + "enumflags2_derive", + "serde", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e9a1f9f7d83e59740248a6e14ecf93929ade55027844dfcea78beafccc15745" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.29", +] + [[package]] name = "env_logger" version = "0.10.0" @@ -689,9 +1323,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2d328fc287c61314c4a61af7cfdcbd7e678e39778488c7cb13ec133ce0f4059" dependencies = [ "fsio", - "indexmap", + "indexmap 1.9.3", +] + +[[package]] +name = "epaint" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09333964d4d57f40a85338ba3ca5ed4716070ab184dcfed966b35491c5c64f3b" +dependencies = [ + "ab_glyph", + "ahash", + "atomic_refcell", + "bytemuck", + "ecolor", + "emath", + "log", + "nohash-hasher", + "parking_lot", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.1" @@ -729,6 +1386,12 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f748b253ceca9fed5f42f8b5ceb3851e93102199bc25b64b65369f76e5c0a35" +[[package]] +name = "event-listener" +version = "2.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" + [[package]] name = "eyre" version = "0.6.8" @@ -759,6 +1422,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "fdeflate" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d329bdeac514ee06249dabc27877490f17f5d371ec693360768b838e19f3ae10" +dependencies = [ + "simd-adler32", +] + [[package]] name = "filetime" version = "0.2.21" @@ -848,6 +1520,21 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +[[package]] +name = "futures-lite" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + [[package]] name = "futures-sink" version = "0.3.28" @@ -868,6 +1555,7 @@ checksum = 
"26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-core", "futures-io", + "futures-sink", "futures-task", "memchr", "pin-project-lite", @@ -892,6 +1580,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "gethostname" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1ebd34e35c46e00bb73e81363248d627782724609fe1b6396f553f68fe3862e" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "getopts" version = "0.2.21" @@ -930,18 +1628,115 @@ dependencies = [ "cc", ] +[[package]] +name = "gguf-explorer" +version = "0.1.0" +dependencies = [ + "anyhow", + "eframe", + "egui_extras", + "ggml", +] + [[package]] name = "gimli" version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +[[package]] +name = "gl_generator" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a95dfc23a2b4a9a2f5ab41d194f8bfda3cabec42af4e39f08c339eb2a0c124d" +dependencies = [ + "khronos_api", + "log", + "xml-rs", +] + [[package]] name = "glob" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "glow" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca0fe580e4b60a8ab24a868bc08e2f03cbcb20d3d676601fa909386713333728" +dependencies = [ + "js-sys", + "slotmap", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "glutin" +version = "0.30.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc93b03242719b8ad39fb26ed2b01737144ce7bd4bfc7adadcef806596760fe" +dependencies = [ + "bitflags 1.3.2", + "cfg_aliases", + "cgl", + "core-foundation", + "dispatch", + "glutin_egl_sys", + "glutin_glx_sys", + "glutin_wgl_sys", + "libloading", + "objc2", + "once_cell", + "raw-window-handle", + "wayland-sys 0.30.1", + "windows-sys 0.45.0", + "x11-dl", +] + +[[package]] +name = "glutin-winit" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629a873fc04062830bfe8f97c03773bcd7b371e23bcc465d0a61448cd1588fa4" +dependencies = [ + "cfg_aliases", + "glutin", + "raw-window-handle", + "winit", +] + +[[package]] +name = "glutin_egl_sys" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af784eb26c5a68ec85391268e074f0aa618c096eadb5d6330b0911cf34fe57c5" +dependencies = [ + "gl_generator", + "windows-sys 0.45.0", +] + +[[package]] +name = "glutin_glx_sys" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b53cb5fe568964aa066a3ba91eac5ecbac869fb0842cd0dc9e412434f1a1494" +dependencies = [ + "gl_generator", + "x11-dl", +] + +[[package]] +name = "glutin_wgl_sys" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef89398e90033fc6bc65e9bd42fd29bbbfd483bda5b56dc5562f455550618165" +dependencies = [ + "gl_generator", +] + [[package]] name = "h2" version = "0.3.20" @@ -954,7 +1749,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -976,6 +1771,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = 
"hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "heck" version = "0.4.1" @@ -997,6 +1798,12 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hmac" version = "0.12.1" @@ -1006,6 +1813,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" +dependencies = [ + "windows-sys 0.48.0", +] + [[package]] name = "http" version = "0.2.9" @@ -1099,6 +1915,20 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "image" +version = "0.24.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f3dfdbdd72063086ff443e297b61695500514b1e41095b6fb9a5ab48a70a711" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "num-rational", + "num-traits", + "png", +] + [[package]] name = "indenter" version = "0.3.3" @@ -1112,7 +1942,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", ] [[package]] @@ -1143,6 +1983,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" dependencies = [ "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", ] [[package]] @@ -1197,6 +2040,28 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "jobserver" version = "0.1.26" @@ -1215,6 +2080,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "khronos_api" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" + [[package]] name = "lazy_static" version = "1.4.0" @@ -1435,6 +2306,15 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1459,6 +2339,24 @@ dependencies = [ "libc", ] +[[package]] +name = "memoffset" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +dependencies = [ + "autocfg", +] + +[[package]] +name = "memoffset" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +dependencies = [ + "autocfg", +] + [[package]] name = "memoffset" version = "0.9.0" @@ -1487,6 +2385,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" dependencies = [ "adler", + "simd-adler32", ] [[package]] @@ -1496,6 +2395,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", + "log", "wasi", "windows-sys 0.48.0", ] @@ -1518,7 +2418,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", ] [[package]] @@ -1539,6 +2439,35 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndk" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451422b7e4718271c8b5b3aadf5adedba43dc76312454b387e98fae0fc951aa0" +dependencies = [ + "bitflags 1.3.2", + "jni-sys", + "ndk-sys", + "num_enum 0.5.11", + "raw-window-handle", + "thiserror", +] + +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + +[[package]] +name = "ndk-sys" +version = "0.4.1+23.1.7779620" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cf2aae958bd232cac5069850591667ad422d263686d75b52a065f9badeee5a3" +dependencies = [ + "jni-sys", +] + [[package]] name = "nias" version = "0.5.0" @@ -1554,6 +2483,31 @@ dependencies = [ "smallvec", ] +[[package]] +name = "nix" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" +dependencies = [ + "bitflags 1.3.2", + "cfg-if", + "libc", + "memoffset 0.6.5", +] + +[[package]] +name = "nix" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f346ff70e7dbfd675fe90590b92d59ef2de15a8779ae305ebcbfd3f0caf59be4" +dependencies = [ + "autocfg", + "bitflags 1.3.2", + "cfg-if", + "libc", + "memoffset 0.6.5", +] + [[package]] name = "nix" version = "0.26.2" @@ -1563,9 +2517,16 @@ dependencies = [ "bitflags 1.3.2", "cfg-if", "libc", + "memoffset 0.7.1", "static_assertions", ] +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "nom" version = "7.1.3" @@ -1586,6 +2547,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + 
+[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -1605,12 +2587,109 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f646caf906c20226733ed5b1374287eb97e3c2a5c227ce668c1f2ce20ae57c9" +dependencies = [ + "num_enum_derive 0.5.11", +] + +[[package]] +name = "num_enum" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a015b430d3c108a207fd776d2e2196aaf8b1cf8cf93253e3a097ff3085076a1" +dependencies = [ + "num_enum_derive 0.6.1", +] + +[[package]] +name = "num_enum_derive" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbff9bc912032c62bf65ef1d5aea88983b420f4f839db1e9b0c281a25c9c799" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "num_enum_derive" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96667db765a921f7b295ffee8b60472b686a51d4f21c2ee4ffdb94c7013b65a6" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.29", +] + [[package]] name = "number_prefix" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + +[[package]] +name = "objc-foundation" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1add1b659e36c9607c7aab864a76c7a4c2760cd0cd2e120f3fb8b952c7e22bf9" +dependencies = [ + "block", + "objc", + "objc_id", +] + +[[package]] +name = "objc-sys" +version = "0.2.0-beta.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b9834c1e95694a05a828b59f55fa2afec6288359cda67146126b3f90a55d7" + +[[package]] +name = "objc2" +version = "0.3.0-beta.3.patch-leaks.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e01640f9f2cb1220bbe80325e179e532cb3379ebcd1bf2279d703c19fe3a468" +dependencies = [ + "block2", + "objc-sys", + "objc2-encode", +] + +[[package]] +name = "objc2-encode" +version = "2.0.0-pre.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abfcac41015b00a120608fdaa6938c44cb983fee294351cc4bac7638b4e50512" +dependencies = [ + "objc-sys", +] + +[[package]] +name = "objc_id" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c92d4ddb4bd7b50d730c215ff871754d0da6b2178849f8a2a2ab69712d0c073b" +dependencies = [ + "objc", +] + [[package]] name = "object" version = "0.31.1" @@ -1671,7 +2750,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", ] [[package]] @@ -1692,18 +2771,52 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "orbclient" +version = "0.3.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8378ac0dfbd4e7895f2d2c1f1345cab3836910baf3a300b000d04250f0c8428f" +dependencies = [ + "redox_syscall 0.3.5", +] + +[[package]] +name = "ordered-stream" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aa2b01e1d916879f73a53d01d1d6cee68adbb31d6d9177a8cfce093cced1d50" +dependencies = [ + "futures-core", + "pin-project-lite", +] + [[package]] name = "overload" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "owned_ttf_parser" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "706de7e2214113d63a8238d1910463cfce781129a6f263d13fdb09ff64355ba4" +dependencies = [ + "ttf-parser", +] + [[package]] name = "owo-colors" version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" +[[package]] +name = "parking" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14f2252c834a40ed9bb5422029649578e63aa341ac401f74e719dd1afda8394e" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1792,6 +2905,35 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "png" +version = "0.17.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd75bf2d8dd3702b9707cdbc56a5b9ef42cec752eb8b3bafc01234558442aa64" +dependencies = [ + "bitflags 1.3.2", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + +[[package]] +name = "polling" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" +dependencies = [ + "autocfg", + "bitflags 1.3.2", + "cfg-if", + "concurrent-queue", + "libc", + "log", + "pin-project-lite", + "windows-sys 0.48.0", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1809,7 +2951,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9825a04601d60621feed79c4e6b56d65db77cdca55cef43b46b0de1096d1c282" dependencies = [ "proc-macro2", - "syn 2.0.22", + "syn 2.0.29", +] + +[[package]] +name = "proc-macro-crate" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +dependencies = [ + "once_cell", + "toml_edit", ] [[package]] @@ -1870,6 +3022,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "raw-window-handle" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" + [[package]] name = "rayon" version = "1.7.0" @@ -2066,7 +3224,7 @@ dependencies = [ "libc", "log", "memchr", - "nix", + "nix 0.26.2", "radix_trie", "rustyline-derive", "scopeguard", @@ -2093,6 +3251,15 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.21" 
@@ -2102,12 +3269,31 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sctk-adwaita" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cda4e97be1fd174ccc2aae81c8b694e803fa99b34e8fd0f057a9d70698e3ed09" +dependencies = [ + "ab_glyph", + "log", + "memmap2", + "smithay-client-toolkit", + "tiny-skia", +] + [[package]] name = "security-framework" version = "2.9.1" @@ -2157,7 +3343,7 @@ checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", ] [[package]] @@ -2171,6 +3357,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_repr" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8725e1dfadb3a50f7e5ce0b1a540466f6ed3fe7a0fca2ac2b8b831d31316bd00" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.29", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2220,6 +3417,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -2229,6 +3436,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "slab" version = "0.4.8" @@ -2238,12 +3451,50 @@ dependencies = [ "autocfg", ] +[[package]] +name = "slotmap" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" +dependencies = [ + "version_check", +] + [[package]] name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +[[package]] +name = "smithay-client-toolkit" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f307c47d32d2715eb2e0ece5589057820e0e5e70d07c247d1063e844e107f454" +dependencies = [ + "bitflags 1.3.2", + "calloop", + "dlib", + "lazy_static", + "log", + "memmap2", + "nix 0.24.3", + "pkg-config", + "wayland-client", + "wayland-cursor", + "wayland-protocols", +] + +[[package]] +name = "smithay-clipboard" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a345c870a1fae0b1b779085e81b51e614767c239e93503588e54c5b17f4b0e8" +dependencies = [ + "smithay-client-toolkit", + "wayland-client", +] + [[package]] name = "socket2" version = "0.4.9" @@ -2289,6 +3540,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0" +[[package]] +name = "strict-num" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6637bab7722d379c8b41ba849228d680cc12d0a45ba1fa2b48f2a30577a06731" + [[package]] name = "strsim" version = "0.10.0" @@ -2314,9 +3571,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.22" +version = "2.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616" +checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" dependencies = [ "proc-macro2", "quote", @@ -2374,7 +3631,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", ] [[package]] @@ -2414,6 +3671,31 @@ dependencies = [ "time-core", ] +[[package]] +name = "tiny-skia" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df8493a203431061e901613751931f047d1971337153f96d0e5e363d6dbf6a67" +dependencies = [ + "arrayref", + "arrayvec", + "bytemuck", + "cfg-if", + "png", + "tiny-skia-path", +] + +[[package]] +name = "tiny-skia-path" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adbfb5d3f3dd57a0e11d12f4f13d4ebbbc1b5c15b7ab0a156d030b21da5f677c" +dependencies = [ + "arrayref", + "bytemuck", + "strict-num", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -2491,7 +3773,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", ] [[package]] @@ -2527,6 +3809,23 @@ dependencies = [ "serde", ] +[[package]] +name = "toml_datetime" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" + +[[package]] +name = "toml_edit" +version = "0.19.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +dependencies = [ + "indexmap 2.0.0", + "toml_datetime", + "winnow", +] + [[package]] name = "tower-service" version = "0.3.2" @@ -2565,7 +3864,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", ] [[package]] @@ -2613,12 +3912,28 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "ttf-parser" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a464a4b34948a5f67fddd2b823c62d9d92e44be75058b99939eae6c5b6960b33" + [[package]] name = "typenum" version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "uds_windows" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce65604324d3cce9b966701489fbd0cf318cb1f7bd9dd07ac9a4ee6fb791930d" +dependencies = [ + "tempfile", + "winapi", +] + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -2696,12 +4011,34 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + [[package]] name = "version_check" version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "waker-fn" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" + +[[package]] +name = "walkdir" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2738,7 +4075,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", "wasm-bindgen-shared", ] @@ -2772,7 +4109,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.29", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2783,6 +4120,91 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wayland-client" +version = "0.29.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f3b068c05a039c9f755f881dc50f01732214f5685e379829759088967c46715" +dependencies = [ + "bitflags 1.3.2", + "downcast-rs", + "libc", + "nix 0.24.3", + "scoped-tls", + "wayland-commons", + "wayland-scanner", + "wayland-sys 0.29.5", +] + +[[package]] +name = "wayland-commons" +version = "0.29.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8691f134d584a33a6606d9d717b95c4fa20065605f798a3f350d78dced02a902" +dependencies = [ + "nix 0.24.3", + "once_cell", + "smallvec", + "wayland-sys 0.29.5", +] + +[[package]] +name = "wayland-cursor" +version = "0.29.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6865c6b66f13d6257bef1cd40cbfe8ef2f150fb8ebbdb1e8e873455931377661" +dependencies = [ + "nix 0.24.3", + "wayland-client", + "xcursor", +] + +[[package]] +name = "wayland-protocols" +version = "0.29.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b950621f9354b322ee817a23474e479b34be96c2e909c14f7bc0100e9a970bc6" +dependencies = [ + "bitflags 1.3.2", + "wayland-client", + "wayland-commons", + "wayland-scanner", +] + +[[package]] +name = "wayland-scanner" +version = "0.29.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f4303d8fa22ab852f789e75a967f0a2cdc430a607751c0499bada3e451cbd53" +dependencies = [ + "proc-macro2", + "quote", + "xml-rs", +] + +[[package]] +name = "wayland-sys" +version = "0.29.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be12ce1a3c39ec7dba25594b97b42cb3195d54953ddb9d3d95a7c3902bc6e9d4" +dependencies = [ + "dlib", + "lazy_static", + "pkg-config", +] + +[[package]] +name = "wayland-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b2a02ac608e07132978689a6f9bf4214949c85998c247abadd4f4129b1aa06" +dependencies = [ + "dlib", + "lazy_static", + "log", + "pkg-config", +] + 
[[package]] name = "web-sys" version = "0.3.64" @@ -2793,6 +4215,23 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webbrowser" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2c79b77f525a2d670cb40619d7d9c673d09e0666f72c591ebd7861f84a87e57" +dependencies = [ + "core-foundation", + "home", + "jni", + "log", + "ndk-context", + "objc", + "raw-window-handle", + "url", + "web-sys", +] + [[package]] name = "which" version = "4.4.0" @@ -2829,12 +4268,54 @@ dependencies = [ "winapi", ] +[[package]] +name = "winapi-wsapoll" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c17110f57155602a80dca10be03852116403c9ff3cd25b079d666f2aa3df6e" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-targets 0.48.1", +] + +[[package]] +name = "windows-implement" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e2ee588991b9e7e6c8338edf3333fbe4da35dc72092643958ebb43f0ab2c49c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "windows-interface" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6fb8df20c9bcaa8ad6ab513f7b40104840c8867d5751126e4df3b08388d0cc7" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "windows-sys" version = "0.42.0" @@ -2982,6 +4463,50 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "winit" +version = "0.28.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "866db3f712fffba75d31bf0cdecf357c8aeafd158c5b7ab51dba2a2b2d47f196" +dependencies = [ + "android-activity", + "bitflags 1.3.2", + "cfg_aliases", + "core-foundation", + "core-graphics", + "dispatch", + "instant", + "libc", + "log", + "mio", + "ndk", + "objc2", + "once_cell", + "orbclient", + "percent-encoding", + "raw-window-handle", + "redox_syscall 0.3.5", + "sctk-adwaita", + "smithay-client-toolkit", + "wasm-bindgen", + "wayland-client", + "wayland-commons", + "wayland-protocols", + "wayland-scanner", + "web-sys", + "windows-sys 0.45.0", + "x11-dl", +] + +[[package]] +name = "winnow" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c2e3184b9c4e92ad5167ca73039d0c42476302ab603e2fec4487511f38ccefc" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.10.1" @@ -2991,6 +4516,39 @@ dependencies = [ "winapi", ] +[[package]] +name = "x11-dl" +version = "2.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38735924fedd5314a6e548792904ed8c6de6636285cb9fec04d5b1db85c1516f" +dependencies = [ + "libc", + "once_cell", + "pkg-config", +] + +[[package]] +name = "x11rb" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "592b4883219f345e712b3209c62654ebda0bb50887f330cbd018d0f654bfd507" 
+dependencies = [ + "gethostname", + "nix 0.24.3", + "winapi", + "winapi-wsapoll", + "x11rb-protocol", +] + +[[package]] +name = "x11rb-protocol" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56b245751c0ac9db0e006dc812031482784e434630205a93c73cfefcaabeac67" +dependencies = [ + "nix 0.24.3", +] + [[package]] name = "xattr" version = "0.2.3" @@ -3000,6 +4558,97 @@ dependencies = [ "libc", ] +[[package]] +name = "xcursor" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "463705a63313cd4301184381c5e8042f0a7e9b4bb63653f216311d4ae74690b7" +dependencies = [ + "nom", +] + +[[package]] +name = "xdg-home" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2769203cd13a0c6015d515be729c526d041e9cf2c0cc478d57faee85f40c6dcd" +dependencies = [ + "nix 0.26.2", + "winapi", +] + +[[package]] +name = "xml-rs" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47430998a7b5d499ccee752b41567bc3afc57e1327dc855b1a2aa44ce29b5fa1" + +[[package]] +name = "zbus" +version = "3.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31de390a2d872e4cd04edd71b425e29853f786dc99317ed72d73d6fcf5ebb948" +dependencies = [ + "async-broadcast", + "async-executor", + "async-fs", + "async-io", + "async-lock", + "async-process", + "async-recursion", + "async-task", + "async-trait", + "blocking", + "byteorder", + "derivative", + "enumflags2", + "event-listener", + "futures-core", + "futures-sink", + "futures-util", + "hex", + "nix 0.26.2", + "once_cell", + "ordered-stream", + "rand", + "serde", + "serde_repr", + "sha1", + "static_assertions", + "tracing", + "uds_windows", + "winapi", + "xdg-home", + "zbus_macros", + "zbus_names", + "zvariant", +] + +[[package]] +name = "zbus_macros" +version = "3.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d1794a946878c0e807f55a397187c11fc7a038ba5d868e7db4f3bd7760bc9d" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "syn 1.0.109", + "zvariant_utils", +] + +[[package]] +name = "zbus_names" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb80bb776dbda6e23d705cf0123c3b95df99c4ebeaec6c2599d4a5419902b4a9" +dependencies = [ + "serde", + "static_assertions", + "zvariant", +] + [[package]] name = "zip" version = "0.6.6" @@ -3068,3 +4717,41 @@ dependencies = [ "libc", "pkg-config", ] + +[[package]] +name = "zvariant" +version = "3.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44b291bee0d960c53170780af148dca5fa260a63cdd24f1962fa82e03e53338c" +dependencies = [ + "byteorder", + "enumflags2", + "libc", + "serde", + "static_assertions", + "zvariant_derive", +] + +[[package]] +name = "zvariant_derive" +version = "3.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "934d7a7dfc310d6ee06c87ffe88ef4eca7d3e37bb251dece2ef93da8f17d8ecd" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", + "zvariant_utils", +] + +[[package]] +name = "zvariant_utils" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7234f0d811589db492d16893e3f21e8e2fd282e6d01b0cddee310322062cc200" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] diff --git a/binaries/gguf-explorer/Cargo.toml b/binaries/gguf-explorer/Cargo.toml new file mode 
100644
index 00000000..85a7aa78
--- /dev/null
+++ b/binaries/gguf-explorer/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "gguf-explorer"
+version = "0.1.0"
+edition = "2021"
+repository = { workspace = true }
+license = { workspace = true }
+publish = false
+
+[package.metadata.release]
+release = false
+
+[dependencies]
+ggml = { path = "../../crates/ggml" }
+
+anyhow = { workspace = true }
+
+eframe = "0.22"
+egui_extras = "0.22"
diff --git a/binaries/gguf-explorer/src/main.rs b/binaries/gguf-explorer/src/main.rs
new file mode 100644
index 00000000..11435057
--- /dev/null
+++ b/binaries/gguf-explorer/src/main.rs
@@ -0,0 +1,220 @@
+use std::{fmt::Display, fs::File, io::BufReader};
+
+use egui_extras::{Column, TableBuilder};
+use ggml::format::gguf::{self, Gguf};
+
+use eframe::egui::{self, Button, CentralPanel, CollapsingHeader, Label, RichText, TopBottomPanel};
+
+fn main() -> eframe::Result<()> {
+    let file_path = match std::env::args().nth(1) {
+        Some(path) => path,
+        None => {
+            eprintln!("Usage: gguf-explorer ");
+            std::process::exit(1);
+        }
+    };
+
+    let mut file = File::open(file_path).expect("Failed to open file");
+    let gguf = Gguf::load(&mut BufReader::new(&mut file)).expect("Failed to load gguf file");
+
+    let native_options = eframe::NativeOptions::default();
+    eframe::run_native(
+        "GGUF Explorer",
+        native_options,
+        Box::new(move |_cc| {
+            Box::new(Explorer {
+                _file: file,
+                gguf,
+
+                selected_tab: Tab::Metadata,
+                tensor_sort_order: TensorColumn::Offset,
+            })
+        }),
+    )
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Tab {
+    Metadata,
+    Tensors,
+}
+impl Display for Tab {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Tab::Metadata => write!(f, "Metadata"),
+            Tab::Tensors => write!(f, "Tensors"),
+        }
+    }
+}
+impl Tab {
+    const ALL: [Tab; 2] = [Tab::Metadata, Tab::Tensors];
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+pub enum TensorColumn {
+    Name,
+    Dimensions,
+    Type,
+    Offset,
+}
+impl Display for TensorColumn {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            TensorColumn::Name => write!(f, "Name"),
+            TensorColumn::Dimensions => write!(f, "Dimensions"),
+            TensorColumn::Type => write!(f, "Type"),
+            TensorColumn::Offset => write!(f, "Offset"),
+        }
+    }
+}
+impl TensorColumn {
+    const ALL: [Self; 4] = [Self::Name, Self::Dimensions, Self::Type, Self::Offset];
+}
+
+struct Explorer {
+    _file: File,
+    gguf: Gguf,
+
+    selected_tab: Tab,
+    tensor_sort_order: TensorColumn,
+}
+impl eframe::App for Explorer {
+    fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
+        TopBottomPanel::top("top_panel").show(ctx, |ui| {
+            ui.horizontal(|ui| {
+                for tab in Tab::ALL.iter().copied() {
+                    let text = RichText::from(tab.to_string());
+                    let text = if tab == self.selected_tab {
+                        text.underline()
+                    } else {
+                        text
+                    };
+
+                    if ui.add(Button::new(text)).clicked() {
+                        self.selected_tab = tab;
+                    }
+                }
+            });
+        });
+
+        CentralPanel::default().show(ctx, |ui| match self.selected_tab {
+            Tab::Metadata => {
+                self.render_metadata(ui);
+            }
+            Tab::Tensors => {
+                self.render_tensors(ui);
+            }
+        });
+    }
+}
+impl Explorer {
+    fn render_metadata(&mut self, ui: &mut egui::Ui) {
+        let metadata = &self.gguf.metadata;
+        let mut metadata_keys = metadata.keys().collect::<Vec<_>>();
+        metadata_keys.sort_by_key(|k| *k);
+
+        TableBuilder::new(ui)
+            .striped(true)
+            .auto_shrink([false, true])
+            .column(Column::auto().resizable(true))
+            .column(Column::remainder().resizable(true))
+            .header(20.0, |mut header| {
+                header.col(|ui| {
+                    ui.label("Key");
+                });
+                header.col(|ui| {
+                    ui.label("Value");
+                });
+            })
+            .body(|mut body| {
+                for key in metadata_keys {
+                    let value = &metadata[key];
+
+                    body.row(30.0, |mut row| {
+                        row.col(|ui| {
+                            ui.add(Label::new(monospace(key)).wrap(false));
+                        });
+                        row.col(|ui| match value {
+                            gguf::MetadataValue::Array(value) => {
+                                CollapsingHeader::new(format!("array ({} elements)", value.len()))
+                                    .id_source(key)
+                                    .show(ui, |ui| {
+                                        ui.add(
+                                            Label::new(monospace(format!("{:?}", value)))
+                                                .wrap(false),
+                                        );
+                                    });
+                            }
+                            value => {
+                                ui.add(Label::new(monospace(format!("{:?}", value))).wrap(false));
+                            }
+                        });
+                    });
+                }
+            });
+    }
+
+    fn render_tensors(&mut self, ui: &mut egui::Ui) {
+        let tensors = &self.gguf.tensor_infos;
+        let mut tensor_names = tensors.keys().collect::<Vec<_>>();
+        match self.tensor_sort_order {
+            TensorColumn::Name => tensor_names.sort_by_key(|k| *k),
+            TensorColumn::Dimensions => {
+                tensor_names.sort_by_key(|k| tensors[*k].dimensions.clone())
+            }
+            TensorColumn::Type => tensor_names.sort_by_key(|k| tensors[*k].element_type),
+            TensorColumn::Offset => tensor_names.sort_by_key(|k| tensors[*k].offset),
+        }
+
+        TableBuilder::new(ui)
+            .striped(true)
+            .auto_shrink([false, true])
+            .column(Column::remainder().resizable(true))
+            .columns(Column::auto().resizable(true), 3)
+            .header(20.0, |mut header| {
+                for column in TensorColumn::ALL.iter().copied() {
+                    header.col(|ui| {
+                        let text = RichText::from(column.to_string());
+                        let text = if self.tensor_sort_order == column {
+                            text.underline()
+                        } else {
+                            text
+                        };
+
+                        if ui.add(Button::new(text).wrap(false)).clicked() {
+                            self.tensor_sort_order = column;
+                        }
+                    });
+                }
+            })
+            .body(|mut body| {
+                for tensor_name in tensor_names {
+                    let tensor = &tensors[tensor_name];
+
+                    body.row(30.0, |mut row| {
+                        row.col(|ui| {
+                            ui.add(Label::new(monospace(tensor_name)).wrap(false));
+                        });
+                        row.col(|ui| {
+                            ui.add(
+                                Label::new(monospace(format!("{:?}", tensor.dimensions)))
+                                    .wrap(false),
+                            );
+                        });
+                        row.col(|ui| {
+                            ui.add(
+                                Label::new(monospace(tensor.element_type.to_string())).wrap(false),
+                            );
+                        });
+                        row.col(|ui| {
+                            ui.add(Label::new(monospace(tensor.offset.to_string())).wrap(false));
+                        });
+                    });
+                }
+            });
+    }
+}
+
+fn monospace(text: impl Into<String>) -> RichText {
+    RichText::new(text).monospace()
+}
diff --git a/crates/ggml/examples/gguf.rs b/crates/ggml/examples/gguf.rs
deleted file mode 100644
index f7c5328f..00000000
--- a/crates/ggml/examples/gguf.rs
+++ /dev/null
@@ -1,14 +0,0 @@
-use std::io::BufReader;
-
-use ggml::format::gguf;
-
-fn main() -> anyhow::Result<()> {
-    let mut file = BufReader::new(std::fs::File::open(
-        std::env::args().nth(1).expect("need a file to read"),
-    )?);
-
-    let gguf = gguf::Gguf::load(&mut file)?;
-    dbg!(gguf);
-
-    Ok(())
-}
diff --git a/crates/ggml/src/lib.rs b/crates/ggml/src/lib.rs
index 0cef3591..1f4ac50f 100644
--- a/crates/ggml/src/lib.rs
+++ b/crates/ggml/src/lib.rs
@@ -63,7 +63,7 @@ impl Default for RoPEOverrides {
     }
 }
 
-#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, PartialOrd, Ord)]
 /// The type of a value in `ggml`.
 pub enum Type {
     /// Quantized 4-bit (type 0).
From 0da661fd5290989db4d63fa94c41e4f0f97b822c Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 28 Aug 2023 02:57:56 +0200 Subject: [PATCH 09/33] refactor(gguf): split metadata out + use macros --- crates/ggml/src/format/gguf/metadata.rs | 376 ++++++++++++++++++++++++ crates/ggml/src/format/gguf/mod.rs | 248 +--------------- 2 files changed, 379 insertions(+), 245 deletions(-) create mode 100644 crates/ggml/src/format/gguf/metadata.rs diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs new file mode 100644 index 00000000..39f3713b --- /dev/null +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -0,0 +1,376 @@ +use std::{collections::HashMap, io::BufRead}; + +use crate::util; + +use super::{GgufContext, GgufLoadError}; + +pub type Metadata = HashMap; + +#[repr(u32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum MetadataValueType { + /// The value is a 8-bit unsigned integer. + UInt8 = 0, + /// The value is a 8-bit signed integer. + Int8 = 1, + /// The value is a 16-bit unsigned little-endian integer. + UInt16 = 2, + /// The value is a 16-bit signed little-endian integer. + Int16 = 3, + /// The value is a 32-bit unsigned little-endian integer. + UInt32 = 4, + /// The value is a 32-bit signed little-endian integer. + Int32 = 5, + /// The value is a 32-bit IEEE754 floating point number. + Float32 = 6, + /// The value is a boolean. + /// 1-byte value where 0 is false and 1 is true. + /// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + Bool = 7, + /// The value is a UTF-8 non-null-terminated string, with length prepended. + String = 8, + /// The value is an array of other values, with the length and type prepended. + /// + /// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. + Array = 9, + /// The value is a 64-bit unsigned little-endian integer. + /// Implemented in GGUFv2. + UInt64 = 10, + /// The value is a 64-bit signed little-endian integer. + /// Implemented in GGUFv2. + Int64 = 11, + /// The value is a 64-bit IEEE754 floating point number. + /// Implemented in GGUFv2. + Float64 = 12, +} +pub trait MetadataValueTypeFromRustType { + fn value_type() -> MetadataValueType; +} +macro_rules! impl_value_boilerplate { + ($($value_type:ident($rust_type:ty)),*) => { + $( + impl MetadataValueTypeFromRustType for $rust_type { + fn value_type() -> MetadataValueType { + MetadataValueType::$value_type + } + } + )* + + + impl TryFrom for MetadataValueType { + type Error = (); + + fn try_from(value: u32) -> Result { + for test_value in [ + $(MetadataValueType::$value_type),* + ] { + if value == test_value as u32 { + return Ok(test_value); + } + } + Err(()) + } + } + + + #[derive(Debug, Clone, PartialEq)] + pub enum MetadataValue { + $( + $value_type($rust_type), + )* + } + + // Public + impl MetadataValue { + pub fn value_type(&self) -> MetadataValueType { + match self { + $(MetadataValue::$value_type(_) => MetadataValueType::$value_type),* + } + } + } + }; +} +impl_value_boilerplate! 
{ + UInt8(u8), + Int8(i8), + UInt16(u16), + Int16(i16), + UInt32(u32), + Int32(i32), + Float32(f32), + Bool(bool), + String(String), + Array(MetadataArrayValue), + UInt64(u64), + Int64(i64), + Float64(f64) +} + +// Public +impl MetadataValue { + pub fn as_uint8(&self) -> Option { + match self { + Self::UInt8(v) => Some(*v), + _ => None, + } + } + + pub fn as_int8(&self) -> Option { + match self { + Self::Int8(v) => Some(*v), + _ => None, + } + } + + pub fn as_uint16(&self) -> Option { + match self { + Self::UInt16(v) => Some(*v), + _ => None, + } + } + + pub fn as_int16(&self) -> Option { + match self { + Self::Int16(v) => Some(*v), + _ => None, + } + } + + pub fn as_uint32(&self) -> Option { + match self { + Self::UInt32(v) => Some(*v), + _ => None, + } + } + + pub fn as_int32(&self) -> Option { + match self { + Self::Int32(v) => Some(*v), + _ => None, + } + } + + pub fn as_float32(&self) -> Option { + match self { + Self::Float32(v) => Some(*v), + _ => None, + } + } + + pub fn as_bool(&self) -> Option { + match self { + Self::Bool(v) => Some(*v), + _ => None, + } + } + + pub fn as_string(&self) -> Option<&str> { + match self { + Self::String(v) => Some(v), + _ => None, + } + } + + pub fn as_array(&self) -> Option<&MetadataArrayValue> { + match self { + Self::Array(v) => Some(v), + _ => None, + } + } + + pub fn as_uint64(&self) -> Option { + match self { + Self::UInt64(v) => Some(*v), + _ => None, + } + } + + pub fn as_int64(&self) -> Option { + match self { + Self::Int64(v) => Some(*v), + _ => None, + } + } + + pub fn as_float64(&self) -> Option { + match self { + Self::Float64(v) => Some(*v), + _ => None, + } + } +} +// Private +impl MetadataValue { + pub(super) fn read_key_value( + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result<(String, Self), GgufLoadError> { + let key = util::read_string(reader, ctx.use_64_bit_length)?; + let value_type = MetadataValueType::try_from(util::read_u32(reader)?) 
+ .expect("TODO: handle invalid value types"); + let value = Self::read_value(ctx, reader, value_type)?; + + Ok((key, value)) + } + + fn read_value( + ctx: &GgufContext, + reader: &mut dyn BufRead, + value_type: MetadataValueType, + ) -> Result { + match value_type { + MetadataValueType::UInt8 => Self::read_u8(ctx, reader).map(MetadataValue::UInt8), + MetadataValueType::Int8 => Self::read_i8(ctx, reader).map(MetadataValue::Int8), + MetadataValueType::UInt16 => Self::read_u16(ctx, reader).map(MetadataValue::UInt16), + MetadataValueType::Int16 => Self::read_i16(ctx, reader).map(MetadataValue::Int16), + MetadataValueType::UInt32 => Self::read_u32(ctx, reader).map(MetadataValue::UInt32), + MetadataValueType::Int32 => Self::read_i32(ctx, reader).map(MetadataValue::Int32), + MetadataValueType::Float32 => Self::read_f32(ctx, reader).map(MetadataValue::Float32), + MetadataValueType::Bool => Self::read_bool(ctx, reader).map(MetadataValue::Bool), + MetadataValueType::String => Self::read_string(ctx, reader).map(MetadataValue::String), + MetadataValueType::Array => Self::read_array(ctx, reader).map(MetadataValue::Array), + MetadataValueType::UInt64 => Self::read_u64(ctx, reader).map(MetadataValue::UInt64), + MetadataValueType::Int64 => Self::read_i64(ctx, reader).map(MetadataValue::Int64), + MetadataValueType::Float64 => Self::read_f64(ctx, reader).map(MetadataValue::Float64), + } + } + + fn read_u8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_bytes::<1>(reader)?[0]) + } + + fn read_i8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_bytes::<1>(reader)?[0] as i8) + } + + fn read_u16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(u16::from_le_bytes(util::read_bytes::<2>(reader)?)) + } + + fn read_i16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(i16::from_le_bytes(util::read_bytes::<2>(reader)?)) + } + + fn read_u32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_u32(reader)?) + } + + fn read_i32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_i32(reader)?) + } + + fn read_f32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_f32(reader)?) + } + + fn read_bool(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_bool(reader)?) + } + + fn read_string(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_string(reader, ctx.use_64_bit_length)?) + } + + fn read_array( + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result { + MetadataArrayValue::read_value(ctx, reader) + } + + fn read_u64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_u64(reader)?) + } + + fn read_i64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_i64(reader)?) + } + + fn read_f64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + Ok(util::read_f64(reader)?) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum MetadataArrayValue { + UInt8(Vec), + Int8(Vec), + UInt16(Vec), + Int16(Vec), + UInt32(Vec), + Int32(Vec), + Float32(Vec), + Bool(Vec), + String(Vec), + Array(Vec), + UInt64(Vec), + Int64(Vec), + Float64(Vec), +} +impl MetadataArrayValue { + fn read_value(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + let value_type = MetadataValueType::try_from(util::read_u32(reader)?) 
+ .expect("TODO: handle invalid value types"); + let length = util::read_length(reader, ctx.use_64_bit_length)?; + + struct ArrayReader<'a> { + ctx: &'a GgufContext, + reader: &'a mut dyn BufRead, + length: usize, + } + impl ArrayReader<'_> { + fn read( + &mut self, + value_reader: impl Fn(&GgufContext, &mut dyn BufRead) -> Result, + value_constructor: impl Fn(Vec) -> MetadataArrayValue, + ) -> Result { + (0..self.length) + .map(|_| value_reader(self.ctx, self.reader)) + .collect::, _>>() + .map(value_constructor) + } + } + + let mut reader = ArrayReader { + ctx, + reader, + length, + }; + use MetadataValue as MV; + use MetadataValueType as MVT; + match value_type { + MVT::UInt8 => reader.read(MV::read_u8, Self::UInt8), + MVT::Int8 => reader.read(MV::read_i8, Self::Int8), + MVT::UInt16 => reader.read(MV::read_u16, Self::UInt16), + MVT::Int16 => reader.read(MV::read_i16, Self::Int16), + MVT::UInt32 => reader.read(MV::read_u32, Self::UInt32), + MVT::Int32 => reader.read(MV::read_i32, Self::Int32), + MVT::Float32 => reader.read(MV::read_f32, Self::Float32), + MVT::Bool => reader.read(MV::read_bool, Self::Bool), + MVT::String => reader.read(MV::read_string, Self::String), + MVT::Array => reader.read(MV::read_array, Self::Array), + MVT::UInt64 => reader.read(MV::read_u64, Self::UInt64), + MVT::Int64 => reader.read(MV::read_i64, Self::Int64), + MVT::Float64 => reader.read(MV::read_f64, Self::Float64), + } + } + + /// Returns the length of the array. + pub fn len(&self) -> usize { + match self { + Self::UInt8(v) => v.len(), + Self::Int8(v) => v.len(), + Self::UInt16(v) => v.len(), + Self::Int16(v) => v.len(), + Self::UInt32(v) => v.len(), + Self::Int32(v) => v.len(), + Self::Float32(v) => v.len(), + Self::Bool(v) => v.len(), + Self::String(v) => v.len(), + Self::Array(v) => v.len(), + Self::UInt64(v) => v.len(), + Self::Int64(v) => v.len(), + Self::Float64(v) => v.len(), + } + } +} diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index 40ffeb68..bbc7cff9 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -7,7 +7,9 @@ use std::{ use crate::{util, ElementType}; -use super::{ContainerType, ContainerTypeReadError}; + +mod metadata; +pub use metadata::*; pub const DEFAULT_ALIGNMENT: u32 = 32; @@ -99,250 +101,6 @@ struct GgufContext { use_64_bit_length: bool, } -#[repr(u32)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum MetadataValueType { - /// The value is a 8-bit unsigned integer. - UInt8 = 0, - /// The value is a 8-bit signed integer. - Int8 = 1, - /// The value is a 16-bit unsigned little-endian integer. - UInt16 = 2, - /// The value is a 16-bit signed little-endian integer. - Int16 = 3, - /// The value is a 32-bit unsigned little-endian integer. - UInt32 = 4, - /// The value is a 32-bit signed little-endian integer. - Int32 = 5, - /// The value is a 32-bit IEEE754 floating point number. - Float32 = 6, - /// The value is a boolean. - /// 1-byte value where 0 is false and 1 is true. - /// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. - Bool = 7, - /// The value is a UTF-8 non-null-terminated string, with length prepended. - String = 8, - /// The value is an array of other values, with the length and type prepended. - /// - /// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. - Array = 9, - /// The value is a 64-bit unsigned little-endian integer. - /// Implemented in GGUFv2. 
- UInt64 = 10, - /// The value is a 64-bit signed little-endian integer. - /// Implemented in GGUFv2. - Int64 = 11, - /// The value is a 64-bit IEEE754 floating point number. - /// Implemented in GGUFv2. - Float64 = 12, -} -impl TryFrom for MetadataValueType { - type Error = (); - - fn try_from(value: u32) -> Result { - // TODO: consider a macro solution to this? - for test_value in [ - MetadataValueType::UInt8, - MetadataValueType::Int8, - MetadataValueType::UInt16, - MetadataValueType::Int16, - MetadataValueType::UInt32, - MetadataValueType::Int32, - MetadataValueType::Float32, - MetadataValueType::Bool, - MetadataValueType::String, - MetadataValueType::Array, - MetadataValueType::UInt64, - MetadataValueType::Int64, - MetadataValueType::Float64, - ] { - if value == test_value as u32 { - return Ok(test_value); - } - } - Err(()) - } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum MetadataValue { - UInt8(u8), - Int8(i8), - UInt16(u16), - Int16(i16), - UInt32(u32), - Int32(i32), - Float32(f32), - Bool(bool), - String(String), - Array(MetadataArrayValue), - UInt64(u64), - Int64(i64), - Float64(f64), -} -impl MetadataValue { - fn read_key_value( - ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result<(String, Self), GgufLoadError> { - let key = util::read_string(reader, ctx.use_64_bit_length)?; - let value_type = MetadataValueType::try_from(util::read_u32(reader)?) - .expect("TODO: handle invalid value types"); - let value = Self::read_value(ctx, reader, value_type)?; - - Ok((key, value)) - } - - fn read_value( - ctx: &GgufContext, - reader: &mut dyn BufRead, - value_type: MetadataValueType, - ) -> Result { - match value_type { - MetadataValueType::UInt8 => Self::read_u8(ctx, reader).map(MetadataValue::UInt8), - MetadataValueType::Int8 => Self::read_i8(ctx, reader).map(MetadataValue::Int8), - MetadataValueType::UInt16 => Self::read_u16(ctx, reader).map(MetadataValue::UInt16), - MetadataValueType::Int16 => Self::read_i16(ctx, reader).map(MetadataValue::Int16), - MetadataValueType::UInt32 => Self::read_u32(ctx, reader).map(MetadataValue::UInt32), - MetadataValueType::Int32 => Self::read_i32(ctx, reader).map(MetadataValue::Int32), - MetadataValueType::Float32 => Self::read_f32(ctx, reader).map(MetadataValue::Float32), - MetadataValueType::Bool => Self::read_bool(ctx, reader).map(MetadataValue::Bool), - MetadataValueType::String => Self::read_string(ctx, reader).map(MetadataValue::String), - MetadataValueType::Array => Self::read_array(ctx, reader).map(MetadataValue::Array), - MetadataValueType::UInt64 => Self::read_u64(ctx, reader).map(MetadataValue::UInt64), - MetadataValueType::Int64 => Self::read_i64(ctx, reader).map(MetadataValue::Int64), - MetadataValueType::Float64 => Self::read_f64(ctx, reader).map(MetadataValue::Float64), - } - } - - fn read_u8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_bytes::<1>(reader)?[0]) - } - - fn read_i8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_bytes::<1>(reader)?[0] as i8) - } - - fn read_u16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(u16::from_le_bytes(util::read_bytes::<2>(reader)?)) - } - - fn read_i16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(i16::from_le_bytes(util::read_bytes::<2>(reader)?)) - } - - fn read_u32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_u32(reader)?) - } - - fn read_i32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_i32(reader)?) 
- } - - fn read_f32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_f32(reader)?) - } - - fn read_bool(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_bool(reader)?) - } - - fn read_string(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_string(reader, ctx.use_64_bit_length)?) - } - - fn read_array( - ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result { - MetadataArrayValue::read_value(ctx, reader) - } - - fn read_u64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_u64(reader)?) - } - - fn read_i64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_i64(reader)?) - } - - fn read_f64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_f64(reader)?) - } - - pub fn as_uint32(&self) -> Option { - match self { - Self::UInt32(v) => Some(*v), - _ => None, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum MetadataArrayValue { - UInt8(Vec), - Int8(Vec), - UInt16(Vec), - Int16(Vec), - UInt32(Vec), - Int32(Vec), - Float32(Vec), - Bool(Vec), - String(Vec), - Array(Vec), - UInt64(Vec), - Int64(Vec), - Float64(Vec), -} -impl MetadataArrayValue { - fn read_value(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - let value_type = MetadataValueType::try_from(util::read_u32(reader)?) - .expect("TODO: handle invalid value types"); - let length = util::read_length(reader, ctx.use_64_bit_length)?; - - struct ArrayReader<'a> { - ctx: &'a GgufContext, - reader: &'a mut dyn BufRead, - length: usize, - } - impl ArrayReader<'_> { - fn read( - &mut self, - value_reader: impl Fn(&GgufContext, &mut dyn BufRead) -> Result, - value_constructor: impl Fn(Vec) -> MetadataArrayValue, - ) -> Result { - (0..self.length) - .map(|_| value_reader(self.ctx, self.reader)) - .collect::, _>>() - .map(value_constructor) - } - } - - let mut reader = ArrayReader { - ctx, - reader, - length, - }; - use MetadataValue as MV; - use MetadataValueType as MVT; - match value_type { - MVT::UInt8 => reader.read(MV::read_u8, Self::UInt8), - MVT::Int8 => reader.read(MV::read_i8, Self::Int8), - MVT::UInt16 => reader.read(MV::read_u16, Self::UInt16), - MVT::Int16 => reader.read(MV::read_i16, Self::Int16), - MVT::UInt32 => reader.read(MV::read_u32, Self::UInt32), - MVT::Int32 => reader.read(MV::read_i32, Self::Int32), - MVT::Float32 => reader.read(MV::read_f32, Self::Float32), - MVT::Bool => reader.read(MV::read_bool, Self::Bool), - MVT::String => reader.read(MV::read_string, Self::String), - MVT::Array => reader.read(MV::read_array, Self::Array), - MVT::UInt64 => reader.read(MV::read_u64, Self::UInt64), - MVT::Int64 => reader.read(MV::read_i64, Self::Int64), - MVT::Float64 => reader.read(MV::read_f64, Self::Float64), - } - } -} - #[derive(Debug, Clone, PartialEq)] pub struct TensorInfo { pub dimensions: Vec, From e182444547822068bb993e7a1f624d90dcf33dca Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 28 Aug 2023 02:58:18 +0200 Subject: [PATCH 10/33] wip: rewrite loader to use GGUF; almost wire up llama --- crates/ggml/src/format/gguf/mod.rs | 35 +- crates/llm-base/Cargo.toml | 2 +- crates/llm-base/src/lib.rs | 2 +- crates/llm-base/src/loader.rs | 514 +++++++++++------------------ crates/llm-base/src/lora.rs | 62 ++-- crates/llm-base/src/model/mod.rs | 21 +- crates/llm-base/src/quantize.rs | 469 +++++++++++++------------- crates/llm-base/src/util.rs | 98 ------ crates/models/llama/src/lib.rs | 208 +++++------- 9 files changed, 596 insertions(+), 815 deletions(-) diff --git 
a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index bbc7cff9..bf0afcdd 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -5,15 +5,17 @@ use std::{ io::{BufRead, Seek}, }; +use super::{data_size, header_size, ContainerType, ContainerTypeReadError}; use crate::{util, ElementType}; +use thiserror::Error; mod metadata; pub use metadata::*; pub const DEFAULT_ALIGNMENT: u32 = 32; -#[derive(Debug, thiserror::Error)] +#[derive(Debug, Error)] /// Errors that can occur while loading a model. pub enum GgufLoadError { #[error("invalid GGUF file magic value: {0}")] @@ -39,15 +41,20 @@ pub enum GgufLoadError { /// The format type that was encountered. ftype: u32, }, - #[error("invariant broken: {0}")] - /// An invariant was broken. - InvariantBroken(String), } +#[derive(Debug, Error)] +/// Errors that can occur while saving a model. +pub enum GgufSaveError { + // TODO! +} + +pub type TensorInfos = HashMap; + #[derive(Debug, Clone, PartialEq)] pub struct Gguf { - pub metadata: HashMap, - pub tensor_infos: HashMap, + pub metadata: Metadata, + pub tensor_infos: TensorInfos, pub tensor_data_position: u64, } impl Gguf { @@ -105,6 +112,8 @@ struct GgufContext { pub struct TensorInfo { pub dimensions: Vec, pub element_type: ElementType, + /// This offset is relative to `tensor_data`, not to the start + /// of the file, to make it easier for writers to write the file. pub offset: u64, } impl TensorInfo { @@ -138,4 +147,18 @@ impl TensorInfo { }, )) } + + /// Calculate the size of the tensor's values in bytes. + pub fn calc_size(&self) -> usize { + data_size(self.element_type, self.dimensions.iter().product()) + } + + /// Calculates the absolute size in bytes of the tensor's data, given the mmap flag. + pub fn calc_absolute_size(&self, mmap: bool) -> usize { + if mmap { + header_size() + } else { + header_size() + self.calc_size() + } + } } diff --git a/crates/llm-base/Cargo.toml b/crates/llm-base/Cargo.toml index 47157429..badcbdc6 100644 --- a/crates/llm-base/Cargo.toml +++ b/crates/llm-base/Cargo.toml @@ -11,7 +11,7 @@ readme = "../../README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -ggml = { path = "../ggml", version = "0.2.0-dev", features = ["pre-gguf-formats"] } +ggml = { path = "../ggml", version = "0.2.0-dev" } bytemuck = { workspace = true } rand = { workspace = true } diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index b69cffa0..0c54c954 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -31,7 +31,7 @@ pub use inference_session::{ pub use llm_samplers::prelude::{Sampler, SamplerChain}; pub use loader::{ load, load_progress_callback_stdout, ContainerType, FileMagic, FileType, FileTypeFormat, - LoadError, LoadProgress, Loader, TensorLoader, + LoadError, LoadProgress, MetadataExt, ModelTensorLoader, }; pub use lora::{LoraAdapter, LoraParameters}; pub use memmap2::Mmap; diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 4b2d363e..50cdc725 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -1,19 +1,21 @@ use std::{ - collections::HashMap, error::Error, - fmt::{Debug, Display, Formatter}, + fmt::{Display, Formatter}, fs::File, - io::{BufRead, BufReader, Read, Seek, SeekFrom}, + io::{BufReader, Read, Seek, SeekFrom}, path::{Path, PathBuf}, sync::Arc, }; use crate::{ - util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelContext, 
ModelParameters, - TokenId, Tokenizer, TokenizerLoadError, TokenizerSource, + Hyperparameters, KnownModel, LoraAdapter, ModelContext, ModelParameters, TokenizerLoadError, + TokenizerSource, }; use ggml::{ - format::ggml::{LoadError as FormatLoadError, PartialHyperparameters, TensorLoadInfo}, + format::gguf::{ + self, GgufLoadError, Metadata, MetadataValue, MetadataValueType, + MetadataValueTypeFromRustType, TensorInfo, + }, Context, MAX_NAME_LENGTH, }; pub use ggml::{format::ContainerType, util::FileMagic}; @@ -178,7 +180,7 @@ impl Display for FileTypeFormat { /// Each variant represents a step within the process of loading the model. /// These can be used to report progress to the user. #[derive(Clone, PartialEq, Eq, Debug)] -pub enum LoadProgress { +pub enum LoadProgress<'a> { /// The hyperparameters have been loaded from the model. HyperparametersLoaded, /// The context has been created. @@ -189,9 +191,9 @@ pub enum LoadProgress { /// A tensor was patched with a LoRA. LoraApplied { /// The name of the patched tensor. - name: String, + name: &'a str, /// LoRA file the patch was applied from. - source: PathBuf, + source: &'a Path, }, /// A tensor from the current part has been loaded. TensorLoaded { @@ -226,20 +228,6 @@ pub enum LoadError { /// The path that failed. path: PathBuf, }, - #[error("no parent path for {path:?}")] - /// There is no parent path for a given path. - NoParentPath { - /// The path without a parent. - path: PathBuf, - }, - #[error("unable to read exactly {bytes} bytes")] - /// Reading exactly `bytes` from a file failed. - ReadExactFailed { - /// The original error. - source: std::io::Error, - /// The number of bytes that were attempted to be read. - bytes: usize, - }, #[error("non-specific I/O error")] /// A non-specific IO error. Io(#[from] std::io::Error), @@ -249,16 +237,10 @@ pub enum LoadError { #[error("invalid integer conversion")] /// One of the integers encountered could not be converted to a more appropriate type. InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("unsupported ftype: {0}")] - /// The `ftype` hyperparameter had an invalid value. This usually means that the format used - /// by this file is unrecognized by this version of `llm`. - UnsupportedFileType(i32), - #[error("invalid magic number {magic} for {path:?}")] - /// An invalid magic number was encountered during the loading process. + #[error("invalid magic value {magic}")] + /// An invalid magic value was encountered during the loading process. InvalidMagic { - /// The path that failed. - path: PathBuf, - /// The magic number that was encountered. + /// The magic value that was encountered. magic: FileMagic, }, #[error("invalid file format {container_type:?}")] @@ -267,94 +249,69 @@ pub enum LoadError { /// The format that was encountered. container_type: ContainerType, }, - #[error("invalid value {ftype} for `f16` in hyperparameters")] - /// The `f16` hyperparameter had an invalid value. - HyperparametersF16Invalid { - /// The format type that was encountered. - ftype: i32, - }, #[error("unknown tensor `{tensor_name}` in {path:?}")] - /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during - /// the model prelude. + /// The tensor `tensor_name` is required for this model architecture, + /// but was not found in the model. UnknownTensor { /// The name of the tensor. tensor_name: String, /// The path that failed. 
path: PathBuf, }, - #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] - /// The tensor `tensor_name` did not match its expected size. - TensorWrongSize { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - /// The tensor `tensor_name` did not have the expected format type. - #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] + /// The tensor `tensor_name` had an unsupported element type. + #[error("invalid element type {element_type} for tensor `{tensor_name}` in {path:?}")] UnsupportedElementType { /// The name of the tensor. tensor_name: String, - /// The format type that was encountered. - ftype: u32, + /// The element type that was encountered. + element_type: u32, /// The path that failed. path: PathBuf, }, - /// An invariant was broken. - /// - /// This error is not relevant unless `loader2` is being used. - #[error("invariant broken: {invariant} in {path:?}")] - InvariantBroken { - /// The path that failed. - path: Option, - /// The invariant that was broken. - invariant: String, - }, - /// The model could not be created. - /// - /// This implies that there were no tensors in the model to be loaded. - /// - /// This error is not relevant unless `loader2` is being used. - #[error("could not create model from {path:?}")] - ModelNotCreated { - /// The path that failed. - path: PathBuf, - }, - /// Multiple parts of the model were found. - /// - /// Multi-part models are not supported. Please convert the model to a single part. - #[error("multipart models are not supported")] - MultipartNotSupported { - /// The paths that were found. - paths: Vec, - }, /// The tokenizer could not be loaded. #[error("could not load tokenizer {path:?}: {error}")] TokenizerLoadFail { - /// The invalid tokenizer path + /// The path of the tokenizer. path: PathBuf, /// The error that occurred. error: Box, }, - /// There is insufficient information to guess the model architecture from the provided file. - /// - /// A model architecture must be provided to load the model. + /// The quantization version was missing, despite this model containing quantized tensors. + #[error("quantization version was missing, despite model containing quantized tensors")] + MissingQuantizationVersion, + /// The quantization version is not supported by this version of `llm`. + #[error("quantization version {quantization_version:?} is not supported")] + UnsupportedQuantizationVersion { + /// The quantization version that was encountered. + quantization_version: MetadataValue, + }, + /// A tensor with an unsupported number of dimensions was encountered. #[error( - "could not guess model architecture from {path:?}. Please provide a model architecture." + "tensor {tensor_name} has {dimensions} dimensions, but only 1-3 dimensions are supported" )] - MissingModelArchitecture { - /// The path that failed. - path: PathBuf, + UnsupportedTensorDimensionCount { + /// The name of the tensor. + tensor_name: String, + /// The number of dimensions that were encountered. + dimensions: usize, + }, + /// The model expected a metadata key-value pair, but the key was missing. + #[error("missing metadata key {key:?}")] + MissingMetadataKey { + /// The key that was missing. + key: String, + }, + /// The metadata key-value pair was not of the expected type. + #[error("metadata key {key:?} was not of the expected type")] + InvalidMetadataType { + /// The key with the invalid type. + key: String, + /// The expected type. 
+ expected_type: MetadataValueType, + /// The actual type. + actual_type: MetadataValueType, }, -} -impl From for LoadError { - fn from(value: util::FindAllModelFilesError) -> Self { - match value { - util::FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path }, - util::FindAllModelFilesError::IO(err) => LoadError::Io(err), - } - } } impl From for LoadError { fn from(value: TokenizerLoadError) -> Self { @@ -364,42 +321,72 @@ impl From for LoadError { } } } - impl LoadError { #[doc(hidden)] - pub fn from_format_error(value: FormatLoadError, path: PathBuf) -> Self { + pub fn from_gguf(value: GgufLoadError, path: PathBuf) -> Self { match value { - FormatLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { path, magic }, - FormatLoadError::InvalidFormatVersion(container_type) => { + GgufLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { magic }, + GgufLoadError::InvalidFormatVersion(container_type) => { LoadError::InvalidFormatVersion { container_type } } - FormatLoadError::Io(err) => LoadError::Io(err), - FormatLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err), - FormatLoadError::InvalidIntegerConversion(err) => { + GgufLoadError::Io(err) => LoadError::Io(err), + GgufLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err), + GgufLoadError::InvalidIntegerConversion(err) => { LoadError::InvalidIntegerConversion(err) } - FormatLoadError::ImplementationError(err) => err, - FormatLoadError::UnsupportedElementType { tensor_name, ftype } => { + GgufLoadError::UnsupportedElementType { tensor_name, ftype } => { LoadError::UnsupportedElementType { path, tensor_name, - ftype, + element_type: ftype, } } - FormatLoadError::InvariantBroken(invariant) => LoadError::InvariantBroken { - path: Some(path), - invariant, - }, } } } -/// Used by models to fetch tensors from a loader. -pub trait TensorLoader { - /// Gets a tensor from the loader. - fn load(&mut self, name: &str) -> Result; - /// Finish loading the model, returning the context. - fn finish(self) -> ModelContext; +#[doc(hidden)] +pub trait MetadataExt { + fn fallible_get(&self, key: &str) -> Result<&MetadataValue, LoadError>; + fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( + &'a self, + key: &'a str, + getter: impl Fn(&MetadataValue) -> Option<&T>, + ) -> Result<&'a T, LoadError>; + fn fallible_get_countable(&self, key: &str) -> Result; +} +impl MetadataExt for Metadata { + fn fallible_get(&self, key: &str) -> Result<&MetadataValue, LoadError> { + self.get(key).ok_or_else(|| LoadError::MissingMetadataKey { + key: key.to_owned(), + }) + } + + fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( + &'a self, + key: &'a str, + getter: impl Fn(&MetadataValue) -> Option<&T>, + ) -> Result<&'a T, LoadError> { + let metadata_value = self.fallible_get(key)?; + getter(metadata_value).ok_or_else(|| LoadError::InvalidMetadataType { + key: key.to_string(), + expected_type: T::value_type(), + actual_type: metadata_value.value_type(), + }) + } + + fn fallible_get_countable(&self, key: &str) -> Result { + let metadata_value = self.fallible_get(key)?; + match metadata_value { + MetadataValue::UInt32(v) => Ok(usize::try_from(*v)?), + MetadataValue::UInt64(v) => Ok(usize::try_from(*v)?), + _ => Err(LoadError::InvalidMetadataType { + key: key.to_string(), + expected_type: MetadataValueType::UInt64, + actual_type: metadata_value.value_type(), + }), + } + } } /// Load a GGML model from the `path` and configure it per the `params`. 
The status @@ -420,7 +407,7 @@ pub fn load( path: &Path, tokenizer_source: TokenizerSource, params: ModelParameters, - load_progress_callback: impl FnMut(LoadProgress), + mut load_progress_callback: impl FnMut(LoadProgress), ) -> Result { if !path.exists() { return Err(LoadError::FileDoesNotExist { @@ -428,12 +415,7 @@ pub fn load( }); } - let paths = util::find_all_model_files(path)?; - if paths.len() != 1 { - return Err(LoadError::MultipartNotSupported { paths }); - } - - let file = File::open(path).map_err(|e| LoadError::OpenFileFailed { + let mut file = File::open(path).map_err(|e| LoadError::OpenFileFailed { source: e, path: path.to_owned(), })?; @@ -441,52 +423,40 @@ pub fn load( log::trace!("Read model file from {:?}", path); let tokenizer = tokenizer_source.retrieve(path)?; - let mut loader = Loader::new(tokenizer, load_progress_callback); - ggml::format::ggml::load(&mut reader, &mut loader) - .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?; + let gguf = + gguf::Gguf::load(&mut reader).map_err(|e| LoadError::from_gguf(e, path.to_owned()))?; log::trace!("Loaded GGML model from reader"); - let Loader { - hyperparameters, - tokenizer, - tensors, - mut load_progress_callback, - container_type, - .. - } = loader; - - let quantization_version = (&hyperparameters as &M::Hyperparameters) - .file_type() - .map(|ft| ft.quantization_version) - .unwrap_or_default(); - let quantization_version = if quantization_version == 0 { - // HACK: I think llama.cpp does not actually write the quantization version correctly, - // so we need to guess it from the container type. - if container_type == ContainerType::Ggjt(2) { - 1 - } else if container_type == ContainerType::Ggjt(3) { - 2 - } else { - quantization_version - } - } else { - quantization_version - }; + let quantization_version = gguf.metadata.get("general.quantization_version"); log::trace!( "Determined quantization version of model as {:?}", quantization_version ); // TODO: this is temporary while we figure out how to handle this - if tensors.values().any(|t| t.element_type.is_quantized()) { - assert_eq!(quantization_version, 2, "quantization version must be 2"); + let any_quantized = gguf + .tensor_infos + .values() + .any(|t| t.element_type.is_quantized()); + if any_quantized { + match quantization_version { + Some(MetadataValue::UInt32(2)) => { + // Currently supported version + } + Some(quantization_version) => { + return Err(LoadError::UnsupportedQuantizationVersion { + quantization_version: quantization_version.clone(), + }) + } + None => return Err(LoadError::MissingQuantizationVersion), + } } - let use_mmap = - params.prefer_mmap && container_type.support_mmap() && params.lora_adapters.is_none(); + let use_mmap = params.prefer_mmap && params.lora_adapters.is_none(); - let ctx_size = tensors + let ctx_size = gguf + .tensor_infos .values() .map(|ti| ti.calc_absolute_size(use_mmap)) .sum::(); @@ -503,28 +473,25 @@ pub fn load( path: lora_path.to_owned(), })?; let mut lora_reader = BufReader::new(&lora_file); - // TODO: Consider updating the progress callback to report the progress of the LoRA file. - // Most LoRAs are small enough that this is not necessary, but it would be nice to have. 
- let mut lora_loader: Loader = - Loader::new(Tokenizer::empty_embedded(), |_| {}); - ggml::format::ggml::load(&mut lora_reader, &mut lora_loader) - .map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?; + let gguf = gguf::Gguf::load(&mut lora_reader).map_err(|e| LoadError::from_gguf(e, lora_path.to_owned()))?; // Collect the names of the tensors that should be patched - let tensors_to_patch = lora_loader - .tensors + let tensors_to_patch = gguf + .tensor_infos .keys() .filter_map(|k| Some(k.rsplit_once('.')?.0.to_owned())) .collect(); log::trace!("Loaded LoRA weights"); // Return the LoRA patches + #[allow(unreachable_code)] Ok::<_, LoadError>(LoraAdapter { - scaling: lora_loader.hyperparameters.calculate_scaling(), - tensors: lora_loader.tensors, + tensors: gguf.tensor_infos.clone(), tensors_to_patch, file: lora_file, path: lora_path.to_owned(), + gguf, + scaling: todo!("Calculate scaling from LoRA file metadata (GGUF does not have standardised metadata yet)"), }) }) .collect(); @@ -543,19 +510,20 @@ pub fn load( (Context::new_with_allocate(ctx_size), file.metadata()?.len()) }; - let tensors_len = tensors.len(); - let tl = MmapCompatibleLoader { - path: path.to_owned(), - file, - tensors, - context, + let hyperparameters = ::read_gguf(&gguf.metadata)?; + let tl = ModelTensorLoader { + tensor_loader: TensorLoader { + file: &mut file, + gguf: &gguf, + context, + }, lora_adapters, load_progress_callback: &mut load_progress_callback, - loaded_tensors: Default::default(), + loaded_tensor_count: 0, }; - let model = KnownModel::new(hyperparameters, params, tokenizer, tl)?; + let tensors_len = gguf.tensor_infos.len(); (load_progress_callback)(LoadProgress::Loaded { file_size, tensor_count: tensors_len, @@ -566,180 +534,88 @@ pub fn load( Ok(model) } -/// A GGML format loader for LLMs. -pub struct Loader { - // Input - load_progress_callback: F, - - // Input/Output - /// The tokenizer of the model. - pub tokenizer: Tokenizer, - - // Output - /// The container type of the model. - pub container_type: ContainerType, - /// The hyperparameters of the model. - pub hyperparameters: Hp, - /// The tensors of the model. - pub tensors: HashMap, -} -impl Loader { - /// Creates a new loader. - pub fn new(tokenizer: Tokenizer, load_progress_callback: F) -> Self { - Self { - load_progress_callback, - - container_type: ContainerType::Ggml, - hyperparameters: Hp::default(), - tokenizer, - tensors: HashMap::default(), - } - } -} -impl ggml::format::ggml::LoadHandler - for Loader -{ - fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> { - self.container_type = container_type; - Ok(()) - } - - fn vocabulary_token(&mut self, i: usize, token: Vec, score: f32) -> Result<(), LoadError> { - if let Tokenizer::Embedded(mv) = &mut self.tokenizer { - let id = match TokenId::try_from(i) { - Ok(id) => id, - Err(err) => return Err(LoadError::InvalidIntegerConversion(err)), - }; - - mv.push_token(id, token, score); - } - - Ok(()) - } - - fn read_hyperparameters( - &mut self, - reader: &mut dyn BufRead, - ) -> Result { - // NOTE: Field order matters! Data is laid out in the file exactly in this order. 
- let hyperparameters = Hp::read_ggml(reader)?; - let partial = PartialHyperparameters { - n_vocab: hyperparameters.n_vocabulary(), - }; - self.hyperparameters = hyperparameters; - (self.load_progress_callback)(LoadProgress::HyperparametersLoaded); - - Ok(partial) - } - - fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), LoadError> { - self.tensors.insert(info.name.clone(), info); - Ok(()) - } -} - -struct MmapCompatibleLoader<'a> { - path: PathBuf, - file: File, - tensors: HashMap, - context: Context, - lora_adapters: Option>, - load_progress_callback: &'a mut dyn FnMut(LoadProgress), - loaded_tensors: HashMap, +/// A helper struct for loading tensors from a model. +pub struct ModelTensorLoader<'a> { + pub(crate) tensor_loader: TensorLoader<'a>, + pub(crate) lora_adapters: Option>, + pub(crate) load_progress_callback: &'a mut dyn FnMut(LoadProgress), + pub(crate) loaded_tensor_count: usize, } -impl TensorLoader for MmapCompatibleLoader<'_> { - fn load(&mut self, name: &str) -> Result { - let info = self.tensors.get(name).ok_or(LoadError::UnknownTensor { - tensor_name: String::from(name), - path: Default::default(), - })?; - - let mut main_context = FileContext::new(&self.context, &mut self.file, &self.path); - - let mut tensor = main_context.get_tensor(info)?; +impl ModelTensorLoader<'_> { + /// Load a tensor from the model. + pub fn load(&mut self, name: &str) -> Result { + let (mut tensor, info) = self.tensor_loader.load(name)?; if let Some(lora_adapters) = &mut self.lora_adapters { for lora_adapter in lora_adapters { - lora_adapter.patch(info, &mut tensor)?; + lora_adapter.patch(name, info, &mut tensor)?; (self.load_progress_callback)(LoadProgress::LoraApplied { - name: name.to_owned(), - source: lora_adapter.path.to_owned(), + name, + source: &lora_adapter.path, }); } } + self.loaded_tensor_count += 1; (self.load_progress_callback)(LoadProgress::TensorLoaded { - current_tensor: self.loaded_tensors.len(), - tensor_count: self.tensors.len(), + current_tensor: self.loaded_tensor_count, + tensor_count: self.tensor_loader.gguf.tensor_infos.len(), }); - self.loaded_tensors.insert(name.to_owned(), tensor.share()); Ok(tensor) } - fn finish(self) -> ModelContext { + /// Finish loading tensors from the model, and get the model context. + pub fn finish(self) -> ModelContext { // We can ignore this warning as it's OK to share this particular // context around, being that it is immutable. 
#[allow(clippy::arc_with_non_send_sync)] - ModelContext(Arc::new(self.context)) + ModelContext(Arc::new(self.tensor_loader.finish())) } } -pub(crate) struct FileContext<'a> { - context: &'a Context, - file: &'a mut File, - path: &'a Path, +pub(crate) struct TensorLoader<'a> { + pub file: &'a mut File, + pub context: Context, + pub gguf: &'a gguf::Gguf, } -impl<'a> FileContext<'a> { - pub(crate) fn new(context: &'a Context, file: &'a mut File, path: &'a Path) -> Self { - Self { - context, - file, - path, - } - } - - pub(crate) fn get_tensor(&mut self, info: &TensorLoadInfo) -> Result { - let name = &info.name; - let ne = info.dims(); - let dims = ne.len(); - - if dims != info.n_dims { - return Err(LoadError::InvariantBroken { - path: Some(self.path.to_owned()), - invariant: format!( - "the tensor {name} should have {} dimensions, not {}", - info.n_dims, dims - ), - }); - } - - let mut tensor = match dims { - 1 => self.context.new_tensor_1d(info.element_type, ne[0]), - 2 => self.context.new_tensor_2d(info.element_type, ne[0], ne[1]), - 3 => self - .context - .new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]), - _ => { - return Err(LoadError::InvariantBroken { - path: Some(self.path.to_owned()), - invariant: format!( - "the tensor {name} should have between 1 and 3 dimensions, not {dims}" - ), - }) +impl TensorLoader<'_> { + pub fn load(&mut self, name: &str) -> Result<(ggml::Tensor, &TensorInfo), LoadError> { + let info = self + .gguf + .tensor_infos + .get(name) + .ok_or(LoadError::UnknownTensor { + tensor_name: String::from(name), + path: Default::default(), + })?; + + let ty = info.element_type; + let dims = &info.dimensions; + + let mut tensor = match dims.len() { + 1 => self.context.new_tensor_1d(ty, dims[0]), + 2 => self.context.new_tensor_2d(ty, dims[0], dims[1]), + 3 => self.context.new_tensor_3d(ty, dims[0], dims[1], dims[2]), + other => { + return Err(LoadError::UnsupportedTensorDimensionCount { + tensor_name: name.to_string(), + dimensions: other, + }); } }; + let offset = self.gguf.tensor_data_position + info.offset; match self.context.storage().as_mmap() { Some(mmap) => unsafe { - let ptr = mmap.as_ptr().offset(info.start_offset as isize); + let ptr = mmap.as_ptr().offset(offset as isize); tensor.set_data(ptr as *mut std::ffi::c_void); }, None => { let buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; - self.file.seek(SeekFrom::Start(info.start_offset))?; + self.file.seek(SeekFrom::Start(offset))?; self.file.read_exact(buf)?; } } @@ -751,7 +627,11 @@ impl<'a> FileContext<'a> { name }; - Ok(tensor.set_name(tensor_name)) + Ok((tensor.set_name(tensor_name), info)) + } + + pub fn finish(self) -> Context { + self.context } } diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index f403d875..4b5a6c9c 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -1,9 +1,8 @@ use crate::{ - loader::FileContext, model::HyperparametersWriteError, util, FileType, Hyperparameters, - LoadError, + loader::TensorLoader, model::HyperparametersWriteError, FileType, Hyperparameters, LoadError, }; -use ggml::{format::ggml::TensorLoadInfo, GraphExecutionPlan}; +use ggml::{format::gguf::TensorInfo, GraphExecutionPlan}; use std::{ collections::{HashMap, HashSet}, fs::File, @@ -25,22 +24,15 @@ impl LoraParameters { } } impl Hyperparameters for LoraParameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - Ok(LoraParameters { - r: util::read_i32(reader)?, - alpha: util::read_i32(reader)?, - }) 
+ fn read_gguf(metadata: &ggml::format::gguf::Metadata) -> Result { + todo!() } - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.r)?; - util::write_i32(writer, self.alpha)?; - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - // LoRA adapters do not have a vocabulary. - 0 + fn write_gguf( + &self, + metadata: &mut ggml::format::gguf::Metadata, + ) -> Result<(), HyperparametersWriteError> { + todo!() } fn file_type(&self) -> Option { @@ -57,37 +49,43 @@ pub struct LoraAdapter { /// Scaling to apply to the LoRA weights. pub scaling: f32, /// The tensors of the LoRA. - pub tensors: HashMap, + pub tensors: HashMap, /// Names of the tensors that should be patched. pub tensors_to_patch: HashSet, /// File containing the LoRA weights. pub file: File, /// Path to the LoRA file. pub path: PathBuf, + /// The loaded GGUF for the LoRA. + pub gguf: ggml::format::gguf::Gguf, } impl LoraAdapter { /// Patch a tensor via LoRA pub fn patch( &mut self, - info: &TensorLoadInfo, + name: &str, + info: &TensorInfo, tensor: &mut ggml::Tensor, ) -> Result<(), LoadError> { // Check if we need to patch this tensor - let name = &info.name; if !self.tensors_to_patch.contains(name) { return Ok(()); } - let a_info = self.get_info(&format!("{}.loraA", name))?; - let b_info = self.get_info(&format!("{}.loraB", name))?; + let a_name = format!("{}.loraA", name); + let a_info = self.get_info(&a_name)?; + + let b_name = format!("{}.loraB", name); + let b_info = self.get_info(&b_name)?; let must_scale = self.scaling != 1.0; // Calculate the size of the patch context via the following steps: // 1. Calculate the size of the two `a` and `b` tensors // 2. Calculate the size of the original tensor // 3. Calculate the size of the `ba` and tensors. It has the same dimensions as the original tensor, but is of the element type of the `a` or `b` tensor e.g. fp16 - let ba_size = ggml::format::tensor_size(a_info.element_type, info.dims().iter().product()); + let ba_size = + ggml::format::tensor_size(a_info.element_type, info.dimensions.iter().product()); let mut patch_context_size = a_info.calc_absolute_size(false) + b_info.calc_absolute_size(false) + info.calc_absolute_size(false) @@ -96,7 +94,7 @@ impl LoraAdapter { // 3b. (Optional) If we need to scale the `ba` tensor, we need to allocate for a second `ba` and the `scaled` tensors which will be crated as an `f32` tensor. 
if must_scale { let scaled_size = - ggml::format::tensor_size(ggml::ElementType::F32, info.dims().iter().product()); + ggml::format::tensor_size(ggml::ElementType::F32, info.dimensions.iter().product()); patch_context_size += scaled_size + ba_size; } @@ -106,14 +104,18 @@ impl LoraAdapter { // Create a temporary context for the patching operations // TODO: test if GPU can be enabled (make it configurable) let patch_context = ggml::Context::new_with_allocate(patch_context_size); - let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path); + let mut loader = TensorLoader { + file: &mut self.file, + context: patch_context, + gguf: &self.gguf, + }; // Load the A and B tensors - let a = patch_file.get_tensor(&a_info)?; - let b = patch_file.get_tensor(&b_info)?; - - //Build a ggml context and apply the patch + let (a, _) = loader.load(&a_name)?; + let (b, _) = loader.load(&b_name)?; + // Build a ggml context and apply the patch + let patch_context = loader.finish(); let mut gf = patch_context.create_compute_graph(); // LoRA formula: w = w + ba*s @@ -141,7 +143,7 @@ impl LoraAdapter { Ok(()) } - fn get_info(&self, name: &str) -> Result { + fn get_info(&self, name: &str) -> Result { self.tensors .get(name) .cloned() diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 494d5aae..e383db1d 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -8,13 +8,13 @@ use std::{ sync::Arc, }; -use ggml::accelerator::Backend; +use ggml::{accelerator::Backend, format::gguf::Metadata}; use regex::Regex; use thiserror::Error; use crate::{ loader::TensorLoader, tokenizer::TokenId, FileType, InferenceSession, InferenceSessionConfig, - LoadError, LoadProgress, Tokenizer, TokenizerSource, + LoadError, LoadProgress, ModelTensorLoader, Tokenizer, TokenizerSource, }; /// Common functions for model evaluation @@ -43,12 +43,12 @@ pub trait KnownModel: Send + Sync { /// Creates a new model from the provided [ModelParameters] hyperparameters. /// This function is called by the [load](crate::loader::load) function. - fn new( + fn new( hyperparameters: Self::Hyperparameters, params: ModelParameters, tokenizer: Tokenizer, - tensor_loader: impl TensorLoader, - ) -> Result + tensor_loader: ModelTensorLoader, + ) -> Result where Self: Sized; @@ -166,14 +166,11 @@ impl> Model for M { /// Implemented by model hyperparameters for interacting with hyperparameters /// without knowing what they are, as well as writing/reading them as required. pub trait Hyperparameters: Sized + Default + Debug + PartialEq + Eq { - /// Read the parameters in GGML format from a reader. - fn read_ggml(reader: &mut dyn BufRead) -> Result; - - /// Write the parameters in GGML format to a writer. - fn write_ggml(&self, writer: &mut dyn Write) -> Result<(), HyperparametersWriteError>; + /// Read the parameters from GGUF metadata. + fn read_gguf(metadata: &Metadata) -> Result; - /// Get the number of tokens in the embedded vocabulary, if any. - fn n_vocabulary(&self) -> usize; + /// Write the parameters to GGUF metadata. + fn write_gguf(&self, metadata: &mut Metadata) -> Result<(), HyperparametersWriteError>; /// Get the filetype of the model. fn file_type(&self) -> Option; diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index efb30044..a870b51e 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -1,12 +1,13 @@ +// TODO: Reimplement entirely for GGUF! +#![allow(unused)] + //!
Implements quantization of weights. use crate::{ loader::FileTypeFormat, model::HyperparametersWriteError, Hyperparameters, KnownModel, - LoadError, LoadProgress, Loader, Tokenizer, -}; -use ggml::format::ggml::{ - SaveContainerType, SaveError, SaveHandler, TensorLoadInfo, TensorSaveInfo, + LoadError, LoadProgress, Tokenizer, }; +use ggml::format::gguf::GgufSaveError; use half::f16; use regex::Regex; use std::{ @@ -122,18 +123,19 @@ pub enum QuantizeError { VocabularyScoringNotSupported, } impl QuantizeError { - pub(crate) fn from_format_error(value: SaveError, path: PathBuf) -> Self { - match value { - SaveError::Io(io) => QuantizeError::Io(io), - SaveError::InvalidIntegerConversion(e) => QuantizeError::InvalidIntegerConversion(e), - SaveError::ImplementationError(e) => e, - SaveError::InvariantBroken(invariant) => { - QuantizeError::InvariantBroken { path, invariant } - } - SaveError::VocabularyScoringNotSupported => { - QuantizeError::VocabularyScoringNotSupported - } - } + pub(crate) fn from_format_error(value: GgufSaveError, path: PathBuf) -> Self { + todo!() + // match value { + // SaveError::Io(io) => QuantizeError::Io(io), + // SaveError::InvalidIntegerConversion(e) => QuantizeError::InvalidIntegerConversion(e), + // SaveError::ImplementationError(e) => e, + // SaveError::InvariantBroken(invariant) => { + // QuantizeError::InvariantBroken { path, invariant } + // } + // SaveError::VocabularyScoringNotSupported => { + // QuantizeError::VocabularyScoringNotSupported + // } + // } } } @@ -142,84 +144,85 @@ pub fn quantize( reader: &mut R, writer: &mut W, tokenizer: Tokenizer, - save_container_type: SaveContainerType, quantization_type: ggml::Type, progress_callback: impl Fn(QuantizeProgress), ) -> Result<(), QuantizeError> { - // Sanity check - let quantization_target = QuantizationTarget::try_from(quantization_type).map_err(|_| { - QuantizeError::InvalidQuantizationTarget { - element_type: quantization_type, - } - })?; + // // Sanity check + // let quantization_target = QuantizationTarget::try_from(quantization_type).map_err(|_| { + // QuantizeError::InvalidQuantizationTarget { + // element_type: quantization_type, + // } + // })?; - // Load the model - let progress_callback = Arc::new(progress_callback); + // // Load the model + // let progress_callback = Arc::new(progress_callback); - let mut loader = Loader::::new(tokenizer, { - let progress_callback = progress_callback.clone(); - move |p| { - if let LoadProgress::HyperparametersLoaded = p { - progress_callback(QuantizeProgress::HyperparametersLoaded) - } - } - }); - ggml::format::ggml::load(reader, &mut loader) - .map_err(|err| LoadError::from_format_error(err, PathBuf::default()))?; - - // Save the quantized model, quantizing as we go - let Loader { - mut hyperparameters, - tokenizer, - tensors, - .. 
- } = loader; - - if let Some(ft) = hyperparameters.file_type_mut() { - ft.quantization_version = ggml::QNT_VERSION; - ft.format = quantization_target - .try_into() - .expect("format has no corresponding ftype"); - } + // let mut loader = Loader::::new(tokenizer, { + // let progress_callback = progress_callback.clone(); + // move |p| { + // if let LoadProgress::HyperparametersLoaded = p { + // progress_callback(QuantizeProgress::HyperparametersLoaded) + // } + // } + // }); + // ggml::format::ggml::load(reader, &mut loader) + // .map_err(|err| LoadError::from_format_error(err, PathBuf::default()))?; + + // // Save the quantized model, quantizing as we go + // let Loader { + // mut hyperparameters, + // tokenizer, + // tensors, + // .. + // } = loader; + + // if let Some(ft) = hyperparameters.file_type_mut() { + // ft.quantization_version = ggml::QNT_VERSION; + // ft.format = quantization_target + // .try_into() + // .expect("format has no corresponding ftype"); + // } - let tokenizer = match tokenizer { - Tokenizer::Embedded(v) => v.iter().collect::>(), - Tokenizer::HuggingFace(_) => vec![], - }; - - let to_quantize = M::quantize_tensors(); - let to_skip = M::skip_quantize_tensors(); - let mut saver = QuantizeSaver::new( - quantization_target, - &hyperparameters, - &tensors, - &to_quantize, - &to_skip, - reader, - |p| progress_callback(p), - ); - ggml::format::ggml::save( - writer, - &mut saver, - save_container_type, - &tokenizer, - &tensors.keys().cloned().collect::>(), - ) - .map_err(|err| QuantizeError::from_format_error(err, PathBuf::default()))?; - - // Final report - let sum_all: i64 = saver.history_all.iter().sum(); - progress_callback(QuantizeProgress::Finished { - original_size: saver.total_size_original, - reduced_size: saver.total_size_new, - history: saver - .history_all - .iter() - .map(|hist| *hist as f32 / sum_all as f32) - .collect(), - }); - - Ok(()) + // let tokenizer = match tokenizer { + // Tokenizer::Embedded(v) => v.iter().collect::>(), + // Tokenizer::HuggingFace(_) => vec![], + // }; + + // let to_quantize = M::quantize_tensors(); + // let to_skip = M::skip_quantize_tensors(); + // let mut saver = QuantizeSaver::new( + // quantization_target, + // &hyperparameters, + // &tensors, + // &to_quantize, + // &to_skip, + // reader, + // |p| progress_callback(p), + // ); + // ggml::format::ggml::save( + // writer, + // &mut saver, + // save_container_type, + // &tokenizer, + // &tensors.keys().cloned().collect::>(), + // ) + // .map_err(|err| QuantizeError::from_format_error(err, PathBuf::default()))?; + + // // Final report + // let sum_all: i64 = saver.history_all.iter().sum(); + // progress_callback(QuantizeProgress::Finished { + // original_size: saver.total_size_original, + // reduced_size: saver.total_size_new, + // history: saver + // .history_all + // .iter() + // .map(|hist| *hist as f32 / sum_all as f32) + // .collect(), + // }); + + // Ok(()) + + todo!("reimeplement for GGUF") } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -267,150 +270,150 @@ impl From for FileTypeFormat { } } -struct QuantizeSaver<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> { - // Input - quantization_target: QuantizationTarget, - hyperparameters: &'a H, - tensors: &'a HashMap, - to_quantize: &'a [Regex], - to_skip: &'a [Regex], - source_reader: &'a mut R, - progress_callback: F, - - // Output - total_size_original: usize, - total_size_new: usize, - history_all: Vec, -} -impl<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> - 
QuantizeSaver<'a, F, H, R> -{ - fn new( - quantization_target: QuantizationTarget, - hyperparameters: &'a H, - tensors: &'a HashMap, - to_quantize: &'a [Regex], - to_skip: &'a [Regex], - source_reader: &'a mut R, - progress_callback: F, - ) -> Self { - Self { - quantization_target, - hyperparameters, - tensors, - to_quantize, - to_skip, - source_reader, - progress_callback, - - total_size_original: 0, - total_size_new: 0, - history_all: vec![0; 16], - } - } -} -impl SaveHandler - for QuantizeSaver<'_, F, H, R> -{ - fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), QuantizeError> { - self.hyperparameters - .write_ggml(writer) - .map_err(QuantizeError::HyperparametersWriteError)?; - Ok(()) - } +// struct QuantizeSaver<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> { +// // Input +// quantization_target: QuantizationTarget, +// hyperparameters: &'a H, +// tensors: &'a HashMap, +// to_quantize: &'a [Regex], +// to_skip: &'a [Regex], +// source_reader: &'a mut R, +// progress_callback: F, - fn tensor_data(&mut self, tensor_name: &str) -> Result { - let tensor = self.tensors.get(tensor_name).expect( - "tensor not found; should be impossible due to handler being populated from loader", - ); - - (self.progress_callback)(QuantizeProgress::TensorLoading { - name: tensor_name, - dims: tensor.dims, - n_elements: tensor.n_elements, - element_type: tensor.element_type, - }); - - // Quantize only 2D tensors - let quantize = tensor.n_dims == 2 - && self.to_quantize.iter().any(|re| re.is_match(tensor_name)) - && !self.to_skip.iter().any(|re| re.is_match(tensor_name)); - let raw_data = tensor.read_data(self.source_reader)?; - - if quantize && !matches!(tensor.element_type, ggml::Type::F32 | ggml::Type::F16) { - return Err(QuantizeError::UnsupportedElementType { - element_type: tensor.element_type, - }); - } +// // Output +// total_size_original: usize, +// total_size_new: usize, +// history_all: Vec, +// } +// impl<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> +// QuantizeSaver<'a, F, H, R> +// { +// fn new( +// quantization_target: QuantizationTarget, +// hyperparameters: &'a H, +// tensors: &'a HashMap, +// to_quantize: &'a [Regex], +// to_skip: &'a [Regex], +// source_reader: &'a mut R, +// progress_callback: F, +// ) -> Self { +// Self { +// quantization_target, +// hyperparameters, +// tensors, +// to_quantize, +// to_skip, +// source_reader, +// progress_callback, - self.total_size_original += raw_data.len(); - - let (element_type, data) = if quantize { - (self.progress_callback)(QuantizeProgress::TensorQuantizing { name: tensor_name }); - - let data_f32: Vec = match tensor.element_type { - ggml::Type::F32 => raw_data - .chunks_exact(4) - .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) - .collect(), - ggml::Type::F16 => raw_data - .chunks_exact(2) - .map(|chunk| { - f16::from_bits(u16::from_le_bytes(chunk.try_into().unwrap())).to_f32() - }) - .collect(), - _ => unreachable!(), - }; - - let result = match self.quantization_target { - QuantizationTarget::Q4_0 => { - ggml::quantize_q4_0(&data_f32, tensor.n_elements, tensor.dims[0]) - } - QuantizationTarget::Q4_1 => { - ggml::quantize_q4_1(&data_f32, tensor.n_elements, tensor.dims[0]) - } - QuantizationTarget::Q5_0 => { - ggml::quantize_q5_0(&data_f32, tensor.n_elements, tensor.dims[0]) - } - QuantizationTarget::Q5_1 => { - ggml::quantize_q5_1(&data_f32, tensor.n_elements, tensor.dims[0]) - } - QuantizationTarget::Q8_0 => { - ggml::quantize_q8_0(&data_f32, tensor.n_elements, 
tensor.dims[0]) - } - }; - let new_data = result.output; - - let mut history_new = vec![]; - for (i, val) in result.history.iter().enumerate() { - self.history_all[i] += val; - history_new.push(*val as f32 / tensor.n_elements as f32); - } - - (self.progress_callback)(QuantizeProgress::TensorQuantized { - name: tensor_name, - original_size: raw_data.len(), - reduced_size: new_data.len(), - history: history_new, - }); - - self.total_size_new += new_data.len(); - - (self.quantization_target.into(), new_data) - } else { - (self.progress_callback)(QuantizeProgress::TensorSkipped { - name: tensor_name, - size: raw_data.len(), - }); - self.total_size_new += raw_data.len(); - (tensor.element_type, raw_data) - }; - - Ok(TensorSaveInfo { - n_dims: tensor.n_dims, - dims: tensor.dims, - element_type, - data, - }) - } -} +// total_size_original: 0, +// total_size_new: 0, +// history_all: vec![0; 16], +// } +// } +// } +// impl SaveHandler +// for QuantizeSaver<'_, F, H, R> +// { +// fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), QuantizeError> { +// self.hyperparameters +// .write_ggml(writer) +// .map_err(QuantizeError::HyperparametersWriteError)?; +// Ok(()) +// } + +// fn tensor_data(&mut self, tensor_name: &str) -> Result { +// let tensor = self.tensors.get(tensor_name).expect( +// "tensor not found; should be impossible due to handler being populated from loader", +// ); + +// (self.progress_callback)(QuantizeProgress::TensorLoading { +// name: tensor_name, +// dims: tensor.dims, +// n_elements: tensor.n_elements, +// element_type: tensor.element_type, +// }); + +// // Quantize only 2D tensors +// let quantize = tensor.n_dims == 2 +// && self.to_quantize.iter().any(|re| re.is_match(tensor_name)) +// && !self.to_skip.iter().any(|re| re.is_match(tensor_name)); +// let raw_data = tensor.read_data(self.source_reader)?; + +// if quantize && !matches!(tensor.element_type, ggml::Type::F32 | ggml::Type::F16) { +// return Err(QuantizeError::UnsupportedElementType { +// element_type: tensor.element_type, +// }); +// } + +// self.total_size_original += raw_data.len(); + +// let (element_type, data) = if quantize { +// (self.progress_callback)(QuantizeProgress::TensorQuantizing { name: tensor_name }); + +// let data_f32: Vec = match tensor.element_type { +// ggml::Type::F32 => raw_data +// .chunks_exact(4) +// .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) +// .collect(), +// ggml::Type::F16 => raw_data +// .chunks_exact(2) +// .map(|chunk| { +// f16::from_bits(u16::from_le_bytes(chunk.try_into().unwrap())).to_f32() +// }) +// .collect(), +// _ => unreachable!(), +// }; + +// let result = match self.quantization_target { +// QuantizationTarget::Q4_0 => { +// ggml::quantize_q4_0(&data_f32, tensor.n_elements, tensor.dims[0]) +// } +// QuantizationTarget::Q4_1 => { +// ggml::quantize_q4_1(&data_f32, tensor.n_elements, tensor.dims[0]) +// } +// QuantizationTarget::Q5_0 => { +// ggml::quantize_q5_0(&data_f32, tensor.n_elements, tensor.dims[0]) +// } +// QuantizationTarget::Q5_1 => { +// ggml::quantize_q5_1(&data_f32, tensor.n_elements, tensor.dims[0]) +// } +// QuantizationTarget::Q8_0 => { +// ggml::quantize_q8_0(&data_f32, tensor.n_elements, tensor.dims[0]) +// } +// }; +// let new_data = result.output; + +// let mut history_new = vec![]; +// for (i, val) in result.history.iter().enumerate() { +// self.history_all[i] += val; +// history_new.push(*val as f32 / tensor.n_elements as f32); +// } + +// (self.progress_callback)(QuantizeProgress::TensorQuantized { +// name: 
tensor_name, +// original_size: raw_data.len(), +// reduced_size: new_data.len(), +// history: history_new, +// }); + +// self.total_size_new += new_data.len(); + +// (self.quantization_target.into(), new_data) +// } else { +// (self.progress_callback)(QuantizeProgress::TensorSkipped { +// name: tensor_name, +// size: raw_data.len(), +// }); +// self.total_size_new += raw_data.len(); +// (tensor.element_type, raw_data) +// }; + +// Ok(TensorSaveInfo { +// n_dims: tensor.n_dims, +// dims: tensor.dims, +// element_type, +// data, +// }) +// } +// } diff --git a/crates/llm-base/src/util.rs b/crates/llm-base/src/util.rs index 55cda41c..586c5b6f 100644 --- a/crates/llm-base/src/util.rs +++ b/crates/llm-base/src/util.rs @@ -1,11 +1,6 @@ //! Utilities for interacting with LLMs and loading them. pub use ggml::util::*; -use std::{ - io::BufRead, - path::{Path, PathBuf}, -}; - /// NOTE: The original code relies in promotion rules and automatic cast between /// int to float. What we do instead is use this macro to convert every term of /// the multiplication to f64, which should have enough precision bits to hold @@ -22,15 +17,6 @@ macro_rules! mulf { } use memmap2::{Mmap, MmapAsRawDesc, MmapOptions}; -use thiserror::Error; - -use crate::{FileType, LoadError}; - -/// Read the filetype from a reader. -pub fn read_filetype(reader: &mut dyn BufRead) -> Result { - let ftype = read_i32(reader)?; - FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype)) -} /// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. /// @@ -73,69 +59,6 @@ impl TokenUtf8Buffer { } } -#[derive(Error, Debug)] -/// Errors encountered during the loading process. -pub enum FindAllModelFilesError { - #[error("no parent path for {path:?}")] - /// There is no parent path for a given path. - NoParentPath { - /// The path without a parent. - path: PathBuf, - }, - #[error("non-specific I/O error")] - /// A non-specific IO error. - IO(#[from] std::io::Error), -} - -/// Find all the files related to a model. -pub fn find_all_model_files(main_path: &Path) -> Result, FindAllModelFilesError> { - let mut main_path_parent = - main_path - .parent() - .ok_or_else(|| FindAllModelFilesError::NoParentPath { - path: main_path.to_owned(), - })?; - if main_path_parent.to_str() == Some("") { - main_path_parent = Path::new("."); - } - Ok(collect_related_paths( - main_path, - std::fs::read_dir(main_path_parent)? 
- .filter_map(Result::ok) - .map(|de| de.path()), - )) -} - -fn collect_related_paths( - main_path: &Path, - directory_paths: impl Iterator, -) -> Vec { - let main_filename = main_path.file_name().and_then(|p| p.to_str()); - - let mut paths: Vec = directory_paths - .filter(|p| { - p.file_name() - .and_then(|p| p.to_str()) - .zip(main_filename) - .map(|(part_filename, main_filename)| { - match part_filename.strip_prefix(main_filename) { - Some(suffix) => { - suffix.is_empty() - || (suffix - .strip_prefix('.') - .map(|s| s.parse::().is_ok()) - .unwrap_or(false)) - } - None => false, - } - }) - .unwrap_or(false) - }) - .collect(); - paths.sort(); - paths -} - /// mmap with MAP_POPULATE pub fn mmap_populate(file: T) -> Result { unsafe { MmapOptions::new().populate().map(file) } @@ -156,27 +79,6 @@ pub fn softmax(logits: &[f32]) -> Vec { mod tests { use super::*; - #[test] - fn test_collect_related_paths() { - let main_path = PathBuf::from("/models/llama.bin"); - let directory_paths = [ - "/models/llama.bin", - "/models/llama.bin.1", - "/models/llama.bin.2", - "/models/llama.bin.tmp", - ] - .map(PathBuf::from); - let expected_paths = [ - "/models/llama.bin", - "/models/llama.bin.1", - "/models/llama.bin.2", - ] - .map(PathBuf::from); - - let output_paths = collect_related_paths(&main_path, directory_paths.into_iter()); - assert_eq!(expected_paths.as_slice(), output_paths); - } - #[test] fn test_valid_utf8() { let mut buffer = TokenUtf8Buffer::new(); diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index a70f315f..f85df602 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -1,13 +1,15 @@ //! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem. #![deny(missing_docs)] -use std::error::Error; - use llm_base::{ - ggml::{self}, + ggml::{ + self, + format::gguf::{Metadata, MetadataValue, MetadataValueTypeFromRustType}, + }, model::{common, HyperparametersWriteError}, - util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, + FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, + MetadataExt, ModelContext, ModelParameters, ModelTensorLoader, OutputRequest, Regex, TokenId, + Tokenizer, }; /// The LLaMA model. 
Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) @@ -18,7 +20,6 @@ pub struct Llama { params: ModelParameters, hyperparameters: Hyperparameters, tokenizer: Tokenizer, - _version: LlamaModelType, // model-global weights // weighted token embeddings wte: ggml::Tensor, @@ -40,12 +41,12 @@ unsafe impl Sync for Llama {} impl KnownModel for Llama { type Hyperparameters = Hyperparameters; - fn new( + fn new( mut hyperparameters: Self::Hyperparameters, params: ModelParameters, tokenizer: Tokenizer, - tensor_loader: impl TensorLoader, - ) -> Result { + tensor_loader: ModelTensorLoader, + ) -> Result { let mut tl = tensor_loader; // model-global weights @@ -58,7 +59,7 @@ impl KnownModel for Llama { let mut layers = Vec::new(); - for i in 0..hyperparameters.n_layer { + for i in 0..hyperparameters.block_count { let backend = params.backend(i); let layer = Layer { @@ -94,32 +95,21 @@ impl KnownModel for Llama { } let context = tl.finish(); - // TODO: read from file - let mut version = match hyperparameters.n_layer { - 26 => LlamaModelType::Model3b, - 32 => LlamaModelType::Model7b, - 40 => LlamaModelType::Model13b, - 60 => LlamaModelType::Model30b, - 80 => LlamaModelType::Model65b, - _ => LlamaModelType::Model7b, // anything < 32 - }; // TODO: temporary fix for 70B models if let Some(n_gqa) = params.n_gqa { - if hyperparameters.n_layer >= 80 { + if hyperparameters.block_count >= 80 { assert_eq!( - hyperparameters.n_head % n_gqa, + hyperparameters.head_count % n_gqa, 0, "assuming 70B Llama2 model based on GQA == 8" ); - hyperparameters.n_head_kv = hyperparameters.n_head / n_gqa; - version = LlamaModelType::Model70b; + hyperparameters.head_count_kv = hyperparameters.head_count / n_gqa; } } Ok(Self { hyperparameters, params, - _version: version, tokenizer, wte, norm, @@ -134,9 +124,9 @@ impl KnownModel for Llama { InferenceSession::new( config, &self.params, - self.hyperparameters.n_layer, - self.hyperparameters.n_embd, - self.hyperparameters.n_vocab, + self.hyperparameters.block_count, + self.hyperparameters.embedding_length, + self.hyperparameters.vocabulary_count, ) } @@ -152,16 +142,15 @@ impl KnownModel for Llama { let ctx_size = self.params.context_size; let Hyperparameters { - n_vocab, - n_embd, - n_mult: _, - n_head, - n_head_kv, - n_layer, + vocabulary_count, + embedding_length, + head_count, + head_count_kv, + block_count, n_rot, file_type: _, } = self.hyperparameters; - let n_embd_gqa = n_embd / (n_head / n_head_kv); + let embedding_length_gqa = embedding_length / (head_count / head_count_kv); let outputs = session.compute(self.context.clone(), input_tokens, |builder| { let mut ctx0 = builder.ctx0.borrow_mut(); @@ -171,7 +160,7 @@ impl KnownModel for Llama { let mut gf = ctx0.create_compute_graph(); - for il in 0..n_layer { + for il in 0..block_count { ctx0.set_offloading(self.params.should_offload(il)); let input_self_attention = input_layer.share(); @@ -192,8 +181,8 @@ impl KnownModel for Llama { .op_rope_inplace( &ctx0.op_reshape_3d( &ctx0.op_mul_mat(&self.layers[il].wq, ¤t), - n_embd / n_head, - n_head, + embedding_length / head_count, + head_count, input_len, ), session_len, @@ -206,8 +195,8 @@ impl KnownModel for Llama { .op_rope_inplace( &ctx0.op_reshape_3d( &ctx0.op_mul_mat(&self.layers[il].wk, ¤t), - n_embd / n_head, - n_head_kv, + embedding_length / head_count, + head_count_kv, input_len, ), session_len, @@ -218,24 +207,25 @@ impl KnownModel for Llama { .set_name("Kcur"); // store key and value to memory - // compute the transposed [N, n_embd] 
V matrix + // compute the transposed [N, embedding_length] V matrix let v_current = ctx0.op_transpose(&ctx0.op_reshape_2d( &ctx0.op_mul_mat(&self.layers[il].wv, ¤t), - n_embd_gqa, + embedding_length_gqa, input_len, )); let k = ctx0.op_view_1d( builder.memory_k, - input_len * n_embd_gqa, - (builder.memory_k.element_size() * n_embd_gqa) * (il * ctx_size + session_len), + input_len * embedding_length_gqa, + (builder.memory_k.element_size() * embedding_length_gqa) + * (il * ctx_size + session_len), ); let v = ctx0.op_view_2d( builder.memory_v, - (input_len, n_embd_gqa), + (input_len, embedding_length_gqa), ctx_size * builder.memory_v.element_size(), - (il * ctx_size) * builder.memory_v.element_size() * n_embd_gqa + (il * ctx_size) * builder.memory_v.element_size() * embedding_length_gqa + session_len * builder.memory_v.element_size(), ); @@ -250,11 +240,13 @@ impl KnownModel for Llama { &ctx0.op_reshape_3d( &ctx0.op_view_1d( builder.memory_k, - (session_len + input_len) * n_embd_gqa, - il * ctx_size * builder.memory_k.element_size() * n_embd_gqa, + (session_len + input_len) * embedding_length_gqa, + il * ctx_size + * builder.memory_k.element_size() + * embedding_length_gqa, ), - n_embd / n_head, - n_head_kv, + embedding_length / head_count, + head_count_kv, session_len + input_len, ), (0, 2, 1, 3), @@ -264,10 +256,10 @@ impl KnownModel for Llama { // K * Q let k_q = ctx0.op_mul_mat(&k, &q).set_name("KQ"); - // KQ_scaled = KQ / sqrt(n_embd/n_head) + // KQ_scaled = KQ / sqrt(embedding_length/head_count) let kq_scale = ctx0 - .new_f32(1.0 / ((n_embd as f32 / n_head as f32).sqrt())) - .set_name("1/sqrt(n_embd/n_head)"); + .new_f32(1.0 / ((embedding_length as f32 / head_count as f32).sqrt())) + .set_name("1/sqrt(embedding_length/head_count)"); let k_q_scaled = ctx0.op_scale_inplace(&k_q, &kq_scale).set_name("KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) @@ -280,16 +272,21 @@ impl KnownModel for Llama { .op_soft_max_inplace(&k_q_masked) .set_name("KQ_soft_max"); - // split cached V into n_head heads + // split cached V into head_count heads let v = ctx0 .op_view_3d( builder.memory_v, - (session_len + input_len, n_embd / n_head, n_head_kv), + ( + session_len + input_len, + embedding_length / head_count, + head_count_kv, + ), ( ctx_size * builder.memory_v.element_size(), - ctx_size * builder.memory_v.element_size() * n_embd / n_head, + ctx_size * builder.memory_v.element_size() * embedding_length + / head_count, ), - il * ctx_size * builder.memory_v.element_size() * n_embd_gqa, + il * ctx_size * builder.memory_v.element_size() * embedding_length_gqa, ) .set_name("V"); @@ -298,11 +295,11 @@ impl KnownModel for Llama { // KQV_merged = KQV.permute(0, 2, 1, 3) let k_q_v_merged = ctx0.op_permute(&k_q_v, (0, 2, 1, 3)).set_name("KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) + // cur = KQV_merged.contiguous().view(embedding_length, N) current = ctx0 .op_cpy( &k_q_v_merged, - &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), + &ctx0.new_tensor_2d(ggml::Type::F32, embedding_length, input_len), ) .set_name("KQV_merged_contiguous"); @@ -362,9 +359,14 @@ impl KnownModel for Llama { }); // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); + common::read_last_token(session, &outputs.result, vocabulary_count, input_len); + common::extract_logits(output_request, &outputs.result, 
vocabulary_count, input_len); + common::extract_embeddings( + output_request, + &outputs.embedding_result, + embedding_length, + input_len, + ); } fn hyperparameters(&self) -> &Self::Hyperparameters { @@ -404,69 +406,51 @@ impl KnownModel for Llama { #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct Hyperparameters { /// Size of the model's vocabulary - pub n_vocab: usize, + pub vocabulary_count: usize, /// Size of the model's embedding layer - pub n_embd: usize, - /// n_mult - pub n_mult: usize, - /// n_head - pub n_head: usize, - /// grouped-query attention - pub n_head_kv: usize, + pub embedding_length: usize, + /// The number of attention heads + pub head_count: usize, + /// The number of grouped-query attention heads + pub head_count_kv: usize, /// Number of layers in the model - pub n_layer: usize, + pub block_count: usize, /// n_rot pub n_rot: usize, /// file_type - pub file_type: FileType, + pub file_type: Option, } impl llm_base::Hyperparameters for Hyperparameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - let n_vocab = util::read_i32(reader)?.try_into()?; - let n_embd = util::read_i32(reader)?.try_into()?; - let n_mult = util::read_i32(reader)?.try_into()?; - let n_head = util::read_i32(reader)?.try_into()?; - let n_layer = util::read_i32(reader)?.try_into()?; - let n_rot = util::read_i32(reader)?.try_into()?; - let file_type = util::read_filetype(reader)?; - - // Defaults to multi-head attention where n_head_kv == n_heads - let n_head_kv = n_head; - - Ok(Hyperparameters { - n_head, - n_head_kv, - n_vocab, - n_embd, - n_mult, - n_layer, - n_rot, - file_type, + fn read_gguf(metadata: &Metadata) -> Result { + Ok(Self { + // TODO: handle models without an embedded vocabulary + vocabulary_count: metadata + .fallible_typed_get("tokenizer.ggml.tokens", |v| v.as_array())? 
+ .len(), + embedding_length: metadata.fallible_get_countable("llama.embedding_length")?, + head_count: metadata.fallible_get_countable("llama.attention.head_count")?, + head_count_kv: metadata.fallible_get_countable("llama.attention.head_count_kv")?, + block_count: metadata.fallible_get_countable("llama.block_count")?, + file_type: metadata + .get("general.file_type") + .and_then(|v| v.as_uint32()) + .map(|v| FileType::try_from(v as i32)) + .transpose()?, + n_rot: todo!(), }) } - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.n_vocab.try_into()?)?; - util::write_i32(writer, self.n_embd.try_into()?)?; - util::write_i32(writer, self.n_mult.try_into()?)?; - util::write_i32(writer, self.n_head.try_into()?)?; - util::write_i32(writer, self.n_layer.try_into()?)?; - util::write_i32(writer, self.n_rot.try_into()?)?; - util::write_i32(writer, self.file_type.into())?; - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - self.n_vocab + fn write_gguf(&self, metadata: &mut Metadata) -> Result<(), HyperparametersWriteError> { + todo!() } fn file_type(&self) -> Option { - Some(self.file_type) + self.file_type } fn file_type_mut(&mut self) -> Option<&mut FileType> { - Some(&mut self.file_type) + self.file_type.as_mut() } } @@ -486,13 +470,3 @@ struct Layer { w2: ggml::Tensor, w3: ggml::Tensor, } - -/// Available Llama models -enum LlamaModelType { - Model3b, - Model7b, - Model13b, - Model30b, - Model65b, - Model70b, -} From 178a0fba9a098bb7cb664f7a8aced4793a5f74e9 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 29 Aug 2023 01:05:36 +0200 Subject: [PATCH 11/33] fix(cli): use info log level --- binaries/llm-cli/src/main.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index a623d721..2d957822 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -17,7 +17,11 @@ mod util; fn main() -> eyre::Result<()> { tracing_subscriber::fmt() .with_writer(std::io::stderr) - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing_subscriber::filter::LevelFilter::INFO.into()) + .from_env_lossy(), + ) .with_ansi(std::io::stderr().is_terminal()) .init(); From 2a9417a3eb4ce3b6afa05f94aa51e5554d4ec6c8 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 29 Aug 2023 01:07:34 +0200 Subject: [PATCH 12/33] wip: successfully load a llama2 gguf* * with some heavy caveats, see the PR --- binaries/llm-cli/src/cli_args.rs | 1 - binaries/llm-cli/src/main.rs | 70 ++++++-------- crates/ggml/src/format/gguf/metadata.rs | 93 ++++++++++++++++++ crates/ggml/src/format/gguf/mod.rs | 17 ++++ crates/llm-base/src/loader.rs | 54 ++++++----- crates/llm-base/src/lora.rs | 1 + crates/llm-base/src/model/mod.rs | 3 - crates/llm-base/src/tokenizer/mod.rs | 32 ++++--- crates/llm/src/lib.rs | 6 +- crates/models/llama/src/lib.rs | 121 +++++++++++------------- 10 files changed, 251 insertions(+), 147 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 71440d47..68839c67 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -519,7 +519,6 @@ impl ModelLoad { use_gpu, gpu_layers: self.gpu_layers, rope_overrides: self.rope_scaling.to_rope_arguments(), - n_gqa: None, }; let mut sp = Some(spinoff::Spinner::new( diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 
2d957822..20dd2d91 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -8,6 +8,7 @@ use clap::Parser; use cli_args::Args; use color_eyre::eyre::{self, Context, ContextCompat}; use is_terminal::IsTerminal; +use llm::ggml_format::gguf; mod cli_args; mod interactive; @@ -132,56 +133,45 @@ fn perplexity(args: &cli_args::Perplexity) -> eyre::Result<()> { } fn info(args: &cli_args::Info) -> eyre::Result<()> { - struct InfoVisitor<'a>(&'a cli_args::Info); - impl llm::ModelArchitectureVisitor> for InfoVisitor<'_> { - fn visit(&mut self) -> eyre::Result<()> { - let args = self.0; + let model_path = &args.model_and_tokenizer.model_path; - let model_path = &args.model_and_tokenizer.model_path; - let tokenizer = args.model_and_tokenizer.to_source()?.retrieve(model_path)?; + let file = File::open(model_path)?; + let mut reader = BufReader::new(&file); + let gguf = gguf::Gguf::load(&mut reader)?; - let file = File::open(model_path)?; - let mut reader = BufReader::new(&file); - let mut loader: llm::Loader = - llm::Loader::new(tokenizer, |_| { - // We purposely do not print progress here, as we are only interested in the metadata - }); + log::info!("Non-array parameters:"); + for (metadata_key, metadata_value) in &gguf.metadata { + if metadata_value.as_array().is_some() { + continue; + } - llm::ggml_format::ggml::load(&mut reader, &mut loader)?; + log::info!("- {}: {:?}", metadata_key, metadata_value); + } - log::info!("Container type: {:?}", loader.container_type); - log::info!("Hyperparameters: {:?}", loader.hyperparameters); - log::info!("Tokenizer vocabulary size: {}", loader.tokenizer.len()); + if let Some((tokens, _scores)) = gguf.tokenizer_embedded() { + log::info!("Embedded tokenizer vocabulary size: {}", tokens.len()); - if args.tokenizer { - log::info!("Tokens:"); - for i in 0..loader.tokenizer.len() { - log::info!("- {}: {}", i, utf8_or_array(&loader.tokenizer.token(i))); - } - } - - if args.tensors { - log::info!("Tensors:"); - for (name, tensor) in &loader.tensors { - log::info!("- {} ({:?} {:?})", name, tensor.element_type, tensor.dims()); - } - } - - fn utf8_or_array(token: &[u8]) -> String { - std::str::from_utf8(token) - .map(|s| s.to_owned()) - .unwrap_or(format!("{:?}", token)) + if args.tokenizer { + log::info!("Embedded tokenizer vocabulary:"); + for (i, token) in tokens.iter().enumerate() { + log::info!("- {}: {}", i, token); } + } + } - Ok(()) + if args.tensors { + log::info!("Tensors:"); + for (name, tensor) in &gguf.tensor_infos { + log::info!( + "- {} ({:?} {:?})", + name, + tensor.element_type, + tensor.dimensions + ); } } - args.model_and_tokenizer - .architecture - .model_architecture - .wrap_err("a model architecture is required at present")? 
- .visit(&mut InfoVisitor(args)) + Ok(()) } fn prompt_tokens(args: &cli_args::PromptTokens) -> eyre::Result<()> { diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index 39f3713b..a39301f4 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -307,6 +307,99 @@ pub enum MetadataArrayValue { Int64(Vec), Float64(Vec), } +// Public +impl MetadataArrayValue { + pub fn as_uint8_array(&self) -> Option<&[u8]> { + match self { + Self::UInt8(v) => Some(v), + _ => None, + } + } + + pub fn as_int8_array(&self) -> Option<&[i8]> { + match self { + Self::Int8(v) => Some(v), + _ => None, + } + } + + pub fn as_uint16_array(&self) -> Option<&[u16]> { + match self { + Self::UInt16(v) => Some(v), + _ => None, + } + } + + pub fn as_int16_array(&self) -> Option<&[i16]> { + match self { + Self::Int16(v) => Some(v), + _ => None, + } + } + + pub fn as_uint32_array(&self) -> Option<&[u32]> { + match self { + Self::UInt32(v) => Some(v), + _ => None, + } + } + + pub fn as_int32_array(&self) -> Option<&[i32]> { + match self { + Self::Int32(v) => Some(v), + _ => None, + } + } + + pub fn as_float32_array(&self) -> Option<&[f32]> { + match self { + Self::Float32(v) => Some(v), + _ => None, + } + } + + pub fn as_bool_array(&self) -> Option<&[bool]> { + match self { + Self::Bool(v) => Some(v), + _ => None, + } + } + + pub fn as_string_array(&self) -> Option<&[String]> { + match self { + Self::String(v) => Some(v), + _ => None, + } + } + + pub fn as_array_array(&self) -> Option<&[MetadataArrayValue]> { + match self { + Self::Array(v) => Some(v), + _ => None, + } + } + + pub fn as_uint64_array(&self) -> Option<&[u64]> { + match self { + Self::UInt64(v) => Some(v), + _ => None, + } + } + + pub fn as_int64_array(&self) -> Option<&[i64]> { + match self { + Self::Int64(v) => Some(v), + _ => None, + } + } + + pub fn as_float64_array(&self) -> Option<&[f64]> { + match self { + Self::Float64(v) => Some(v), + _ => None, + } + } +} impl MetadataArrayValue { fn read_value(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { let value_type = MetadataValueType::try_from(util::read_u32(reader)?) diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index bf0afcdd..5c2152a7 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -102,6 +102,23 @@ impl Gguf { tensor_data_position, }) } + + // TODO: consider moving this to a `ModelGguf` abstraction that wraps this + // and provides a model-specific interface + pub fn tokenizer_embedded(&self) -> Option<(&[String], &[f32])> { + let tokens = self + .metadata + .get("tokenizer.ggml.tokens")? + .as_array()? + .as_string_array()?; + let scores = self + .metadata + .get("tokenizer.ggml.scores")? + .as_array()? 
+ .as_float32_array()?; + + Some((tokens, scores)) + } } struct GgufContext { diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 50cdc725..e5330410 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -8,8 +8,8 @@ use std::{ }; use crate::{ - Hyperparameters, KnownModel, LoraAdapter, ModelContext, ModelParameters, TokenizerLoadError, - TokenizerSource, + Hyperparameters, KnownModel, LoraAdapter, ModelContext, ModelParameters, Tokenizer, + TokenizerLoadError, TokenizerSource, }; use ggml::{ format::gguf::{ @@ -38,7 +38,7 @@ impl From for i32 { } } impl TryFrom for FileType { - type Error = (); + type Error = LoadError; fn try_from(value: i32) -> Result { let format = FileTypeFormat::try_from( @@ -99,7 +99,7 @@ pub enum FileTypeFormat { MostlyQ6_K, } impl TryFrom for FileTypeFormat { - type Error = (); + type Error = LoadError; fn try_from(value: ggml::sys::llama::llama_ftype) -> Result { use ggml::sys::llama::*; @@ -121,7 +121,10 @@ impl TryFrom for FileTypeFormat { LLAMA_FTYPE_MOSTLY_Q5_K_S => Ok(FileTypeFormat::MostlyQ5_K_S), LLAMA_FTYPE_MOSTLY_Q5_K_M => Ok(FileTypeFormat::MostlyQ5_K_M), LLAMA_FTYPE_MOSTLY_Q6_K => Ok(FileTypeFormat::MostlyQ6_K), - _ => Err(()), + #[allow(clippy::unnecessary_cast)] + _ => Err(LoadError::UnsupportedFileType { + file_type_format: value as u32, + }), } } } @@ -269,14 +272,8 @@ pub enum LoadError { path: PathBuf, }, /// The tokenizer could not be loaded. - #[error("could not load tokenizer {path:?}: {error}")] - TokenizerLoadFail { - /// The path of the tokenizer. - path: PathBuf, - - /// The error that occurred. - error: Box, - }, + #[error("could not load tokenizer: {0}")] + TokenizerLoadFail(#[from] TokenizerLoadError), /// The quantization version was missing, despite this model containing quantized tensors. #[error("quantization version was missing, despite model containing quantized tensors")] MissingQuantizationVersion, @@ -312,14 +309,12 @@ pub enum LoadError { /// The actual type. actual_type: MetadataValueType, }, -} -impl From for LoadError { - fn from(value: TokenizerLoadError) -> Self { - LoadError::TokenizerLoadFail { - path: value.path, - error: value.error, - } - } + /// The file type within the model was not supported by this version of `llm`. + #[error("file type {file_type_format} is not supported")] + UnsupportedFileType { + /// The file type format (ignoring the quantization version) that was encountered. 
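    // Illustrative sketch, not part of the patch (hypothetical caller of the `UnsupportedFileType`
    // variant being added here): with the `TryFrom` error types changed from `()` to `LoadError`,
    // an unrecognised ftype read from metadata can now surface its raw value, e.g.:
    //
    //     match FileType::try_from(raw_ftype /* i32 read from metadata */) {
    //         Ok(file_type) => { /* record it in the hyperparameters */ }
    //         Err(LoadError::UnsupportedFileType { file_type_format }) => { /* report the raw value */ }
    //         Err(other) => { /* propagate any other load error */ }
    //     }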
+ file_type_format: u32, + }, } impl LoadError { #[doc(hidden)] @@ -422,7 +417,7 @@ pub fn load( let mut reader = BufReader::new(&file); log::trace!("Read model file from {:?}", path); - let tokenizer = tokenizer_source.retrieve(path)?; + let mut tokenizer = tokenizer_source.retrieve(path)?; let gguf = gguf::Gguf::load(&mut reader).map_err(|e| LoadError::from_gguf(e, path.to_owned()))?; @@ -453,6 +448,17 @@ pub fn load( } } + // Populate the embedded tokenizer if required + if let Tokenizer::Embedded(tokenizer) = &mut tokenizer { + if let Some((tokens, scores)) = gguf.tokenizer_embedded() { + for (i, (token, score)) in tokens.iter().zip(scores.iter()).enumerate() { + tokenizer.push_token(i as u32, token.as_bytes().to_vec(), *score); + } + } else { + return Err(TokenizerLoadError::NoTokenizerFound.into()); + } + } + let use_mmap = params.prefer_mmap && params.lora_adapters.is_none(); let ctx_size = gguf @@ -513,6 +519,7 @@ pub fn load( let hyperparameters = ::read_gguf(&gguf.metadata)?; let tl = ModelTensorLoader { tensor_loader: TensorLoader { + path, file: &mut file, gguf: &gguf, context, @@ -575,6 +582,7 @@ impl ModelTensorLoader<'_> { } pub(crate) struct TensorLoader<'a> { + pub path: &'a Path, pub file: &'a mut File, pub context: Context, pub gguf: &'a gguf::Gguf, @@ -587,7 +595,7 @@ impl TensorLoader<'_> { .get(name) .ok_or(LoadError::UnknownTensor { tensor_name: String::from(name), - path: Default::default(), + path: self.path.to_path_buf(), })?; let ty = info.element_type; diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index 4b5a6c9c..d44da997 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -105,6 +105,7 @@ impl LoraAdapter { // TODO: test if GPU can be enabled (make it configurable) let patch_context = ggml::Context::new_with_allocate(patch_context_size); let mut loader = TensorLoader { + path: &self.path, file: &mut self.file, context: patch_context, gguf: &self.gguf, diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index e383db1d..768cdf9d 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -207,8 +207,6 @@ pub struct ModelParameters { pub gpu_layers: Option, /// The arguments/overrides to pass to the [custom RoPE](https://arxiv.org/pdf/2306.15595.pdf) function, if it is used by the model. pub rope_overrides: Option, - /// Enables gouped-query attention for Llama-2 70B model - pub n_gqa: Option, } impl Default for ModelParameters { @@ -220,7 +218,6 @@ impl Default for ModelParameters { use_gpu: false, gpu_layers: None, rope_overrides: None, - n_gqa: None, } } } diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 48be8926..0469098e 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -20,7 +20,7 @@ pub(crate) type TokenScore = f32; #[derive(Error, Debug)] /// Errors related to tokenization. pub enum TokenizationError { - #[error("an invalid token was encountered during tokenization")] + #[error("an invalid token was encountered during tokenization: {error}")] /// During tokenization, one of the produced tokens was invalid / zero. TokenizationFailed { #[source] @@ -35,16 +35,26 @@ pub enum TokenizationError { #[derive(Error, Debug)] /// Errors related to loading the tokenizer. #[error("error loading tokenizer from {path}: {error}")] -pub struct TokenizerLoadError { - /// The path to the tokenizer. - pub path: PathBuf, - /// The error that occurred during loading. 
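// Illustrative sketch, not part of the patch (hypothetical caller of the enum introduced just
// below): turning `TokenizerLoadError` into an enum lets callers distinguish "no tokenizer was
// found anywhere" from a Hugging Face tokenizer that failed to load, e.g.:
//
//     match tokenizer_source.retrieve(model_path) {
//         Ok(tokenizer) => { /* use it */ }
//         Err(TokenizerLoadError::NoTokenizerFound) => { /* fall back to another source */ }
//         Err(TokenizerLoadError::HuggingFaceTokenizerError { path, error }) => { /* report path and error */ }
//     }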
- pub error: Box, +pub enum TokenizerLoadError { + #[error("error loading Hugging Face tokenizer from {path}: {error}")] + /// An error occurred while loading a Hugging Face tokenizer. + HuggingFaceTokenizerError { + /// The path to the tokenizer. + path: PathBuf, + /// The error that occurred during loading. + error: Box, + }, + #[error("no tokenizer was found, including in the model file")] + /// No tokenizer was found, including in the model file. + NoTokenizerFound, } impl TokenizerLoadError { - fn new(path: impl Into, error: impl Into>) -> Self { - Self { + fn huggingface_error( + path: impl Into, + error: impl Into>, + ) -> Self { + Self::HuggingFaceTokenizerError { path: path.into(), error: error.into(), } @@ -84,19 +94,19 @@ impl TokenizerSource { #[cfg(feature = "tokenizers-remote")] Self::HuggingFaceRemote(identifier) => HuggingFaceTokenizer::new( tokenizers::Tokenizer::from_pretrained(&identifier, None) - .map_err(|error| TokenizerLoadError::new(model_path, error))?, + .map_err(|error| TokenizerLoadError::huggingface_error(model_path, error))?, ) .into(), Self::HuggingFaceTokenizerFile(path) => HuggingFaceTokenizer::new( tokenizers::Tokenizer::from_file(&path) - .map_err(|error| TokenizerLoadError::new(path, error))?, + .map_err(|error| TokenizerLoadError::huggingface_error(path, error))?, ) .into(), Self::HuggingFaceTokenizerString(s) => HuggingFaceTokenizer::new( tokenizers::Tokenizer::from_str(&s) - .map_err(|error| TokenizerLoadError::new(model_path, error))?, + .map_err(|error| TokenizerLoadError::huggingface_error(model_path, error))?, ) .into(), diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 8d62f40a..b15bbdf3 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -84,7 +84,7 @@ pub use llm_base::{ FileMagic, FileType, FileTypeFormat, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, - InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, + InvalidTokenBias, KnownModel, LoadError, LoadProgress, Model, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, RewindError, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource, @@ -232,9 +232,7 @@ pub fn load_dynamic( )?)) } - let architecture = architecture.ok_or_else(|| LoadError::MissingModelArchitecture { - path: path.to_owned(), - })?; + let architecture = architecture.expect("TODO: This option will be removed soon"); struct LoadVisitor<'a, F: FnMut(LoadProgress)> { path: &'a Path, diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index f85df602..09955f4e 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -29,7 +29,7 @@ pub struct Llama { output: ggml::Tensor, // weights for the model - layers: Vec, + blocks: Vec, // must be kept alive for the model context: ModelContext, @@ -50,63 +50,51 @@ impl KnownModel for Llama { let mut tl = tensor_loader; // model-global weights - let wte = tl.load("tok_embeddings.weight")?; + let wte = tl.load("token_embd.weight")?; let backend = params.backend(0); - let norm = tl.load("norm.weight")?.transfer_to(backend); + let norm = tl.load("output_norm.weight")?.transfer_to(backend); let output = tl.load("output.weight")?.transfer_to(backend); - let mut layers = Vec::new(); + let mut blocks = Vec::new(); for i in 
0..hyperparameters.block_count { let backend = params.backend(i); - let layer = Layer { - attention_norm: tl - .load(&format!("layers.{i}.attention_norm.weight"))? + let block = Block { + attn_n: tl + .load(&format!("blk.{i}.attn_norm.weight"))? .transfer_to(backend), - wq: tl - .load(&format!("layers.{i}.attention.wq.weight"))? + attn_q: tl + .load(&format!("blk.{i}.attn_q.weight"))? .transfer_to(backend), - wk: tl - .load(&format!("layers.{i}.attention.wk.weight"))? + attn_k: tl + .load(&format!("blk.{i}.attn_k.weight"))? .transfer_to(backend), - wv: tl - .load(&format!("layers.{i}.attention.wv.weight"))? + attn_v: tl + .load(&format!("blk.{i}.attn_v.weight"))? .transfer_to(backend), - wo: tl - .load(&format!("layers.{i}.attention.wo.weight"))? + attn_output: tl + .load(&format!("blk.{i}.attn_output.weight"))? .transfer_to(backend), ffn_norm: tl - .load(&format!("layers.{i}.ffn_norm.weight"))? + .load(&format!("blk.{i}.ffn_norm.weight"))? .transfer_to(backend), - w1: tl - .load(&format!("layers.{i}.feed_forward.w1.weight"))? + ffn_gate: tl + .load(&format!("blk.{i}.ffn_gate.weight"))? .transfer_to(backend), - w2: tl - .load(&format!("layers.{i}.feed_forward.w2.weight"))? + ffn_down: tl + .load(&format!("blk.{i}.ffn_down.weight"))? .transfer_to(backend), - w3: tl - .load(&format!("layers.{i}.feed_forward.w3.weight"))? + ffn_up: tl + .load(&format!("blk.{i}.ffn_up.weight"))? .transfer_to(backend), }; - layers.push(layer); + blocks.push(block); } let context = tl.finish(); - // TODO: temporary fix for 70B models - if let Some(n_gqa) = params.n_gqa { - if hyperparameters.block_count >= 80 { - assert_eq!( - hyperparameters.head_count % n_gqa, - 0, - "assuming 70B Llama2 model based on GQA == 8" - ); - hyperparameters.head_count_kv = hyperparameters.head_count / n_gqa; - } - } - Ok(Self { hyperparameters, params, @@ -114,7 +102,7 @@ impl KnownModel for Llama { wte, norm, output, - layers, + blocks, context, }) } @@ -147,10 +135,10 @@ impl KnownModel for Llama { head_count, head_count_kv, block_count, - n_rot, file_type: _, } = self.hyperparameters; - let embedding_length_gqa = embedding_length / (head_count / head_count_kv); + let embedding_length_gqa = + embedding_length / self.hyperparameters.grouped_query_attention(); let outputs = session.compute(self.context.clone(), input_tokens, |builder| { let mut ctx0 = builder.ctx0.borrow_mut(); @@ -172,21 +160,22 @@ impl KnownModel for Llama { current = ctx0.op_rms_norm(&input_layer); // cur = attention_norm * cur - current = ctx0.op_mul(¤t, &self.layers[il].attention_norm); + current = ctx0.op_mul(¤t, &self.blocks[il].attn_n); // self-attention // compute Q and K and RoPE them let overrides = self.params.rope_overrides.as_ref(); + let n_embd_head = embedding_length / head_count; let q_current = ctx0 .op_rope_inplace( &ctx0.op_reshape_3d( - &ctx0.op_mul_mat(&self.layers[il].wq, ¤t), - embedding_length / head_count, + &ctx0.op_mul_mat(&self.blocks[il].attn_q, ¤t), + n_embd_head, head_count, input_len, ), session_len, - n_rot, + n_embd_head, 0, overrides, ) @@ -194,13 +183,13 @@ impl KnownModel for Llama { let k_current = ctx0 .op_rope_inplace( &ctx0.op_reshape_3d( - &ctx0.op_mul_mat(&self.layers[il].wk, ¤t), - embedding_length / head_count, + &ctx0.op_mul_mat(&self.blocks[il].attn_k, ¤t), + n_embd_head, head_count_kv, input_len, ), session_len, - n_rot, + n_embd_head, 0, overrides, ) @@ -209,7 +198,7 @@ impl KnownModel for Llama { // store key and value to memory // compute the transposed [N, embedding_length] V matrix let v_current = 
ctx0.op_transpose(&ctx0.op_reshape_2d( - &ctx0.op_mul_mat(&self.layers[il].wv, ¤t), + &ctx0.op_mul_mat(&self.blocks[il].attn_v, ¤t), embedding_length_gqa, input_len, )); @@ -245,7 +234,7 @@ impl KnownModel for Llama { * builder.memory_k.element_size() * embedding_length_gqa, ), - embedding_length / head_count, + n_embd_head, head_count_kv, session_len + input_len, ), @@ -304,7 +293,7 @@ impl KnownModel for Llama { .set_name("KQV_merged_contiguous"); // projection (no bias) - current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); + current = ctx0.op_mul_mat(&self.blocks[il].attn_output, ¤t); ctx0.use_scratch(builder.get_scratch(1)); @@ -315,18 +304,18 @@ impl KnownModel for Llama { current = ctx0.op_rms_norm(&input_feed_forward); // cur = cur*ffn_norm(broadcasted) - current = ctx0.op_mul(¤t, &self.layers[il].ffn_norm); + current = ctx0.op_mul(¤t, &self.blocks[il].ffn_norm); - let tmp = ctx0.op_mul_mat(&self.layers[il].w3, ¤t); + let tmp = ctx0.op_mul_mat(&self.blocks[il].ffn_up, ¤t); - current = ctx0.op_mul_mat(&self.layers[il].w1, ¤t); + current = ctx0.op_mul_mat(&self.blocks[il].ffn_gate, ¤t); // SILU activation current = ctx0.op_silu(¤t); current = ctx0.op_mul(¤t, &tmp); - current = ctx0.op_mul_mat(&self.layers[il].w2, ¤t); + current = ctx0.op_mul_mat(&self.blocks[il].ffn_down, ¤t); current = ctx0.op_add(¤t, &input_feed_forward); @@ -415,12 +404,9 @@ pub struct Hyperparameters { pub head_count_kv: usize, /// Number of layers in the model pub block_count: usize, - /// n_rot - pub n_rot: usize, /// file_type pub file_type: Option, } - impl llm_base::Hyperparameters for Hyperparameters { fn read_gguf(metadata: &Metadata) -> Result { Ok(Self { @@ -437,7 +423,6 @@ impl llm_base::Hyperparameters for Hyperparameters { .and_then(|v| v.as_uint32()) .map(|v| FileType::try_from(v as i32)) .transpose()?, - n_rot: todo!(), }) } @@ -453,20 +438,26 @@ impl llm_base::Hyperparameters for Hyperparameters { self.file_type.as_mut() } } +impl Hyperparameters { + /// Returns the number of grouped-query attention heads. 
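    // Worked example with illustrative values (not from the patch) for the helper defined just
    // below: a Llama-2-70B-style checkpoint with head_count = 64 and head_count_kv = 8 gives a
    // grouping factor of 64 / 8 = 8, so with embedding_length = 8192 the KV views in `evaluate`
    // use embedding_length_gqa = 8192 / 8 = 1024. This replaces the removed `n_gqa` override:
    // the factor is now derived from the hyperparameters rather than passed in by the caller.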
+ pub fn grouped_query_attention(&self) -> usize { + self.head_count / self.head_count_kv + } +} -struct Layer { - attention_norm: ggml::Tensor, +struct Block { + attn_n: ggml::Tensor, - wq: ggml::Tensor, - wk: ggml::Tensor, - wv: ggml::Tensor, - wo: ggml::Tensor, + attn_q: ggml::Tensor, + attn_k: ggml::Tensor, + attn_v: ggml::Tensor, + attn_output: ggml::Tensor, // normalization ffn_norm: ggml::Tensor, // ff - w1: ggml::Tensor, - w2: ggml::Tensor, - w3: ggml::Tensor, + ffn_gate: ggml::Tensor, + ffn_down: ggml::Tensor, + ffn_up: ggml::Tensor, } From 823828d9f3dbdcc0267269724d2f32d37b433dae Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 29 Aug 2023 01:08:19 +0200 Subject: [PATCH 13/33] wip: disable everything that's broken --- binaries/llm-cli/src/cli_args.rs | 106 +-- binaries/llm-cli/src/main.rs | 120 ++-- binaries/llm-test/src/common.rs | 28 +- binaries/llm-test/src/main.rs | 2 +- crates/llm/Cargo.toml | 2 +- crates/models/bloom/src/lib.rs | 888 +++++++++++++------------- crates/models/falcon/src/lib.rs | 942 +++++++++++++-------------- crates/models/gpt2/src/lib.rs | 928 +++++++++++++-------------- crates/models/gptj/src/lib.rs | 868 ++++++++++++------------- crates/models/gptneox/src/lib.rs | 1030 +++++++++++++++--------------- crates/models/mpt/src/lib.rs | 742 ++++++++++----------- 11 files changed, 2828 insertions(+), 2828 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 68839c67..00c0b0ed 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -45,9 +45,9 @@ pub enum Args { /// and do not support a long enough context window to be able to /// have an extended conversation. Chat(Box), - - /// Quantize a GGML model to 4-bit. - Quantize(Box), + // + // /// Quantize a GGML model to 4-bit. + // Quantize(Box), } #[derive(Parser, Debug)] @@ -629,56 +629,56 @@ pub fn read_prompt_file(path: &Path) -> eyre::Result { .wrap_err_with(|| format!("Could not read prompt file at {path:?}")) } -#[derive(Parser, Debug)] -pub struct Quantize { - #[command(flatten)] - pub architecture: ModelArchitecture, - - /// The path to the model to quantize - #[arg()] - pub source: PathBuf, - - /// The path to save the quantized model to - #[arg()] - pub destination: PathBuf, - - #[command(flatten)] - pub tokenizer: ModelTokenizer, - - /// The GGML container type to target. - /// - /// Note that using GGML requires the original model to have - /// an unscored vocabulary, which is not the case for newer models. - #[arg(short, long, default_value_t = SaveContainerType::GgjtV3)] - pub container_type: SaveContainerType, - - /// The format to convert to - pub target: QuantizationTarget, -} - -#[derive(Parser, Debug, ValueEnum, Clone, Copy)] -pub enum SaveContainerType { - /// GGML container. - Ggml, - /// GGJT v3 container. 
- GgjtV3, -} -impl fmt::Display for SaveContainerType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - SaveContainerType::Ggml => write!(f, "ggml"), - SaveContainerType::GgjtV3 => write!(f, "ggjt-v3"), - } - } -} -impl From for ggml_format::ggml::SaveContainerType { - fn from(value: SaveContainerType) -> Self { - match value { - SaveContainerType::Ggml => ggml_format::ggml::SaveContainerType::Ggml, - SaveContainerType::GgjtV3 => ggml_format::ggml::SaveContainerType::GgjtV3, - } - } -} +// #[derive(Parser, Debug)] +// pub struct Quantize { +// #[command(flatten)] +// pub architecture: ModelArchitecture, + +// /// The path to the model to quantize +// #[arg()] +// pub source: PathBuf, + +// /// The path to save the quantized model to +// #[arg()] +// pub destination: PathBuf, + +// #[command(flatten)] +// pub tokenizer: ModelTokenizer, + +// /// The GGML container type to target. +// /// +// /// Note that using GGML requires the original model to have +// /// an unscored vocabulary, which is not the case for newer models. +// #[arg(short, long, default_value_t = SaveContainerType::GgjtV3)] +// pub container_type: SaveContainerType, + +// /// The format to convert to +// pub target: QuantizationTarget, +// } + +// #[derive(Parser, Debug, ValueEnum, Clone, Copy)] +// pub enum SaveContainerType { +// /// GGML container. +// Ggml, +// /// GGJT v3 container. +// GgjtV3, +// } +// impl fmt::Display for SaveContainerType { +// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +// match self { +// SaveContainerType::Ggml => write!(f, "ggml"), +// SaveContainerType::GgjtV3 => write!(f, "ggjt-v3"), +// } +// } +// } +// impl From for ggml_format::ggml::SaveContainerType { +// fn from(value: SaveContainerType) -> Self { +// match value { +// SaveContainerType::Ggml => ggml_format::ggml::SaveContainerType::Ggml, +// SaveContainerType::GgjtV3 => ggml_format::ggml::SaveContainerType::GgjtV3, +// } +// } +// } #[derive(Parser, Debug, ValueEnum, Clone, Copy)] #[clap(rename_all = "snake_case")] diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 20dd2d91..1f8b330c 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -36,7 +36,7 @@ fn main() -> eyre::Result<()> { Args::PromptTokens(args) => prompt_tokens(&args), Args::Repl(args) => interactive::repl(&args), Args::Chat(args) => interactive::chat(&args), - Args::Quantize(args) => quantize(&args), + // Args::Quantize(args) => quantize(&args), } } @@ -203,65 +203,65 @@ fn prompt_tokens(args: &cli_args::PromptTokens) -> eyre::Result<()> { Ok(()) } -fn quantize(args: &cli_args::Quantize) -> eyre::Result<()> { - use llm::QuantizeProgress; - - struct QuantizeVisitor<'a>(&'a cli_args::Quantize); - impl llm::ModelArchitectureVisitor> for QuantizeVisitor<'_> { - fn visit(&mut self) -> eyre::Result<()> { - let args = self.0; - - let mut source: BufReader = BufReader::new(std::fs::File::open(&args.source)?); - let mut destination: BufWriter = - BufWriter::new(std::fs::File::create(&args.destination)?); - let tokenizer: llm::Tokenizer = args.tokenizer.to_source()?.retrieve(&args.source)?; - - llm::quantize::( - &mut source, - &mut destination, - tokenizer, - args.container_type.into(), - args.target.into(), - |progress| match progress { - QuantizeProgress::HyperparametersLoaded => log::info!("Loaded hyperparameters"), - QuantizeProgress::TensorLoading { - name, - dims, - element_type, - n_elements, - } => log::info!( - "Loading tensor `{name}` ({n_elements} ({dims:?}) 
{element_type} elements)" - ), - QuantizeProgress::TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"), - QuantizeProgress::TensorQuantized { - name, - original_size, - reduced_size, - history, - } => log::info!( - "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})" - ), - QuantizeProgress::TensorSkipped { name, size } => { - log::info!("Skipped tensor `{name}` ({size} bytes)") - } - QuantizeProgress::Finished { - original_size, - reduced_size, - history, - } => log::info!( - "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})" - ), - }, - ) - .wrap_err("failed to quantize model") - } - } - - args.architecture - .model_architecture - .wrap_err("the architecture must be known for quantization")? - .visit(&mut QuantizeVisitor(args)) -} +// fn quantize(args: &cli_args::Quantize) -> eyre::Result<()> { +// use llm::QuantizeProgress; + +// struct QuantizeVisitor<'a>(&'a cli_args::Quantize); +// impl llm::ModelArchitectureVisitor> for QuantizeVisitor<'_> { +// fn visit(&mut self) -> eyre::Result<()> { +// let args = self.0; + +// let mut source: BufReader = BufReader::new(std::fs::File::open(&args.source)?); +// let mut destination: BufWriter = +// BufWriter::new(std::fs::File::create(&args.destination)?); +// let tokenizer: llm::Tokenizer = args.tokenizer.to_source()?.retrieve(&args.source)?; + +// llm::quantize::( +// &mut source, +// &mut destination, +// tokenizer, +// args.container_type.into(), +// args.target.into(), +// |progress| match progress { +// QuantizeProgress::HyperparametersLoaded => log::info!("Loaded hyperparameters"), +// QuantizeProgress::TensorLoading { +// name, +// dims, +// element_type, +// n_elements, +// } => log::info!( +// "Loading tensor `{name}` ({n_elements} ({dims:?}) {element_type} elements)" +// ), +// QuantizeProgress::TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"), +// QuantizeProgress::TensorQuantized { +// name, +// original_size, +// reduced_size, +// history, +// } => log::info!( +// "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})" +// ), +// QuantizeProgress::TensorSkipped { name, size } => { +// log::info!("Skipped tensor `{name}` ({size} bytes)") +// } +// QuantizeProgress::Finished { +// original_size, +// reduced_size, +// history, +// } => log::info!( +// "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})" +// ), +// }, +// ) +// .wrap_err("failed to quantize model") +// } +// } + +// args.architecture +// .model_architecture +// .wrap_err("the architecture must be known for quantization")? 
+// .visit(&mut QuantizeVisitor(args)) +// } fn load_prompt_file_with_prompt( prompt_file: &cli_args::PromptFile, diff --git a/binaries/llm-test/src/common.rs b/binaries/llm-test/src/common.rs index 4c858820..46ab2a50 100644 --- a/binaries/llm-test/src/common.rs +++ b/binaries/llm-test/src/common.rs @@ -10,21 +10,21 @@ pub(super) fn can_send(model: M) -> anyhow::Result model } -pub(super) fn can_roundtrip_hyperparameters( - model: &M, -) -> anyhow::Result<()> { - fn test_hyperparameters(hyperparameters: &M) -> anyhow::Result<()> { - let mut data = vec![]; - hyperparameters.write_ggml(&mut data)?; - let new_hyperparameters = - ::read_ggml(&mut std::io::Cursor::new(data))?; +// pub(super) fn can_roundtrip_hyperparameters( +// model: &M, +// ) -> anyhow::Result<()> { +// fn test_hyperparameters(hyperparameters: &M) -> anyhow::Result<()> { +// let mut data = vec![]; +// hyperparameters.write_ggml(&mut data)?; +// let new_hyperparameters = +// ::read_ggml(&mut std::io::Cursor::new(data))?; - assert_eq!(hyperparameters, &new_hyperparameters); +// assert_eq!(hyperparameters, &new_hyperparameters); - log::info!("`can_roundtrip_hyperparameters` test passed!"); +// log::info!("`can_roundtrip_hyperparameters` test passed!"); - Ok(()) - } +// Ok(()) +// } - test_hyperparameters(model.hyperparameters()) -} +// test_hyperparameters(model.hyperparameters()) +// } diff --git a/binaries/llm-test/src/main.rs b/binaries/llm-test/src/main.rs index b1bc9b07..4b982505 100644 --- a/binaries/llm-test/src/main.rs +++ b/binaries/llm-test/src/main.rs @@ -255,7 +255,7 @@ async fn test_model( let model = common::can_send(model)?; // Confirm that the hyperparameters can be roundtripped - common::can_roundtrip_hyperparameters(&model)?; + // common::can_roundtrip_hyperparameters(&model)?; // diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index 0f395f5a..f3c93e65 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -34,7 +34,7 @@ default = ["models", "tokenizers-remote"] tokenizers-remote = ["llm-base/tokenizers-remote"] -models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"] +models = ["llama"] #, "gpt2", "gptj", "bloom", "gptneox", "mpt"] llama = ["dep:llm-llama"] gpt2 = ["dep:llm-gpt2"] gptj = ["dep:llm-gptj"] diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index efa1f338..3880f8c4 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -1,444 +1,444 @@ -//! An implementation of [BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom) -//! for the `llm` ecosystem. -#![deny(missing_docs)] - -use llm_base::{ - ggml, - model::{common, HyperparametersWriteError}, - util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, - ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, -}; - -/// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom) -/// -/// # Safety -/// This implements [Send] and [Sync] as it is immutable after construction. 
-pub struct Bloom { - params: ModelParameters, - - hyperparameters: Hyperparameters, - tokenizer: Tokenizer, - - // model-global weights - // weighted token embeddings - wte: ggml::Tensor, - // normalization weight & bias - norm: ggml::Tensor, - norm_bias: ggml::Tensor, - // output normalization weight & bias - output_norm: ggml::Tensor, - output_norm_bias: ggml::Tensor, - // output weight - output: ggml::Tensor, - - // weights for the model - layers: Vec, - - // must be kept alive for the model - context: ModelContext, -} - -unsafe impl Send for Bloom {} -unsafe impl Sync for Bloom {} - -impl KnownModel for Bloom { - type Hyperparameters = Hyperparameters; - - fn new( - hyperparameters: Self::Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: impl llm_base::TensorLoader, - ) -> Result { - let mut tl = tensor_loader; - - // model-global weights - let wte = tl.load("tok_embeddings.weight")?; - let norm = tl.load("norm.weight")?; - let norm_bias = tl.load("norm.bias")?; - let output_norm = tl.load("output_norm.weight")?; - let output_norm_bias = tl.load("output_norm.bias")?; - let output = tl.load("output.weight")?; - - let mut layers = Vec::new(); - for i in 0..hyperparameters.n_layer { - let layer = Layer { - attention_norm: tl.load(&format!("layers.{i}.attention_norm.weight"))?, - attention_norm_b: tl.load(&format!("layers.{i}.attention_norm.bias"))?, - - query_key_value: tl - .load(&format!("layers.{i}.attention.query_key_value.weight"))?, - query_key_value_b: tl - .load(&format!("layers.{i}.attention.query_key_value.bias"))?, - - wo: tl.load(&format!("layers.{i}.attention.wo.weight"))?, - wo_b: tl.load(&format!("layers.{i}.attention.wo.bias"))?, - - ffn_norm: tl.load(&format!("layers.{i}.ffn_norm.weight"))?, - ffn_norm_b: tl.load(&format!("layers.{i}.ffn_norm.bias"))?, - - w1: tl.load(&format!("layers.{i}.feed_forward.w1.weight"))?, - w1_b: tl.load(&format!("layers.{i}.feed_forward.w1.bias"))?, - w2: tl.load(&format!("layers.{i}.feed_forward.w2.weight"))?, - w2_b: tl.load(&format!("layers.{i}.feed_forward.w2.bias"))?, - }; - - layers.push(layer); - } - - let context = tl.finish(); - - Ok(Bloom { - hyperparameters, - params, - tokenizer, - wte, - norm, - norm_bias, - output_norm, - output_norm_bias, - output, - layers, - context, - }) - } - - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { - InferenceSession::new( - config, - &self.params, - self.hyperparameters.n_layer, - self.hyperparameters.n_embd, - self.hyperparameters.n_vocab, - ) - } - - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; - let ctx_size = self.params.context_size; - - let Hyperparameters { - n_vocab, - n_embd, - n_mult: _, - n_head, - n_layer, - file_type: _, - } = self.hyperparameters; - - let outputs = session.compute(self.context.clone(), input_tokens, |builder| { - let ctx0 = builder.ctx0.borrow(); - let (memory_k_size, memory_v_size) = ( - builder.memory_k.element_size(), - builder.memory_v.element_size(), - ); - let embd = &builder.embd; - let mut input_layer = ctx0.op_get_rows(&self.wte, embd); - - // normalize embeddings - input_layer = ctx0.op_norm(&input_layer); - input_layer = ctx0.op_mul(&input_layer, &self.norm); - input_layer = ctx0.op_add(&input_layer, &self.norm_bias); - - let mut gf = ctx0.create_compute_graph(); - for il in 0..n_layer { - let input_self_attention = input_layer.share(); 
- let mut current: ggml::Tensor; - - // norm - current = ctx0.op_norm(&input_layer); - - // cur = attention_norm * cur - current = ctx0.op_mul(¤t, &self.layers[il].attention_norm); - current = ctx0.op_add(¤t, &self.layers[il].attention_norm_b); - - //attention - current = ctx0.op_mul_mat(&self.layers[il].query_key_value, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].query_key_value_b); - - // self-attention - let nb = current.get_nb()[1]; - let q_current = ctx0.op_view_2d( - ¤t, - (n_embd, input_len), - nb, - //0 * std::mem::size_of::() * n_embd as usize, - 0, - ); - let k_current = ctx0.op_view_2d( - ¤t, - (n_embd, input_len), - nb, - std::mem::size_of::() * n_embd, - ); - let v_current = ctx0.op_view_2d( - ¤t, - (n_embd, input_len), - nb, - 2 * std::mem::size_of::() * n_embd, - ); - - // store key and value to memory - if input_len >= 1 { - let k = ctx0.op_view_1d( - builder.memory_k, - input_len * n_embd, - (memory_k_size * n_embd) * (il * ctx_size + session_len), - ); - - let v = ctx0.op_view_1d( - builder.memory_v, - input_len * n_embd, - (memory_v_size * n_embd) * (il * ctx_size + session_len), - ); - - gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v)); - } - - // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) - let big_q = ctx0.op_permute( - &ctx0.op_cpy( - &q_current, - &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, input_len), - ), - (0, 2, 1, 3), - ); - - // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) - let big_k = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - builder.memory_k, - (session_len + input_len) * n_embd, - il * ctx_size * memory_k_size * n_embd, - ), - n_embd / n_head, - n_head, - session_len + input_len, - ), - (0, 2, 1, 3), - ); - - // K * Q - let k_q = ctx0.op_mul_mat(&big_k, &big_q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - let k_q_scaled = ctx0.op_scale( - &k_q, - &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - - //alibi - // KQ_scaled_alibi = KQ_scaled + alibi_bias - let k_q_scaled_alibi = ctx0.op_alibi(&k_q_scaled, session_len, n_head, 8f32); - - // KQ_masked = mask_past(KQ_scaled) - let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled_alibi, session_len); - - // KQ = soft_max(KQ_masked) - let k_q_soft_max = ctx0.op_soft_max(&k_q_masked); - - let memv_elsize = memory_v_size; - - let v_trans = ctx0.op_cpy( - &ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - builder.memory_v, - (session_len + input_len) * n_embd, - il * ctx_size * memv_elsize * n_embd, - ), - n_embd / n_head, - n_head, - session_len + input_len, - ), - (1, 2, 0, 3), - ), - &ctx0.new_tensor_3d( - builder.memory_v.get_type(), - session_len + input_len, - n_embd / n_head, - n_head, - ), - ); - - let k_q_v = ctx0.op_mul_mat(&v_trans, &k_q_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - let k_q_v_merged = ctx0.op_permute(&k_q_v, (0, 2, 1, 3)); - - // cur = KQV_merged.contiguous().view(n_embd, N) - current = ctx0.op_cpy( - &k_q_v_merged, - &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), - ); - - // projection - current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].wo_b); - - let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); - - // feed-forward network - // norm - current = ctx0.op_norm(&input_feed_forward); - - // cur = ffn_norm*cur + ffn_norm_b - current = ctx0.op_mul(¤t, &self.layers[il].ffn_norm); - - current = ctx0.op_add(¤t, 
&self.layers[il].ffn_norm_b); - - current = ctx0.op_mul_mat(&self.layers[il].w1, ¤t); - - current = ctx0.op_add(¤t, &self.layers[il].w1_b); - - // SILU activation - - current = ctx0.op_gelu(¤t); - - current = ctx0.op_mul_mat(&self.layers[il].w2, ¤t); - - current = ctx0.op_add(¤t, &self.layers[il].w2_b); - - current = ctx0.op_add(¤t, &input_feed_forward); - - // input for next layer - input_layer = current; - } - - // norm - input_layer = ctx0.op_norm(&input_layer); - - // inpL = norm*inpL - input_layer = ctx0.op_mul(&input_layer, &self.output_norm); - - input_layer = ctx0.op_add(&input_layer, &self.output_norm_bias); - - let embeddings_tensor: ggml::Tensor = input_layer.share(); - - // lm_head - input_layer = ctx0.op_mul_mat(&self.output, &input_layer); - - ( - gf, - GraphOutputs { - result: input_layer, - embedding_result: embeddings_tensor, - }, - ) - }); - - // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); - } - - fn hyperparameters(&self) -> &Self::Hyperparameters { - &self.hyperparameters - } - - fn tokenizer(&self) -> &Tokenizer { - &self.tokenizer - } - - fn context_size(&self) -> usize { - self.params.context_size - } - - fn bot_token_id(&self) -> Option { - self.tokenizer.id("".as_bytes()) - } - - fn eot_token_id(&self) -> TokenId { - self.tokenizer.id("".as_bytes()).unwrap() - } - - fn quantize_tensors() -> Vec { - vec![Regex::new(".*weight").unwrap()] - } - - fn skip_quantize_tensors() -> Vec { - vec![] - } - - fn supports_rewind(&self) -> bool { - true - } -} - -/// BLOOM [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct Hyperparameters { - /// Size of the model's vocabulary - pub n_vocab: usize, - /// Size of the model's embedding layer - pub n_embd: usize, - /// n_mult - pub n_mult: usize, - /// n_head - pub n_head: usize, - /// Number of layers in the model - pub n_layer: usize, - /// file_type - pub file_type: FileType, -} - -impl llm_base::Hyperparameters for Hyperparameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - Ok(Hyperparameters { - n_vocab: util::read_i32(reader)?.try_into()?, - n_embd: util::read_i32(reader)?.try_into()?, - n_mult: util::read_i32(reader)?.try_into()?, - n_head: util::read_i32(reader)?.try_into()?, - n_layer: util::read_i32(reader)?.try_into()?, - file_type: util::read_filetype(reader)?, - }) - } - - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.n_vocab.try_into()?)?; - util::write_i32(writer, self.n_embd.try_into()?)?; - util::write_i32(writer, self.n_mult.try_into()?)?; - util::write_i32(writer, self.n_head.try_into()?)?; - util::write_i32(writer, self.n_layer.try_into()?)?; - util::write_i32(writer, self.file_type.into())?; - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - self.n_vocab - } - - fn file_type(&self) -> Option { - Some(self.file_type) - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - Some(&mut self.file_type) - } -} - -struct Layer { - pub attention_norm: ggml::Tensor, - pub attention_norm_b: ggml::Tensor, - pub wo: ggml::Tensor, - pub wo_b: ggml::Tensor, - pub query_key_value: ggml::Tensor, - pub query_key_value_b: ggml::Tensor, - // normalization - pub ffn_norm: ggml::Tensor, - pub ffn_norm_b: 
ggml::Tensor, - // ff - pub w1: ggml::Tensor, - pub w1_b: ggml::Tensor, - pub w2: ggml::Tensor, - pub w2_b: ggml::Tensor, -} +// //! An implementation of [BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom) +// //! for the `llm` ecosystem. +// #![deny(missing_docs)] + +// use llm_base::{ +// ggml, +// model::{common, HyperparametersWriteError}, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, +// ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, +// }; + +// /// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom) +// /// +// /// # Safety +// /// This implements [Send] and [Sync] as it is immutable after construction. +// pub struct Bloom { +// params: ModelParameters, + +// hyperparameters: Hyperparameters, +// tokenizer: Tokenizer, + +// // model-global weights +// // weighted token embeddings +// wte: ggml::Tensor, +// // normalization weight & bias +// norm: ggml::Tensor, +// norm_bias: ggml::Tensor, +// // output normalization weight & bias +// output_norm: ggml::Tensor, +// output_norm_bias: ggml::Tensor, +// // output weight +// output: ggml::Tensor, + +// // weights for the model +// layers: Vec, + +// // must be kept alive for the model +// context: ModelContext, +// } + +// unsafe impl Send for Bloom {} +// unsafe impl Sync for Bloom {} + +// impl KnownModel for Bloom { +// type Hyperparameters = Hyperparameters; + +// fn new( +// hyperparameters: Self::Hyperparameters, +// params: ModelParameters, +// tokenizer: Tokenizer, +// tensor_loader: impl llm_base::TensorLoader, +// ) -> Result { +// let mut tl = tensor_loader; + +// // model-global weights +// let wte = tl.load("tok_embeddings.weight")?; +// let norm = tl.load("norm.weight")?; +// let norm_bias = tl.load("norm.bias")?; +// let output_norm = tl.load("output_norm.weight")?; +// let output_norm_bias = tl.load("output_norm.bias")?; +// let output = tl.load("output.weight")?; + +// let mut layers = Vec::new(); +// for i in 0..hyperparameters.n_layer { +// let layer = Layer { +// attention_norm: tl.load(&format!("layers.{i}.attention_norm.weight"))?, +// attention_norm_b: tl.load(&format!("layers.{i}.attention_norm.bias"))?, + +// query_key_value: tl +// .load(&format!("layers.{i}.attention.query_key_value.weight"))?, +// query_key_value_b: tl +// .load(&format!("layers.{i}.attention.query_key_value.bias"))?, + +// wo: tl.load(&format!("layers.{i}.attention.wo.weight"))?, +// wo_b: tl.load(&format!("layers.{i}.attention.wo.bias"))?, + +// ffn_norm: tl.load(&format!("layers.{i}.ffn_norm.weight"))?, +// ffn_norm_b: tl.load(&format!("layers.{i}.ffn_norm.bias"))?, + +// w1: tl.load(&format!("layers.{i}.feed_forward.w1.weight"))?, +// w1_b: tl.load(&format!("layers.{i}.feed_forward.w1.bias"))?, +// w2: tl.load(&format!("layers.{i}.feed_forward.w2.weight"))?, +// w2_b: tl.load(&format!("layers.{i}.feed_forward.w2.bias"))?, +// }; + +// layers.push(layer); +// } + +// let context = tl.finish(); + +// Ok(Bloom { +// hyperparameters, +// params, +// tokenizer, +// wte, +// norm, +// norm_bias, +// output_norm, +// output_norm_bias, +// output, +// layers, +// context, +// }) +// } + +// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { +// InferenceSession::new( +// config, +// &self.params, +// self.hyperparameters.n_layer, +// self.hyperparameters.n_embd, +// self.hyperparameters.n_vocab, +// ) +// } + +// fn evaluate( +// &self, +// session: &mut InferenceSession, +// input_tokens: &[TokenId], 
+// output_request: &mut OutputRequest, +// ) { +// let input_len = input_tokens.len(); +// let session_len = session.n_past; +// let ctx_size = self.params.context_size; + +// let Hyperparameters { +// n_vocab, +// n_embd, +// n_mult: _, +// n_head, +// n_layer, +// file_type: _, +// } = self.hyperparameters; + +// let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let ctx0 = builder.ctx0.borrow(); +// let (memory_k_size, memory_v_size) = ( +// builder.memory_k.element_size(), +// builder.memory_v.element_size(), +// ); +// let embd = &builder.embd; +// let mut input_layer = ctx0.op_get_rows(&self.wte, embd); + +// // normalize embeddings +// input_layer = ctx0.op_norm(&input_layer); +// input_layer = ctx0.op_mul(&input_layer, &self.norm); +// input_layer = ctx0.op_add(&input_layer, &self.norm_bias); + +// let mut gf = ctx0.create_compute_graph(); +// for il in 0..n_layer { +// let input_self_attention = input_layer.share(); +// let mut current: ggml::Tensor; + +// // norm +// current = ctx0.op_norm(&input_layer); + +// // cur = attention_norm * cur +// current = ctx0.op_mul(¤t, &self.layers[il].attention_norm); +// current = ctx0.op_add(¤t, &self.layers[il].attention_norm_b); + +// //attention +// current = ctx0.op_mul_mat(&self.layers[il].query_key_value, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].query_key_value_b); + +// // self-attention +// let nb = current.get_nb()[1]; +// let q_current = ctx0.op_view_2d( +// ¤t, +// (n_embd, input_len), +// nb, +// //0 * std::mem::size_of::() * n_embd as usize, +// 0, +// ); +// let k_current = ctx0.op_view_2d( +// ¤t, +// (n_embd, input_len), +// nb, +// std::mem::size_of::() * n_embd, +// ); +// let v_current = ctx0.op_view_2d( +// ¤t, +// (n_embd, input_len), +// nb, +// 2 * std::mem::size_of::() * n_embd, +// ); + +// // store key and value to memory +// if input_len >= 1 { +// let k = ctx0.op_view_1d( +// builder.memory_k, +// input_len * n_embd, +// (memory_k_size * n_embd) * (il * ctx_size + session_len), +// ); + +// let v = ctx0.op_view_1d( +// builder.memory_v, +// input_len * n_embd, +// (memory_v_size * n_embd) * (il * ctx_size + session_len), +// ); + +// gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k)); +// gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v)); +// } + +// // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) +// let big_q = ctx0.op_permute( +// &ctx0.op_cpy( +// &q_current, +// &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, input_len), +// ), +// (0, 2, 1, 3), +// ); + +// // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) +// let big_k = ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_k, +// (session_len + input_len) * n_embd, +// il * ctx_size * memory_k_size * n_embd, +// ), +// n_embd / n_head, +// n_head, +// session_len + input_len, +// ), +// (0, 2, 1, 3), +// ); + +// // K * Q +// let k_q = ctx0.op_mul_mat(&big_k, &big_q); + +// // KQ_scaled = KQ / sqrt(n_embd/n_head) +// let k_q_scaled = ctx0.op_scale( +// &k_q, +// &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)), +// ); + +// //alibi +// // KQ_scaled_alibi = KQ_scaled + alibi_bias +// let k_q_scaled_alibi = ctx0.op_alibi(&k_q_scaled, session_len, n_head, 8f32); + +// // KQ_masked = mask_past(KQ_scaled) +// let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled_alibi, session_len); + +// // KQ = soft_max(KQ_masked) +// let k_q_soft_max = ctx0.op_soft_max(&k_q_masked); + +// let memv_elsize = memory_v_size; + +// 
let v_trans = ctx0.op_cpy( +// &ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_v, +// (session_len + input_len) * n_embd, +// il * ctx_size * memv_elsize * n_embd, +// ), +// n_embd / n_head, +// n_head, +// session_len + input_len, +// ), +// (1, 2, 0, 3), +// ), +// &ctx0.new_tensor_3d( +// builder.memory_v.get_type(), +// session_len + input_len, +// n_embd / n_head, +// n_head, +// ), +// ); + +// let k_q_v = ctx0.op_mul_mat(&v_trans, &k_q_soft_max); + +// // KQV_merged = KQV.permute(0, 2, 1, 3) +// let k_q_v_merged = ctx0.op_permute(&k_q_v, (0, 2, 1, 3)); + +// // cur = KQV_merged.contiguous().view(n_embd, N) +// current = ctx0.op_cpy( +// &k_q_v_merged, +// &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), +// ); + +// // projection +// current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].wo_b); + +// let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); + +// // feed-forward network +// // norm +// current = ctx0.op_norm(&input_feed_forward); + +// // cur = ffn_norm*cur + ffn_norm_b +// current = ctx0.op_mul(¤t, &self.layers[il].ffn_norm); + +// current = ctx0.op_add(¤t, &self.layers[il].ffn_norm_b); + +// current = ctx0.op_mul_mat(&self.layers[il].w1, ¤t); + +// current = ctx0.op_add(¤t, &self.layers[il].w1_b); + +// // SILU activation + +// current = ctx0.op_gelu(¤t); + +// current = ctx0.op_mul_mat(&self.layers[il].w2, ¤t); + +// current = ctx0.op_add(¤t, &self.layers[il].w2_b); + +// current = ctx0.op_add(¤t, &input_feed_forward); + +// // input for next layer +// input_layer = current; +// } + +// // norm +// input_layer = ctx0.op_norm(&input_layer); + +// // inpL = norm*inpL +// input_layer = ctx0.op_mul(&input_layer, &self.output_norm); + +// input_layer = ctx0.op_add(&input_layer, &self.output_norm_bias); + +// let embeddings_tensor: ggml::Tensor = input_layer.share(); + +// // lm_head +// input_layer = ctx0.op_mul_mat(&self.output, &input_layer); + +// ( +// gf, +// GraphOutputs { +// result: input_layer, +// embedding_result: embeddings_tensor, +// }, +// ) +// }); + +// // finish evaluation +// common::read_last_token(session, &outputs.result, n_vocab, input_len); +// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); +// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// } + +// fn hyperparameters(&self) -> &Self::Hyperparameters { +// &self.hyperparameters +// } + +// fn tokenizer(&self) -> &Tokenizer { +// &self.tokenizer +// } + +// fn context_size(&self) -> usize { +// self.params.context_size +// } + +// fn bot_token_id(&self) -> Option { +// self.tokenizer.id("".as_bytes()) +// } + +// fn eot_token_id(&self) -> TokenId { +// self.tokenizer.id("".as_bytes()).unwrap() +// } + +// fn quantize_tensors() -> Vec { +// vec![Regex::new(".*weight").unwrap()] +// } + +// fn skip_quantize_tensors() -> Vec { +// vec![] +// } + +// fn supports_rewind(&self) -> bool { +// true +// } +// } + +// /// BLOOM [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +// #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +// pub struct Hyperparameters { +// /// Size of the model's vocabulary +// pub n_vocab: usize, +// /// Size of the model's embedding layer +// pub n_embd: usize, +// /// n_mult +// pub n_mult: usize, +// /// n_head +// pub n_head: usize, +// /// Number of layers in the model +// pub n_layer: usize, +// /// file_type +// pub file_type: FileType, +// } + +// impl 
llm_base::Hyperparameters for Hyperparameters { +// fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { +// Ok(Hyperparameters { +// n_vocab: util::read_i32(reader)?.try_into()?, +// n_embd: util::read_i32(reader)?.try_into()?, +// n_mult: util::read_i32(reader)?.try_into()?, +// n_head: util::read_i32(reader)?.try_into()?, +// n_layer: util::read_i32(reader)?.try_into()?, +// file_type: util::read_filetype(reader)?, +// }) +// } + +// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// util::write_i32(writer, self.n_embd.try_into()?)?; +// util::write_i32(writer, self.n_mult.try_into()?)?; +// util::write_i32(writer, self.n_head.try_into()?)?; +// util::write_i32(writer, self.n_layer.try_into()?)?; +// util::write_i32(writer, self.file_type.into())?; +// Ok(()) +// } + +// fn n_vocabulary(&self) -> usize { +// self.n_vocab +// } + +// fn file_type(&self) -> Option { +// Some(self.file_type) +// } + +// fn file_type_mut(&mut self) -> Option<&mut FileType> { +// Some(&mut self.file_type) +// } +// } + +// struct Layer { +// pub attention_norm: ggml::Tensor, +// pub attention_norm_b: ggml::Tensor, +// pub wo: ggml::Tensor, +// pub wo_b: ggml::Tensor, +// pub query_key_value: ggml::Tensor, +// pub query_key_value_b: ggml::Tensor, +// // normalization +// pub ffn_norm: ggml::Tensor, +// pub ffn_norm_b: ggml::Tensor, +// // ff +// pub w1: ggml::Tensor, +// pub w1_b: ggml::Tensor, +// pub w2: ggml::Tensor, +// pub w2_b: ggml::Tensor, +// } diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs index 0322e2f2..db26a6d1 100644 --- a/crates/models/falcon/src/lib.rs +++ b/crates/models/falcon/src/lib.rs @@ -1,471 +1,471 @@ -//! An implementation of the [Falcon](https://falconllm.tii.ae/) model for the `llm` ecosystem. -//! -//! This implementation only works for Falcon 7B, and with 32-bit memory tensors (i.e. your inference session -//! must be configured with a 32-bit [InferenceSessionConfig]). -//! -//! This model will not be generally available in the `llm` ecosystem until Falcon 40B and 16-bit memory is -//! supported. It is currently only available as a preview. -#![deny(missing_docs)] - -use ggml::Tensor; -use llm_base::{ - ggml, - model::{common, HyperparametersWriteError}, - util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, -}; - -/// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae) -/// -/// # Safety -/// This implements [Send] and [Sync] as it is immutable after construction. 
-pub struct Falcon { - params: ModelParameters, - - hyperparameters: Hyperparameters, - - tokenizer: Tokenizer, - - // model-global weights - // weighted token embeddings - tok_embeddings: Tensor, - output_norm: Tensor, - output_norm_b: Tensor, - lm_head: Tensor, - - // weights for the model - layers: Vec, - - // must be kept alive for the model - context: ModelContext, -} - -unsafe impl Send for Falcon {} -unsafe impl Sync for Falcon {} - -impl KnownModel for Falcon { - type Hyperparameters = Hyperparameters; - - fn new( - hyperparameters: Self::Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: impl llm_base::TensorLoader, - ) -> Result { - let mut tl = tensor_loader; - - // model-gobal weights - let tok_embeddings = tl.load("transformer.word_embeddings.weight")?; - - let backend = params.backend(0); - - let output_norm = tl.load("transformer.ln_f.weight")?.transfer_to(backend); - let output_norm_b = tl.load("transformer.ln_f.bias")?.transfer_to(backend); - let lm_head = tl.load("lm_head.weight")?.transfer_to(backend); - - let mut layers = Vec::new(); - // utilizing n_head_kv to determine the model version (parameters) - let Hyperparameters { n_head_kv, .. } = hyperparameters; - for i in 0..hyperparameters.n_layer { - let backend = params.backend(i); - - let (input_layernorm_name, attention_norm_name) = if n_head_kv == 1 { - // falcon 7b - (format!("transformer.h.{i}.input_layernorm"), None) - } else { - // falcon 40b - ( - format!("transformer.h.{i}.ln_mlp"), - Some(format!("transformer.h.{i}.ln_attn")), - ) - }; - - let (attention_norm_weight, attention_norm_bias) = - if let Some(norm_name) = attention_norm_name { - ( - Some( - tl.load(&format!("{}.weight", norm_name))? - .transfer_to(backend), - ), - Some( - tl.load(&format!("{}.bias", norm_name))? - .transfer_to(backend), - ), - ) - } else { - (None, None) - }; - - let layer = Layer { - input_layernorm: tl - .load(&format!("{}.weight", input_layernorm_name))? - .transfer_to(backend), - input_layernorm_b: tl - .load(&format!("{}.bias", input_layernorm_name))? - .transfer_to(backend), - attention_norm: attention_norm_weight, - attention_norm_b: attention_norm_bias, - query_key_value: tl - .load(&format!( - "transformer.h.{i}.self_attention.query_key_value.weight" - ))? - .transfer_to(backend), - wo: tl - .load(&format!("transformer.h.{i}.self_attention.dense.weight"))? - .transfer_to(backend), - - ffn_up: tl - .load(&format!("transformer.h.{i}.mlp.dense_h_to_4h.weight"))? - .transfer_to(backend), - ffn_down: tl - .load(&format!("transformer.h.{i}.mlp.dense_4h_to_h.weight"))? - .transfer_to(backend), - }; - - layers.push(layer); - } - - let context = tl.finish(); - - Ok(Falcon { - hyperparameters, - params, - tokenizer, - tok_embeddings, - output_norm, - output_norm_b, - lm_head, - layers, - context, - }) - } - - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { - InferenceSession::new( - config, - &self.params, - self.hyperparameters.n_layer, - self.hyperparameters.n_embd, - self.hyperparameters.n_vocab, - ) - } - - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; - let ctx_size = self.params.context_size; - - let Hyperparameters { - n_embd, - n_head, - n_head_kv, - n_vocab, - n_layer, - .. 
- } = self.hyperparameters; - - let head_dim = n_embd / n_head; - let n = input_len; - - let outputs = session.compute(self.context.clone(), input_tokens, |builder| { - let mut ctx0 = builder.ctx0.borrow_mut(); - let embd = builder.embd; - let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd); - - let f32_size = std::mem::size_of::(); - - let memory_k = builder.memory_k; - let memory_k_size = memory_k.element_size(); - - let memory_v = builder.memory_v; - let memory_v_size = memory_v.element_size(); - - let mut gf = ctx0.create_compute_graph(); - - let mut current: Tensor; - let mut layernorm_output: Tensor; - - for il in 0..n_layer { - // attention uses first scratch buffer - ctx0.use_scratch(builder.get_scratch(0)); - ctx0.set_offloading(self.params.should_offload(il)); - - // self-attention - layernorm_output = ctx0.op_norm(&input_layer); - layernorm_output = ctx0.op_add( - &ctx0.op_mul(&layernorm_output, &self.layers[il].input_layernorm), - &self.layers[il].input_layernorm_b, - ); - - if n_head_kv == 1 { - // Falcon-7B only - current = layernorm_output.share(); - } else { - // Falcon-40B only - current = ctx0.op_norm(&input_layer); - current = ctx0.op_add( - &ctx0.op_mul(¤t, self.layers[il].attention_norm.as_ref().unwrap()), - self.layers[il].attention_norm_b.as_ref().unwrap(), - ); - } - - // compute QKV - current = ctx0.op_mul_mat(&self.layers[il].query_key_value, ¤t); - - let fused_qkv_row_nb = head_dim * (n_head + 2 * n_head_kv) * f32_size; - - let mut qcur = ctx0.op_view_3d( - ¤t, - (head_dim, n_head, n), - (head_dim * f32_size, fused_qkv_row_nb), - 0, - ); - - let mut kcur = ctx0.op_view_3d( - ¤t, - (head_dim, n_head_kv, n), - (head_dim * f32_size, fused_qkv_row_nb), - head_dim * n_head * f32_size, - ); - - let vcur = ctx0.op_view_3d( - ¤t, - (head_dim, n_head_kv, n), - (head_dim * f32_size, fused_qkv_row_nb), - head_dim * (n_head + n_head_kv) * f32_size, - ); - - // using mode = 2 for neox mode - let overrides = self.params.rope_overrides.as_ref(); - qcur = ctx0.op_rope_inplace(&qcur, session_len, head_dim, 2, overrides); - kcur = ctx0.op_rope_inplace(&kcur, session_len, head_dim, 2, overrides); - - // store key and value to memory - - let k = ctx0.op_view_1d( - memory_k, - n * n_head_kv * head_dim, - (memory_k_size * n_head_kv * head_dim) * (il * ctx_size + session_len), - ); - let v = ctx0.op_view_1d( - memory_v, - n * n_head_kv * head_dim, - (memory_v_size * n_head_kv * head_dim) * (il * ctx_size + session_len), - ); - - gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); - - // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) - let bigq = ctx0.op_permute(&qcur, (0, 2, 1, 3)); - - let bigk = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - memory_k, - (session_len + n) * n_head_kv * head_dim, - il * ctx_size * memory_k_size * n_head_kv * head_dim, - ), - head_dim, - n_head_kv, - session_len + n, - ), - (0, 2, 1, 3), - ); - - // K * Q - let big_kq = ctx0.op_mul_mat(&bigk, &bigq); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - let big_kq_scaled = ctx0.op_scale_inplace( - &big_kq, - &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - - let big_kq_masked = ctx0.op_diag_mask_inf_inplace(&big_kq_scaled, session_len); - - let big_kq_softmax = ctx0.op_soft_max_inplace(&big_kq_masked); - - let mut bigv = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - memory_v, - (session_len + n) * n_head_kv * head_dim, - il * ctx_size * memory_v_size * n_head_kv * 
head_dim, - ), - head_dim, - n_head_kv, - session_len + n, - ), - (0, 2, 1, 3), - ); - bigv = ctx0.op_cont(&ctx0.op_transpose(&bigv)); - - let big_kqv = ctx0.op_mul_mat(&bigv, &big_kq_softmax); - // KQV_merged = KQV.permute(0, 2, 1, 3) - let big_kqv_merged = ctx0.op_permute(&big_kqv, (0, 2, 1, 3)); - - // cur = KQV_merged.contiguous().view(n_embd, N) - current = ctx0.op_cpy( - &big_kqv_merged, - &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n), - ); - - // projection - current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); - - // feed forward uses second scratch buffer - ctx0.use_scratch(builder.get_scratch(1)); - - let inp_ff = layernorm_output.share(); - let attn_out = - ctx0.op_cpy(¤t, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); - - current = ctx0.op_mul_mat(&self.layers[il].ffn_up, &inp_ff); - current = ctx0.op_gelu(¤t); - current = ctx0.op_mul_mat(&self.layers[il].ffn_down, ¤t); - - current = ctx0.op_add(¤t, &attn_out); - current = ctx0.op_add(¤t, &input_layer); - - input_layer = current.share(); - } - - ctx0.use_scratch(builder.get_scratch(0)); - - // norm - input_layer = ctx0.op_norm(&input_layer); - - input_layer = ctx0.op_add( - &ctx0.op_mul(&input_layer, &self.output_norm), - &self.output_norm_b, - ); - - let embeddings_tensor: ggml::Tensor = input_layer.share(); - - ctx0.set_offloading(false); - ctx0.use_scratch(None); - - // lm_head - input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer); - - ( - gf, - GraphOutputs { - result: input_layer, - embedding_result: embeddings_tensor, - }, - ) - }); - - // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); - } - - fn hyperparameters(&self) -> &Self::Hyperparameters { - &self.hyperparameters - } - - fn tokenizer(&self) -> &Tokenizer { - &self.tokenizer - } - - fn context_size(&self) -> usize { - self.params.context_size - } - - fn bot_token_id(&self) -> Option { - None - } - - fn eot_token_id(&self) -> TokenId { - self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() - } - - fn quantize_tensors() -> Vec { - vec![Regex::new(".*weight").unwrap()] - } - - fn skip_quantize_tensors() -> Vec { - vec![] - } -} - -/// Falcon [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -#[derive(Debug, Default, PartialEq, Clone, Copy, Eq)] -pub struct Hyperparameters { - /// Size of the model's vocabulary - n_vocab: usize, - /// Size of the model's embedding layer - n_embd: usize, - /// n_heads - n_head: usize, - // Number of heads for key-value pairs - n_head_kv: usize, - /// Number of layers in the model - n_layer: usize, - /// file_type - file_type: FileType, -} - -impl llm_base::Hyperparameters for Hyperparameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - let hyperparameters = Hyperparameters { - n_vocab: util::read_i32(reader)?.try_into()?, - n_embd: util::read_i32(reader)?.try_into()?, - n_head: util::read_i32(reader)?.try_into()?, - n_head_kv: util::read_i32(reader)?.try_into()?, - n_layer: util::read_i32(reader)?.try_into()?, - file_type: util::read_filetype(reader)?, - }; - - Ok(hyperparameters) - } - - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.n_vocab.try_into()?)?; - util::write_i32(writer, self.n_embd.try_into()?)?; - util::write_i32(writer, self.n_head.try_into()?)?; - 
util::write_i32(writer, self.n_head_kv.try_into()?)?; - util::write_i32(writer, self.n_layer.try_into()?)?; - util::write_i32(writer, self.file_type.into())?; - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - self.n_vocab - } - - fn file_type(&self) -> Option { - Some(self.file_type) - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - Some(&mut self.file_type) - } -} - -struct Layer { - // normalization - input_layernorm: Tensor, - input_layernorm_b: Tensor, - - // Falcon-40B only - attention_norm: Option, - attention_norm_b: Option, - - // attention - query_key_value: Tensor, - wo: Tensor, - - // ff - ffn_up: Tensor, - ffn_down: Tensor, -} +// //! An implementation of the [Falcon](https://falconllm.tii.ae/) model for the `llm` ecosystem. +// //! +// //! This implementation only works for Falcon 7B, and with 32-bit memory tensors (i.e. your inference session +// //! must be configured with a 32-bit [InferenceSessionConfig]). +// //! +// //! This model will not be generally available in the `llm` ecosystem until Falcon 40B and 16-bit memory is +// //! supported. It is currently only available as a preview. +// #![deny(missing_docs)] + +// use ggml::Tensor; +// use llm_base::{ +// ggml, +// model::{common, HyperparametersWriteError}, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, +// }; + +// /// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae) +// /// +// /// # Safety +// /// This implements [Send] and [Sync] as it is immutable after construction. +// pub struct Falcon { +// params: ModelParameters, + +// hyperparameters: Hyperparameters, + +// tokenizer: Tokenizer, + +// // model-global weights +// // weighted token embeddings +// tok_embeddings: Tensor, +// output_norm: Tensor, +// output_norm_b: Tensor, +// lm_head: Tensor, + +// // weights for the model +// layers: Vec, + +// // must be kept alive for the model +// context: ModelContext, +// } + +// unsafe impl Send for Falcon {} +// unsafe impl Sync for Falcon {} + +// impl KnownModel for Falcon { +// type Hyperparameters = Hyperparameters; + +// fn new( +// hyperparameters: Self::Hyperparameters, +// params: ModelParameters, +// tokenizer: Tokenizer, +// tensor_loader: impl llm_base::TensorLoader, +// ) -> Result { +// let mut tl = tensor_loader; + +// // model-gobal weights +// let tok_embeddings = tl.load("transformer.word_embeddings.weight")?; + +// let backend = params.backend(0); + +// let output_norm = tl.load("transformer.ln_f.weight")?.transfer_to(backend); +// let output_norm_b = tl.load("transformer.ln_f.bias")?.transfer_to(backend); +// let lm_head = tl.load("lm_head.weight")?.transfer_to(backend); + +// let mut layers = Vec::new(); +// // utilizing n_head_kv to determine the model version (parameters) +// let Hyperparameters { n_head_kv, .. } = hyperparameters; +// for i in 0..hyperparameters.n_layer { +// let backend = params.backend(i); + +// let (input_layernorm_name, attention_norm_name) = if n_head_kv == 1 { +// // falcon 7b +// (format!("transformer.h.{i}.input_layernorm"), None) +// } else { +// // falcon 40b +// ( +// format!("transformer.h.{i}.ln_mlp"), +// Some(format!("transformer.h.{i}.ln_attn")), +// ) +// }; + +// let (attention_norm_weight, attention_norm_bias) = +// if let Some(norm_name) = attention_norm_name { +// ( +// Some( +// tl.load(&format!("{}.weight", norm_name))? 
+// .transfer_to(backend), +// ), +// Some( +// tl.load(&format!("{}.bias", norm_name))? +// .transfer_to(backend), +// ), +// ) +// } else { +// (None, None) +// }; + +// let layer = Layer { +// input_layernorm: tl +// .load(&format!("{}.weight", input_layernorm_name))? +// .transfer_to(backend), +// input_layernorm_b: tl +// .load(&format!("{}.bias", input_layernorm_name))? +// .transfer_to(backend), +// attention_norm: attention_norm_weight, +// attention_norm_b: attention_norm_bias, +// query_key_value: tl +// .load(&format!( +// "transformer.h.{i}.self_attention.query_key_value.weight" +// ))? +// .transfer_to(backend), +// wo: tl +// .load(&format!("transformer.h.{i}.self_attention.dense.weight"))? +// .transfer_to(backend), + +// ffn_up: tl +// .load(&format!("transformer.h.{i}.mlp.dense_h_to_4h.weight"))? +// .transfer_to(backend), +// ffn_down: tl +// .load(&format!("transformer.h.{i}.mlp.dense_4h_to_h.weight"))? +// .transfer_to(backend), +// }; + +// layers.push(layer); +// } + +// let context = tl.finish(); + +// Ok(Falcon { +// hyperparameters, +// params, +// tokenizer, +// tok_embeddings, +// output_norm, +// output_norm_b, +// lm_head, +// layers, +// context, +// }) +// } + +// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { +// InferenceSession::new( +// config, +// &self.params, +// self.hyperparameters.n_layer, +// self.hyperparameters.n_embd, +// self.hyperparameters.n_vocab, +// ) +// } + +// fn evaluate( +// &self, +// session: &mut InferenceSession, +// input_tokens: &[TokenId], +// output_request: &mut OutputRequest, +// ) { +// let input_len = input_tokens.len(); +// let session_len = session.n_past; +// let ctx_size = self.params.context_size; + +// let Hyperparameters { +// n_embd, +// n_head, +// n_head_kv, +// n_vocab, +// n_layer, +// .. 
+//         } = self.hyperparameters;
+
+//         let head_dim = n_embd / n_head;
+//         let n = input_len;
+
+//         let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
+//             let mut ctx0 = builder.ctx0.borrow_mut();
+//             let embd = builder.embd;
+//             let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd);
+
+//             let f32_size = std::mem::size_of::<f32>();
+
+//             let memory_k = builder.memory_k;
+//             let memory_k_size = memory_k.element_size();
+
+//             let memory_v = builder.memory_v;
+//             let memory_v_size = memory_v.element_size();
+
+//             let mut gf = ctx0.create_compute_graph();
+
+//             let mut current: Tensor;
+//             let mut layernorm_output: Tensor;
+
+//             for il in 0..n_layer {
+//                 // attention uses first scratch buffer
+//                 ctx0.use_scratch(builder.get_scratch(0));
+//                 ctx0.set_offloading(self.params.should_offload(il));
+
+//                 // self-attention
+//                 layernorm_output = ctx0.op_norm(&input_layer);
+//                 layernorm_output = ctx0.op_add(
+//                     &ctx0.op_mul(&layernorm_output, &self.layers[il].input_layernorm),
+//                     &self.layers[il].input_layernorm_b,
+//                 );
+
+//                 if n_head_kv == 1 {
+//                     // Falcon-7B only
+//                     current = layernorm_output.share();
+//                 } else {
+//                     // Falcon-40B only
+//                     current = ctx0.op_norm(&input_layer);
+//                     current = ctx0.op_add(
+//                         &ctx0.op_mul(&current, self.layers[il].attention_norm.as_ref().unwrap()),
+//                         self.layers[il].attention_norm_b.as_ref().unwrap(),
+//                     );
+//                 }
+
+//                 // compute QKV
+//                 current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current);
+
+//                 let fused_qkv_row_nb = head_dim * (n_head + 2 * n_head_kv) * f32_size;
+
+//                 let mut qcur = ctx0.op_view_3d(
+//                     &current,
+//                     (head_dim, n_head, n),
+//                     (head_dim * f32_size, fused_qkv_row_nb),
+//                     0,
+//                 );
+
+//                 let mut kcur = ctx0.op_view_3d(
+//                     &current,
+//                     (head_dim, n_head_kv, n),
+//                     (head_dim * f32_size, fused_qkv_row_nb),
+//                     head_dim * n_head * f32_size,
+//                 );
+
+//                 let vcur = ctx0.op_view_3d(
+//                     &current,
+//                     (head_dim, n_head_kv, n),
+//                     (head_dim * f32_size, fused_qkv_row_nb),
+//                     head_dim * (n_head + n_head_kv) * f32_size,
+//                 );
+
+//                 // using mode = 2 for neox mode
+//                 let overrides = self.params.rope_overrides.as_ref();
+//                 qcur = ctx0.op_rope_inplace(&qcur, session_len, head_dim, 2, overrides);
+//                 kcur = ctx0.op_rope_inplace(&kcur, session_len, head_dim, 2, overrides);
+
+//                 // store key and value to memory
+
+//                 let k = ctx0.op_view_1d(
+//                     memory_k,
+//                     n * n_head_kv * head_dim,
+//                     (memory_k_size * n_head_kv * head_dim) * (il * ctx_size + session_len),
+//                 );
+//                 let v = ctx0.op_view_1d(
+//                     memory_v,
+//                     n * n_head_kv * head_dim,
+//                     (memory_v_size * n_head_kv * head_dim) * (il * ctx_size + session_len),
+//                 );
+
+//                 gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
+//                 gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));
+
+//                 // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+//                 let bigq = ctx0.op_permute(&qcur, (0, 2, 1, 3));
+
+//                 let bigk = ctx0.op_permute(
+//                     &ctx0.op_reshape_3d(
+//                         &ctx0.op_view_1d(
+//                             memory_k,
+//                             (session_len + n) * n_head_kv * head_dim,
+//                             il * ctx_size * memory_k_size * n_head_kv * head_dim,
+//                         ),
+//                         head_dim,
+//                         n_head_kv,
+//                         session_len + n,
+//                     ),
+//                     (0, 2, 1, 3),
+//                 );
+
+//                 // K * Q
+//                 let big_kq = ctx0.op_mul_mat(&bigk, &bigq);
+
+//                 // KQ_scaled = KQ / sqrt(n_embd/n_head)
+//                 let big_kq_scaled = ctx0.op_scale_inplace(
+//                     &big_kq,
+//                     &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
+//                 );
+
+//                 let big_kq_masked = ctx0.op_diag_mask_inf_inplace(&big_kq_scaled, session_len);
+
+//                 let big_kq_softmax =
ctx0.op_soft_max_inplace(&big_kq_masked); + +// let mut bigv = ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// memory_v, +// (session_len + n) * n_head_kv * head_dim, +// il * ctx_size * memory_v_size * n_head_kv * head_dim, +// ), +// head_dim, +// n_head_kv, +// session_len + n, +// ), +// (0, 2, 1, 3), +// ); +// bigv = ctx0.op_cont(&ctx0.op_transpose(&bigv)); + +// let big_kqv = ctx0.op_mul_mat(&bigv, &big_kq_softmax); +// // KQV_merged = KQV.permute(0, 2, 1, 3) +// let big_kqv_merged = ctx0.op_permute(&big_kqv, (0, 2, 1, 3)); + +// // cur = KQV_merged.contiguous().view(n_embd, N) +// current = ctx0.op_cpy( +// &big_kqv_merged, +// &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n), +// ); + +// // projection +// current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); + +// // feed forward uses second scratch buffer +// ctx0.use_scratch(builder.get_scratch(1)); + +// let inp_ff = layernorm_output.share(); +// let attn_out = +// ctx0.op_cpy(¤t, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); + +// current = ctx0.op_mul_mat(&self.layers[il].ffn_up, &inp_ff); +// current = ctx0.op_gelu(¤t); +// current = ctx0.op_mul_mat(&self.layers[il].ffn_down, ¤t); + +// current = ctx0.op_add(¤t, &attn_out); +// current = ctx0.op_add(¤t, &input_layer); + +// input_layer = current.share(); +// } + +// ctx0.use_scratch(builder.get_scratch(0)); + +// // norm +// input_layer = ctx0.op_norm(&input_layer); + +// input_layer = ctx0.op_add( +// &ctx0.op_mul(&input_layer, &self.output_norm), +// &self.output_norm_b, +// ); + +// let embeddings_tensor: ggml::Tensor = input_layer.share(); + +// ctx0.set_offloading(false); +// ctx0.use_scratch(None); + +// // lm_head +// input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer); + +// ( +// gf, +// GraphOutputs { +// result: input_layer, +// embedding_result: embeddings_tensor, +// }, +// ) +// }); + +// // finish evaluation +// common::read_last_token(session, &outputs.result, n_vocab, input_len); +// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); +// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// } + +// fn hyperparameters(&self) -> &Self::Hyperparameters { +// &self.hyperparameters +// } + +// fn tokenizer(&self) -> &Tokenizer { +// &self.tokenizer +// } + +// fn context_size(&self) -> usize { +// self.params.context_size +// } + +// fn bot_token_id(&self) -> Option { +// None +// } + +// fn eot_token_id(&self) -> TokenId { +// self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() +// } + +// fn quantize_tensors() -> Vec { +// vec![Regex::new(".*weight").unwrap()] +// } + +// fn skip_quantize_tensors() -> Vec { +// vec![] +// } +// } + +// /// Falcon [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +// #[derive(Debug, Default, PartialEq, Clone, Copy, Eq)] +// pub struct Hyperparameters { +// /// Size of the model's vocabulary +// n_vocab: usize, +// /// Size of the model's embedding layer +// n_embd: usize, +// /// n_heads +// n_head: usize, +// // Number of heads for key-value pairs +// n_head_kv: usize, +// /// Number of layers in the model +// n_layer: usize, +// /// file_type +// file_type: FileType, +// } + +// impl llm_base::Hyperparameters for Hyperparameters { +// fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { +// let hyperparameters = Hyperparameters { +// n_vocab: util::read_i32(reader)?.try_into()?, +// n_embd: util::read_i32(reader)?.try_into()?, +// n_head: util::read_i32(reader)?.try_into()?, +// n_head_kv: 
util::read_i32(reader)?.try_into()?, +// n_layer: util::read_i32(reader)?.try_into()?, +// file_type: util::read_filetype(reader)?, +// }; + +// Ok(hyperparameters) +// } + +// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// util::write_i32(writer, self.n_embd.try_into()?)?; +// util::write_i32(writer, self.n_head.try_into()?)?; +// util::write_i32(writer, self.n_head_kv.try_into()?)?; +// util::write_i32(writer, self.n_layer.try_into()?)?; +// util::write_i32(writer, self.file_type.into())?; +// Ok(()) +// } + +// fn n_vocabulary(&self) -> usize { +// self.n_vocab +// } + +// fn file_type(&self) -> Option { +// Some(self.file_type) +// } + +// fn file_type_mut(&mut self) -> Option<&mut FileType> { +// Some(&mut self.file_type) +// } +// } + +// struct Layer { +// // normalization +// input_layernorm: Tensor, +// input_layernorm_b: Tensor, + +// // Falcon-40B only +// attention_norm: Option, +// attention_norm_b: Option, + +// // attention +// query_key_value: Tensor, +// wo: Tensor, + +// // ff +// ffn_up: Tensor, +// ffn_down: Tensor, +// } diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index b4434ad5..59933b67 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -1,464 +1,464 @@ -//! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem. -#![deny(missing_docs)] - -use ggml::Tensor; -use llm_base::{ - ggml, - model::{common, HyperparametersWriteError}, - util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, -}; - -/// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/) -/// -/// # Safety -/// This implements [Send] and [Sync] as it is immutable after construction. -pub struct Gpt2 { - params: ModelParameters, - - hyperparameters: Hyperparameters, - tokenizer: Tokenizer, - - // model-global weights - // normalization gain & bias - ln_f_g: Tensor, - ln_f_b: Tensor, - // weighted token embeddings - wte: Tensor, - // weighted positional encodings - wpe: Tensor, - // language model head - // - // Optional: if not present, the `wte` tensor is used instead. - lm_head: Option, - - // weights for the model - layers: Vec, - - // must be kept alive for the model - context: ModelContext, -} - -unsafe impl Send for Gpt2 {} -unsafe impl Sync for Gpt2 {} - -impl KnownModel for Gpt2 { - type Hyperparameters = Hyperparameters; - - fn new( - hyperparameters: Self::Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: impl llm_base::TensorLoader, - ) -> Result { - let mut tl = tensor_loader; - - // model-global weights - let backend = params.backend(0); - - let wpe = tl.load("model/wpe")?.transfer_to(backend); - let wte = tl.load("model/wte")?.transfer_to(backend); - - let ln_f_g = tl.load("model/ln_f/g")?.transfer_to(backend); - let ln_f_b = tl.load("model/ln_f/b")?.transfer_to(backend); - - // GPT-2's language model head is optional; if it is not present, - // the `wte` tensor is used instead. 
- let lm_head = { - if let Ok(tensor) = tl.load("model/lm_head") { - Some(tensor.transfer_to(backend)) - } else { - None - } - }; - - let mut layers = Vec::new(); - for i in 0..hyperparameters.n_layer { - let backend = params.backend(i); - let layer = Layer { - ln_1_g: tl.load(&format!("model/h{i}/ln_1/g"))?.transfer_to(backend), - ln_1_b: tl.load(&format!("model/h{i}/ln_1/b"))?.transfer_to(backend), - ln_2_g: tl.load(&format!("model/h{i}/ln_2/g"))?.transfer_to(backend), - ln_2_b: tl.load(&format!("model/h{i}/ln_2/b"))?.transfer_to(backend), - c_attn_attn_w: tl - .load(&format!("model/h{i}/attn/c_attn/w"))? - .transfer_to(backend), - c_attn_attn_b: tl - .load(&format!("model/h{i}/attn/c_attn/b"))? - .transfer_to(backend), - c_attn_proj_w: tl - .load(&format!("model/h{i}/attn/c_proj/w"))? - .transfer_to(backend), - c_attn_proj_b: tl - .load(&format!("model/h{i}/attn/c_proj/b"))? - .transfer_to(backend), - c_mlp_fc_w: tl - .load(&format!("model/h{i}/mlp/c_fc/w"))? - .transfer_to(backend), - c_mlp_fc_b: tl - .load(&format!("model/h{i}/mlp/c_fc/b"))? - .transfer_to(backend), - c_mlp_proj_w: tl - .load(&format!("model/h{i}/mlp/c_proj/w"))? - .transfer_to(backend), - c_mlp_proj_b: tl - .load(&format!("model/h{i}/mlp/c_proj/b"))? - .transfer_to(backend), - }; - - layers.push(layer); - } - - let context = tl.finish(); - - Ok(Gpt2 { - hyperparameters, - params, - tokenizer, - layers, - ln_f_g, - ln_f_b, - wte, - wpe, - lm_head, - context, - }) - } - - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { - InferenceSession::new( - config, - &self.params, - self.hyperparameters.n_layer, - self.hyperparameters.n_embd, - self.hyperparameters.n_vocab, - ) - } - - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; - let ctx_size = self.params.context_size; - - let Hyperparameters { - n_embd, - n_head, - n_vocab, - n_layer, - .. 
- } = self.hyperparameters; - - let outputs = session.compute(self.context.clone(), input_tokens, |builder| { - let mut ctx0 = builder.ctx0.borrow_mut(); - let (memory_k_size, memory_v_size) = ( - builder.memory_k.element_size(), - builder.memory_v.element_size(), - ); - let embd = &builder.embd; - - let position_buf: Vec = (0..input_len).map(|i| (session_len + i) as i32).collect(); - - let mut position = ctx0.new_tensor_1d(ggml::Type::I32, input_len); - unsafe { position.write_data(bytemuck::cast_slice(&position_buf)) }; - - let mut input_layer = ctx0.op_add( - &ctx0.op_get_rows(&self.wte, embd), - &ctx0.op_get_rows(&self.wpe, &position), - ); - - let mut gf = ctx0.create_compute_graph(); - for il in 0..n_layer { - ctx0.set_offloading(self.params.should_offload(il)); - ctx0.use_scratch(builder.get_scratch(0)); - // norm - let mut current = ctx0.op_norm(&input_layer); - current = ctx0.op_add( - &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), - &self.layers[il].ln_1_b, - ); - - // attn - current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_attn_attn_b); - - // self-attn - let nb = current.get_nb()[1]; - let f32_size = std::mem::size_of::(); - let qcur = ctx0.op_view_2d(¤t, (n_embd, input_len), nb, 0); - let kcur = ctx0.op_view_2d(¤t, (n_embd, input_len), nb, f32_size * n_embd); - let vcur = - ctx0.op_view_2d(¤t, (n_embd, input_len), nb, f32_size * n_embd * 2); - - if input_len >= 1 { - let k = ctx0.op_view_1d( - builder.memory_k, - input_len * n_embd, - (memory_k_size * n_embd) * (il * ctx_size + session_len), - ); - let v = ctx0.op_view_1d( - builder.memory_v, - input_len * n_embd, - (memory_v_size * n_embd) * (il * ctx_size + session_len), - ); - - gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); - } - - let q = ctx0.op_permute( - &ctx0.op_cpy( - &qcur, - &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, input_len), - ), - (0, 2, 1, 3), - ); - - let k = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - builder.memory_k, - (session_len + input_len) * n_embd, - il * ctx_size * memory_k_size * n_embd, - ), - n_embd / n_head, - n_head, - session_len + input_len, - ), - (0, 2, 1, 3), - ); - - let kq = ctx0.op_mul_mat(&k, &q); - let kq_scaled = ctx0.op_scale_inplace( - &kq, - &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - - let kq_masked = ctx0.op_diag_mask_inf_inplace(&kq_scaled, session_len); - let kq_softmax = ctx0.op_soft_max_inplace(&kq_masked); - - let v_trans = ctx0.op_cpy( - &ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - builder.memory_v, - (session_len + input_len) * n_embd, - il * ctx_size * memory_v_size * n_embd, - ), - n_embd / n_head, - n_head, - session_len + input_len, - ), - (1, 2, 0, 3), - ), - &ctx0.new_tensor_3d( - builder.memory_v.get_type(), - session_len + input_len, - n_embd / n_head, - n_head, - ), - ); - - let kqv = ctx0.op_mul_mat(&v_trans, &kq_softmax); - let kqv_merged = ctx0.op_permute(&kqv, (0, 2, 1, 3)); - - current = ctx0.op_cpy( - &kqv_merged, - &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), - ); - - // projection - current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); - - // add input - current = ctx0.op_add(¤t, &input_layer); - - // feed-forward - let ff_in = current.share(); - - ctx0.use_scratch(builder.get_scratch(1)); - - // feed-forward normalization - current = ctx0.op_norm(&ff_in); - current = ctx0.op_add( - 
&ctx0.op_mul(¤t, &self.layers[il].ln_2_g), - &self.layers[il].ln_2_b, - ); - - // feed-forward fully connected - current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_mlp_fc_b); - - // feed-forward activation - current = ctx0.op_gelu(¤t); - - // feed-forward projection - current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_mlp_proj_b); - - // input for next layer - input_layer = ctx0.op_add(¤t, &ff_in); - } - - ctx0.use_scratch(builder.get_scratch(0)); - - // normalization - input_layer = ctx0.op_norm(&input_layer); - input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); - - ctx0.use_scratch(None); - ctx0.set_offloading(false); - - let embeddings_tensor: ggml::Tensor = input_layer.share(); - - let head = self.lm_head.as_ref().unwrap_or(&self.wte); - input_layer = ctx0.op_mul_mat(head, &input_layer); - - ( - gf, - GraphOutputs { - result: input_layer, - embedding_result: embeddings_tensor, - }, - ) - }); - - // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); - } - - fn hyperparameters(&self) -> &Self::Hyperparameters { - &self.hyperparameters - } - - fn tokenizer(&self) -> &Tokenizer { - &self.tokenizer - } - - fn context_size(&self) -> usize { - self.params.context_size - } - - fn bot_token_id(&self) -> Option { - None - } - - fn eot_token_id(&self) -> TokenId { - self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() - } - - fn quantize_tensors() -> Vec { - [ - "model/wte", - "model/lm_head", - "model/h.*/attn/c_attn/w", - "model/h.*/attn/c_proj/w", - "model/h.*/mlp/c_fc/w", - "model/h.*/mlp/c_proj/w", - ] - .into_iter() - .map(|s| Regex::new(s).unwrap()) - .collect() - } - - fn skip_quantize_tensors() -> Vec { - vec![] - } -} - -/// GPT-2 [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct Hyperparameters { - /// Size of the model's vocabulary - n_vocab: usize, - /// Size of the model's context - n_ctx: usize, - /// Size of the model's embedding layer - n_embd: usize, - /// n_head - n_head: usize, - /// Number of layers in the model - n_layer: usize, - /// file type - file_type: FileType, -} - -impl llm_base::Hyperparameters for Hyperparameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - let hyperparameters = Hyperparameters { - n_vocab: util::read_i32(reader)?.try_into()?, - n_ctx: util::read_i32(reader)?.try_into()?, - n_embd: util::read_i32(reader)?.try_into()?, - n_head: util::read_i32(reader)?.try_into()?, - n_layer: util::read_i32(reader)?.try_into()?, - file_type: util::read_filetype(reader)?, - }; - - let n_vocab = util::read_i32(reader)? 
as usize; - if hyperparameters.n_vocab != n_vocab { - return Err(LoadError::InvariantBroken { - path: None, - invariant: format!( - "GPT2 model expected n_vocab {} found {}", - hyperparameters.n_vocab, n_vocab - ), - }); - } - - Ok(hyperparameters) - } - - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.n_vocab.try_into()?)?; - util::write_i32(writer, self.n_ctx.try_into()?)?; - util::write_i32(writer, self.n_embd.try_into()?)?; - util::write_i32(writer, self.n_head.try_into()?)?; - util::write_i32(writer, self.n_layer.try_into()?)?; - util::write_i32(writer, self.file_type.into())?; - util::write_i32(writer, self.n_vocab.try_into()?)?; - - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - self.n_vocab - } - - fn file_type(&self) -> Option { - Some(self.file_type) - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - Some(&mut self.file_type) - } -} - -struct Layer { - // normalization - ln_1_g: Tensor, - ln_1_b: Tensor, - - ln_2_g: Tensor, - ln_2_b: Tensor, - - // attention - c_attn_attn_w: Tensor, - c_attn_attn_b: Tensor, - - c_attn_proj_w: Tensor, - c_attn_proj_b: Tensor, - - // mlp - c_mlp_fc_w: Tensor, - c_mlp_fc_b: Tensor, - - c_mlp_proj_w: Tensor, - c_mlp_proj_b: Tensor, -} +// //! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem. +// #![deny(missing_docs)] + +// use ggml::Tensor; +// use llm_base::{ +// ggml, +// model::{common, HyperparametersWriteError}, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, +// }; + +// /// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/) +// /// +// /// # Safety +// /// This implements [Send] and [Sync] as it is immutable after construction. +// pub struct Gpt2 { +// params: ModelParameters, + +// hyperparameters: Hyperparameters, +// tokenizer: Tokenizer, + +// // model-global weights +// // normalization gain & bias +// ln_f_g: Tensor, +// ln_f_b: Tensor, +// // weighted token embeddings +// wte: Tensor, +// // weighted positional encodings +// wpe: Tensor, +// // language model head +// // +// // Optional: if not present, the `wte` tensor is used instead. +// lm_head: Option, + +// // weights for the model +// layers: Vec, + +// // must be kept alive for the model +// context: ModelContext, +// } + +// unsafe impl Send for Gpt2 {} +// unsafe impl Sync for Gpt2 {} + +// impl KnownModel for Gpt2 { +// type Hyperparameters = Hyperparameters; + +// fn new( +// hyperparameters: Self::Hyperparameters, +// params: ModelParameters, +// tokenizer: Tokenizer, +// tensor_loader: impl llm_base::TensorLoader, +// ) -> Result { +// let mut tl = tensor_loader; + +// // model-global weights +// let backend = params.backend(0); + +// let wpe = tl.load("model/wpe")?.transfer_to(backend); +// let wte = tl.load("model/wte")?.transfer_to(backend); + +// let ln_f_g = tl.load("model/ln_f/g")?.transfer_to(backend); +// let ln_f_b = tl.load("model/ln_f/b")?.transfer_to(backend); + +// // GPT-2's language model head is optional; if it is not present, +// // the `wte` tensor is used instead. 
+// let lm_head = { +// if let Ok(tensor) = tl.load("model/lm_head") { +// Some(tensor.transfer_to(backend)) +// } else { +// None +// } +// }; + +// let mut layers = Vec::new(); +// for i in 0..hyperparameters.n_layer { +// let backend = params.backend(i); +// let layer = Layer { +// ln_1_g: tl.load(&format!("model/h{i}/ln_1/g"))?.transfer_to(backend), +// ln_1_b: tl.load(&format!("model/h{i}/ln_1/b"))?.transfer_to(backend), +// ln_2_g: tl.load(&format!("model/h{i}/ln_2/g"))?.transfer_to(backend), +// ln_2_b: tl.load(&format!("model/h{i}/ln_2/b"))?.transfer_to(backend), +// c_attn_attn_w: tl +// .load(&format!("model/h{i}/attn/c_attn/w"))? +// .transfer_to(backend), +// c_attn_attn_b: tl +// .load(&format!("model/h{i}/attn/c_attn/b"))? +// .transfer_to(backend), +// c_attn_proj_w: tl +// .load(&format!("model/h{i}/attn/c_proj/w"))? +// .transfer_to(backend), +// c_attn_proj_b: tl +// .load(&format!("model/h{i}/attn/c_proj/b"))? +// .transfer_to(backend), +// c_mlp_fc_w: tl +// .load(&format!("model/h{i}/mlp/c_fc/w"))? +// .transfer_to(backend), +// c_mlp_fc_b: tl +// .load(&format!("model/h{i}/mlp/c_fc/b"))? +// .transfer_to(backend), +// c_mlp_proj_w: tl +// .load(&format!("model/h{i}/mlp/c_proj/w"))? +// .transfer_to(backend), +// c_mlp_proj_b: tl +// .load(&format!("model/h{i}/mlp/c_proj/b"))? +// .transfer_to(backend), +// }; + +// layers.push(layer); +// } + +// let context = tl.finish(); + +// Ok(Gpt2 { +// hyperparameters, +// params, +// tokenizer, +// layers, +// ln_f_g, +// ln_f_b, +// wte, +// wpe, +// lm_head, +// context, +// }) +// } + +// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { +// InferenceSession::new( +// config, +// &self.params, +// self.hyperparameters.n_layer, +// self.hyperparameters.n_embd, +// self.hyperparameters.n_vocab, +// ) +// } + +// fn evaluate( +// &self, +// session: &mut InferenceSession, +// input_tokens: &[TokenId], +// output_request: &mut OutputRequest, +// ) { +// let input_len = input_tokens.len(); +// let session_len = session.n_past; +// let ctx_size = self.params.context_size; + +// let Hyperparameters { +// n_embd, +// n_head, +// n_vocab, +// n_layer, +// .. 
+// } = self.hyperparameters; + +// let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let mut ctx0 = builder.ctx0.borrow_mut(); +// let (memory_k_size, memory_v_size) = ( +// builder.memory_k.element_size(), +// builder.memory_v.element_size(), +// ); +// let embd = &builder.embd; + +// let position_buf: Vec = (0..input_len).map(|i| (session_len + i) as i32).collect(); + +// let mut position = ctx0.new_tensor_1d(ggml::Type::I32, input_len); +// unsafe { position.write_data(bytemuck::cast_slice(&position_buf)) }; + +// let mut input_layer = ctx0.op_add( +// &ctx0.op_get_rows(&self.wte, embd), +// &ctx0.op_get_rows(&self.wpe, &position), +// ); + +// let mut gf = ctx0.create_compute_graph(); +// for il in 0..n_layer { +// ctx0.set_offloading(self.params.should_offload(il)); +// ctx0.use_scratch(builder.get_scratch(0)); +// // norm +// let mut current = ctx0.op_norm(&input_layer); +// current = ctx0.op_add( +// &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), +// &self.layers[il].ln_1_b, +// ); + +// // attn +// current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].c_attn_attn_b); + +// // self-attn +// let nb = current.get_nb()[1]; +// let f32_size = std::mem::size_of::(); +// let qcur = ctx0.op_view_2d(¤t, (n_embd, input_len), nb, 0); +// let kcur = ctx0.op_view_2d(¤t, (n_embd, input_len), nb, f32_size * n_embd); +// let vcur = +// ctx0.op_view_2d(¤t, (n_embd, input_len), nb, f32_size * n_embd * 2); + +// if input_len >= 1 { +// let k = ctx0.op_view_1d( +// builder.memory_k, +// input_len * n_embd, +// (memory_k_size * n_embd) * (il * ctx_size + session_len), +// ); +// let v = ctx0.op_view_1d( +// builder.memory_v, +// input_len * n_embd, +// (memory_v_size * n_embd) * (il * ctx_size + session_len), +// ); + +// gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); +// gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); +// } + +// let q = ctx0.op_permute( +// &ctx0.op_cpy( +// &qcur, +// &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, input_len), +// ), +// (0, 2, 1, 3), +// ); + +// let k = ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_k, +// (session_len + input_len) * n_embd, +// il * ctx_size * memory_k_size * n_embd, +// ), +// n_embd / n_head, +// n_head, +// session_len + input_len, +// ), +// (0, 2, 1, 3), +// ); + +// let kq = ctx0.op_mul_mat(&k, &q); +// let kq_scaled = ctx0.op_scale_inplace( +// &kq, +// &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), +// ); + +// let kq_masked = ctx0.op_diag_mask_inf_inplace(&kq_scaled, session_len); +// let kq_softmax = ctx0.op_soft_max_inplace(&kq_masked); + +// let v_trans = ctx0.op_cpy( +// &ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_v, +// (session_len + input_len) * n_embd, +// il * ctx_size * memory_v_size * n_embd, +// ), +// n_embd / n_head, +// n_head, +// session_len + input_len, +// ), +// (1, 2, 0, 3), +// ), +// &ctx0.new_tensor_3d( +// builder.memory_v.get_type(), +// session_len + input_len, +// n_embd / n_head, +// n_head, +// ), +// ); + +// let kqv = ctx0.op_mul_mat(&v_trans, &kq_softmax); +// let kqv_merged = ctx0.op_permute(&kqv, (0, 2, 1, 3)); + +// current = ctx0.op_cpy( +// &kqv_merged, +// &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), +// ); + +// // projection +// current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); + +// // add input +// current = 
ctx0.op_add(¤t, &input_layer); + +// // feed-forward +// let ff_in = current.share(); + +// ctx0.use_scratch(builder.get_scratch(1)); + +// // feed-forward normalization +// current = ctx0.op_norm(&ff_in); +// current = ctx0.op_add( +// &ctx0.op_mul(¤t, &self.layers[il].ln_2_g), +// &self.layers[il].ln_2_b, +// ); + +// // feed-forward fully connected +// current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].c_mlp_fc_b); + +// // feed-forward activation +// current = ctx0.op_gelu(¤t); + +// // feed-forward projection +// current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].c_mlp_proj_b); + +// // input for next layer +// input_layer = ctx0.op_add(¤t, &ff_in); +// } + +// ctx0.use_scratch(builder.get_scratch(0)); + +// // normalization +// input_layer = ctx0.op_norm(&input_layer); +// input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); + +// ctx0.use_scratch(None); +// ctx0.set_offloading(false); + +// let embeddings_tensor: ggml::Tensor = input_layer.share(); + +// let head = self.lm_head.as_ref().unwrap_or(&self.wte); +// input_layer = ctx0.op_mul_mat(head, &input_layer); + +// ( +// gf, +// GraphOutputs { +// result: input_layer, +// embedding_result: embeddings_tensor, +// }, +// ) +// }); + +// // finish evaluation +// common::read_last_token(session, &outputs.result, n_vocab, input_len); +// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); +// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// } + +// fn hyperparameters(&self) -> &Self::Hyperparameters { +// &self.hyperparameters +// } + +// fn tokenizer(&self) -> &Tokenizer { +// &self.tokenizer +// } + +// fn context_size(&self) -> usize { +// self.params.context_size +// } + +// fn bot_token_id(&self) -> Option { +// None +// } + +// fn eot_token_id(&self) -> TokenId { +// self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() +// } + +// fn quantize_tensors() -> Vec { +// [ +// "model/wte", +// "model/lm_head", +// "model/h.*/attn/c_attn/w", +// "model/h.*/attn/c_proj/w", +// "model/h.*/mlp/c_fc/w", +// "model/h.*/mlp/c_proj/w", +// ] +// .into_iter() +// .map(|s| Regex::new(s).unwrap()) +// .collect() +// } + +// fn skip_quantize_tensors() -> Vec { +// vec![] +// } +// } + +// /// GPT-2 [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +// #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +// pub struct Hyperparameters { +// /// Size of the model's vocabulary +// n_vocab: usize, +// /// Size of the model's context +// n_ctx: usize, +// /// Size of the model's embedding layer +// n_embd: usize, +// /// n_head +// n_head: usize, +// /// Number of layers in the model +// n_layer: usize, +// /// file type +// file_type: FileType, +// } + +// impl llm_base::Hyperparameters for Hyperparameters { +// fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { +// let hyperparameters = Hyperparameters { +// n_vocab: util::read_i32(reader)?.try_into()?, +// n_ctx: util::read_i32(reader)?.try_into()?, +// n_embd: util::read_i32(reader)?.try_into()?, +// n_head: util::read_i32(reader)?.try_into()?, +// n_layer: util::read_i32(reader)?.try_into()?, +// file_type: util::read_filetype(reader)?, +// }; + +// let n_vocab = util::read_i32(reader)? 
as usize; +// if hyperparameters.n_vocab != n_vocab { +// return Err(LoadError::InvariantBroken { +// path: None, +// invariant: format!( +// "GPT2 model expected n_vocab {} found {}", +// hyperparameters.n_vocab, n_vocab +// ), +// }); +// } + +// Ok(hyperparameters) +// } + +// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// util::write_i32(writer, self.n_ctx.try_into()?)?; +// util::write_i32(writer, self.n_embd.try_into()?)?; +// util::write_i32(writer, self.n_head.try_into()?)?; +// util::write_i32(writer, self.n_layer.try_into()?)?; +// util::write_i32(writer, self.file_type.into())?; +// util::write_i32(writer, self.n_vocab.try_into()?)?; + +// Ok(()) +// } + +// fn n_vocabulary(&self) -> usize { +// self.n_vocab +// } + +// fn file_type(&self) -> Option { +// Some(self.file_type) +// } + +// fn file_type_mut(&mut self) -> Option<&mut FileType> { +// Some(&mut self.file_type) +// } +// } + +// struct Layer { +// // normalization +// ln_1_g: Tensor, +// ln_1_b: Tensor, + +// ln_2_g: Tensor, +// ln_2_b: Tensor, + +// // attention +// c_attn_attn_w: Tensor, +// c_attn_attn_b: Tensor, + +// c_attn_proj_w: Tensor, +// c_attn_proj_b: Tensor, + +// // mlp +// c_mlp_fc_w: Tensor, +// c_mlp_fc_b: Tensor, + +// c_mlp_proj_w: Tensor, +// c_mlp_proj_b: Tensor, +// } diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index c013625a..dd70728f 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -1,434 +1,434 @@ -//! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem. -#![deny(missing_docs)] - -use std::error::Error; - -use ggml::Tensor; -use llm_base::{ - ggml, - model::{common, HyperparametersWriteError}, - util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, -}; - -/// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b) -/// -/// # Safety -/// This implements [Send] and [Sync] as it is immutable after construction. 
-pub struct GptJ { - params: ModelParameters, - - hyperparameters: Hyperparameters, - tokenizer: Tokenizer, - - // model-global weights - // normalization gain & bias - ln_f_g: Tensor, - ln_f_b: Tensor, - // weighted token embeddings - wte: Tensor, - // language model head gain & bias - lmh_g: Tensor, - lmh_b: Tensor, - - // weights for the model - layers: Vec, - - // must be kept alive for the model - context: ModelContext, -} - -unsafe impl Send for GptJ {} -unsafe impl Sync for GptJ {} - -impl KnownModel for GptJ { - type Hyperparameters = Hyperparameters; - - fn new( - hyperparameters: Self::Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: impl TensorLoader, - ) -> Result - where - Self: Sized, - { - let mut tl = tensor_loader; - - // model-global weights - let wte = tl.load("transformer.wte.weight")?; - - let backend = params.backend(0); - - let ln_f_g = tl.load("transformer.ln_f.weight")?.transfer_to(backend); - let ln_f_b = tl.load("transformer.ln_f.bias")?.transfer_to(backend); - let lmh_g = tl.load("lm_head.weight")?.transfer_to(backend); - let lmh_b = tl.load("lm_head.bias")?.transfer_to(backend); - - let mut layers = Vec::new(); - for i in 0..hyperparameters.n_layer { - let backend = params.backend(i); - - let layer = Layer { - ln_1_g: tl - .load(&format!("transformer.h.{i}.ln_1.weight"))? - .transfer_to(backend), - ln_1_b: tl - .load(&format!("transformer.h.{i}.ln_1.bias"))? - .transfer_to(backend), - c_attn_q_proj_w: tl - .load(&format!("transformer.h.{i}.attn.q_proj.weight"))? - .transfer_to(backend), - c_attn_k_proj_w: tl - .load(&format!("transformer.h.{i}.attn.k_proj.weight"))? - .transfer_to(backend), - c_attn_v_proj_w: tl - .load(&format!("transformer.h.{i}.attn.v_proj.weight"))? - .transfer_to(backend), - c_attn_proj_w: tl - .load(&format!("transformer.h.{i}.attn.out_proj.weight"))? - .transfer_to(backend), - c_mlp_fc_w: tl - .load(&format!("transformer.h.{i}.mlp.fc_in.weight"))? - .transfer_to(backend), - c_mlp_fc_b: tl - .load(&format!("transformer.h.{i}.mlp.fc_in.bias"))? - .transfer_to(backend), - c_mlp_proj_w: tl - .load(&format!("transformer.h.{i}.mlp.fc_out.weight"))? - .transfer_to(backend), - c_mlp_proj_b: tl - .load(&format!("transformer.h.{i}.mlp.fc_out.bias"))? - .transfer_to(backend), - }; - - layers.push(layer); - } - - let context = tl.finish(); - - Ok(GptJ { - hyperparameters, - params, - tokenizer, - ln_f_g, - ln_f_b, - wte, - lmh_g, - lmh_b, - layers, - context, - }) - } - - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { - InferenceSession::new( - config, - &self.params, - self.hyperparameters.n_layer, - self.hyperparameters.n_embd, - self.hyperparameters.n_vocab, - ) - } - - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; - let ctx_size = self.params.context_size; - - let Hyperparameters { - n_embd, - n_head, - n_vocab, - n_layer, - n_rot, - .. 
-        } = self.hyperparameters;
-
-        let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
-            let mut ctx0 = builder.ctx0.borrow_mut();
-            let (memory_k_size, memory_v_size) = (
-                builder.memory_k.element_size(),
-                builder.memory_v.element_size(),
-            );
-            let embd = builder.embd;
-
-            let mut input_layer = ctx0.op_get_rows(&self.wte, embd);
-
-            let mut gf = ctx0.create_compute_graph();
-            for il in 0..n_layer {
-                ctx0.set_offloading(self.params.should_offload(il));
-
-                // norm
-                let mut current = ctx0.op_norm(&input_layer);
-                current = ctx0.op_add(
-                    &ctx0.op_mul(&current, &self.layers[il].ln_1_g),
-                    &self.layers[il].ln_1_b,
-                );
-
-                let input_sa = current.share();
-
-                // self-attention
-                let overrides = self.params.rope_overrides.as_ref();
-                let qcur = ctx0.op_rope_inplace(
-                    &ctx0.op_reshape_3d(
-                        &ctx0.op_mul_mat(&self.layers[il].c_attn_q_proj_w, &current),
-                        n_embd / n_head,
-                        n_head,
-                        input_len,
-                    ),
-                    session_len,
-                    n_rot,
-                    0,
-                    overrides,
-                );
-                let kcur = ctx0.op_rope_inplace(
-                    &ctx0.op_reshape_3d(
-                        &ctx0.op_mul_mat(&self.layers[il].c_attn_k_proj_w, &current),
-                        n_embd / n_head,
-                        n_head,
-                        input_len,
-                    ),
-                    session_len,
-                    n_rot,
-                    0,
-                    overrides,
-                );
-
-                // self-attention store key and value to memory
-                let vcur =
-                    ctx0.op_transpose(&ctx0.op_mul_mat(&self.layers[il].c_attn_v_proj_w, &current));
-
-                let k = ctx0.op_view_1d(
-                    builder.memory_k,
-                    input_len * n_embd,
-                    (memory_k_size * n_embd) * (il * ctx_size + session_len),
-                );
-                let v = ctx0.op_view_2d(
-                    builder.memory_v,
-                    (input_len, n_embd),
-                    ctx_size * memory_v_size,
-                    (il * ctx_size) * memory_v_size * n_embd + session_len * memory_v_size,
-                );
-
-                gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
-                gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));
-
-                let q = ctx0.op_permute(&qcur, (0, 2, 1, 3));
-                let big_k = ctx0.op_permute(
-                    &ctx0.op_reshape_3d(
-                        &ctx0.op_view_1d(
-                            builder.memory_k,
-                            (session_len + input_len) * n_embd,
-                            il * ctx_size * memory_k_size * n_embd,
-                        ),
-                        n_embd / n_head,
-                        n_head,
-                        session_len + input_len,
-                    ),
-                    (0, 2, 1, 3),
-                );
-
-                let kq = ctx0.op_mul_mat(&big_k, &q);
-                let kq_scaled = ctx0.op_scale_inplace(
-                    &kq,
-                    &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
-                );
-
-                let kq_masked = ctx0.op_diag_mask_inf_inplace(&kq_scaled, session_len);
-                let kq_softmax = ctx0.op_soft_max_inplace(&kq_masked);
-
-                let big_v = ctx0.op_view_3d(
-                    builder.memory_v,
-                    (session_len + input_len, n_embd / n_head, n_head),
-                    (
-                        ctx_size * memory_v_size,
-                        ctx_size * memory_v_size * n_embd / n_head,
-                    ),
-                    il * ctx_size * memory_v_size * n_embd,
-                );
-
-                let kqv = ctx0.op_mul_mat(&big_v, &kq_softmax);
-                let kqv_merged = ctx0.op_permute(&kqv, (0, 2, 1, 3));
-
-                current = ctx0.op_cpy(
-                    &kqv_merged,
-                    &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len),
-                );
-
-                // self-attention projection
-                current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);
-
-                // feed-forward
-                let ff_in = current.share();
-
-                current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &input_sa);
-                current = ctx0.op_add(&current, &self.layers[il].c_mlp_fc_b);
-
-                current = ctx0.op_gelu(&current);
-
-                // feed-forward projection
-                current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
-                current = ctx0.op_add(&current, &self.layers[il].c_mlp_proj_b);
-
-                current = ctx0.op_add(&current, &ff_in);
-
-                // input for next layer
-                input_layer = ctx0.op_add(&current, &input_layer);
-            }
-
-            // norm
-            input_layer = ctx0.op_norm(&input_layer);
-            input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b);
-
-            let embeddings_tensor:
ggml::Tensor = input_layer.share(); - - // lm_head - input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); - - ctx0.set_offloading(false); - - input_layer = ctx0.op_add(&input_layer, &self.lmh_b); - - ( - gf, - GraphOutputs { - result: input_layer, - embedding_result: embeddings_tensor, - }, - ) - }); - - // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, input_len); - common::extract_logits(output_request, &outputs.result, n_vocab, input_len); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); - } - - fn hyperparameters(&self) -> &Self::Hyperparameters { - &self.hyperparameters - } - - fn tokenizer(&self) -> &Tokenizer { - &self.tokenizer - } - - fn context_size(&self) -> usize { - self.params.context_size - } - - fn bot_token_id(&self) -> Option { - None - } - - fn eot_token_id(&self) -> TokenId { - self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() - } - - fn quantize_tensors() -> Vec { - vec![Regex::new(".*weight").unwrap()] - } - - fn skip_quantize_tensors() -> Vec { - vec![] - } - - fn supports_rewind(&self) -> bool { - true - } -} - -/// GPT-J [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct Hyperparameters { - /// Size of the model's vocabulary - pub n_vocab: usize, - /// Size of the model's context - pub n_ctx: usize, - /// Size of the model's embedding layer - pub n_embd: usize, - /// n_head - pub n_head: usize, - /// Number of layers in the model - pub n_layer: usize, - /// n_rot - pub n_rot: usize, - /// file_type - pub file_type: FileType, -} - -impl llm_base::Hyperparameters for Hyperparameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - let hyperparameters = Hyperparameters { - n_vocab: util::read_i32(reader)?.try_into()?, - n_ctx: util::read_i32(reader)?.try_into()?, - n_embd: util::read_i32(reader)?.try_into()?, - n_head: util::read_i32(reader)?.try_into()?, - n_layer: util::read_i32(reader)?.try_into()?, - n_rot: util::read_i32(reader)?.try_into()?, - file_type: util::read_filetype(reader)?, - }; - - let n_vocab = util::read_i32(reader)? as usize; - if hyperparameters.n_vocab != n_vocab { - return Err(LoadError::InvariantBroken { - path: None, - invariant: format!( - "GPTJ model expected n_vocab {} found {}", - hyperparameters.n_vocab, n_vocab - ), - }); - } - - Ok(hyperparameters) - } - - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.n_vocab.try_into()?)?; - util::write_i32(writer, self.n_ctx.try_into()?)?; - util::write_i32(writer, self.n_embd.try_into()?)?; - util::write_i32(writer, self.n_head.try_into()?)?; - util::write_i32(writer, self.n_layer.try_into()?)?; - util::write_i32(writer, self.n_rot.try_into()?)?; - util::write_i32(writer, self.file_type.into())?; - util::write_i32(writer, self.n_vocab.try_into()?)?; - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - self.n_vocab - } - - fn file_type(&self) -> Option { - Some(self.file_type) - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - Some(&mut self.file_type) - } -} - -struct Layer { - // normalization - ln_1_g: Tensor, - ln_1_b: Tensor, - - // attention - c_attn_q_proj_w: Tensor, - c_attn_k_proj_w: Tensor, - c_attn_v_proj_w: Tensor, - - c_attn_proj_w: Tensor, - - // ff - c_mlp_fc_w: Tensor, - c_mlp_fc_b: Tensor, - - c_mlp_proj_w: Tensor, - c_mlp_proj_b: Tensor, -} +// //! 
An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem. +// #![deny(missing_docs)] + +// use std::error::Error; + +// use ggml::Tensor; +// use llm_base::{ +// ggml, +// model::{common, HyperparametersWriteError}, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, +// }; + +// /// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b) +// /// +// /// # Safety +// /// This implements [Send] and [Sync] as it is immutable after construction. +// pub struct GptJ { +// params: ModelParameters, + +// hyperparameters: Hyperparameters, +// tokenizer: Tokenizer, + +// // model-global weights +// // normalization gain & bias +// ln_f_g: Tensor, +// ln_f_b: Tensor, +// // weighted token embeddings +// wte: Tensor, +// // language model head gain & bias +// lmh_g: Tensor, +// lmh_b: Tensor, + +// // weights for the model +// layers: Vec, + +// // must be kept alive for the model +// context: ModelContext, +// } + +// unsafe impl Send for GptJ {} +// unsafe impl Sync for GptJ {} + +// impl KnownModel for GptJ { +// type Hyperparameters = Hyperparameters; + +// fn new( +// hyperparameters: Self::Hyperparameters, +// params: ModelParameters, +// tokenizer: Tokenizer, +// tensor_loader: impl TensorLoader, +// ) -> Result +// where +// Self: Sized, +// { +// let mut tl = tensor_loader; + +// // model-global weights +// let wte = tl.load("transformer.wte.weight")?; + +// let backend = params.backend(0); + +// let ln_f_g = tl.load("transformer.ln_f.weight")?.transfer_to(backend); +// let ln_f_b = tl.load("transformer.ln_f.bias")?.transfer_to(backend); +// let lmh_g = tl.load("lm_head.weight")?.transfer_to(backend); +// let lmh_b = tl.load("lm_head.bias")?.transfer_to(backend); + +// let mut layers = Vec::new(); +// for i in 0..hyperparameters.n_layer { +// let backend = params.backend(i); + +// let layer = Layer { +// ln_1_g: tl +// .load(&format!("transformer.h.{i}.ln_1.weight"))? +// .transfer_to(backend), +// ln_1_b: tl +// .load(&format!("transformer.h.{i}.ln_1.bias"))? +// .transfer_to(backend), +// c_attn_q_proj_w: tl +// .load(&format!("transformer.h.{i}.attn.q_proj.weight"))? +// .transfer_to(backend), +// c_attn_k_proj_w: tl +// .load(&format!("transformer.h.{i}.attn.k_proj.weight"))? +// .transfer_to(backend), +// c_attn_v_proj_w: tl +// .load(&format!("transformer.h.{i}.attn.v_proj.weight"))? +// .transfer_to(backend), +// c_attn_proj_w: tl +// .load(&format!("transformer.h.{i}.attn.out_proj.weight"))? +// .transfer_to(backend), +// c_mlp_fc_w: tl +// .load(&format!("transformer.h.{i}.mlp.fc_in.weight"))? +// .transfer_to(backend), +// c_mlp_fc_b: tl +// .load(&format!("transformer.h.{i}.mlp.fc_in.bias"))? +// .transfer_to(backend), +// c_mlp_proj_w: tl +// .load(&format!("transformer.h.{i}.mlp.fc_out.weight"))? +// .transfer_to(backend), +// c_mlp_proj_b: tl +// .load(&format!("transformer.h.{i}.mlp.fc_out.bias"))? 
+// .transfer_to(backend), +// }; + +// layers.push(layer); +// } + +// let context = tl.finish(); + +// Ok(GptJ { +// hyperparameters, +// params, +// tokenizer, +// ln_f_g, +// ln_f_b, +// wte, +// lmh_g, +// lmh_b, +// layers, +// context, +// }) +// } + +// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { +// InferenceSession::new( +// config, +// &self.params, +// self.hyperparameters.n_layer, +// self.hyperparameters.n_embd, +// self.hyperparameters.n_vocab, +// ) +// } + +// fn evaluate( +// &self, +// session: &mut InferenceSession, +// input_tokens: &[TokenId], +// output_request: &mut OutputRequest, +// ) { +// let input_len = input_tokens.len(); +// let session_len = session.n_past; +// let ctx_size = self.params.context_size; + +// let Hyperparameters { +// n_embd, +// n_head, +// n_vocab, +// n_layer, +// n_rot, +// .. +// } = self.hyperparameters; + +// let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let mut ctx0 = builder.ctx0.borrow_mut(); +// let (memory_k_size, memory_v_size) = ( +// builder.memory_k.element_size(), +// builder.memory_v.element_size(), +// ); +// let embd = builder.embd; + +// let mut input_layer = ctx0.op_get_rows(&self.wte, embd); + +// let mut gf = ctx0.create_compute_graph(); +// for il in 0..n_layer { +// ctx0.set_offloading(self.params.should_offload(il)); + +// // norm +// let mut current = ctx0.op_norm(&input_layer); +// current = ctx0.op_add( +// &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), +// &self.layers[il].ln_1_b, +// ); + +// let input_sa = current.share(); + +// // self-attention +// let overrides = self.params.rope_overrides.as_ref(); +// let qcur = ctx0.op_rope_inplace( +// &ctx0.op_reshape_3d( +// &ctx0.op_mul_mat(&self.layers[il].c_attn_q_proj_w, ¤t), +// n_embd / n_head, +// n_head, +// input_len, +// ), +// session_len, +// n_rot, +// 0, +// overrides, +// ); +// let kcur = ctx0.op_rope_inplace( +// &ctx0.op_reshape_3d( +// &ctx0.op_mul_mat(&self.layers[il].c_attn_k_proj_w, ¤t), +// n_embd / n_head, +// n_head, +// input_len, +// ), +// session_len, +// n_rot, +// 0, +// overrides, +// ); + +// // self-attention store key and value to memory +// let vcur = +// ctx0.op_transpose(&ctx0.op_mul_mat(&self.layers[il].c_attn_v_proj_w, ¤t)); + +// let k = ctx0.op_view_1d( +// builder.memory_k, +// input_len * n_embd, +// (memory_k_size * n_embd) * (il * ctx_size + session_len), +// ); +// let v = ctx0.op_view_2d( +// builder.memory_v, +// (input_len, n_embd), +// ctx_size * memory_v_size, +// (il * ctx_size) * memory_v_size * n_embd + session_len * memory_v_size, +// ); + +// gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); +// gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); + +// let q = ctx0.op_permute(&qcur, (0, 2, 1, 3)); +// let big_k = ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_k, +// (session_len + input_len) * n_embd, +// il * ctx_size * memory_k_size * n_embd, +// ), +// n_embd / n_head, +// n_head, +// session_len + input_len, +// ), +// (0, 2, 1, 3), +// ); + +// let kq = ctx0.op_mul_mat(&big_k, &q); +// let kq_scaled = ctx0.op_scale_inplace( +// &kq, +// &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), +// ); + +// let kq_masked = ctx0.op_diag_mask_inf_inplace(&kq_scaled, session_len); +// let kq_softmax = ctx0.op_soft_max_inplace(&kq_masked); + +// let big_v = ctx0.op_view_3d( +// builder.memory_v, +// (session_len + input_len, n_embd / n_head, n_head), +// ( +// ctx_size * memory_v_size, +// ctx_size * 
memory_v_size * n_embd / n_head, +// ), +// il * ctx_size * memory_v_size * n_embd, +// ); + +// let kqv = ctx0.op_mul_mat(&big_v, &kq_softmax); +// let kqv_merged = ctx0.op_permute(&kqv, (0, 2, 1, 3)); + +// current = ctx0.op_cpy( +// &kqv_merged, +// &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), +// ); + +// // self-attention projection +// current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); + +// // feed-forward +// let ff_in = current.share(); + +// current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &input_sa); +// current = ctx0.op_add(¤t, &self.layers[il].c_mlp_fc_b); + +// current = ctx0.op_gelu(¤t); + +// // feed-forward projection +// current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].c_mlp_proj_b); + +// current = ctx0.op_add(¤t, &ff_in); + +// // input for next layer +// input_layer = ctx0.op_add(¤t, &input_layer); +// } + +// // norm +// input_layer = ctx0.op_norm(&input_layer); +// input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); + +// let embeddings_tensor: ggml::Tensor = input_layer.share(); + +// // lm_head +// input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); + +// ctx0.set_offloading(false); + +// input_layer = ctx0.op_add(&input_layer, &self.lmh_b); + +// ( +// gf, +// GraphOutputs { +// result: input_layer, +// embedding_result: embeddings_tensor, +// }, +// ) +// }); + +// // finish evaluation +// common::read_last_token(session, &outputs.result, n_vocab, input_len); +// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); +// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// } + +// fn hyperparameters(&self) -> &Self::Hyperparameters { +// &self.hyperparameters +// } + +// fn tokenizer(&self) -> &Tokenizer { +// &self.tokenizer +// } + +// fn context_size(&self) -> usize { +// self.params.context_size +// } + +// fn bot_token_id(&self) -> Option { +// None +// } + +// fn eot_token_id(&self) -> TokenId { +// self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() +// } + +// fn quantize_tensors() -> Vec { +// vec![Regex::new(".*weight").unwrap()] +// } + +// fn skip_quantize_tensors() -> Vec { +// vec![] +// } + +// fn supports_rewind(&self) -> bool { +// true +// } +// } + +// /// GPT-J [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +// #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +// pub struct Hyperparameters { +// /// Size of the model's vocabulary +// pub n_vocab: usize, +// /// Size of the model's context +// pub n_ctx: usize, +// /// Size of the model's embedding layer +// pub n_embd: usize, +// /// n_head +// pub n_head: usize, +// /// Number of layers in the model +// pub n_layer: usize, +// /// n_rot +// pub n_rot: usize, +// /// file_type +// pub file_type: FileType, +// } + +// impl llm_base::Hyperparameters for Hyperparameters { +// fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { +// let hyperparameters = Hyperparameters { +// n_vocab: util::read_i32(reader)?.try_into()?, +// n_ctx: util::read_i32(reader)?.try_into()?, +// n_embd: util::read_i32(reader)?.try_into()?, +// n_head: util::read_i32(reader)?.try_into()?, +// n_layer: util::read_i32(reader)?.try_into()?, +// n_rot: util::read_i32(reader)?.try_into()?, +// file_type: util::read_filetype(reader)?, +// }; + +// let n_vocab = util::read_i32(reader)? 
as usize; +// if hyperparameters.n_vocab != n_vocab { +// return Err(LoadError::InvariantBroken { +// path: None, +// invariant: format!( +// "GPTJ model expected n_vocab {} found {}", +// hyperparameters.n_vocab, n_vocab +// ), +// }); +// } + +// Ok(hyperparameters) +// } + +// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// util::write_i32(writer, self.n_ctx.try_into()?)?; +// util::write_i32(writer, self.n_embd.try_into()?)?; +// util::write_i32(writer, self.n_head.try_into()?)?; +// util::write_i32(writer, self.n_layer.try_into()?)?; +// util::write_i32(writer, self.n_rot.try_into()?)?; +// util::write_i32(writer, self.file_type.into())?; +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// Ok(()) +// } + +// fn n_vocabulary(&self) -> usize { +// self.n_vocab +// } + +// fn file_type(&self) -> Option { +// Some(self.file_type) +// } + +// fn file_type_mut(&mut self) -> Option<&mut FileType> { +// Some(&mut self.file_type) +// } +// } + +// struct Layer { +// // normalization +// ln_1_g: Tensor, +// ln_1_b: Tensor, + +// // attention +// c_attn_q_proj_w: Tensor, +// c_attn_k_proj_w: Tensor, +// c_attn_v_proj_w: Tensor, + +// c_attn_proj_w: Tensor, + +// // ff +// c_mlp_fc_w: Tensor, +// c_mlp_fc_b: Tensor, + +// c_mlp_proj_w: Tensor, +// c_mlp_proj_b: Tensor, +// } diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 9075eb01..93ecf6cf 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -1,515 +1,515 @@ -//! An implementation of [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox) for the `llm` ecosystem. -//! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model. -#![deny(missing_docs)] - -use std::error::Error; - -use ggml::Tensor; -use llm_base::{ - ggml, - model::{common, HyperparametersWriteError}, - util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, -}; - -/// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox) -/// -/// # Safety -/// This implements [Send] and [Sync] as it is immutable after construction. -pub struct GptNeoX { - params: ModelParameters, - - hyperparameters: Hyperparameters, - tokenizer: Tokenizer, - - // model-global weights - // normalization gain & bias - ln_f_g: Tensor, - ln_f_b: Tensor, - // weight token embeddings - wte: Tensor, - // language model head gain - lmh_g: Tensor, - - // weights for the model - layers: Vec, - - // must be kept alive for the model - context: ModelContext, -} - -unsafe impl Send for GptNeoX {} -unsafe impl Sync for GptNeoX {} - -impl KnownModel for GptNeoX { - type Hyperparameters = Hyperparameters; - - fn new( - hyperparameters: Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: impl TensorLoader, - ) -> Result - where - Self: Sized, - { - let mut tl = tensor_loader; - - // model-global weights - let wte = tl.load("gpt_neox.embed_in.weight")?; - - let backend = params.backend(0); - - let ln_f_g = tl - .load("gpt_neox.final_layer_norm.weight")? - .transfer_to(backend); - let ln_f_b = tl - .load("gpt_neox.final_layer_norm.bias")? 
- .transfer_to(backend); - let lmh_g = tl.load("embed_out.weight")?.transfer_to(backend); - - let mut layers = Vec::new(); - for i in 0..hyperparameters.n_layer { - let backend = params.backend(i); - let layer = Layer { - ln_1_g: tl - .load(&format!("gpt_neox.layers.{i}.input_layernorm.weight"))? - .transfer_to(backend), - ln_1_b: tl - .load(&format!("gpt_neox.layers.{i}.input_layernorm.bias"))? - .transfer_to(backend), - - c_attn_attn_w: tl - .load(&format!( - "gpt_neox.layers.{i}.attention.query_key_value.weight" - ))? - .transfer_to(backend), - c_attn_attn_b: tl - .load(&format!( - "gpt_neox.layers.{i}.attention.query_key_value.bias" - ))? - .transfer_to(backend), - - c_attn_proj_w: tl - .load(&format!("gpt_neox.layers.{i}.attention.dense.weight"))? - .transfer_to(backend), - c_attn_proj_b: tl - .load(&format!("gpt_neox.layers.{i}.attention.dense.bias"))? - .transfer_to(backend), - - ln_2_g: tl - .load(&format!( - "gpt_neox.layers.{i}.post_attention_layernorm.weight" - ))? - .transfer_to(backend), - ln_2_b: tl - .load(&format!( - "gpt_neox.layers.{i}.post_attention_layernorm.bias" - ))? - .transfer_to(backend), - - c_mlp_fc_w: tl - .load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight"))? - .transfer_to(backend), - c_mlp_fc_b: tl - .load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias"))? - .transfer_to(backend), - - c_mlp_proj_w: tl - .load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight"))? - .transfer_to(backend), - c_mlp_proj_b: tl - .load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias"))? - .transfer_to(backend), - }; - - layers.push(layer); - } - - let context = tl.finish(); - - Ok(GptNeoX { - hyperparameters, - params, - tokenizer, - ln_f_g, - ln_f_b, - wte, - lmh_g, - layers, - context, - }) - } - - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { - InferenceSession::new( - config, - &self.params, - self.hyperparameters.n_layer, - self.hyperparameters.n_embd, - self.hyperparameters.n_vocab, - ) - } - - // allow snake case here as its a one-to-one mapping of the original names - #[allow(non_snake_case)] - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ) { - let n = input_tokens.len(); - let n_past = session.n_past; - let n_ctx = self.params.context_size; - - let Hyperparameters { - n_embd, - n_head, - n_vocab, - n_layer, - n_rot, - use_parallel_residual, - .. 
- } = self.hyperparameters; - - let outputs = session.compute(self.context.clone(), input_tokens, |builder| { - let mut ctx0 = builder.ctx0.borrow_mut(); - let embd = builder.embd; - let mut input_layer = ctx0.op_get_rows(&self.wte, embd); - let (memory_k_size, memory_v_size) = ( - builder.memory_k.element_size(), - builder.memory_v.element_size(), - ); - - let mut gf = ctx0.create_compute_graph(); - - for il in 0..n_layer { - ctx0.set_offloading(self.params.should_offload(il)); - // attention uses first scratch buffer - ctx0.use_scratch(builder.get_scratch(0)); - - // self-attention - let mut current = ctx0.op_norm(&input_layer); - current = ctx0.op_add( - &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), - &self.layers[il].ln_1_b, - ); - - // self-attention compute QKV - current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_attn_attn_b); - - let nb = current.get_nb()[1]; - let f32_size = std::mem::size_of::(); - - let mut qcur = ctx0.op_cont(&ctx0.op_view_3d( - ¤t, - (n_embd / n_head, n_head, n), - (nb / n_head, nb), - 0, - )); - let mut kcur = ctx0.op_cont(&ctx0.op_view_3d( - ¤t, - (n_embd / n_head, n_head, n), - (nb / n_head, nb), - f32_size * n_embd / n_head, - )); - let mut vcur = ctx0.op_cont(&ctx0.op_view_3d( - ¤t, - (n_embd / n_head, n_head, n), - (nb / n_head, nb), - 2 * f32_size * n_embd / n_head, - )); - - // self-attention using mode = 2 for GPT-NeoX mode - let overrides = self.params.rope_overrides.as_ref(); - qcur = ctx0.op_rope_inplace(&qcur, n_past, n_rot, 2, overrides); - kcur = ctx0.op_rope_inplace(&kcur, n_past, n_rot, 2, overrides); - - // store key and value to memory - vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, n_embd, n)); - - let k = ctx0.op_view_1d( - builder.memory_k, - n * n_embd, - (memory_k_size * n_embd) * (il * n_ctx + n_past), - ); - - let v = ctx0.op_view_2d( - builder.memory_v, - (n, n_embd), - n_ctx * memory_v_size, - (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size, - ); - - gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); - - // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) - let Q = ctx0.op_permute(&qcur, (0, 2, 1, 3)); - // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) - let K = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - builder.memory_k, - (n_past + n) * n_embd, - il * n_ctx * memory_k_size * n_embd, - ), - n_embd / n_head, - n_head, - n_past + n, - ), - (0, 2, 1, 3), - ); - - // K * Q - let KQ = ctx0.op_mul_mat(&K, &Q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - let KQ_scaled = ctx0.op_scale_inplace( - &KQ, - &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - - // KQ_masked = mask_past(KQ_scaled) - let KQ_masked = ctx0.op_diag_mask_inf_inplace(&KQ_scaled, n_past); - - // KQ = soft_max(KQ_masked) - let KQ_softmax = ctx0.op_soft_max_inplace(&KQ_masked); - - // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() - let V = ctx0.op_view_3d( - builder.memory_v, - (n_past + n, n_embd / n_head, n_head), - ( - n_ctx * memory_v_size, - n_ctx * memory_v_size * n_embd / n_head, - ), - il * n_ctx * memory_v_size * n_embd, - ); - - // KQV = transpose(V) * KQ_soft_max - let KQV = ctx0.op_mul_mat(&V, &KQ_softmax); - // KQV_merged = KQV.permute(0, 2, 1, 3) - let KQV_merged = ctx0.op_permute(&KQV, (0, 2, 1, 3)); - - // cur = KQV_merged.contiguous().view(n_embd, N) - current = ctx0.op_cpy(&KQV_merged, 
&ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); - - // self-attention projection - current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); - - // use the second scratch for the feed forward - ctx0.use_scratch(builder.get_scratch(1)); - - let feedforward_input: Tensor; - if !use_parallel_residual { - feedforward_input = ctx0.op_add(¤t, &input_layer); - current = feed_forward_network(&ctx0, &self.layers[il], &feedforward_input); - // input for next layer - input_layer = ctx0.op_add(¤t, &feedforward_input); - } else { - // calculate with parallel residual - feedforward_input = current.share(); - - // this is independent of the self-attention result, so it could be done in parallel to the self-attention - // note here we pass inpL instead of cur - current = feed_forward_network(&ctx0, &self.layers[il], &input_layer); - - // layer input + FF - current = ctx0.op_add(¤t, &feedforward_input); - - // input for next layer - input_layer = ctx0.op_add(¤t, &input_layer); - } - } - - // use the first scratch for the norm - ctx0.use_scratch(builder.get_scratch(0)); - - // normalize the output - input_layer = ctx0.op_norm(&input_layer); - // inpL = ln_f_g*inpL + ln_f_b - input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); - - let embeddings_tensor: ggml::Tensor = input_layer.share(); - - // Disable the scratchbuffer - ctx0.use_scratch(None); - ctx0.set_offloading(false); - // apply language model head - input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); - - ( - gf, - GraphOutputs { - result: input_layer, - embedding_result: embeddings_tensor, - }, - ) - }); - - // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, n); - common::extract_logits(output_request, &outputs.result, n_vocab, n); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); - } - - fn hyperparameters(&self) -> &Self::Hyperparameters { - &self.hyperparameters - } - - fn tokenizer(&self) -> &Tokenizer { - &self.tokenizer - } - - fn context_size(&self) -> usize { - self.params.context_size - } - - fn bot_token_id(&self) -> Option { - None - } - - fn eot_token_id(&self) -> TokenId { - self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() - } - - fn quantize_tensors() -> Vec { - vec![Regex::new(".*weight").unwrap()] - } - - fn skip_quantize_tensors() -> Vec { - vec![] - } - - fn supports_rewind(&self) -> bool { - true - } -} - -/// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct Hyperparameters { - /// Size of the model's vocabulary - pub n_vocab: usize, - /// Size of the model's context - pub n_ctx: usize, - /// Size of the model's embedding layer - pub n_embd: usize, - /// n_head - pub n_head: usize, - /// Number of layers in the model - pub n_layer: usize, - /// n_rot - pub n_rot: usize, - /// Whether to use a "parallel" formulation in each Transformer layer. - /// This is on for most models, but is off for some e.g. RedPajama. 
- pub use_parallel_residual: bool, - /// file_type - pub file_type: FileType, -} - -impl Default for Hyperparameters { - fn default() -> Self { - Self { - n_vocab: Default::default(), - n_ctx: Default::default(), - n_embd: Default::default(), - n_head: Default::default(), - n_layer: Default::default(), - n_rot: Default::default(), - file_type: Default::default(), - use_parallel_residual: true, - } - } -} - -impl llm_base::Hyperparameters for Hyperparameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - Ok(Hyperparameters { - n_vocab: util::read_i32(reader)?.try_into()?, - n_ctx: util::read_i32(reader)?.try_into()?, - n_embd: util::read_i32(reader)?.try_into()?, - n_head: util::read_i32(reader)?.try_into()?, - n_layer: util::read_i32(reader)?.try_into()?, - n_rot: util::read_i32(reader)?.try_into()?, - use_parallel_residual: util::read_bool(reader)?, - file_type: util::read_filetype(reader)?, - }) - } - - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.n_vocab.try_into()?)?; - util::write_i32(writer, self.n_ctx.try_into()?)?; - util::write_i32(writer, self.n_embd.try_into()?)?; - util::write_i32(writer, self.n_head.try_into()?)?; - util::write_i32(writer, self.n_layer.try_into()?)?; - util::write_i32(writer, self.n_rot.try_into()?)?; - util::write_bool(writer, self.use_parallel_residual)?; - util::write_i32(writer, self.file_type.into())?; - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - self.n_vocab - } - - fn file_type(&self) -> Option { - Some(self.file_type) - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - Some(&mut self.file_type) - } -} - -struct Layer { - // pre-normalization - ln_1_g: Tensor, - ln_1_b: Tensor, - - // attention - c_attn_attn_w: Tensor, - c_attn_attn_b: Tensor, - - c_attn_proj_w: Tensor, - c_attn_proj_b: Tensor, - - // post normalization - ln_2_g: Tensor, - ln_2_b: Tensor, - - // feed-forward - c_mlp_fc_w: Tensor, - c_mlp_fc_b: Tensor, - - c_mlp_proj_w: Tensor, - c_mlp_proj_b: Tensor, -} - -fn feed_forward_network(context: &ggml::Context, layer: &Layer, input: &Tensor) -> Tensor { - let mut current = context.op_norm(input); - - //gain and bias - current = context.op_add(&context.op_mul(¤t, &layer.ln_2_g), &layer.ln_2_b); - - // apply weights - current = context.op_mul_mat(&layer.c_mlp_fc_w, ¤t); - - // apply bias - current = context.op_add(¤t, &layer.c_mlp_fc_b); - - // GELU activation - current = context.op_gelu(¤t); - - // projection - // cur = proj_w*cur + proj_b - current = context.op_mul_mat(&layer.c_mlp_proj_w, ¤t); - - current = context.op_add(¤t, &layer.c_mlp_proj_b); - - current -} +// //! An implementation of [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox) for the `llm` ecosystem. +// //! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model. +// #![deny(missing_docs)] + +// use std::error::Error; + +// use ggml::Tensor; +// use llm_base::{ +// ggml, +// model::{common, HyperparametersWriteError}, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, +// }; + +// /// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox) +// /// +// /// # Safety +// /// This implements [Send] and [Sync] as it is immutable after construction. 
+// pub struct GptNeoX { +// params: ModelParameters, + +// hyperparameters: Hyperparameters, +// tokenizer: Tokenizer, + +// // model-global weights +// // normalization gain & bias +// ln_f_g: Tensor, +// ln_f_b: Tensor, +// // weight token embeddings +// wte: Tensor, +// // language model head gain +// lmh_g: Tensor, + +// // weights for the model +// layers: Vec, + +// // must be kept alive for the model +// context: ModelContext, +// } + +// unsafe impl Send for GptNeoX {} +// unsafe impl Sync for GptNeoX {} + +// impl KnownModel for GptNeoX { +// type Hyperparameters = Hyperparameters; + +// fn new( +// hyperparameters: Hyperparameters, +// params: ModelParameters, +// tokenizer: Tokenizer, +// tensor_loader: impl TensorLoader, +// ) -> Result +// where +// Self: Sized, +// { +// let mut tl = tensor_loader; + +// // model-global weights +// let wte = tl.load("gpt_neox.embed_in.weight")?; + +// let backend = params.backend(0); + +// let ln_f_g = tl +// .load("gpt_neox.final_layer_norm.weight")? +// .transfer_to(backend); +// let ln_f_b = tl +// .load("gpt_neox.final_layer_norm.bias")? +// .transfer_to(backend); +// let lmh_g = tl.load("embed_out.weight")?.transfer_to(backend); + +// let mut layers = Vec::new(); +// for i in 0..hyperparameters.n_layer { +// let backend = params.backend(i); +// let layer = Layer { +// ln_1_g: tl +// .load(&format!("gpt_neox.layers.{i}.input_layernorm.weight"))? +// .transfer_to(backend), +// ln_1_b: tl +// .load(&format!("gpt_neox.layers.{i}.input_layernorm.bias"))? +// .transfer_to(backend), + +// c_attn_attn_w: tl +// .load(&format!( +// "gpt_neox.layers.{i}.attention.query_key_value.weight" +// ))? +// .transfer_to(backend), +// c_attn_attn_b: tl +// .load(&format!( +// "gpt_neox.layers.{i}.attention.query_key_value.bias" +// ))? +// .transfer_to(backend), + +// c_attn_proj_w: tl +// .load(&format!("gpt_neox.layers.{i}.attention.dense.weight"))? +// .transfer_to(backend), +// c_attn_proj_b: tl +// .load(&format!("gpt_neox.layers.{i}.attention.dense.bias"))? +// .transfer_to(backend), + +// ln_2_g: tl +// .load(&format!( +// "gpt_neox.layers.{i}.post_attention_layernorm.weight" +// ))? +// .transfer_to(backend), +// ln_2_b: tl +// .load(&format!( +// "gpt_neox.layers.{i}.post_attention_layernorm.bias" +// ))? +// .transfer_to(backend), + +// c_mlp_fc_w: tl +// .load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight"))? +// .transfer_to(backend), +// c_mlp_fc_b: tl +// .load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias"))? +// .transfer_to(backend), + +// c_mlp_proj_w: tl +// .load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight"))? +// .transfer_to(backend), +// c_mlp_proj_b: tl +// .load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias"))? 
+// .transfer_to(backend), +// }; + +// layers.push(layer); +// } + +// let context = tl.finish(); + +// Ok(GptNeoX { +// hyperparameters, +// params, +// tokenizer, +// ln_f_g, +// ln_f_b, +// wte, +// lmh_g, +// layers, +// context, +// }) +// } + +// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { +// InferenceSession::new( +// config, +// &self.params, +// self.hyperparameters.n_layer, +// self.hyperparameters.n_embd, +// self.hyperparameters.n_vocab, +// ) +// } + +// // allow snake case here as its a one-to-one mapping of the original names +// #[allow(non_snake_case)] +// fn evaluate( +// &self, +// session: &mut InferenceSession, +// input_tokens: &[TokenId], +// output_request: &mut OutputRequest, +// ) { +// let n = input_tokens.len(); +// let n_past = session.n_past; +// let n_ctx = self.params.context_size; + +// let Hyperparameters { +// n_embd, +// n_head, +// n_vocab, +// n_layer, +// n_rot, +// use_parallel_residual, +// .. +// } = self.hyperparameters; + +// let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let mut ctx0 = builder.ctx0.borrow_mut(); +// let embd = builder.embd; +// let mut input_layer = ctx0.op_get_rows(&self.wte, embd); +// let (memory_k_size, memory_v_size) = ( +// builder.memory_k.element_size(), +// builder.memory_v.element_size(), +// ); + +// let mut gf = ctx0.create_compute_graph(); + +// for il in 0..n_layer { +// ctx0.set_offloading(self.params.should_offload(il)); +// // attention uses first scratch buffer +// ctx0.use_scratch(builder.get_scratch(0)); + +// // self-attention +// let mut current = ctx0.op_norm(&input_layer); +// current = ctx0.op_add( +// &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), +// &self.layers[il].ln_1_b, +// ); + +// // self-attention compute QKV +// current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].c_attn_attn_b); + +// let nb = current.get_nb()[1]; +// let f32_size = std::mem::size_of::(); + +// let mut qcur = ctx0.op_cont(&ctx0.op_view_3d( +// ¤t, +// (n_embd / n_head, n_head, n), +// (nb / n_head, nb), +// 0, +// )); +// let mut kcur = ctx0.op_cont(&ctx0.op_view_3d( +// ¤t, +// (n_embd / n_head, n_head, n), +// (nb / n_head, nb), +// f32_size * n_embd / n_head, +// )); +// let mut vcur = ctx0.op_cont(&ctx0.op_view_3d( +// ¤t, +// (n_embd / n_head, n_head, n), +// (nb / n_head, nb), +// 2 * f32_size * n_embd / n_head, +// )); + +// // self-attention using mode = 2 for GPT-NeoX mode +// let overrides = self.params.rope_overrides.as_ref(); +// qcur = ctx0.op_rope_inplace(&qcur, n_past, n_rot, 2, overrides); +// kcur = ctx0.op_rope_inplace(&kcur, n_past, n_rot, 2, overrides); + +// // store key and value to memory +// vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, n_embd, n)); + +// let k = ctx0.op_view_1d( +// builder.memory_k, +// n * n_embd, +// (memory_k_size * n_embd) * (il * n_ctx + n_past), +// ); + +// let v = ctx0.op_view_2d( +// builder.memory_v, +// (n, n_embd), +// n_ctx * memory_v_size, +// (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size, +// ); + +// gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); +// gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); + +// // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) +// let Q = ctx0.op_permute(&qcur, (0, 2, 1, 3)); +// // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) +// let K = ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_k, +// (n_past + n) * 
n_embd, +// il * n_ctx * memory_k_size * n_embd, +// ), +// n_embd / n_head, +// n_head, +// n_past + n, +// ), +// (0, 2, 1, 3), +// ); + +// // K * Q +// let KQ = ctx0.op_mul_mat(&K, &Q); + +// // KQ_scaled = KQ / sqrt(n_embd/n_head) +// let KQ_scaled = ctx0.op_scale_inplace( +// &KQ, +// &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), +// ); + +// // KQ_masked = mask_past(KQ_scaled) +// let KQ_masked = ctx0.op_diag_mask_inf_inplace(&KQ_scaled, n_past); + +// // KQ = soft_max(KQ_masked) +// let KQ_softmax = ctx0.op_soft_max_inplace(&KQ_masked); + +// // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() +// let V = ctx0.op_view_3d( +// builder.memory_v, +// (n_past + n, n_embd / n_head, n_head), +// ( +// n_ctx * memory_v_size, +// n_ctx * memory_v_size * n_embd / n_head, +// ), +// il * n_ctx * memory_v_size * n_embd, +// ); + +// // KQV = transpose(V) * KQ_soft_max +// let KQV = ctx0.op_mul_mat(&V, &KQ_softmax); +// // KQV_merged = KQV.permute(0, 2, 1, 3) +// let KQV_merged = ctx0.op_permute(&KQV, (0, 2, 1, 3)); + +// // cur = KQV_merged.contiguous().view(n_embd, N) +// current = ctx0.op_cpy(&KQV_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); + +// // self-attention projection +// current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); + +// // use the second scratch for the feed forward +// ctx0.use_scratch(builder.get_scratch(1)); + +// let feedforward_input: Tensor; +// if !use_parallel_residual { +// feedforward_input = ctx0.op_add(¤t, &input_layer); +// current = feed_forward_network(&ctx0, &self.layers[il], &feedforward_input); +// // input for next layer +// input_layer = ctx0.op_add(¤t, &feedforward_input); +// } else { +// // calculate with parallel residual +// feedforward_input = current.share(); + +// // this is independent of the self-attention result, so it could be done in parallel to the self-attention +// // note here we pass inpL instead of cur +// current = feed_forward_network(&ctx0, &self.layers[il], &input_layer); + +// // layer input + FF +// current = ctx0.op_add(¤t, &feedforward_input); + +// // input for next layer +// input_layer = ctx0.op_add(¤t, &input_layer); +// } +// } + +// // use the first scratch for the norm +// ctx0.use_scratch(builder.get_scratch(0)); + +// // normalize the output +// input_layer = ctx0.op_norm(&input_layer); +// // inpL = ln_f_g*inpL + ln_f_b +// input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); + +// let embeddings_tensor: ggml::Tensor = input_layer.share(); + +// // Disable the scratchbuffer +// ctx0.use_scratch(None); +// ctx0.set_offloading(false); +// // apply language model head +// input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); + +// ( +// gf, +// GraphOutputs { +// result: input_layer, +// embedding_result: embeddings_tensor, +// }, +// ) +// }); + +// // finish evaluation +// common::read_last_token(session, &outputs.result, n_vocab, n); +// common::extract_logits(output_request, &outputs.result, n_vocab, n); +// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); +// } + +// fn hyperparameters(&self) -> &Self::Hyperparameters { +// &self.hyperparameters +// } + +// fn tokenizer(&self) -> &Tokenizer { +// &self.tokenizer +// } + +// fn context_size(&self) -> usize { +// self.params.context_size +// } + +// fn bot_token_id(&self) -> Option { +// None +// } + +// fn eot_token_id(&self) -> TokenId { +// 
self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() +// } + +// fn quantize_tensors() -> Vec { +// vec![Regex::new(".*weight").unwrap()] +// } + +// fn skip_quantize_tensors() -> Vec { +// vec![] +// } + +// fn supports_rewind(&self) -> bool { +// true +// } +// } + +// /// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +// #[derive(Debug, PartialEq, Eq, Clone, Copy)] +// pub struct Hyperparameters { +// /// Size of the model's vocabulary +// pub n_vocab: usize, +// /// Size of the model's context +// pub n_ctx: usize, +// /// Size of the model's embedding layer +// pub n_embd: usize, +// /// n_head +// pub n_head: usize, +// /// Number of layers in the model +// pub n_layer: usize, +// /// n_rot +// pub n_rot: usize, +// /// Whether to use a "parallel" formulation in each Transformer layer. +// /// This is on for most models, but is off for some e.g. RedPajama. +// pub use_parallel_residual: bool, +// /// file_type +// pub file_type: FileType, +// } + +// impl Default for Hyperparameters { +// fn default() -> Self { +// Self { +// n_vocab: Default::default(), +// n_ctx: Default::default(), +// n_embd: Default::default(), +// n_head: Default::default(), +// n_layer: Default::default(), +// n_rot: Default::default(), +// file_type: Default::default(), +// use_parallel_residual: true, +// } +// } +// } + +// impl llm_base::Hyperparameters for Hyperparameters { +// fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { +// Ok(Hyperparameters { +// n_vocab: util::read_i32(reader)?.try_into()?, +// n_ctx: util::read_i32(reader)?.try_into()?, +// n_embd: util::read_i32(reader)?.try_into()?, +// n_head: util::read_i32(reader)?.try_into()?, +// n_layer: util::read_i32(reader)?.try_into()?, +// n_rot: util::read_i32(reader)?.try_into()?, +// use_parallel_residual: util::read_bool(reader)?, +// file_type: util::read_filetype(reader)?, +// }) +// } + +// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// util::write_i32(writer, self.n_ctx.try_into()?)?; +// util::write_i32(writer, self.n_embd.try_into()?)?; +// util::write_i32(writer, self.n_head.try_into()?)?; +// util::write_i32(writer, self.n_layer.try_into()?)?; +// util::write_i32(writer, self.n_rot.try_into()?)?; +// util::write_bool(writer, self.use_parallel_residual)?; +// util::write_i32(writer, self.file_type.into())?; +// Ok(()) +// } + +// fn n_vocabulary(&self) -> usize { +// self.n_vocab +// } + +// fn file_type(&self) -> Option { +// Some(self.file_type) +// } + +// fn file_type_mut(&mut self) -> Option<&mut FileType> { +// Some(&mut self.file_type) +// } +// } + +// struct Layer { +// // pre-normalization +// ln_1_g: Tensor, +// ln_1_b: Tensor, + +// // attention +// c_attn_attn_w: Tensor, +// c_attn_attn_b: Tensor, + +// c_attn_proj_w: Tensor, +// c_attn_proj_b: Tensor, + +// // post normalization +// ln_2_g: Tensor, +// ln_2_b: Tensor, + +// // feed-forward +// c_mlp_fc_w: Tensor, +// c_mlp_fc_b: Tensor, + +// c_mlp_proj_w: Tensor, +// c_mlp_proj_b: Tensor, +// } + +// fn feed_forward_network(context: &ggml::Context, layer: &Layer, input: &Tensor) -> Tensor { +// let mut current = context.op_norm(input); + +// //gain and bias +// current = context.op_add(&context.op_mul(¤t, &layer.ln_2_g), &layer.ln_2_b); + +// // apply weights +// current = context.op_mul_mat(&layer.c_mlp_fc_w, ¤t); + +// // apply bias +// current = context.op_add(¤t, &layer.c_mlp_fc_b); + +// // GELU 
activation +// current = context.op_gelu(¤t); + +// // projection +// // cur = proj_w*cur + proj_b +// current = context.op_mul_mat(&layer.c_mlp_proj_w, ¤t); + +// current = context.op_add(¤t, &layer.c_mlp_proj_b); + +// current +// } diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index 3d22efff..2b7db5b8 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -1,371 +1,371 @@ -//! An implementation of [MPT](https://huggingface.co/mosaicml) for the `llm` ecosystem. -#![deny(missing_docs)] - -use ggml::Tensor; -use llm_base::{ - ggml::{self}, - model::{common, HyperparametersWriteError}, - util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, -}; - -/// The MosaicML Pretrained Transformer (MPT) model. Ref: [Mosaic ML](https://www.mosaicml.com/blog/mpt-7b) -/// -/// # Safety -/// This implements [Send] and [Sync] as it is immutable after construction. -pub struct Mpt { - params: ModelParameters, - - hyperparameters: Hyperparameters, - tokenizer: Tokenizer, - - // model-global weights - // weighted token embeddings - wte: Tensor, - // normalization - norm: Tensor, - - // weights for the model - layers: Vec, - - // must be kept alive for the model - context: ModelContext, -} - -unsafe impl Send for Mpt {} -unsafe impl Sync for Mpt {} - -impl KnownModel for Mpt { - type Hyperparameters = Hyperparameters; - - fn new( - hyperparameters: Self::Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: impl llm_base::TensorLoader, - ) -> Result { - let mut tl = tensor_loader; - - // model-gobal weights - let wte = tl.load("transformer.wte.weight")?; - let norm = tl.load("transformer.norm_f.weight")?; - - let mut layers = Vec::new(); - for i in 0..hyperparameters.n_layer { - let layer = Layer { - norm_1_weight: tl.load(&format!("transformer.blocks.{i}.norm_1.weight"))?, - c_attn_wqkv_weight: tl.load(&format!("transformer.blocks.{i}.attn.Wqkv.weight"))?, - - c_attn_out_proj_weight: tl - .load(&format!("transformer.blocks.{i}.attn.out_proj.weight"))?, - norm_2_weight: tl.load(&format!("transformer.blocks.{i}.norm_2.weight"))?, - - ffn_up_proj: tl.load(&format!("transformer.blocks.{i}.ffn.up_proj.weight"))?, - ffn_down_proj: tl.load(&format!("transformer.blocks.{i}.ffn.down_proj.weight"))?, - }; - - layers.push(layer); - } - - let context = tl.finish(); - - Ok(Mpt { - hyperparameters, - params, - tokenizer, - wte, - norm, - layers, - context, - }) - } - - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { - InferenceSession::new( - config, - &self.params, - self.hyperparameters.n_layer, - self.hyperparameters.n_embd, - self.hyperparameters.n_vocab, - ) - } - - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ) { - let n = input_tokens.len(); - let session_len = session.n_past; - let ctx_size = self.params.context_size; - - let Hyperparameters { - n_embd, - n_head, - n_vocab, - n_layer, - alibi_bias_max, - .. 
- } = self.hyperparameters; - - let outputs = session.compute(self.context.clone(), input_tokens, |builder| { - let ctx0 = builder.ctx0.borrow(); - let (memory_k_size, memory_v_size) = ( - builder.memory_k.element_size(), - builder.memory_v.element_size(), - ); - let embd = builder.embd; - - let mut input_layer = ctx0.op_get_rows(&self.wte, embd); - - let f32_size = std::mem::size_of::(); - - let mut gf = ctx0.create_compute_graph(); - for il in 0..n_layer { - // attention uses first scratch buffer - ctx0.use_scratch(builder.get_scratch(0)); - - let mut current = ctx0.op_norm(&input_layer); - current = ctx0.op_mul(¤t, &self.layers[il].norm_1_weight); - - current = ctx0.op_mul_mat(&self.layers[il].c_attn_wqkv_weight, ¤t); - - let nb = current.get_nb()[1]; - let qcur = ctx0.op_view_2d(¤t, (n_embd, n), nb, 0); - let kcur = ctx0.op_view_2d(¤t, (n_embd, n), nb, f32_size * n_embd); - let vcur = ctx0.op_view_2d(¤t, (n_embd, n), nb, f32_size * n_embd * 2); - - let k = ctx0.op_view_1d( - builder.memory_k, - n * n_embd, - (memory_k_size * n_embd) * (il * ctx_size + session_len), - ); - let v = ctx0.op_view_1d( - builder.memory_v, - n * n_embd, - (memory_v_size * n_embd) * (il * ctx_size + session_len), - ); - - gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); - - let q = ctx0.op_permute( - &ctx0.op_cpy( - &qcur, - &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, n), - ), - (0, 2, 1, 3), - ); - - let bigk = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - builder.memory_k, - (session_len + n) * n_embd, - il * ctx_size * memory_k_size * n_embd, - ), - n_embd / n_head, - n_head, - session_len + n, - ), - (0, 2, 1, 3), - ); - - let kq = ctx0.op_mul_mat(&bigk, &q); - let kq_scaled = ctx0.op_scale( - &kq, - &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - let kq_scaled_alibi = - ctx0.op_alibi(&kq_scaled, session_len, n_head, alibi_bias_max); - let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled_alibi, session_len); - let kq_softmax = ctx0.op_soft_max(&kq_masked); - - let v_trans = ctx0.op_cpy( - &ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - builder.memory_v, - (session_len + n) * n_embd, - il * ctx_size * memory_v_size * n_embd, - ), - n_embd / n_head, - n_head, - session_len + n, - ), - (1, 2, 0, 3), - ), - &ctx0.new_tensor_3d( - builder.memory_v.get_type(), - session_len + n, - n_embd / n_head, - n_head, - ), - ); - - let kqv = ctx0.op_mul_mat(&v_trans, &kq_softmax); - let kqv_merged = ctx0.op_permute(&kqv, (0, 2, 1, 3)); - - current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); - // projection - current = ctx0.op_mul_mat(&self.layers[il].c_attn_out_proj_weight, ¤t); - - input_layer = ctx0.op_add(&input_layer, ¤t); - - // feed forward uses second scratch buffer - ctx0.use_scratch(builder.get_scratch(1)); - - current = ctx0.op_norm(&input_layer); - current = ctx0.op_mul(¤t, &self.layers[il].norm_2_weight); - - current = ctx0.op_mul_mat(&self.layers[il].ffn_up_proj, ¤t); - - current = ctx0.op_gelu(¤t); - - // projection - current = ctx0.op_mul_mat(&self.layers[il].ffn_down_proj, ¤t); - - input_layer = ctx0.op_add(&input_layer, ¤t); - } - - //use scratch buffer 0 for the rest - ctx0.use_scratch(builder.get_scratch(0)); - - // norm - input_layer = ctx0.op_norm(&input_layer); - input_layer = ctx0.op_mul(&input_layer, &self.norm); - - let embeddings_tensor: ggml::Tensor = input_layer.share(); - - // disable scratch buffer for last layer - ctx0.use_scratch(None); - 
// output embedding weight tied to input embedding - input_layer = ctx0.op_mul_mat(&self.wte, &input_layer); - - ( - gf, - GraphOutputs { - result: input_layer, - embedding_result: embeddings_tensor, - }, - ) - }); - - // finish evaluation - common::read_last_token(session, &outputs.result, n_vocab, n); - common::extract_logits(output_request, &outputs.result, n_vocab, n); - common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); - } - - fn hyperparameters(&self) -> &Self::Hyperparameters { - &self.hyperparameters - } - - fn tokenizer(&self) -> &Tokenizer { - &self.tokenizer - } - - fn context_size(&self) -> usize { - self.params.context_size - } - - fn bot_token_id(&self) -> Option { - self.tokenizer.id("<|padding|>".as_bytes()) - } - - fn eot_token_id(&self) -> TokenId { - self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() - } - - fn quantize_tensors() -> Vec { - vec![Regex::new(".*weight").unwrap()] - } - - fn skip_quantize_tensors() -> Vec { - vec![] - } - - fn supports_rewind(&self) -> bool { - true - } -} - -/// MPT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -#[derive(Debug, Default, PartialEq, Clone, Copy)] -pub struct Hyperparameters { - /// Size of the model's embedding layer - n_embd: usize, - /// Maximum sequence length - max_seq_len: usize, - /// n_heads - n_head: usize, - /// Number of layers in the model - n_layer: usize, - /// Size of the model's vocabulary - n_vocab: usize, - /// Alibi bias max - alibi_bias_max: f32, - /// Clip KQV - clip_kqv: f32, - /// file_type - file_type: FileType, -} -impl Eq for Hyperparameters {} - -impl llm_base::Hyperparameters for Hyperparameters { - fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { - let hyperparameters = Hyperparameters { - n_embd: util::read_i32(reader)?.try_into()?, - max_seq_len: util::read_i32(reader)?.try_into()?, - n_head: util::read_i32(reader)?.try_into()?, - n_layer: util::read_i32(reader)?.try_into()?, - n_vocab: util::read_i32(reader)?.try_into()?, - alibi_bias_max: util::read_f32(reader)?, - clip_kqv: util::read_f32(reader)?, - file_type: util::read_filetype(reader)?, - }; - - Ok(hyperparameters) - } - - fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { - util::write_i32(writer, self.n_embd.try_into()?)?; - util::write_i32(writer, self.max_seq_len.try_into()?)?; - util::write_i32(writer, self.n_head.try_into()?)?; - util::write_i32(writer, self.n_layer.try_into()?)?; - util::write_i32(writer, self.n_vocab.try_into()?)?; - util::write_f32(writer, self.alibi_bias_max)?; - util::write_f32(writer, self.clip_kqv)?; - util::write_i32(writer, self.file_type.into())?; - Ok(()) - } - - fn n_vocabulary(&self) -> usize { - self.n_vocab - } - - fn file_type(&self) -> Option { - Some(self.file_type) - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - Some(&mut self.file_type) - } -} - -struct Layer { - // pre normalization - norm_1_weight: Tensor, - - // attention - c_attn_wqkv_weight: Tensor, - c_attn_out_proj_weight: Tensor, - - // post normalization - norm_2_weight: Tensor, - - // ff - ffn_up_proj: Tensor, - ffn_down_proj: Tensor, -} +// //! An implementation of [MPT](https://huggingface.co/mosaicml) for the `llm` ecosystem. 
+// #![deny(missing_docs)] + +// use ggml::Tensor; +// use llm_base::{ +// ggml::{self}, +// model::{common, HyperparametersWriteError}, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, +// }; + +// /// The MosaicML Pretrained Transformer (MPT) model. Ref: [Mosaic ML](https://www.mosaicml.com/blog/mpt-7b) +// /// +// /// # Safety +// /// This implements [Send] and [Sync] as it is immutable after construction. +// pub struct Mpt { +// params: ModelParameters, + +// hyperparameters: Hyperparameters, +// tokenizer: Tokenizer, + +// // model-global weights +// // weighted token embeddings +// wte: Tensor, +// // normalization +// norm: Tensor, + +// // weights for the model +// layers: Vec, + +// // must be kept alive for the model +// context: ModelContext, +// } + +// unsafe impl Send for Mpt {} +// unsafe impl Sync for Mpt {} + +// impl KnownModel for Mpt { +// type Hyperparameters = Hyperparameters; + +// fn new( +// hyperparameters: Self::Hyperparameters, +// params: ModelParameters, +// tokenizer: Tokenizer, +// tensor_loader: impl llm_base::TensorLoader, +// ) -> Result { +// let mut tl = tensor_loader; + +// // model-gobal weights +// let wte = tl.load("transformer.wte.weight")?; +// let norm = tl.load("transformer.norm_f.weight")?; + +// let mut layers = Vec::new(); +// for i in 0..hyperparameters.n_layer { +// let layer = Layer { +// norm_1_weight: tl.load(&format!("transformer.blocks.{i}.norm_1.weight"))?, +// c_attn_wqkv_weight: tl.load(&format!("transformer.blocks.{i}.attn.Wqkv.weight"))?, + +// c_attn_out_proj_weight: tl +// .load(&format!("transformer.blocks.{i}.attn.out_proj.weight"))?, +// norm_2_weight: tl.load(&format!("transformer.blocks.{i}.norm_2.weight"))?, + +// ffn_up_proj: tl.load(&format!("transformer.blocks.{i}.ffn.up_proj.weight"))?, +// ffn_down_proj: tl.load(&format!("transformer.blocks.{i}.ffn.down_proj.weight"))?, +// }; + +// layers.push(layer); +// } + +// let context = tl.finish(); + +// Ok(Mpt { +// hyperparameters, +// params, +// tokenizer, +// wte, +// norm, +// layers, +// context, +// }) +// } + +// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { +// InferenceSession::new( +// config, +// &self.params, +// self.hyperparameters.n_layer, +// self.hyperparameters.n_embd, +// self.hyperparameters.n_vocab, +// ) +// } + +// fn evaluate( +// &self, +// session: &mut InferenceSession, +// input_tokens: &[TokenId], +// output_request: &mut OutputRequest, +// ) { +// let n = input_tokens.len(); +// let session_len = session.n_past; +// let ctx_size = self.params.context_size; + +// let Hyperparameters { +// n_embd, +// n_head, +// n_vocab, +// n_layer, +// alibi_bias_max, +// .. 
+// } = self.hyperparameters; + +// let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let ctx0 = builder.ctx0.borrow(); +// let (memory_k_size, memory_v_size) = ( +// builder.memory_k.element_size(), +// builder.memory_v.element_size(), +// ); +// let embd = builder.embd; + +// let mut input_layer = ctx0.op_get_rows(&self.wte, embd); + +// let f32_size = std::mem::size_of::(); + +// let mut gf = ctx0.create_compute_graph(); +// for il in 0..n_layer { +// // attention uses first scratch buffer +// ctx0.use_scratch(builder.get_scratch(0)); + +// let mut current = ctx0.op_norm(&input_layer); +// current = ctx0.op_mul(¤t, &self.layers[il].norm_1_weight); + +// current = ctx0.op_mul_mat(&self.layers[il].c_attn_wqkv_weight, ¤t); + +// let nb = current.get_nb()[1]; +// let qcur = ctx0.op_view_2d(¤t, (n_embd, n), nb, 0); +// let kcur = ctx0.op_view_2d(¤t, (n_embd, n), nb, f32_size * n_embd); +// let vcur = ctx0.op_view_2d(¤t, (n_embd, n), nb, f32_size * n_embd * 2); + +// let k = ctx0.op_view_1d( +// builder.memory_k, +// n * n_embd, +// (memory_k_size * n_embd) * (il * ctx_size + session_len), +// ); +// let v = ctx0.op_view_1d( +// builder.memory_v, +// n * n_embd, +// (memory_v_size * n_embd) * (il * ctx_size + session_len), +// ); + +// gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); +// gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); + +// let q = ctx0.op_permute( +// &ctx0.op_cpy( +// &qcur, +// &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, n), +// ), +// (0, 2, 1, 3), +// ); + +// let bigk = ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_k, +// (session_len + n) * n_embd, +// il * ctx_size * memory_k_size * n_embd, +// ), +// n_embd / n_head, +// n_head, +// session_len + n, +// ), +// (0, 2, 1, 3), +// ); + +// let kq = ctx0.op_mul_mat(&bigk, &q); +// let kq_scaled = ctx0.op_scale( +// &kq, +// &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), +// ); +// let kq_scaled_alibi = +// ctx0.op_alibi(&kq_scaled, session_len, n_head, alibi_bias_max); +// let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled_alibi, session_len); +// let kq_softmax = ctx0.op_soft_max(&kq_masked); + +// let v_trans = ctx0.op_cpy( +// &ctx0.op_permute( +// &ctx0.op_reshape_3d( +// &ctx0.op_view_1d( +// builder.memory_v, +// (session_len + n) * n_embd, +// il * ctx_size * memory_v_size * n_embd, +// ), +// n_embd / n_head, +// n_head, +// session_len + n, +// ), +// (1, 2, 0, 3), +// ), +// &ctx0.new_tensor_3d( +// builder.memory_v.get_type(), +// session_len + n, +// n_embd / n_head, +// n_head, +// ), +// ); + +// let kqv = ctx0.op_mul_mat(&v_trans, &kq_softmax); +// let kqv_merged = ctx0.op_permute(&kqv, (0, 2, 1, 3)); + +// current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); +// // projection +// current = ctx0.op_mul_mat(&self.layers[il].c_attn_out_proj_weight, ¤t); + +// input_layer = ctx0.op_add(&input_layer, ¤t); + +// // feed forward uses second scratch buffer +// ctx0.use_scratch(builder.get_scratch(1)); + +// current = ctx0.op_norm(&input_layer); +// current = ctx0.op_mul(¤t, &self.layers[il].norm_2_weight); + +// current = ctx0.op_mul_mat(&self.layers[il].ffn_up_proj, ¤t); + +// current = ctx0.op_gelu(¤t); + +// // projection +// current = ctx0.op_mul_mat(&self.layers[il].ffn_down_proj, ¤t); + +// input_layer = ctx0.op_add(&input_layer, ¤t); +// } + +// //use scratch buffer 0 for the rest +// ctx0.use_scratch(builder.get_scratch(0)); + +// // norm +// input_layer = 
ctx0.op_norm(&input_layer); +// input_layer = ctx0.op_mul(&input_layer, &self.norm); + +// let embeddings_tensor: ggml::Tensor = input_layer.share(); + +// // disable scratch buffer for last layer +// ctx0.use_scratch(None); +// // output embedding weight tied to input embedding +// input_layer = ctx0.op_mul_mat(&self.wte, &input_layer); + +// ( +// gf, +// GraphOutputs { +// result: input_layer, +// embedding_result: embeddings_tensor, +// }, +// ) +// }); + +// // finish evaluation +// common::read_last_token(session, &outputs.result, n_vocab, n); +// common::extract_logits(output_request, &outputs.result, n_vocab, n); +// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); +// } + +// fn hyperparameters(&self) -> &Self::Hyperparameters { +// &self.hyperparameters +// } + +// fn tokenizer(&self) -> &Tokenizer { +// &self.tokenizer +// } + +// fn context_size(&self) -> usize { +// self.params.context_size +// } + +// fn bot_token_id(&self) -> Option { +// self.tokenizer.id("<|padding|>".as_bytes()) +// } + +// fn eot_token_id(&self) -> TokenId { +// self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() +// } + +// fn quantize_tensors() -> Vec { +// vec![Regex::new(".*weight").unwrap()] +// } + +// fn skip_quantize_tensors() -> Vec { +// vec![] +// } + +// fn supports_rewind(&self) -> bool { +// true +// } +// } + +// /// MPT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +// #[derive(Debug, Default, PartialEq, Clone, Copy)] +// pub struct Hyperparameters { +// /// Size of the model's embedding layer +// n_embd: usize, +// /// Maximum sequence length +// max_seq_len: usize, +// /// n_heads +// n_head: usize, +// /// Number of layers in the model +// n_layer: usize, +// /// Size of the model's vocabulary +// n_vocab: usize, +// /// Alibi bias max +// alibi_bias_max: f32, +// /// Clip KQV +// clip_kqv: f32, +// /// file_type +// file_type: FileType, +// } +// impl Eq for Hyperparameters {} + +// impl llm_base::Hyperparameters for Hyperparameters { +// fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { +// let hyperparameters = Hyperparameters { +// n_embd: util::read_i32(reader)?.try_into()?, +// max_seq_len: util::read_i32(reader)?.try_into()?, +// n_head: util::read_i32(reader)?.try_into()?, +// n_layer: util::read_i32(reader)?.try_into()?, +// n_vocab: util::read_i32(reader)?.try_into()?, +// alibi_bias_max: util::read_f32(reader)?, +// clip_kqv: util::read_f32(reader)?, +// file_type: util::read_filetype(reader)?, +// }; + +// Ok(hyperparameters) +// } + +// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { +// util::write_i32(writer, self.n_embd.try_into()?)?; +// util::write_i32(writer, self.max_seq_len.try_into()?)?; +// util::write_i32(writer, self.n_head.try_into()?)?; +// util::write_i32(writer, self.n_layer.try_into()?)?; +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// util::write_f32(writer, self.alibi_bias_max)?; +// util::write_f32(writer, self.clip_kqv)?; +// util::write_i32(writer, self.file_type.into())?; +// Ok(()) +// } + +// fn n_vocabulary(&self) -> usize { +// self.n_vocab +// } + +// fn file_type(&self) -> Option { +// Some(self.file_type) +// } + +// fn file_type_mut(&mut self) -> Option<&mut FileType> { +// Some(&mut self.file_type) +// } +// } + +// struct Layer { +// // pre normalization +// norm_1_weight: Tensor, + +// // attention +// c_attn_wqkv_weight: Tensor, +// c_attn_out_proj_weight: Tensor, + +// // post normalization +// 
norm_2_weight: Tensor, + +// // ff +// ffn_up_proj: Tensor, +// ffn_down_proj: Tensor, +// } From eb8c5085df424e2859813e298086b1e34b08679d Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 6 Sep 2023 23:29:35 +0200 Subject: [PATCH 14/33] feat(llama): validate tensor data layout --- crates/llm-base/src/loader.rs | 14 ++++++++++++++ crates/models/llama/src/lib.rs | 17 ++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index e5330410..eb254091 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -348,6 +348,7 @@ pub trait MetadataExt { key: &'a str, getter: impl Fn(&MetadataValue) -> Option<&T>, ) -> Result<&'a T, LoadError>; + fn fallible_get_string(&self, key: &str) -> Result; fn fallible_get_countable(&self, key: &str) -> Result; } impl MetadataExt for Metadata { @@ -370,6 +371,19 @@ impl MetadataExt for Metadata { }) } + // TODO: see if we can generalize this with `ToOwned` or something? + fn fallible_get_string(&self, key: &str) -> Result { + let metadata_value = self.fallible_get(key)?; + Ok(metadata_value + .as_string() + .ok_or_else(|| LoadError::InvalidMetadataType { + key: key.to_string(), + expected_type: MetadataValueType::String, + actual_type: metadata_value.value_type(), + })? + .to_string()) + } + fn fallible_get_countable(&self, key: &str) -> Result { let metadata_value = self.fallible_get(key)?; match metadata_value { diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 09955f4e..f32669a2 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -2,16 +2,15 @@ #![deny(missing_docs)] use llm_base::{ - ggml::{ - self, - format::gguf::{Metadata, MetadataValue, MetadataValueTypeFromRustType}, - }, + ggml::{self, format::gguf::Metadata}, model::{common, HyperparametersWriteError}, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, MetadataExt, ModelContext, ModelParameters, ModelTensorLoader, OutputRequest, Regex, TokenId, Tokenizer, }; +const META_TENSOR_DATA_LAYOUT: &str = "Meta AI original pth"; + /// The LLaMA model. 
Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) /// /// # Safety @@ -47,6 +46,8 @@ impl KnownModel for Llama { tokenizer: Tokenizer, tensor_loader: ModelTensorLoader, ) -> Result { + assert_eq!(hyperparameters.tensor_data_layout, META_TENSOR_DATA_LAYOUT); + let mut tl = tensor_loader; // model-global weights @@ -136,6 +137,7 @@ impl KnownModel for Llama { head_count_kv, block_count, file_type: _, + tensor_data_layout: _, } = self.hyperparameters; let embedding_length_gqa = embedding_length / self.hyperparameters.grouped_query_attention(); @@ -392,7 +394,7 @@ impl KnownModel for Llama { } /// LLaMA [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone)] pub struct Hyperparameters { /// Size of the model's vocabulary pub vocabulary_count: usize, @@ -406,6 +408,8 @@ pub struct Hyperparameters { pub block_count: usize, /// file_type pub file_type: Option, + /// The tensor data layout that this model was encoded with + pub tensor_data_layout: String, } impl llm_base::Hyperparameters for Hyperparameters { fn read_gguf(metadata: &Metadata) -> Result { @@ -423,6 +427,9 @@ impl llm_base::Hyperparameters for Hyperparameters { .and_then(|v| v.as_uint32()) .map(|v| FileType::try_from(v as i32)) .transpose()?, + tensor_data_layout: metadata + .fallible_get_string("llama.tensor_data_layout") + .unwrap_or(META_TENSOR_DATA_LAYOUT.to_string()), }) } From f398ebdf828f2c92ed6adedd70715f2fa52f1786 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 8 Oct 2023 04:29:19 +0200 Subject: [PATCH 15/33] feat(llm): remove architecture param --- Cargo.lock | 1 + binaries/llm-cli/src/cli_args.rs | 16 +- binaries/llm-test/configs/bloom.json | 1 - binaries/llm-test/configs/gptj.json | 1 - binaries/llm-test/configs/gptneox.json | 1 - binaries/llm-test/configs/llama.json | 1 - binaries/llm-test/configs/mpt.json | 1 - binaries/llm-test/src/common.rs | 4 +- binaries/llm-test/src/delete.rs | 4 +- binaries/llm-test/src/main.rs | 274 +++++------- binaries/llm-test/src/tokens.rs | 4 +- crates/ggml/src/format/gguf/metadata.rs | 85 ++++ crates/llm-base/src/lib.rs | 6 +- crates/llm-base/src/loader.rs | 540 +++++------------------- crates/llm-base/src/lora.rs | 33 +- crates/llm-base/src/model/mod.rs | 50 +-- crates/llm-base/src/quantize.rs | 8 +- crates/llm-base/src/tokenizer/mod.rs | 107 +++-- crates/llm/Cargo.toml | 1 + crates/llm/examples/embeddings.rs | 9 +- crates/llm/examples/inference.rs | 9 +- crates/llm/examples/vicuna-chat.rs | 9 +- crates/llm/src/lib.rs | 77 +--- crates/llm/src/loader.rs | 380 +++++++++++++++++ crates/models/llama/src/lib.rs | 20 +- 25 files changed, 839 insertions(+), 803 deletions(-) create mode 100644 crates/llm/src/loader.rs diff --git a/Cargo.lock b/Cargo.lock index ab1f6bcf..811ef654 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2146,6 +2146,7 @@ dependencies = [ "serde", "serde_json", "spinoff", + "thiserror", "tracing", ] diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 00c0b0ed..97fa8a85 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -1,5 +1,4 @@ use std::{ - fmt, ops::Deref, path::{Path, PathBuf}, }; @@ -7,7 +6,7 @@ use std::{ use clap::{Parser, ValueEnum}; use color_eyre::eyre::{self, WrapErr}; use llm::{ - ggml_format, samplers::build_sampler, ElementType, InferenceParameters, InferenceSessionConfig, + samplers::build_sampler, ElementType, 
InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress, Model, ModelKVMemoryType, ModelParameters, RoPEOverrides, TokenBias, TokenId, TokenizerSource, }; @@ -427,22 +426,12 @@ impl ModelTokenizer { } } -#[derive(Parser, Debug)] -pub struct ModelArchitecture { - /// The model architecture to use. Will attempt to guess if not specified. - #[arg(long, short = 'a')] - pub model_architecture: Option, -} - #[derive(Parser, Debug)] pub struct ModelAndTokenizer { /// Where to load the model from #[arg(long, short = 'm')] pub model_path: PathBuf, - #[command(flatten)] - pub architecture: ModelArchitecture, - #[command(flatten)] pub tokenizer: ModelTokenizer, } @@ -539,8 +528,7 @@ impl ModelLoad { } }; - let model = llm::load_dynamic( - self.model_and_tokenizer.architecture.model_architecture, + let model = llm::load( &self.model_and_tokenizer.model_path, tokenizer_source, params, diff --git a/binaries/llm-test/configs/bloom.json b/binaries/llm-test/configs/bloom.json index 5383386d..cec5e750 100644 --- a/binaries/llm-test/configs/bloom.json +++ b/binaries/llm-test/configs/bloom.json @@ -1,7 +1,6 @@ { "url": "https://huggingface.co/rustformers/bloom-ggml/resolve/main/bloom-560m-q4_0.bin", "filename": "bloom.bin", - "architecture": "bloom", "test_cases": [ { "Inference": { diff --git a/binaries/llm-test/configs/gptj.json b/binaries/llm-test/configs/gptj.json index 50966748..febf76f9 100644 --- a/binaries/llm-test/configs/gptj.json +++ b/binaries/llm-test/configs/gptj.json @@ -1,7 +1,6 @@ { "url": "https://huggingface.co/rustformers/gpt-j-ggml/resolve/main/gpt-j-6b-q4_0-ggjt.bin", "filename": "gptj.bin", - "architecture": "gptj", "test_cases": [ { "Inference": { diff --git a/binaries/llm-test/configs/gptneox.json b/binaries/llm-test/configs/gptneox.json index c8cce4d9..96c58906 100644 --- a/binaries/llm-test/configs/gptneox.json +++ b/binaries/llm-test/configs/gptneox.json @@ -1,7 +1,6 @@ { "url": "https://huggingface.co/rustformers/redpajama-3b-ggml/resolve/main/RedPajama-INCITE-Base-3B-v1-q4_0-ggjt.bin", "filename": "gptneox.bin", - "architecture": "gptneox", "test_cases": [ { "Inference": { diff --git a/binaries/llm-test/configs/llama.json b/binaries/llm-test/configs/llama.json index 9bd6094a..9eec8a73 100644 --- a/binaries/llm-test/configs/llama.json +++ b/binaries/llm-test/configs/llama.json @@ -1,7 +1,6 @@ { "url": "https://huggingface.co/rustformers/open-llama-ggml/resolve/main/open_llama_3b-q4_0-ggjt.bin", "filename": "llama.bin", - "architecture": "llama", "test_cases": [ { "Inference": { diff --git a/binaries/llm-test/configs/mpt.json b/binaries/llm-test/configs/mpt.json index 57a8bc89..37b39bf3 100644 --- a/binaries/llm-test/configs/mpt.json +++ b/binaries/llm-test/configs/mpt.json @@ -1,7 +1,6 @@ { "url": "https://huggingface.co/rustformers/mpt-7b-ggml/resolve/main/mpt-7b-q4_0-ggjt.bin", "filename": "mpt.bin", - "architecture": "mpt", "test_cases": [ { "Inference": { diff --git a/binaries/llm-test/src/common.rs b/binaries/llm-test/src/common.rs index 46ab2a50..63ea41d4 100644 --- a/binaries/llm-test/src/common.rs +++ b/binaries/llm-test/src/common.rs @@ -1,6 +1,8 @@ //! Tests that are run on every model, regardless of config. 
-pub(super) fn can_send(model: M) -> anyhow::Result { +use llm::Model; + +pub(super) fn can_send(model: Box) -> anyhow::Result> { let model = std::thread::spawn(move || model) .join() .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}")); diff --git a/binaries/llm-test/src/delete.rs b/binaries/llm-test/src/delete.rs index 7bcf81df..d8e40aa8 100644 --- a/binaries/llm-test/src/delete.rs +++ b/binaries/llm-test/src/delete.rs @@ -12,7 +12,7 @@ use serde::Serialize; use crate::{TestCaseReport, TestCaseReportMeta}; /// Tests that models can delete tokens without changing the model's behavior. -pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { +pub(crate) fn can_delete(model: &dyn Model) -> TestCaseReport { let report = DeleteReport::default(); let mut session = model.start_session(Default::default()); let mut output = OutputRequest { @@ -61,7 +61,7 @@ pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { fn feed_prompt( prompt: &str, session: &mut InferenceSession, - model: &impl Model, + model: &dyn Model, output: &mut OutputRequest, ) -> Result<(), llm::InferenceError> { session.feed_prompt(model, prompt, output, always_continue) diff --git a/binaries/llm-test/src/main.rs b/binaries/llm-test/src/main.rs index 4b982505..0493821f 100644 --- a/binaries/llm-test/src/main.rs +++ b/binaries/llm-test/src/main.rs @@ -18,7 +18,6 @@ use std::{ fs::{self, File}, io::Write, path::{Path, PathBuf}, - str::FromStr, time::Instant, }; @@ -61,7 +60,7 @@ async fn main() -> anyhow::Result<()> { fs::create_dir_all(&results_dir)?; // Load configurations - let test_configs: HashMap = fs::read_dir(configs_dir)? + let mut test_configs: HashMap = fs::read_dir(configs_dir)? .filter_map(Result::ok) .map(|de| de.path()) .filter(|p| p.is_file()) @@ -78,24 +77,20 @@ async fn main() -> anyhow::Result<()> { }; // Test models - let mut test_configs = if let Some(specific_architecture) = specific_model { - vec![test_configs - .get(&specific_architecture) - .with_context(|| { - format!( - "No config found for `{specific_architecture}`. Available configs: {:?}", - test_configs.keys() - ) - })? 
- .clone()] - } else { - test_configs.values().cloned().collect() - }; - test_configs.sort_by_key(|tc| tc.architecture.clone()); + if let Some(specific_architecture) = specific_model { + test_configs.retain(|k, _| *k == specific_architecture); + } let test_configs_len = test_configs.len(); - for test_config in test_configs { - test_model(&model_config, &test_config, &download_dir, &results_dir).await?; + for (test_name, test_config) in &test_configs { + test_model( + &model_config, + test_name, + test_config, + &download_dir, + &results_dir, + ) + .await?; if test_configs_len > 1 { log::info!("----"); } @@ -114,7 +109,6 @@ struct ModelConfig { struct TestConfig { url: String, filename: PathBuf, - architecture: String, test_cases: Vec, } @@ -165,13 +159,12 @@ pub enum TestCaseReportInner { async fn test_model( model_config: &ModelConfig, + test_name: &str, test_config: &TestConfig, download_dir: &Path, results_dir: &Path, ) -> anyhow::Result<()> { // Load the model - let architecture = llm::ModelArchitecture::from_str(&test_config.architecture)?; - let local_path = if test_config.filename.is_file() { // If this filename points towards a valid file, use it test_config.filename.clone() @@ -180,160 +173,127 @@ async fn test_model( download_dir.join(&test_config.filename) }; - log::info!( - "Testing architecture: `{}` ({})", - test_config.architecture, - local_path.display() - ); + log::info!("Testing `{test_name}`: `{}`", local_path.display()); // Download the model if necessary download_file(&test_config.url, &local_path).await?; - struct TestVisitor<'a> { - model_config: &'a ModelConfig, - test_config: &'a TestConfig, - results_dir: &'a Path, - local_path: &'a Path, - } - impl<'a> llm::ModelArchitectureVisitor> for TestVisitor<'a> { - fn visit(&mut self) -> anyhow::Result<()> { - let Self { - model_config, - test_config, - results_dir, - local_path, - } = *self; - - let start_time = Instant::now(); - - let model = { - let model = llm::load::( - local_path, - llm::TokenizerSource::Embedded, - llm::ModelParameters { - prefer_mmap: model_config.mmap, - ..Default::default() - }, - |progress| { - let print = !matches!(&progress, - llm::LoadProgress::TensorLoaded { current_tensor, tensor_count } - if current_tensor % (tensor_count / 10) != 0 - ); - - if print { - log::info!("loading: {:?}", progress); - } - }, + let start_time = Instant::now(); + + let model = { + let model = llm::load( + &local_path, + llm::TokenizerSource::Embedded, + llm::ModelParameters { + prefer_mmap: model_config.mmap, + ..Default::default() + }, + |progress| { + let print = !matches!(&progress, + llm::LoadProgress::TensorLoaded { current_tensor, tensor_count } + if current_tensor % (tensor_count / 10) != 0 ); - match model { - Ok(m) => m, - Err(err) => { - write_report( - test_config, - results_dir, - &Report::LoadFail { - error: format!("Failed to load model: {}", err), - }, - )?; - - return Err(err.into()); - } - } - }; - - log::info!( - "Model fully loaded! 
Elapsed: {}ms", - start_time.elapsed().as_millis() - ); - - // - // Non-model-specific tests - // - - // Confirm that the model can be sent to a thread, then sent back - let model = common::can_send(model)?; - - // Confirm that the hyperparameters can be roundtripped - // common::can_roundtrip_hyperparameters(&model)?; - - // - - // - // Model-specific tests - // - - // Run the test cases - let mut test_case_reports = vec![]; - for test_case in &test_config.test_cases { - match test_case { - TestCase::Inference { - input, - output, - maximum_token_count, - } => test_case_reports.push(inference::can_infer( - &model, - model_config, - input, - output.as_deref(), - *maximum_token_count, - )?), - TestCase::Tokens { input, output } => { - test_case_reports.push(tokens::can_feed(&model, input, *output)); - } - TestCase::Delete {} => { - test_case_reports.push(delete::can_delete(&model)); - } + if print { + log::info!("loading: {:?}", progress); } + }, + ); + + match model { + Ok(m) => m, + Err(err) => { + write_report( + test_name, + results_dir, + &Report::LoadFail { + error: format!("Failed to load model: {}", err), + }, + )?; + + return Err(err.into()); } - let first_error: Option = - test_case_reports - .iter() - .find_map(|report: &TestCaseReport| match &report.meta { - TestCaseReportMeta::Error { error } => Some(error.clone()), - _ => None, - }); - - // Save the results - // Serialize the report to a JSON string - write_report( - test_config, - results_dir, - &Report::LoadSuccess { - test_cases: test_case_reports, - }, - )?; - - // Optionally, panic if there was an error - if let Some(err) = first_error { - panic!("Error: {}", err); - } + } + }; - log::info!( - "Successfully tested architecture `{}`!", - test_config.architecture - ); + log::info!( + "Model fully loaded! 
Elapsed: {}ms", + start_time.elapsed().as_millis() + ); + + // + // Non-model-specific tests + // + + // Confirm that the model can be sent to a thread, then sent back + let model = common::can_send(model)?; + + // Confirm that the hyperparameters can be roundtripped + // common::can_roundtrip_hyperparameters(&model)?; + + // - Ok(()) + // + // Model-specific tests + // + + // Run the test cases + let mut test_case_reports = vec![]; + for test_case in &test_config.test_cases { + match test_case { + TestCase::Inference { + input, + output, + maximum_token_count, + } => test_case_reports.push(inference::can_infer( + model.as_ref(), + model_config, + input, + output.as_deref(), + *maximum_token_count, + )?), + TestCase::Tokens { input, output } => { + test_case_reports.push(tokens::can_feed(model.as_ref(), input, *output)); + } + TestCase::Delete {} => { + test_case_reports.push(delete::can_delete(model.as_ref())); + } } } - architecture.visit(&mut TestVisitor { - model_config, - test_config, + let first_error: Option = + test_case_reports + .iter() + .find_map(|report: &TestCaseReport| match &report.meta { + TestCaseReportMeta::Error { error } => Some(error.clone()), + _ => None, + }); + + // Save the results + // Serialize the report to a JSON string + write_report( + test_name, results_dir, - local_path: &local_path, - })?; + &Report::LoadSuccess { + test_cases: test_case_reports, + }, + )?; + + // Optionally, panic if there was an error + if let Some(err) = first_error { + panic!("Error: {}", err); + } + + log::info!( + "Successfully tested `{test_name}`: `{}`!", + local_path.display() + ); Ok(()) } -fn write_report( - test_config: &TestConfig, - results_dir: &Path, - report: &Report, -) -> anyhow::Result<()> { +fn write_report(test_name: &str, results_dir: &Path, report: &Report) -> anyhow::Result<()> { let json_report = serde_json::to_string_pretty(&report)?; - let report_path = results_dir.join(format!("{}.json", test_config.architecture)); + let report_path = results_dir.join(format!("{test_name}.json")); fs::write(report_path, json_report)?; Ok(()) } diff --git a/binaries/llm-test/src/tokens.rs b/binaries/llm-test/src/tokens.rs index adddd678..b9fed471 100644 --- a/binaries/llm-test/src/tokens.rs +++ b/binaries/llm-test/src/tokens.rs @@ -12,7 +12,7 @@ use serde::Serialize; use crate::{TestCaseReport, TestCaseReportMeta}; /// Tests that the model performs as expected when feeding tokens -pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) -> TestCaseReport { +pub(crate) fn can_feed(model: &dyn Model, input: &str, expected_output: usize) -> TestCaseReport { let mut report = TokensReport::default(); let mut session = model.start_session(Default::default()); let mut output = OutputRequest { @@ -62,7 +62,7 @@ pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) fn feed_prompt( prompt: &str, session: &mut InferenceSession, - model: &impl Model, + model: &dyn Model, output: &mut OutputRequest, ) -> Result<(), llm::InferenceError> { session.feed_prompt(model, prompt, output, always_continue) diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index a39301f4..e58bcb13 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -1,9 +1,12 @@ use std::{collections::HashMap, io::BufRead}; +use thiserror::Error; + use crate::util; use super::{GgufContext, GgufLoadError}; +// TODO: make this a newtype instead pub type Metadata = HashMap; #[repr(u32)] @@ -467,3 
+470,85 @@ impl MetadataArrayValue { } } } + +#[doc(hidden)] +pub trait MetadataExt { + fn fallible_get(&self, key: &str) -> Result<&MetadataValue, MetadataError>; + fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( + &'a self, + key: &'a str, + getter: impl Fn(&MetadataValue) -> Option<&T>, + ) -> Result<&'a T, MetadataError>; + fn fallible_get_string(&self, key: &str) -> Result; + fn fallible_get_countable(&self, key: &str) -> Result; +} +impl MetadataExt for Metadata { + fn fallible_get(&self, key: &str) -> Result<&MetadataValue, MetadataError> { + self.get(key).ok_or_else(|| MetadataError::MissingKey { + key: key.to_owned(), + }) + } + + fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( + &'a self, + key: &'a str, + getter: impl Fn(&MetadataValue) -> Option<&T>, + ) -> Result<&'a T, MetadataError> { + let metadata_value = self.fallible_get(key)?; + getter(metadata_value).ok_or_else(|| MetadataError::InvalidType { + key: key.to_string(), + expected_type: T::value_type(), + actual_type: metadata_value.value_type(), + }) + } + + // TODO: see if we can generalize this with `ToOwned` or something? + fn fallible_get_string(&self, key: &str) -> Result { + let metadata_value = self.fallible_get(key)?; + Ok(metadata_value + .as_string() + .ok_or_else(|| MetadataError::InvalidType { + key: key.to_string(), + expected_type: MetadataValueType::String, + actual_type: metadata_value.value_type(), + })? + .to_string()) + } + + fn fallible_get_countable(&self, key: &str) -> Result { + let metadata_value = self.fallible_get(key)?; + match metadata_value { + MetadataValue::UInt32(v) => Ok(usize::try_from(*v)?), + MetadataValue::UInt64(v) => Ok(usize::try_from(*v)?), + _ => Err(MetadataError::InvalidType { + key: key.to_string(), + expected_type: MetadataValueType::UInt64, + actual_type: metadata_value.value_type(), + }), + } + } +} + +#[derive(Error, Debug)] +/// Errors encountered during the loading process. +pub enum MetadataError { + /// The model expected a metadata key-value pair, but the key was missing. + #[error("missing metadata key {key:?}")] + MissingKey { + /// The key that was missing. + key: String, + }, + /// The metadata key-value pair was not of the expected type. + #[error("metadata key {key:?} was not of the expected type")] + InvalidType { + /// The key with the invalid type. + key: String, + /// The expected type. + expected_type: MetadataValueType, + /// The actual type. + actual_type: MetadataValueType, + }, + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. 
+ InvalidIntegerConversion(#[from] std::num::TryFromIntError), +} diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index 0c54c954..fe521c08 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -8,11 +8,11 @@ #![deny(missing_docs)] mod inference_session; -mod loader; mod lora; mod quantize; mod tokenizer; +pub mod loader; pub mod model; pub mod samplers; pub mod util; @@ -30,8 +30,8 @@ pub use inference_session::{ }; pub use llm_samplers::prelude::{Sampler, SamplerChain}; pub use loader::{ - load, load_progress_callback_stdout, ContainerType, FileMagic, FileType, FileTypeFormat, - LoadError, LoadProgress, MetadataExt, ModelTensorLoader, + load_known_internal, ContainerType, FileMagic, FileType, FileTypeFormat, LoadKnownError, + MetadataError, MetadataExt, ModelTensorLoader, TensorLoadError, }; pub use lora::{LoraAdapter, LoraParameters}; pub use memmap2::Mmap; diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index eb254091..edfb59e3 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -1,27 +1,27 @@ +//! Functionality for loading models. Very barebones; designed to be driven by `llm`. + use std::{ - error::Error, fmt::{Display, Formatter}, - fs::File, - io::{BufReader, Read, Seek, SeekFrom}, - path::{Path, PathBuf}, + io::{BufRead, Seek, SeekFrom}, + path::Path, sync::Arc, }; use crate::{ - Hyperparameters, KnownModel, LoraAdapter, ModelContext, ModelParameters, Tokenizer, - TokenizerLoadError, TokenizerSource, + model::{Hyperparameters, HyperparametersReadError}, + KnownModel, LoraAdapter, ModelContext, ModelParameters, Tokenizer, }; use ggml::{ - format::gguf::{ - self, GgufLoadError, Metadata, MetadataValue, MetadataValueType, - MetadataValueTypeFromRustType, TensorInfo, - }, + format::gguf::{Gguf, Metadata, TensorInfo}, + sys::llama::llama_ftype, Context, MAX_NAME_LENGTH, }; -pub use ggml::{format::ContainerType, util::FileMagic}; -use memmap2::Mmap; +pub use ggml::{ + format::gguf::{MetadataError, MetadataExt}, + format::ContainerType, + util::FileMagic, +}; use thiserror::Error; -use tracing::log; #[derive(Debug, PartialEq, Clone, Copy, Eq, Default)] /// Information about the file. @@ -34,16 +34,15 @@ pub struct FileType { impl From for i32 { fn from(value: FileType) -> Self { (value.quantization_version * ggml::QNT_VERSION_FACTOR) as i32 - + ggml::sys::llama::llama_ftype::from(value.format) + + llama_ftype::from(value.format) } } impl TryFrom for FileType { - type Error = LoadError; + type Error = llama_ftype; fn try_from(value: i32) -> Result { - let format = FileTypeFormat::try_from( - ((value as u32) % ggml::QNT_VERSION_FACTOR) as ggml::sys::llama::llama_ftype, - )?; + let format = + FileTypeFormat::try_from(((value as u32) % ggml::QNT_VERSION_FACTOR) as llama_ftype)?; Ok(Self { format, @@ -56,6 +55,23 @@ impl Display for FileType { write!(f, "{}_qnt{}", self.format, self.quantization_version) } } +impl FileType { + /// Helper function that reads the file type from the metadata and converts + /// it to the enum, or fails with a `HyperparametersReadError`. + pub fn read_for_hyperparameters( + metadata: &Metadata, + ) -> Result, HyperparametersReadError> { + metadata + .get("general.file_type") + .and_then(|v| v.as_uint32()) + .map(|v| { + FileType::try_from(v as i32).map_err(|ftype| { + HyperparametersReadError::UnsupportedFileType { file_type: ftype } + }) + }) + .transpose() + } +} /// How the tensors are stored in GGML LLM models. 
#[derive(Debug, PartialEq, Clone, Copy, Eq, Default)] @@ -98,10 +114,10 @@ pub enum FileTypeFormat { /// The tensors are stored using the `Q6_K` quantization scheme. MostlyQ6_K, } -impl TryFrom for FileTypeFormat { - type Error = LoadError; +impl TryFrom for FileTypeFormat { + type Error = llama_ftype; - fn try_from(value: ggml::sys::llama::llama_ftype) -> Result { + fn try_from(value: llama_ftype) -> Result { use ggml::sys::llama::*; match value { LLAMA_FTYPE_ALL_F32 => Ok(FileTypeFormat::F32), @@ -122,13 +138,11 @@ impl TryFrom for FileTypeFormat { LLAMA_FTYPE_MOSTLY_Q5_K_M => Ok(FileTypeFormat::MostlyQ5_K_M), LLAMA_FTYPE_MOSTLY_Q6_K => Ok(FileTypeFormat::MostlyQ6_K), #[allow(clippy::unnecessary_cast)] - _ => Err(LoadError::UnsupportedFileType { - file_type_format: value as u32, - }), + _ => Err(value), } } } -impl From for ggml::sys::llama::llama_ftype { +impl From for llama_ftype { fn from(value: FileTypeFormat) -> Self { use ggml::sys::llama::*; match value { @@ -180,397 +194,78 @@ impl Display for FileTypeFormat { } } -/// Each variant represents a step within the process of loading the model. -/// These can be used to report progress to the user. -#[derive(Clone, PartialEq, Eq, Debug)] -pub enum LoadProgress<'a> { - /// The hyperparameters have been loaded from the model. - HyperparametersLoaded, - /// The context has been created. - ContextSize { - /// The size of the context. - bytes: usize, - }, - /// A tensor was patched with a LoRA. - LoraApplied { - /// The name of the patched tensor. - name: &'a str, - /// LoRA file the patch was applied from. - source: &'a Path, - }, - /// A tensor from the current part has been loaded. - TensorLoaded { - /// The current tensor (0-indexed). - current_tensor: usize, - /// The number of total tensors. - tensor_count: usize, - }, - /// A model part has finished fully loading. - Loaded { - /// The number of bytes in the part. - file_size: u64, - /// The number of tensors in the part. - tensor_count: usize, - }, -} +/// Helper trait that implements traits required for reading. +pub trait Source: BufRead + Seek {} +impl Source for S {} +/// Errors that can occur when loading a known model. #[derive(Error, Debug)] -/// Errors encountered during the loading process. -pub enum LoadError { - #[error("the file {path:?} does not exist")] - /// The file does not exist. - FileDoesNotExist { - /// The path that failed. - path: PathBuf, - }, - #[error("could not open file {path:?}")] - /// A file failed to open. - OpenFileFailed { - /// The original error. - source: std::io::Error, - /// The path that failed. - path: PathBuf, - }, - #[error("non-specific I/O error")] - /// A non-specific IO error. - Io(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("invalid magic value {magic}")] - /// An invalid magic value was encountered during the loading process. - InvalidMagic { - /// The magic value that was encountered. - magic: FileMagic, - }, - #[error("invalid file format {container_type:?}")] - /// The version of the format is not supported by this version of `llm`. - InvalidFormatVersion { - /// The format that was encountered. 
- container_type: ContainerType, - }, - #[error("unknown tensor `{tensor_name}` in {path:?}")] - /// The tensor `tensor_name` is required for this model architecture, - /// but was not found in the model. - UnknownTensor { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - /// The tensor `tensor_name` had an unsupported element type. - #[error("invalid element type {element_type} for tensor `{tensor_name}` in {path:?}")] - UnsupportedElementType { - /// The name of the tensor. - tensor_name: String, - /// The element type that was encountered. - element_type: u32, - /// The path that failed. - path: PathBuf, - }, - /// The tokenizer could not be loaded. - #[error("could not load tokenizer: {0}")] - TokenizerLoadFail(#[from] TokenizerLoadError), - /// The quantization version was missing, despite this model containing quantized tensors. - #[error("quantization version was missing, despite model containing quantized tensors")] - MissingQuantizationVersion, - /// The quantization version is not supported by this version of `llm`. - #[error("quantization version {quantization_version:?} is not supported")] - UnsupportedQuantizationVersion { - /// The quantization version that was encountered. - quantization_version: MetadataValue, - }, - /// A tensor with an unsupported number of dimensions was encountered. - #[error( - "tensor {tensor_name} has {dimensions} dimensions, but only 1-3 dimensions are supported" - )] - UnsupportedTensorDimensionCount { - /// The name of the tensor. - tensor_name: String, - /// The number of dimensions that were encountered. - dimensions: usize, - }, - /// The model expected a metadata key-value pair, but the key was missing. - #[error("missing metadata key {key:?}")] - MissingMetadataKey { - /// The key that was missing. - key: String, - }, - /// The metadata key-value pair was not of the expected type. - #[error("metadata key {key:?} was not of the expected type")] - InvalidMetadataType { - /// The key with the invalid type. - key: String, - /// The expected type. - expected_type: MetadataValueType, - /// The actual type. - actual_type: MetadataValueType, - }, - /// The file type within the model was not supported by this version of `llm`. - #[error("file type {file_type_format} is not supported")] - UnsupportedFileType { - /// The file type format (ignoring the quantization version) that was encountered. - file_type_format: u32, - }, -} -impl LoadError { - #[doc(hidden)] - pub fn from_gguf(value: GgufLoadError, path: PathBuf) -> Self { - match value { - GgufLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { magic }, - GgufLoadError::InvalidFormatVersion(container_type) => { - LoadError::InvalidFormatVersion { container_type } - } - GgufLoadError::Io(err) => LoadError::Io(err), - GgufLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err), - GgufLoadError::InvalidIntegerConversion(err) => { - LoadError::InvalidIntegerConversion(err) - } - GgufLoadError::UnsupportedElementType { tensor_name, ftype } => { - LoadError::UnsupportedElementType { - path, - tensor_name, - element_type: ftype, - } - } - } - } +pub enum LoadKnownError { + /// Failed to read the hyperparameters + #[error("{0}")] + HyperparametersReadError(#[from] HyperparametersReadError), + /// Failed to load the tensors + #[error("{0}")] + TensorLoadError(#[from] TensorLoadError), } +/// Each variant represents a step within loading a known model. 
+#[derive(Debug, Copy, Clone)] #[doc(hidden)] -pub trait MetadataExt { - fn fallible_get(&self, key: &str) -> Result<&MetadataValue, LoadError>; - fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( - &'a self, - key: &'a str, - getter: impl Fn(&MetadataValue) -> Option<&T>, - ) -> Result<&'a T, LoadError>; - fn fallible_get_string(&self, key: &str) -> Result; - fn fallible_get_countable(&self, key: &str) -> Result; +pub enum LoadKnownProgress<'a> { + /// A LoRA has been applied. + LoraApplied { name: &'a str, source: &'a Path }, + /// A tensor has been loaded. + TensorLoaded { current_tensor: usize }, } -impl MetadataExt for Metadata { - fn fallible_get(&self, key: &str) -> Result<&MetadataValue, LoadError> { - self.get(key).ok_or_else(|| LoadError::MissingMetadataKey { - key: key.to_owned(), - }) - } - - fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( - &'a self, - key: &'a str, - getter: impl Fn(&MetadataValue) -> Option<&T>, - ) -> Result<&'a T, LoadError> { - let metadata_value = self.fallible_get(key)?; - getter(metadata_value).ok_or_else(|| LoadError::InvalidMetadataType { - key: key.to_string(), - expected_type: T::value_type(), - actual_type: metadata_value.value_type(), - }) - } - - // TODO: see if we can generalize this with `ToOwned` or something? - fn fallible_get_string(&self, key: &str) -> Result { - let metadata_value = self.fallible_get(key)?; - Ok(metadata_value - .as_string() - .ok_or_else(|| LoadError::InvalidMetadataType { - key: key.to_string(), - expected_type: MetadataValueType::String, - actual_type: metadata_value.value_type(), - })? - .to_string()) - } - fn fallible_get_countable(&self, key: &str) -> Result { - let metadata_value = self.fallible_get(key)?; - match metadata_value { - MetadataValue::UInt32(v) => Ok(usize::try_from(*v)?), - MetadataValue::UInt64(v) => Ok(usize::try_from(*v)?), - _ => Err(LoadError::InvalidMetadataType { - key: key.to_string(), - expected_type: MetadataValueType::UInt64, - actual_type: metadata_value.value_type(), - }), - } - } -} - -/// Load a GGML model from the `path` and configure it per the `params`. The status -/// of the loading process will be reported through `load_progress_callback`. -/// -/// Note that the model must be a single-part model, and the model in `path` -/// *must* match the architecture of `M`. -/// -/// # Panics -/// -/// - If the model does not match the architecture of `M`. This is not checked -/// before execution, so this function will panic if the model does not match -/// the architecture. -/// -/// This is a limitation of the GGML format, which does not -/// store any information about the architecture. -pub fn load( - path: &Path, - tokenizer_source: TokenizerSource, +/// Internal function that takes all of the state that can be derived without +/// knowing a concrete type and loads a concrete model. A *lot* of precondition +/// logic is done in `llm`. +// TODO: think about this design. Do we want to let people to be able to load +// known models directly? 
+#[doc(hidden)] +#[allow(clippy::too_many_arguments)] +pub fn load_known_internal( + source: &mut dyn Source, + gguf: &Gguf, + tokenizer: Tokenizer, + context: Context, + lora_adapters: Option>, + progress_callback: &mut dyn FnMut(LoadKnownProgress), params: ModelParameters, - mut load_progress_callback: impl FnMut(LoadProgress), -) -> Result { - if !path.exists() { - return Err(LoadError::FileDoesNotExist { - path: path.to_owned(), - }); - } - - let mut file = File::open(path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: path.to_owned(), - })?; - let mut reader = BufReader::new(&file); - log::trace!("Read model file from {:?}", path); - - let mut tokenizer = tokenizer_source.retrieve(path)?; - - let gguf = - gguf::Gguf::load(&mut reader).map_err(|e| LoadError::from_gguf(e, path.to_owned()))?; - log::trace!("Loaded GGML model from reader"); - - let quantization_version = gguf.metadata.get("general.quantization_version"); - log::trace!( - "Determined quantization version of model as {:?}", - quantization_version - ); - - // TODO: this is temporary while we figure out how to handle this - let any_quantized = gguf - .tensor_infos - .values() - .any(|t| t.element_type.is_quantized()); - if any_quantized { - match quantization_version { - Some(MetadataValue::UInt32(2)) => { - // Currently supported version - } - Some(quantization_version) => { - return Err(LoadError::UnsupportedQuantizationVersion { - quantization_version: quantization_version.clone(), - }) - } - None => return Err(LoadError::MissingQuantizationVersion), - } - } - - // Populate the embedded tokenizer if required - if let Tokenizer::Embedded(tokenizer) = &mut tokenizer { - if let Some((tokens, scores)) = gguf.tokenizer_embedded() { - for (i, (token, score)) in tokens.iter().zip(scores.iter()).enumerate() { - tokenizer.push_token(i as u32, token.as_bytes().to_vec(), *score); - } - } else { - return Err(TokenizerLoadError::NoTokenizerFound.into()); - } - } - - let use_mmap = params.prefer_mmap && params.lora_adapters.is_none(); - - let ctx_size = gguf - .tensor_infos - .values() - .map(|ti| ti.calc_absolute_size(use_mmap)) - .sum::(); - log::trace!("Context size: {:?}", ctx_size); - - let mut lora_adapters: Option> = None; - if let Some(lora_paths) = ¶ms.lora_adapters { - let adapters: Result, _> = lora_paths - .iter() - .map(|lora_path| { - // Read the LoRA file - let lora_file = File::open(lora_path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: lora_path.to_owned(), - })?; - let mut lora_reader = BufReader::new(&lora_file); - let gguf = gguf::Gguf::load(&mut lora_reader).map_err(|e| LoadError::from_gguf(e, lora_path.to_owned()))?; - - // Collect the names of the tensors that should be patched - let tensors_to_patch = gguf - .tensor_infos - .keys() - .filter_map(|k| Some(k.rsplit_once('.')?.0.to_owned())) - .collect(); - - log::trace!("Loaded LoRA weights"); - // Return the LoRA patches - #[allow(unreachable_code)] - Ok::<_, LoadError>(LoraAdapter { - tensors: gguf.tensor_infos.clone(), - tensors_to_patch, - file: lora_file, - path: lora_path.to_owned(), - gguf, - scaling: todo!("Calculate scaling from LoRA file metadata (GGUF does not have standardised metadata yet)"), - }) - }) - .collect(); - lora_adapters = Some(adapters?); - } - - (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size }); - let (context, file_size) = if use_mmap { - let file = File::open(path)?; - unsafe { - let mmap = Mmap::map(&file)?; - let file_size = mmap.len() as u64; - (Context::new_with_mmap(mmap), 
file_size) - } - } else { - (Context::new_with_allocate(ctx_size), file.metadata()?.len()) - }; - +) -> Result { let hyperparameters = ::read_gguf(&gguf.metadata)?; let tl = ModelTensorLoader { tensor_loader: TensorLoader { - path, - file: &mut file, + source, gguf: &gguf, context, }, lora_adapters, - load_progress_callback: &mut load_progress_callback, + progress_callback, loaded_tensor_count: 0, }; - let model = KnownModel::new(hyperparameters, params, tokenizer, tl)?; - - let tensors_len = gguf.tensor_infos.len(); - (load_progress_callback)(LoadProgress::Loaded { - file_size, - tensor_count: tensors_len, - }); - - log::trace!("Loaded model"); - Ok(model) + Ok(KnownModel::new(hyperparameters, params, tokenizer, tl)?) } /// A helper struct for loading tensors from a model. pub struct ModelTensorLoader<'a> { pub(crate) tensor_loader: TensorLoader<'a>, pub(crate) lora_adapters: Option>, - pub(crate) load_progress_callback: &'a mut dyn FnMut(LoadProgress), + pub(crate) progress_callback: &'a mut dyn FnMut(LoadKnownProgress), pub(crate) loaded_tensor_count: usize, } impl ModelTensorLoader<'_> { /// Load a tensor from the model. - pub fn load(&mut self, name: &str) -> Result { + pub fn load(&mut self, name: &str) -> Result { let (mut tensor, info) = self.tensor_loader.load(name)?; if let Some(lora_adapters) = &mut self.lora_adapters { for lora_adapter in lora_adapters { lora_adapter.patch(name, info, &mut tensor)?; - (self.load_progress_callback)(LoadProgress::LoraApplied { + (self.progress_callback)(LoadKnownProgress::LoraApplied { name, source: &lora_adapter.path, }); @@ -578,9 +273,8 @@ impl ModelTensorLoader<'_> { } self.loaded_tensor_count += 1; - (self.load_progress_callback)(LoadProgress::TensorLoaded { + (self.progress_callback)(LoadKnownProgress::TensorLoaded { current_tensor: self.loaded_tensor_count, - tensor_count: self.tensor_loader.gguf.tensor_infos.len(), }); Ok(tensor) @@ -596,20 +290,18 @@ impl ModelTensorLoader<'_> { } pub(crate) struct TensorLoader<'a> { - pub path: &'a Path, - pub file: &'a mut File, + pub source: &'a mut dyn Source, + pub gguf: &'a Gguf, pub context: Context, - pub gguf: &'a gguf::Gguf, } impl TensorLoader<'_> { - pub fn load(&mut self, name: &str) -> Result<(ggml::Tensor, &TensorInfo), LoadError> { + pub fn load(&mut self, name: &str) -> Result<(ggml::Tensor, &TensorInfo), TensorLoadError> { let info = self .gguf .tensor_infos .get(name) - .ok_or(LoadError::UnknownTensor { + .ok_or(TensorLoadError::UnknownTensor { tensor_name: String::from(name), - path: self.path.to_path_buf(), })?; let ty = info.element_type; @@ -620,7 +312,7 @@ impl TensorLoader<'_> { 2 => self.context.new_tensor_2d(ty, dims[0], dims[1]), 3 => self.context.new_tensor_3d(ty, dims[0], dims[1], dims[2]), other => { - return Err(LoadError::UnsupportedTensorDimensionCount { + return Err(TensorLoadError::UnsupportedTensorDimensionCount { tensor_name: name.to_string(), dimensions: other, }); @@ -637,8 +329,8 @@ impl TensorLoader<'_> { let buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; - self.file.seek(SeekFrom::Start(offset))?; - self.file.read_exact(buf)?; + self.source.seek(SeekFrom::Start(offset))?; + self.source.read_exact(buf)?; } } @@ -657,41 +349,27 @@ impl TensorLoader<'_> { } } -/// A implementation for `load_progress_callback` that outputs to `stdout`. 
-pub fn load_progress_callback_stdout(progress: LoadProgress) { - match progress { - LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"), - LoadProgress::ContextSize { bytes } => println!( - "ggml ctx size = {:.2} MB\n", - bytes as f64 / (1024.0 * 1024.0) - ), - LoadProgress::TensorLoaded { - current_tensor, - tensor_count, - .. - } => { - let current_tensor = current_tensor + 1; - if current_tensor % 8 == 0 { - println!("Loaded tensor {current_tensor}/{tensor_count}"); - } - } - LoadProgress::Loaded { - file_size: byte_size, - tensor_count, - } => { - println!("Loading of model complete"); - println!( - "Model size = {:.2} MB / num tensors = {}", - byte_size as f64 / 1024.0 / 1024.0, - tensor_count - ); - } - LoadProgress::LoraApplied { name, source } => { - println!( - "Patched tensor {} via LoRA from '{}'", - name, - source.file_name().unwrap().to_str().unwrap() - ); - } - }; +#[derive(Error, Debug)] +/// Errors encountered during loaing of tensors. +pub enum TensorLoadError { + #[error("unknown tensor `{tensor_name}`")] + /// The tensor `tensor_name` is required for this model architecture, + /// but was not found in the model. + UnknownTensor { + /// The name of the tensor. + tensor_name: String, + }, + /// A tensor with an unsupported number of dimensions was encountered. + #[error( + "tensor {tensor_name} has {dimensions} dimensions, but only 1-3 dimensions are supported" + )] + UnsupportedTensorDimensionCount { + /// The name of the tensor. + tensor_name: String, + /// The number of dimensions that were encountered. + dimensions: usize, + }, + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), } diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index d44da997..c2010767 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -1,11 +1,15 @@ use crate::{ - loader::TensorLoader, model::HyperparametersWriteError, FileType, Hyperparameters, LoadError, + loader::{Source, TensorLoadError, TensorLoader}, + model::{HyperparametersReadError, HyperparametersWriteError}, + FileType, Hyperparameters, }; -use ggml::{format::gguf::TensorInfo, GraphExecutionPlan}; +use ggml::{ + format::gguf::{Gguf, Metadata, TensorInfo}, + GraphExecutionPlan, +}; use std::{ collections::{HashMap, HashSet}, - fs::File, path::PathBuf, }; @@ -24,14 +28,11 @@ impl LoraParameters { } } impl Hyperparameters for LoraParameters { - fn read_gguf(metadata: &ggml::format::gguf::Metadata) -> Result { + fn read_gguf(metadata: &Metadata) -> Result { todo!() } - fn write_gguf( - &self, - metadata: &mut ggml::format::gguf::Metadata, - ) -> Result<(), HyperparametersWriteError> { + fn write_gguf(&self, metadata: &mut Metadata) -> Result<(), HyperparametersWriteError> { todo!() } @@ -52,12 +53,12 @@ pub struct LoraAdapter { pub tensors: HashMap, /// Names of the tensors that should be patched. pub tensors_to_patch: HashSet, - /// File containing the LoRA weights. - pub file: File, + /// Source containing the LoRA weights. + pub source: Box, /// Path to the LoRA file. pub path: PathBuf, /// The loaded GGUF for the LoRA. 
- pub gguf: ggml::format::gguf::Gguf, + pub gguf: Gguf, } impl LoraAdapter { @@ -67,7 +68,7 @@ impl LoraAdapter { name: &str, info: &TensorInfo, tensor: &mut ggml::Tensor, - ) -> Result<(), LoadError> { + ) -> Result<(), TensorLoadError> { // Check if we need to patch this tensor if !self.tensors_to_patch.contains(name) { return Ok(()); @@ -105,8 +106,7 @@ impl LoraAdapter { // TODO: test if GPU can be enabled (make it configurable) let patch_context = ggml::Context::new_with_allocate(patch_context_size); let mut loader = TensorLoader { - path: &self.path, - file: &mut self.file, + source: self.source.as_mut(), context: patch_context, gguf: &self.gguf, }; @@ -144,12 +144,11 @@ impl LoraAdapter { Ok(()) } - fn get_info(&self, name: &str) -> Result { + fn get_info(&self, name: &str) -> Result { self.tensors .get(name) .cloned() - .ok_or(LoadError::UnknownTensor { - path: self.path.to_owned(), + .ok_or(TensorLoadError::UnknownTensor { tensor_name: name.to_owned(), }) } diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 768cdf9d..849dc317 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -1,20 +1,18 @@ //! Large language model traits and types -use std::{ - error::Error, - fmt::Debug, - io::{BufRead, Write}, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{fmt::Debug, path::PathBuf, sync::Arc}; -use ggml::{accelerator::Backend, format::gguf::Metadata}; +use ggml::{ + accelerator::Backend, + format::gguf::{Metadata, MetadataError}, + sys::llama::llama_ftype, +}; use regex::Regex; use thiserror::Error; use crate::{ - loader::TensorLoader, tokenizer::TokenId, FileType, InferenceSession, InferenceSessionConfig, - LoadError, LoadProgress, ModelTensorLoader, Tokenizer, TokenizerSource, + tokenizer::TokenId, FileType, InferenceSession, InferenceSessionConfig, ModelTensorLoader, + TensorLoadError, Tokenizer, }; /// Common functions for model evaluation @@ -26,21 +24,6 @@ pub trait KnownModel: Send + Sync { /// Hyperparameters for the model. type Hyperparameters: Hyperparameters; - /// Load this model from the `path` and configure it per the `params`. The status - /// of the loading process will be reported through `load_progress_callback`. This - /// is a helper function on top of [llm_base::load](crate::load). - fn load( - path: &Path, - tokenizer_source: TokenizerSource, - params: ModelParameters, - load_progress_callback: impl FnMut(LoadProgress), - ) -> Result - where - Self: Sized, - { - crate::load(path, tokenizer_source, params, load_progress_callback) - } - /// Creates a new model from the provided [ModelParameters] hyperparameters. /// This function is called by the [load](crate::loader::load) function. fn new( @@ -48,7 +31,7 @@ pub trait KnownModel: Send + Sync { params: ModelParameters, tokenizer: Tokenizer, tensor_loader: ModelTensorLoader, - ) -> Result + ) -> Result where Self: Sized; @@ -167,7 +150,7 @@ impl> Model for M { /// without knowing what they are, as well as writing/reading them as required. pub trait Hyperparameters: Sized + Default + Debug + PartialEq + Eq { /// Read the parameters from GGUF metadata. - fn read_gguf(metadata: &Metadata) -> Result; + fn read_gguf(metadata: &Metadata) -> Result; /// Write the parameters to GGUF metadata. 
fn write_gguf(&self, metadata: &mut Metadata) -> Result<(), HyperparametersWriteError>; @@ -180,6 +163,19 @@ pub trait Hyperparameters: Sized + Default + Debug + PartialEq + Eq { } #[derive(Error, Debug)] /// Reported from functions that write +pub enum HyperparametersReadError { + #[error("{0}")] + /// A metadata error. + MetadataError(#[from] MetadataError), + /// The file type within the model was not supported by this version of `llm`. + #[error("file type {file_type} is not supported")] + UnsupportedFileType { + /// The file type (ignoring the quantization version) that was encountered. + file_type: llama_ftype, + }, +} +#[derive(Error, Debug)] +/// Reported from functions that write pub enum HyperparametersWriteError { #[error("non-specific I/O error")] /// A non-specific IO error. diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index a870b51e..95d30e50 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -5,7 +5,7 @@ use crate::{ loader::FileTypeFormat, model::HyperparametersWriteError, Hyperparameters, KnownModel, - LoadError, LoadProgress, Tokenizer, + Tokenizer, }; use ggml::format::gguf::GgufSaveError; use half::f16; @@ -72,9 +72,9 @@ pub enum QuantizeProgress<'a> { #[derive(Error, Debug)] /// Errors encountered during the quantization process. pub enum QuantizeError { - #[error("could not load model")] - /// There was an error while attempting to load the model. - Load(#[from] LoadError), + // #[error("could not load model")] + // /// There was an error while attempting to load the model. + // Load(#[from] LoadError), #[error("non-specific I/O error")] /// A non-specific IO error. Io(#[from] std::io::Error), diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 0469098e..fae8a25c 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -1,10 +1,6 @@ -use std::{ - error::Error, - fmt::Display, - path::{Path, PathBuf}, - str::FromStr, -}; +use std::{error::Error, fmt::Display, path::PathBuf, str::FromStr}; +use ggml::format::gguf::Gguf; use thiserror::Error; mod embedded; @@ -36,11 +32,11 @@ pub enum TokenizationError { /// Errors related to loading the tokenizer. #[error("error loading tokenizer from {path}: {error}")] pub enum TokenizerLoadError { - #[error("error loading Hugging Face tokenizer from {path}: {error}")] + #[error("error loading Hugging Face tokenizer from {tokenizer_source}: {error}")] /// An error occurred while loading a Hugging Face tokenizer. HuggingFaceTokenizerError { - /// The path to the tokenizer. - path: PathBuf, + /// The source of the tokenizer that failed. + tokenizer_source: HuggingFaceTokenizerErrorSource, /// The error that occurred during loading. error: Box, }, @@ -49,14 +45,27 @@ pub enum TokenizerLoadError { NoTokenizerFound, } -impl TokenizerLoadError { - fn huggingface_error( - path: impl Into, - error: impl Into>, - ) -> Self { - Self::HuggingFaceTokenizerError { - path: path.into(), - error: error.into(), +/// Used to identify where the tokenizer that errored came from. +// NOTE: We could potentially reuse `TokenizerSource` for this, but I want to avoid +// cloning and/or displaying the entire `String` case. Revisit in future and see if +// I still feel the same. +#[derive(Debug)] +pub enum HuggingFaceTokenizerErrorSource { + /// The tokenizer was loaded from this file. + File(PathBuf), + /// The tokenizer was loaded from thep rovided string. 
+ String, + #[cfg(feature = "tokenizers-remote")] + /// The tokenizer was loaded from the given HF ID. + Remote(String), +} +impl Display for HuggingFaceTokenizerErrorSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::File(file) => write!(f, "file {file:?}"), + Self::String => write!(f, "string"), + #[cfg(feature = "tokenizers-remote")] + Self::Remote(remote) => write!(f, "HF ID {remote:?}"), } } } @@ -81,36 +90,61 @@ pub enum TokenizerSource { /// and may store files locally, so it is not recommended for production use. #[cfg(feature = "tokenizers-remote")] HuggingFaceRemote(String), + // + // TODO: Support embedded huggingface tokenizer from GGUF + // } impl TokenizerSource { /// Retrieve the tokenizer from the source. /// /// Note that this may make a blocking HTTP request to Hugging Face to retrieve the tokenizer. /// if `self` is `Self::HuggingFaceRemote`. - pub fn retrieve(self, model_path: &Path) -> Result { - let _ = model_path; - + pub fn retrieve(self, gguf: &Gguf) -> Result { Ok(match self { #[cfg(feature = "tokenizers-remote")] Self::HuggingFaceRemote(identifier) => HuggingFaceTokenizer::new( - tokenizers::Tokenizer::from_pretrained(&identifier, None) - .map_err(|error| TokenizerLoadError::huggingface_error(model_path, error))?, + tokenizers::Tokenizer::from_pretrained(&identifier, None).map_err(|error| { + TokenizerLoadError::HuggingFaceTokenizerError { + tokenizer_source: HuggingFaceTokenizerErrorSource::Remote( + identifier.clone(), + ), + error: error.into(), + } + })?, ) .into(), - Self::HuggingFaceTokenizerFile(path) => HuggingFaceTokenizer::new( - tokenizers::Tokenizer::from_file(&path) - .map_err(|error| TokenizerLoadError::huggingface_error(path, error))?, - ) - .into(), + Self::HuggingFaceTokenizerFile(path) => { + HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_file(&path).map_err( + |error| TokenizerLoadError::HuggingFaceTokenizerError { + tokenizer_source: HuggingFaceTokenizerErrorSource::File(path.clone()), + error: error.into(), + }, + )?) + .into() + } - Self::HuggingFaceTokenizerString(s) => HuggingFaceTokenizer::new( - tokenizers::Tokenizer::from_str(&s) - .map_err(|error| TokenizerLoadError::huggingface_error(model_path, error))?, - ) - .into(), + Self::HuggingFaceTokenizerString(s) => { + HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_str(&s).map_err(|error| { + TokenizerLoadError::HuggingFaceTokenizerError { + tokenizer_source: HuggingFaceTokenizerErrorSource::String, + error: error.into(), + } + })?) + .into() + } - Self::Embedded => EmbeddedTokenizer::default().into(), + Self::Embedded => { + let mut tokenizer = EmbeddedTokenizer::default(); + if let Some((tokens, scores)) = gguf.tokenizer_embedded() { + for (i, (token, score)) in tokens.iter().zip(scores.iter()).enumerate() { + tokenizer.push_token(i as u32, token.as_bytes().to_vec(), *score); + } + } else { + return Err(TokenizerLoadError::NoTokenizerFound); + } + tokenizer.into() + } }) } } @@ -133,13 +167,6 @@ impl From for Tokenizer { Self::HuggingFace(v) } } -impl Tokenizer { - /// Creates an empty embedded tokenizer, for contexts where you need a tokenizer but don't - /// need to tokenize anything. - pub(crate) fn empty_embedded() -> Self { - Self::Embedded(EmbeddedTokenizer::default()) - } -} impl Tokenizer { /// Converts a token to the token ID it represents in this tokenizer. 
pub fn id(&self, token: &[u8]) -> Option { diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index f3c93e65..43f3eba7 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -19,6 +19,7 @@ llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" serde = { workspace = true } tracing = { workspace = true } +thiserror = { workspace = true } [dev-dependencies] bytesize = { workspace = true } diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs index 427c9e48..795d9740 100644 --- a/crates/llm/examples/embeddings.rs +++ b/crates/llm/examples/embeddings.rs @@ -4,7 +4,6 @@ use clap::Parser; #[derive(Parser)] struct Args { - model_architecture: llm::ModelArchitecture, model_path: PathBuf, #[arg(long, short = 'v')] pub tokenizer_path: Option, @@ -32,7 +31,6 @@ fn main() { let args = Args::parse(); let tokenizer_source = args.to_tokenizer_source(); - let model_architecture = args.model_architecture; let model_path = args.model_path; let query = args .query @@ -50,16 +48,13 @@ fn main() { // Load model let model_params = llm::ModelParameters::default(); - let model = llm::load_dynamic( - Some(model_architecture), + let model = llm::load( &model_path, tokenizer_source, model_params, llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| { - panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") - }); + .unwrap_or_else(|err| panic!("Failed to load model from {model_path:?}: {err}")); let inference_parameters = llm::InferenceParameters::default(); // Generate embeddings for query and comparands diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index 51e7369a..c3ffcb02 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -3,7 +3,6 @@ use std::{convert::Infallible, io::Write, path::PathBuf}; #[derive(Parser)] struct Args { - model_architecture: llm::ModelArchitecture, model_path: PathBuf, #[arg(long, short = 'p')] prompt: Option, @@ -29,7 +28,6 @@ fn main() { let args = Args::parse(); let tokenizer_source = args.to_tokenizer_source(); - let model_architecture = args.model_architecture; let model_path = args.model_path; let prompt = args .prompt @@ -38,16 +36,13 @@ fn main() { let now = std::time::Instant::now(); - let model = llm::load_dynamic( - Some(model_architecture), + let model = llm::load( &model_path, tokenizer_source, Default::default(), llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| { - panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") - }); + .unwrap_or_else(|err| panic!("Failed to load model from {model_path:?}: {err}")); println!( "Model fully loaded! 
Elapsed: {}ms", diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index 4ced1ef2..1efb088e 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -5,7 +5,6 @@ use std::{convert::Infallible, io::Write, path::PathBuf}; #[derive(Parser)] struct Args { - model_architecture: llm::ModelArchitecture, model_path: PathBuf, #[arg(long, short = 'v')] pub tokenizer_path: Option, @@ -29,18 +28,14 @@ fn main() { let args = Args::parse(); let tokenizer_source = args.to_tokenizer_source(); - let model_architecture = args.model_architecture; let model_path = args.model_path; - let model = llm::load_dynamic( - Some(model_architecture), + let model = llm::load( &model_path, tokenizer_source, Default::default(), llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| { - panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") - }); + .unwrap_or_else(|err| panic!("Failed to load model from {model_path:?}: {err}")); let mut session = model.start_session(Default::default()); diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index b15bbdf3..c58f36f2 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -70,7 +70,6 @@ use std::{ error::Error, fmt::{Debug, Display}, - path::Path, str::FromStr, }; @@ -80,16 +79,18 @@ pub use llm_base::{ conversation_inference_callback, feed_prompt_callback, ggml::accelerator::get_accelerator as ggml_get_accelerator, ggml::accelerator::Accelerator as GgmlAccelerator, ggml::format as ggml_format, - ggml::RoPEOverrides, load, load_progress_callback_stdout, quantize, samplers, ElementType, - FileMagic, FileType, FileTypeFormat, Hyperparameters, InferenceError, InferenceFeedback, - InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession, - InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, - InvalidTokenBias, KnownModel, LoadError, LoadProgress, Model, ModelKVMemoryType, + ggml::RoPEOverrides, quantize, samplers, ElementType, FileMagic, FileType, FileTypeFormat, + Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, + InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, + InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, Model, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, RewindError, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource, }; +mod loader; +pub use loader::{load, load_progress_callback_stdout, LoadError, LoadProgress}; + use serde::Serialize; macro_rules! define_models { @@ -124,7 +125,7 @@ macro_rules! define_models { impl ModelArchitecture { /// Use a visitor to dispatch some code based on the model architecture. - pub fn visit(&self, visitor: &mut impl ModelArchitectureVisitor) -> R { + pub fn visit(&self, visitor: impl ModelArchitectureVisitor) -> R { match self { $( #[cfg(feature = $model_lowercase_str)] @@ -184,11 +185,11 @@ define_models!( /// Used to dispatch some code based on the model architecture. pub trait ModelArchitectureVisitor { /// Visit a model architecture. - fn visit(&mut self) -> R; + fn visit(self) -> R; } /// An unsupported model architecture was specified. 
-pub struct UnsupportedModelArchitecture(String); +pub struct UnsupportedModelArchitecture(pub String); impl Display for UnsupportedModelArchitecture { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) @@ -203,64 +204,6 @@ impl Debug for UnsupportedModelArchitecture { } } -/// A helper function that loads the specified model from disk using an architecture -/// specified at runtime. If no architecture is specified, it will try to infer it -/// from the model's metadata. -/// -/// This method returns a [`Box`], which means that the model will have single ownership. -/// If you'd like to share ownership (i.e. to use the model in multiple threads), we -/// suggest using [`Arc::from(Box)`](https://doc.rust-lang.org/std/sync/struct.Arc.html#impl-From%3CBox%3CT,+Global%3E%3E-for-Arc%3CT%3E) -/// to convert the [`Box`] into an [`Arc`](std::sync::Arc) after loading. -pub fn load_dynamic( - architecture: Option, - path: &Path, - tokenizer_source: TokenizerSource, - params: ModelParameters, - load_progress_callback: impl FnMut(LoadProgress), -) -> Result, LoadError> { - fn load_model( - path: &Path, - tokenizer_source: TokenizerSource, - params: ModelParameters, - load_progress_callback: impl FnMut(LoadProgress), - ) -> Result, LoadError> { - Ok(Box::new(load::( - path, - tokenizer_source, - params, - load_progress_callback, - )?)) - } - - let architecture = architecture.expect("TODO: This option will be removed soon"); - - struct LoadVisitor<'a, F: FnMut(LoadProgress)> { - path: &'a Path, - tokenizer_source: TokenizerSource, - params: ModelParameters, - load_progress_callback: F, - } - impl<'a, F: FnMut(LoadProgress)> ModelArchitectureVisitor, LoadError>> - for LoadVisitor<'a, F> - { - fn visit(&mut self) -> Result, LoadError> { - load_model::( - self.path, - self.tokenizer_source.clone(), - self.params.clone(), - &mut self.load_progress_callback, - ) - } - } - - architecture.visit(&mut LoadVisitor { - path, - tokenizer_source, - params, - load_progress_callback, - }) -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/llm/src/loader.rs b/crates/llm/src/loader.rs new file mode 100644 index 00000000..9bbe0d51 --- /dev/null +++ b/crates/llm/src/loader.rs @@ -0,0 +1,380 @@ +use std::{fs::File, io::BufReader, path::Path}; + +use llm_base::{ + ggml::{ + format::gguf::{Gguf, GgufLoadError, MetadataValue, MetadataValueType}, + sys::llama::llama_ftype, + Context, + }, + loader::{LoadKnownProgress, Source}, + model::HyperparametersReadError, + ContainerType, FileMagic, KnownModel, LoadKnownError, LoraAdapter, MetadataError, MetadataExt, + Mmap, Model, ModelParameters, TensorLoadError, Tokenizer, TokenizerLoadError, TokenizerSource, +}; +use thiserror::Error; + +use tracing::log; + +use crate::{ModelArchitecture, ModelArchitectureVisitor, UnsupportedModelArchitecture}; + +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum LoadProgress<'a> { + /// The hyperparameters have been loaded from the model. + HyperparametersLoaded, + /// The context has been created. + ContextSize { + /// The size of the context. + bytes: usize, + }, + /// A tensor was patched with a LoRA. + LoraApplied { + /// The name of the patched tensor. + name: &'a str, + /// LoRA file the patch was applied from. + source: &'a Path, + }, + /// A tensor from the current part has been loaded. + TensorLoaded { + /// The current tensor (0-indexed). 
+ current_tensor: usize, + /// The number of total tensors. + tensor_count: usize, + }, + /// A model part has finished fully loading. + Loaded { + /// The number of bytes in the part. + file_size: u64, + /// The number of tensors in the part. + tensor_count: usize, + }, +} + +#[derive(Error, Debug)] +/// Errors encountered during the loading process. +pub enum LoadError { + #[error("the file does not exist")] + /// The file does not exist. + FileDoesNotExist, + #[error("could not open file")] + /// A file failed to open. + OpenFileFailed { + /// The original error. + source: std::io::Error, + }, + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("invalid magic value {magic}")] + /// An invalid magic value was encountered during the loading process. + InvalidMagic { + /// The magic value that was encountered. + magic: FileMagic, + }, + #[error("invalid file format {container_type:?}")] + /// The version of the format is not supported by this version of `llm`. + InvalidFormatVersion { + /// The format that was encountered. + container_type: ContainerType, + }, + /// The tensor `tensor_name` had an unsupported element type. + #[error("invalid element type {element_type} for tensor `{tensor_name}`")] + UnsupportedElementType { + /// The name of the tensor. + tensor_name: String, + /// The element type that was encountered. + element_type: u32, + }, + /// The tokenizer could not be loaded. + #[error("could not load tokenizer: {0}")] + TokenizerLoadFail(#[from] TokenizerLoadError), + /// The quantization version was missing, despite this model containing quantized tensors. + #[error("quantization version was missing, despite model containing quantized tensors")] + MissingQuantizationVersion, + /// The quantization version is not supported by this version of `llm`. + #[error("quantization version {quantization_version:?} is not supported")] + UnsupportedQuantizationVersion { + /// The quantization version that was encountered. + quantization_version: MetadataValue, + }, + /// The model expected a metadata key-value pair, but the key was missing. + #[error("missing metadata key {key:?}")] + MissingMetadataKey { + /// The key that was missing. + key: String, + }, + /// The metadata key-value pair was not of the expected type. + #[error("metadata key {key:?} was not of the expected type")] + InvalidMetadataType { + /// The key with the invalid type. + key: String, + /// The expected type. + expected_type: MetadataValueType, + /// The actual type. + actual_type: MetadataValueType, + }, + /// The file type within the model was not supported by this version of `llm`. + #[error("file type {file_type} is not supported")] + UnsupportedFileType { + /// The file type (ignoring the quantization version) that was encountered. + file_type: llama_ftype, + }, + /// The architecture specified in this model is not supported by `llm`. + #[error("architecture is not supported: {0}")] + UnsupportedArchitecture(#[from] UnsupportedModelArchitecture), + /// An error occurred while reading the hyperparameters. 
+ #[error("{0}")] + HyperparametersReadError(HyperparametersReadError), + /// An error occurred while reading the tensors. + #[error("{0}")] + TensorLoadError(TensorLoadError), +} +impl From for LoadError { + fn from(value: GgufLoadError) -> Self { + match value { + GgufLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { magic }, + GgufLoadError::InvalidFormatVersion(container_type) => { + LoadError::InvalidFormatVersion { container_type } + } + GgufLoadError::Io(err) => LoadError::Io(err), + GgufLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err), + GgufLoadError::InvalidIntegerConversion(err) => { + LoadError::InvalidIntegerConversion(err) + } + GgufLoadError::UnsupportedElementType { tensor_name, ftype } => { + LoadError::UnsupportedElementType { + tensor_name, + element_type: ftype, + } + } + } + } +} +impl From for LoadError { + fn from(value: LoadKnownError) -> Self { + match value { + LoadKnownError::HyperparametersReadError(e) => Self::HyperparametersReadError(e), + LoadKnownError::TensorLoadError(e) => Self::TensorLoadError(e), + } + } +} +impl From for LoadError { + fn from(value: MetadataError) -> Self { + Self::HyperparametersReadError(HyperparametersReadError::MetadataError(value)) + } +} + +/// Loads the specified GGUF model from disk, determining its architecture from the metadata. +/// +/// This method returns a [`Box`], which means that the model will have single ownership. +/// If you'd like to share ownership (i.e. to use the model in multiple threads), we +/// suggest using [`Arc::from(Box)`](https://doc.rust-lang.org/std/sync/struct.Arc.html#impl-From%3CBox%3CT,+Global%3E%3E-for-Arc%3CT%3E) +/// to convert the [`Box`] into an [`Arc`](std::sync::Arc) after loading. +pub fn load( + path: &Path, + tokenizer_source: TokenizerSource, + params: ModelParameters, + mut load_progress_callback: impl FnMut(LoadProgress), +) -> Result, LoadError> { + if !path.exists() { + return Err(LoadError::FileDoesNotExist); + } + + let file = File::open(path).map_err(|e| LoadError::OpenFileFailed { source: e })?; + let mut reader = BufReader::new(&file); + log::trace!("Read model file from {:?}", path); + + let gguf = Gguf::load(&mut reader)?; + log::trace!("Loaded GGML model from reader"); + + let architecture = gguf + .metadata + .fallible_get_string("general.architecture")? 
+ .parse::()?; + + let tokenizer = tokenizer_source.retrieve(&gguf)?; + + let quantization_version = gguf.metadata.get("general.quantization_version"); + log::trace!( + "Determined quantization version of model as {:?}", + quantization_version + ); + + // TODO: this is temporary while we figure out how to handle this + let any_quantized = gguf + .tensor_infos + .values() + .any(|t| t.element_type.is_quantized()); + if any_quantized { + match quantization_version { + Some(MetadataValue::UInt32(2)) => { + // Currently supported version + } + Some(quantization_version) => { + return Err(LoadError::UnsupportedQuantizationVersion { + quantization_version: quantization_version.clone(), + }) + } + None => return Err(LoadError::MissingQuantizationVersion), + } + } + + let use_mmap = params.prefer_mmap && params.lora_adapters.is_none(); + + let ctx_size = gguf + .tensor_infos + .values() + .map(|ti| ti.calc_absolute_size(use_mmap)) + .sum::(); + log::trace!("Context size: {:?}", ctx_size); + + let mut lora_adapters: Option> = None; + if let Some(lora_paths) = ¶ms.lora_adapters { + let adapters: Result, _> = lora_paths + .iter() + .map(|lora_path| { + // Read the LoRA file + let lora_file = File::open(lora_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + })?; + let mut lora_reader = BufReader::new(&lora_file); + let gguf = Gguf::load(&mut lora_reader)?; + + // Collect the names of the tensors that should be patched + let tensors_to_patch = gguf + .tensor_infos + .keys() + .filter_map(|k| Some(k.rsplit_once('.')?.0.to_owned())) + .collect(); + + log::trace!("Loaded LoRA weights"); + // Return the LoRA patches + #[allow(unreachable_code)] + Ok::<_, LoadError>(LoraAdapter { + tensors: gguf.tensor_infos.clone(), + tensors_to_patch, + source: Box::new(lora_reader), + path: lora_path.to_owned(), + gguf, + scaling: todo!("Calculate scaling from LoRA file metadata (GGUF does not have standardised metadata yet)"), + }) + }) + .collect(); + lora_adapters = Some(adapters?); + } + + (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size }); + let (context, file_size) = if use_mmap { + unsafe { + let mmap = Mmap::map(&file)?; + let file_size = mmap.len() as u64; + (Context::new_with_mmap(mmap), file_size) + } + } else { + (Context::new_with_allocate(ctx_size), file.metadata()?.len()) + }; + + let model = architecture.visit(LoadVisitor { + source: &mut reader, + gguf: &gguf, + tokenizer, + context, + lora_adapters, + load_progress_callback: &mut load_progress_callback, + params, + })?; + + (load_progress_callback)(LoadProgress::Loaded { + file_size, + tensor_count: gguf.tensor_infos.len(), + }); + + log::trace!("Loaded model"); + + Ok(model) +} + +struct LoadVisitor<'a, F: FnMut(LoadProgress)> { + source: &'a mut dyn Source, + gguf: &'a Gguf, + tokenizer: Tokenizer, + context: Context, + lora_adapters: Option>, + load_progress_callback: F, + params: ModelParameters, +} +impl<'a, F: FnMut(LoadProgress)> ModelArchitectureVisitor, LoadError>> + for LoadVisitor<'a, F> +{ + fn visit(mut self) -> Result, LoadError> { + let model = Box::new(llm_base::load_known_internal::( + self.source, + self.gguf, + self.tokenizer, + self.context, + self.lora_adapters, + &mut |step| { + (self.load_progress_callback)(match step { + LoadKnownProgress::LoraApplied { name, source } => { + LoadProgress::LoraApplied { name, source } + } + LoadKnownProgress::TensorLoaded { current_tensor } => { + LoadProgress::TensorLoaded { + current_tensor, + tensor_count: self.gguf.tensor_infos.len(), + } + } + }) + }, + 
self.params, + )?); + + Ok(model) + } +} + +/// A implementation for `load_progress_callback` that outputs to `stdout`. +pub fn load_progress_callback_stdout(progress: LoadProgress) { + match progress { + LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"), + LoadProgress::ContextSize { bytes } => println!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::TensorLoaded { + current_tensor, + tensor_count, + .. + } => { + let current_tensor = current_tensor + 1; + if current_tensor % 8 == 0 { + println!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::Loaded { + file_size: byte_size, + tensor_count, + } => { + println!("Loading of model complete"); + println!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + LoadProgress::LoraApplied { name, source } => { + println!( + "Patched tensor {} via LoRA from '{}'", + name, + source.file_name().unwrap().to_str().unwrap() + ); + } + }; +} diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index f32669a2..46fe8152 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -3,10 +3,10 @@ use llm_base::{ ggml::{self, format::gguf::Metadata}, - model::{common, HyperparametersWriteError}, - FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, - MetadataExt, ModelContext, ModelParameters, ModelTensorLoader, OutputRequest, Regex, TokenId, - Tokenizer, + model::{common, HyperparametersReadError, HyperparametersWriteError}, + FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, MetadataExt, + ModelContext, ModelParameters, ModelTensorLoader, OutputRequest, Regex, TensorLoadError, + TokenId, Tokenizer, }; const META_TENSOR_DATA_LAYOUT: &str = "Meta AI original pth"; @@ -41,11 +41,11 @@ impl KnownModel for Llama { type Hyperparameters = Hyperparameters; fn new( - mut hyperparameters: Self::Hyperparameters, + hyperparameters: Self::Hyperparameters, params: ModelParameters, tokenizer: Tokenizer, tensor_loader: ModelTensorLoader, - ) -> Result { + ) -> Result { assert_eq!(hyperparameters.tensor_data_layout, META_TENSOR_DATA_LAYOUT); let mut tl = tensor_loader; @@ -412,7 +412,7 @@ pub struct Hyperparameters { pub tensor_data_layout: String, } impl llm_base::Hyperparameters for Hyperparameters { - fn read_gguf(metadata: &Metadata) -> Result { + fn read_gguf(metadata: &Metadata) -> Result { Ok(Self { // TODO: handle models without an embedded vocabulary vocabulary_count: metadata @@ -422,11 +422,7 @@ impl llm_base::Hyperparameters for Hyperparameters { head_count: metadata.fallible_get_countable("llama.attention.head_count")?, head_count_kv: metadata.fallible_get_countable("llama.attention.head_count_kv")?, block_count: metadata.fallible_get_countable("llama.block_count")?, - file_type: metadata - .get("general.file_type") - .and_then(|v| v.as_uint32()) - .map(|v| FileType::try_from(v as i32)) - .transpose()?, + file_type: FileType::read_for_hyperparameters(metadata)?, tensor_data_layout: metadata .fallible_get_string("llama.tensor_data_layout") .unwrap_or(META_TENSOR_DATA_LAYOUT.to_string()), From 588eb98f2a65563b78bee54cdcd9ed43e22e07b2 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 8 Oct 2023 13:34:55 -0700 Subject: [PATCH 16/33] feat(ggml): use newtype for metadata --- binaries/gguf-explorer/src/main.rs | 2 +- binaries/llm-cli/src/main.rs | 11 +- crates/ggml/src/format/gguf/metadata.rs | 138 
++++++++++++---------- crates/ggml/src/format/gguf/mod.rs | 20 +--- crates/llm-base/src/lib.rs | 3 +- crates/llm-base/src/loader.rs | 8 +- crates/llm-base/src/model/mod.rs | 5 +- crates/llm-base/src/tokenizer/embedded.rs | 87 +++++++++++++- crates/llm-base/src/tokenizer/mod.rs | 15 ++- crates/llm/src/lib.rs | 12 +- crates/llm/src/loader.rs | 9 +- crates/models/llama/src/lib.rs | 18 +-- 12 files changed, 209 insertions(+), 119 deletions(-) diff --git a/binaries/gguf-explorer/src/main.rs b/binaries/gguf-explorer/src/main.rs index 11435057..bb37bce2 100644 --- a/binaries/gguf-explorer/src/main.rs +++ b/binaries/gguf-explorer/src/main.rs @@ -128,7 +128,7 @@ impl Explorer { }) .body(|mut body| { for key in metadata_keys { - let value = &metadata[key]; + let value = metadata.get_optional(key).unwrap(); body.row(30.0, |mut row| { row.col(|ui| { diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 1f8b330c..5cfb6c64 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -140,7 +140,7 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { let gguf = gguf::Gguf::load(&mut reader)?; log::info!("Non-array parameters:"); - for (metadata_key, metadata_value) in &gguf.metadata { + for (metadata_key, metadata_value) in gguf.metadata.iter() { if metadata_value.as_array().is_some() { continue; } @@ -148,12 +148,15 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { log::info!("- {}: {:?}", metadata_key, metadata_value); } - if let Some((tokens, _scores)) = gguf.tokenizer_embedded() { - log::info!("Embedded tokenizer vocabulary size: {}", tokens.len()); + if let Ok(tokenizer) = llm::tokenizer::GgufEmbeddedTokenizer::from_metadata(&gguf.metadata) { + log::info!( + "Embedded tokenizer vocabulary size: {}", + tokenizer.tokens.len() + ); if args.tokenizer { log::info!("Embedded tokenizer vocabulary:"); - for (i, token) in tokens.iter().enumerate() { + for (i, token) in tokenizer.tokens.iter().enumerate() { log::info!("- {}: {}", i, token); } } diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index e58bcb13..58959e4a 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -6,8 +6,84 @@ use crate::util; use super::{GgufContext, GgufLoadError}; -// TODO: make this a newtype instead -pub type Metadata = HashMap; +#[derive(Debug, Clone, PartialEq)] +pub struct Metadata(pub HashMap); +impl Metadata { + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn keys(&self) -> impl Iterator { + self.0.keys() + } + + pub fn values(&self) -> impl Iterator { + self.0.values() + } + + pub fn get_optional(&self, key: &str) -> Option<&MetadataValue> { + self.0.get(key) + } + + pub fn get(&self, key: &str) -> Result<&MetadataValue, MetadataError> { + self.get_optional(key) + .ok_or_else(|| MetadataError::MissingKey { + key: key.to_owned(), + }) + } + + pub fn get_with_type<'a, T: MetadataValueTypeFromRustType>( + &'a self, + key: &'a str, + getter: impl Fn(&MetadataValue) -> Option<&T>, + ) -> Result<&'a T, MetadataError> { + let metadata_value = self.get(key)?; + getter(metadata_value).ok_or_else(|| MetadataError::InvalidType { + key: key.to_string(), + expected_type: T::value_type(), + actual_type: metadata_value.value_type(), + }) + } + + pub fn get_array_with_type<'a, T: MetadataValueTypeFromRustType>( + &'a self, + key: &'a str, + getter: impl Fn(&MetadataValue) -> Option<&[T]>, + ) -> Result<&'a [T], MetadataError> { + let metadata_value = self.get(key)?; + 
getter(metadata_value).ok_or_else(|| MetadataError::InvalidType { + key: key.to_string(), + expected_type: T::value_type(), + actual_type: metadata_value.value_type(), + }) + } + + // TODO: see if we can generalize this with `ToOwned` or something? + pub fn get_string(&self, key: &str) -> Result { + let metadata_value = self.get(key)?; + Ok(metadata_value + .as_string() + .ok_or_else(|| MetadataError::InvalidType { + key: key.to_string(), + expected_type: MetadataValueType::String, + actual_type: metadata_value.value_type(), + })? + .to_string()) + } + + pub fn get_countable(&self, key: &str) -> Result { + let metadata_value = self.get(key)?; + match metadata_value { + MetadataValue::UInt32(v) => Ok(usize::try_from(*v)?), + MetadataValue::UInt64(v) => Ok(usize::try_from(*v)?), + _ => Err(MetadataError::InvalidType { + key: key.to_string(), + expected_type: MetadataValueType::UInt64, + actual_type: metadata_value.value_type(), + }), + } + } +} #[repr(u32)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -471,64 +547,6 @@ impl MetadataArrayValue { } } -#[doc(hidden)] -pub trait MetadataExt { - fn fallible_get(&self, key: &str) -> Result<&MetadataValue, MetadataError>; - fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( - &'a self, - key: &'a str, - getter: impl Fn(&MetadataValue) -> Option<&T>, - ) -> Result<&'a T, MetadataError>; - fn fallible_get_string(&self, key: &str) -> Result; - fn fallible_get_countable(&self, key: &str) -> Result; -} -impl MetadataExt for Metadata { - fn fallible_get(&self, key: &str) -> Result<&MetadataValue, MetadataError> { - self.get(key).ok_or_else(|| MetadataError::MissingKey { - key: key.to_owned(), - }) - } - - fn fallible_typed_get<'a, T: MetadataValueTypeFromRustType>( - &'a self, - key: &'a str, - getter: impl Fn(&MetadataValue) -> Option<&T>, - ) -> Result<&'a T, MetadataError> { - let metadata_value = self.fallible_get(key)?; - getter(metadata_value).ok_or_else(|| MetadataError::InvalidType { - key: key.to_string(), - expected_type: T::value_type(), - actual_type: metadata_value.value_type(), - }) - } - - // TODO: see if we can generalize this with `ToOwned` or something? - fn fallible_get_string(&self, key: &str) -> Result { - let metadata_value = self.fallible_get(key)?; - Ok(metadata_value - .as_string() - .ok_or_else(|| MetadataError::InvalidType { - key: key.to_string(), - expected_type: MetadataValueType::String, - actual_type: metadata_value.value_type(), - })? - .to_string()) - } - - fn fallible_get_countable(&self, key: &str) -> Result { - let metadata_value = self.fallible_get(key)?; - match metadata_value { - MetadataValue::UInt32(v) => Ok(usize::try_from(*v)?), - MetadataValue::UInt64(v) => Ok(usize::try_from(*v)?), - _ => Err(MetadataError::InvalidType { - key: key.to_string(), - expected_type: MetadataValueType::UInt64, - actual_type: metadata_value.value_type(), - }), - } - } -} - #[derive(Error, Debug)] /// Errors encountered during the loading process. 
pub enum MetadataError { diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index 5c2152a7..0b5da22b 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -79,9 +79,10 @@ impl Gguf { let (key, value) = MetadataValue::read_key_value(&ctx, reader)?; metadata.insert(key, value); } + let metadata = Metadata(metadata); let alignment = metadata - .get("general.alignment") + .get_optional("general.alignment") .and_then(|v| v.as_uint32()) .unwrap_or(DEFAULT_ALIGNMENT) as u64; @@ -102,23 +103,6 @@ impl Gguf { tensor_data_position, }) } - - // TODO: consider moving this to a `ModelGguf` abstraction that wraps this - // and provides a model-specific interface - pub fn tokenizer_embedded(&self) -> Option<(&[String], &[f32])> { - let tokens = self - .metadata - .get("tokenizer.ggml.tokens")? - .as_array()? - .as_string_array()?; - let scores = self - .metadata - .get("tokenizer.ggml.scores")? - .as_array()? - .as_float32_array()?; - - Some((tokens, scores)) - } } struct GgufContext { diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index fe521c08..8796bbe1 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -10,11 +10,11 @@ mod inference_session; mod lora; mod quantize; -mod tokenizer; pub mod loader; pub mod model; pub mod samplers; +pub mod tokenizer; pub mod util; use std::sync::{Arc, Mutex}; @@ -31,7 +31,6 @@ pub use inference_session::{ pub use llm_samplers::prelude::{Sampler, SamplerChain}; pub use loader::{ load_known_internal, ContainerType, FileMagic, FileType, FileTypeFormat, LoadKnownError, - MetadataError, MetadataExt, ModelTensorLoader, TensorLoadError, }; pub use lora::{LoraAdapter, LoraParameters}; pub use memmap2::Mmap; diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index edfb59e3..a765d77c 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -11,16 +11,12 @@ use crate::{ model::{Hyperparameters, HyperparametersReadError}, KnownModel, LoraAdapter, ModelContext, ModelParameters, Tokenizer, }; +pub use ggml::{format::gguf::MetadataError, format::ContainerType, util::FileMagic}; use ggml::{ format::gguf::{Gguf, Metadata, TensorInfo}, sys::llama::llama_ftype, Context, MAX_NAME_LENGTH, }; -pub use ggml::{ - format::gguf::{MetadataError, MetadataExt}, - format::ContainerType, - util::FileMagic, -}; use thiserror::Error; #[derive(Debug, PartialEq, Clone, Copy, Eq, Default)] @@ -62,7 +58,7 @@ impl FileType { metadata: &Metadata, ) -> Result, HyperparametersReadError> { metadata - .get("general.file_type") + .get_optional("general.file_type") .and_then(|v| v.as_uint32()) .map(|v| { FileType::try_from(v as i32).map_err(|ftype| { diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 849dc317..a2640516 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -11,8 +11,9 @@ use regex::Regex; use thiserror::Error; use crate::{ - tokenizer::TokenId, FileType, InferenceSession, InferenceSessionConfig, ModelTensorLoader, - TensorLoadError, Tokenizer, + loader::{ModelTensorLoader, TensorLoadError}, + tokenizer::TokenId, + FileType, InferenceSession, InferenceSessionConfig, Tokenizer, }; /// Common functions for model evaluation diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs index cf96b183..2825b609 100644 --- a/crates/llm-base/src/tokenizer/embedded.rs +++ b/crates/llm-base/src/tokenizer/embedded.rs @@ -1,5 +1,6 @@ -use 
std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; +use ggml::format::gguf::{Metadata, MetadataError}; use thiserror::Error; use super::{Token, TokenId, TokenScore, TokenizationError}; @@ -155,3 +156,87 @@ impl EmbeddedTokenizer { .map(|(token, score)| (token.clone(), *score)) } } + +/// An embedded tokenizer definition in a GGUF. +pub struct GgufEmbeddedTokenizer<'a> { + /// The model type. + pub model: Option<&'a str>, + /// The tokens. + pub tokens: &'a [String], + /// The token scores. + pub scores: &'a [f32], + /// The token types. + pub types: Option<&'a [u32]>, +} +impl GgufEmbeddedTokenizer<'_> { + /// Attempt to retrieve the embedded tokenizer from the metadata. + pub fn from_metadata(metadata: &Metadata) -> Result { + Ok(GgufEmbeddedTokenizer { + model: metadata + .get_optional("tokenizer.ggml.model") + .and_then(|v| v.as_string()), + tokens: metadata.get_array_with_type("tokenizer.ggml.tokens", |v| { + v.as_array()?.as_string_array() + })?, + scores: metadata.get_array_with_type("tokenizer.ggml.scores", |v| { + v.as_array()?.as_float32_array() + })?, + types: metadata + .get_array_with_type("tokenizer.ggml.token_type", |v| { + v.as_array()?.as_uint32_array() + }) + .ok(), + }) + } +} + +/// Typesafe tokenizer models. +pub enum GgufEmbeddedTokenizerModel { + /// Llama style SentencePiece (tokens and scores extracted from HF `tokenizer.model`) + Llama, + /// Replit style SentencePiece (tokens and scores extracted from HF `spiece.model`) + Replit, + /// GPT-2 / GPT-NeoX style BPE (tokens extracted from HF `tokenizer.json`) + Gpt2, + /// RWKV tokenizer + Rwkv, +} +impl FromStr for GgufEmbeddedTokenizerModel { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "llama" => Ok(Self::Llama), + "replit" => Ok(Self::Replit), + "gpt2" => Ok(Self::Gpt2), + "rwkv" => Ok(Self::Rwkv), + other => Err(other.to_string()), + } + } +} + +/// The type of a token. +#[allow(missing_docs)] +pub enum GgufEmbeddedTokenizerTokenType { + Normal, + Unknown, + Control, + UserDefined, + Unused, + Byte, +} +impl TryFrom for GgufEmbeddedTokenizerTokenType { + type Error = u32; + + fn try_from(value: u32) -> Result { + match value { + 1 => Ok(Self::Normal), + 2 => Ok(Self::Unknown), + 3 => Ok(Self::Control), + 4 => Ok(Self::UserDefined), + 5 => Ok(Self::Unused), + 6 => Ok(Self::Byte), + other => Err(other), + } + } +} diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index fae8a25c..8837e1ef 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -1,3 +1,5 @@ +//! Tokenizer-related functionality. 
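// A minimal usage sketch (illustrative only, not part of the patch) for the
// `GgufEmbeddedTokenizer` accessor added above, mirroring the llm-cli `info`
// change in this commit; `gguf` is assumed to be an already-loaded `Gguf`.
if let Ok(tok) = GgufEmbeddedTokenizer::from_metadata(&gguf.metadata) {
    println!("tokenizer model: {:?}", tok.model);
    println!("vocabulary size: {}", tok.tokens.len());
}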
+ use std::{error::Error, fmt::Display, path::PathBuf, str::FromStr}; use ggml::format::gguf::Gguf; @@ -136,13 +138,14 @@ impl TokenizerSource { Self::Embedded => { let mut tokenizer = EmbeddedTokenizer::default(); - if let Some((tokens, scores)) = gguf.tokenizer_embedded() { - for (i, (token, score)) in tokens.iter().zip(scores.iter()).enumerate() { - tokenizer.push_token(i as u32, token.as_bytes().to_vec(), *score); - } - } else { - return Err(TokenizerLoadError::NoTokenizerFound); + let tok = GgufEmbeddedTokenizer::from_metadata(&gguf.metadata).map_err(|_| { + // TODO: consider passing the error along + TokenizerLoadError::NoTokenizerFound + })?; + for (i, (token, score)) in tok.tokens.iter().zip(tok.scores.iter()).enumerate() { + tokenizer.push_token(i as u32, token.as_bytes().to_vec(), *score); } + tokenizer.into() } }) diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index c58f36f2..02b2ef4c 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -79,12 +79,12 @@ pub use llm_base::{ conversation_inference_callback, feed_prompt_callback, ggml::accelerator::get_accelerator as ggml_get_accelerator, ggml::accelerator::Accelerator as GgmlAccelerator, ggml::format as ggml_format, - ggml::RoPEOverrides, quantize, samplers, ElementType, FileMagic, FileType, FileTypeFormat, - Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, - InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, - InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, Model, ModelKVMemoryType, - ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, RewindError, - SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, + ggml::RoPEOverrides, quantize, samplers, tokenizer, ElementType, FileMagic, FileType, + FileTypeFormat, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, + InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, + InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, Model, + ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, + RewindError, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource, }; diff --git a/crates/llm/src/loader.rs b/crates/llm/src/loader.rs index 9bbe0d51..5397f7c1 100644 --- a/crates/llm/src/loader.rs +++ b/crates/llm/src/loader.rs @@ -7,9 +7,10 @@ use llm_base::{ Context, }, loader::{LoadKnownProgress, Source}, + loader::{MetadataError, TensorLoadError}, model::HyperparametersReadError, - ContainerType, FileMagic, KnownModel, LoadKnownError, LoraAdapter, MetadataError, MetadataExt, - Mmap, Model, ModelParameters, TensorLoadError, Tokenizer, TokenizerLoadError, TokenizerSource, + ContainerType, FileMagic, KnownModel, LoadKnownError, LoraAdapter, Mmap, Model, + ModelParameters, Tokenizer, TokenizerLoadError, TokenizerSource, }; use thiserror::Error; @@ -196,12 +197,12 @@ pub fn load( let architecture = gguf .metadata - .fallible_get_string("general.architecture")? + .get_string("general.architecture")? 
.parse::()?; let tokenizer = tokenizer_source.retrieve(&gguf)?; - let quantization_version = gguf.metadata.get("general.quantization_version"); + let quantization_version = gguf.metadata.get_optional("general.quantization_version"); log::trace!( "Determined quantization version of model as {:?}", quantization_version diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 46fe8152..192eec8c 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -3,10 +3,10 @@ use llm_base::{ ggml::{self, format::gguf::Metadata}, + loader::{ModelTensorLoader, TensorLoadError}, model::{common, HyperparametersReadError, HyperparametersWriteError}, - FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, MetadataExt, - ModelContext, ModelParameters, ModelTensorLoader, OutputRequest, Regex, TensorLoadError, - TokenId, Tokenizer, + FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, ModelContext, + ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, }; const META_TENSOR_DATA_LAYOUT: &str = "Meta AI original pth"; @@ -416,15 +416,15 @@ impl llm_base::Hyperparameters for Hyperparameters { Ok(Self { // TODO: handle models without an embedded vocabulary vocabulary_count: metadata - .fallible_typed_get("tokenizer.ggml.tokens", |v| v.as_array())? + .get_with_type("tokenizer.ggml.tokens", |v| v.as_array())? .len(), - embedding_length: metadata.fallible_get_countable("llama.embedding_length")?, - head_count: metadata.fallible_get_countable("llama.attention.head_count")?, - head_count_kv: metadata.fallible_get_countable("llama.attention.head_count_kv")?, - block_count: metadata.fallible_get_countable("llama.block_count")?, + embedding_length: metadata.get_countable("llama.embedding_length")?, + head_count: metadata.get_countable("llama.attention.head_count")?, + head_count_kv: metadata.get_countable("llama.attention.head_count_kv")?, + block_count: metadata.get_countable("llama.block_count")?, file_type: FileType::read_for_hyperparameters(metadata)?, tensor_data_layout: metadata - .fallible_get_string("llama.tensor_data_layout") + .get_string("llama.tensor_data_layout") .unwrap_or(META_TENSOR_DATA_LAYOUT.to_string()), }) } From 388fa87bbb379379de2c180328beadae7639e4a0 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 8 Oct 2023 15:31:43 -0700 Subject: [PATCH 17/33] feat(llm): first pass at tokenizer re-port --- crates/ggml/src/format/gguf/metadata.rs | 13 + crates/llm-base/src/tokenizer/embedded.rs | 478 +++++++++++++++++----- crates/llm-base/src/tokenizer/mod.rs | 18 +- crates/models/llama/src/lib.rs | 2 +- 4 files changed, 403 insertions(+), 108 deletions(-) diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index 58959e4a..35041fc5 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -33,6 +33,19 @@ impl Metadata { } pub fn get_with_type<'a, T: MetadataValueTypeFromRustType>( + &'a self, + key: &'a str, + getter: impl Fn(&MetadataValue) -> Option, + ) -> Result { + let metadata_value = self.get(key)?; + getter(metadata_value).ok_or_else(|| MetadataError::InvalidType { + key: key.to_string(), + expected_type: T::value_type(), + actual_type: metadata_value.value_type(), + }) + } + + pub fn get_with_ref_type<'a, T: MetadataValueTypeFromRustType>( &'a self, key: &'a str, getter: impl Fn(&MetadataValue) -> Option<&T>, diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs index 
2825b609..791505cf 100644 --- a/crates/llm-base/src/tokenizer/embedded.rs +++ b/crates/llm-base/src/tokenizer/embedded.rs @@ -1,8 +1,14 @@ -use std::{collections::HashMap, str::FromStr}; +use std::{ + cmp::Ordering, + collections::{BinaryHeap, HashMap}, + str::FromStr, +}; use ggml::format::gguf::{Metadata, MetadataError}; use thiserror::Error; +use crate::TokenizerLoadError; + use super::{Token, TokenId, TokenScore, TokenizationError}; #[derive(Debug, Error)] @@ -14,43 +20,109 @@ pub enum EmbeddedTokenizerError { } /// The built-in GGML tokenizer. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct EmbeddedTokenizer { /// Maps every integer (index) token ID to its corresponding token. - id_to_token: Vec, - - /// Maps every integer (index) token ID to corresponding score. - id_to_token_score: Vec, + id_to_token: Vec, // todo: use a radix tree /// Maps a token to a token ID. token_to_id: HashMap, - /// The longest token in this tokenizer. - max_token_length: usize, + model: GgufEmbeddedTokenizerModel, + bos_id: TokenId, + eos_id: TokenId, + unknown_id: TokenId, + linefeed_id: TokenId, + separator_id: Option, + padding_id: Option, +} +#[derive(Debug, Clone, Default)] +struct TokenData { + text: Token, + score: TokenScore, + ty: TokenType, } - impl EmbeddedTokenizer { - /// Add a token to the internal vocabulary. - /// - /// The token added must have `id` directly after the last token in the vocabulary. - /// - /// # Panics - /// - This function can panic if `id` does not correspond to the next token in the vocabulary. - /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. - pub(crate) fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { - // These are loader invariants. If this is broken, then the loader is broken and this is a bug, - // not an issue with the model itself. 
- assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); - if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize { - let expected_id = self.id_to_token.len() as TokenId; - panic!("the id of token added should be {expected_id}; is {id}"); - } + pub(crate) fn from_metadata(metadata: &Metadata) -> Result { + let tok = GgufEmbeddedTokenizer::from_metadata(metadata)?; + + let model = if let Some(model) = tok.model { + model + .parse::() + .expect("TODO: handle invalid tokenizer model") + } else { + GgufEmbeddedTokenizerModel::Llama + }; + + match model { + GgufEmbeddedTokenizerModel::Llama => { + let bos_id = metadata + .get_with_type("tokenizer.ggml.bos_token_id", |v| v.as_uint32()) + .unwrap_or(1); + let eos_id = metadata + .get_with_type("tokenizer.ggml.eos_token_id", |v| v.as_uint32()) + .unwrap_or(2); + let unknown_id = metadata + .get_with_type("tokenizer.ggml.unknown_token_id", |v| v.as_uint32()) + .unwrap_or(0); + let separator_id = metadata + .get_with_type("tokenizer.ggml.separator_token_id", |v| v.as_uint32()) + .ok(); + let padding_id = metadata + .get_with_type("tokenizer.ggml.padding_token_id", |v| v.as_uint32()) + .ok(); + + let tokens = metadata.get_array_with_type("tokenizer.ggml.tokens", |v| { + v.as_array()?.as_string_array() + })?; + let scores = metadata + .get_array_with_type("tokenizer.ggml.scores", |v| { + v.as_array()?.as_float32_array() + }) + .unwrap_or_default(); + let types = metadata + .get_array_with_type("tokenizer.ggml.token_types", |v| { + v.as_array()?.as_uint32_array() + }) + .unwrap_or_default(); + + let mut token_to_id = HashMap::default(); + let mut id_to_token = vec![TokenData::default(); tokens.len()]; + + for (i, token) in tokens.iter().enumerate() { + let word = token.as_bytes().to_vec(); + token_to_id.insert(word.clone(), i as TokenId); + id_to_token[i] = TokenData { + text: word.clone(), + score: scores.get(i).copied().unwrap_or(0.0), + ty: match types.get(i) { + Some(tok) => { + TokenType::try_from(*tok).expect("TODO: handle invalid token type") + } + None => TokenType::Normal, + }, + }; + } - self.max_token_length = self.max_token_length.max(content.len()); - self.id_to_token.push(content.clone()); - self.id_to_token_score.push(score); - self.token_to_id.insert(content, id); + let mut tokenizer = EmbeddedTokenizer { + token_to_id, + id_to_token, + model: GgufEmbeddedTokenizerModel::Llama, + bos_id, + eos_id, + unknown_id, + linefeed_id: 0, + separator_id, + padding_id, + }; + + tokenizer.linefeed_id = tokenizer.byte_to_token(b'\n'); + + Ok(tokenizer) + } + _ => unimplemented!(), + } } pub(crate) fn id(&self, token: &[u8]) -> Option { @@ -58,8 +130,8 @@ impl EmbeddedTokenizer { } /// Converts a token index to the token it represents in this tokenizer. - pub(crate) fn token(&self, idx: usize) -> Vec { - self.id_to_token[idx].clone() + pub(crate) fn token(&self, idx: usize) -> Token { + self.id_to_token[idx].text.clone() } /// Returns the number of tokens in the tokenizer. @@ -72,7 +144,6 @@ impl EmbeddedTokenizer { self.id_to_token.is_empty() } - // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece /// Tokenize a `text` with this tokenizer. /// /// `bos` controls whether a beginning-of-string token should be inserted. 
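// A short illustration of the byte-fallback convention the loader above relies
// on: single bytes are stored in the vocabulary as literal "<0xNN>" tokens, so
// `byte_to_token(b'\n')` resolves `linefeed_id` by looking up "<0x0A>". The
// helper below is illustrative only and is not part of this patch.
fn byte_token_name(byte: u8) -> String {
    // Same formatting as the lookup key built by `byte_to_token`.
    format!("<0x{byte:02X}>")
}
// byte_token_name(b'\n') == "<0x0A>"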
@@ -80,80 +151,82 @@ impl EmbeddedTokenizer { &self, text: &str, bos: bool, - ) -> Result, TokenId)>, TokenizationError> { - let len = text.len(); - - let mut score = vec![0usize; len + 1]; - let mut prev = vec![TokenId::default(); len + 1]; - - for i in 0..len { - let max_len = (len - i).min(self.max_token_length); - for sub_len in 1..=max_len { - let sub = &text.as_bytes()[i..i + sub_len]; - let token = self.token_to_id.get(sub); - - if let Some(token) = token { - let token_score = sub.len() * sub.len(); - let local_score = score[i] + token_score; - let next = i + sub_len; - - if score[next] < local_score { - score[next] = local_score; - prev[next] = *token; - } - } - } - } - - // Backward pass - let mut res = vec![]; - let mut i = len; - while i > 0 { - let token_id = prev[i]; - if token_id == 0 { - return Err(TokenizationError::TokenizationFailed { - error: Box::new(EmbeddedTokenizerError::Arbitrary( - "the backward pass for the tokenizer encountered a non-set token" - .to_string(), - )), - }); - } - let token = self.id_to_token[token_id as usize].as_slice(); - res.push((token.to_vec(), token_id)); - i -= token.len(); - } + ) -> Result, TokenizationError> { + let mut output = vec![]; if bos { - // TODO: replace with vocab.bos - res.push((vec![], 1)); + output.push(( + self.id_to_token[self.bos_id as usize].text.clone(), + self.bos_id, + )); } - // Pieces are in reverse order so correct that - res.reverse(); + if text.is_empty() { + return Ok(output); + } - Ok(res) + match self.model { + GgufEmbeddedTokenizerModel::Llama => { + let text = escape_whitespace(format!(" {text}").as_bytes()); + + Ok(TokenizerSpm::new(self) + .tokenize(&text) + .into_iter() + .map(|id| { + // TODO: see if this can be made more efficient + (self.id_to_token[id as usize].text.clone(), id) + }) + .collect()) + } + _ => unimplemented!(), + } } /// Decode a list `tokens` with this tokenizer. 
- pub(crate) fn decode(&self, tokens: Vec, skip_special_tokens: bool) -> Vec { - let mut vec = vec![]; - - for token in tokens { - if skip_special_tokens && token == 1 { - continue; + pub(crate) fn decode(&self, tokens: Vec, _skip_special_tokens: bool) -> Vec { + let mut ret = vec![]; + + match self.model { + GgufEmbeddedTokenizerModel::Llama => { + for token_id in tokens { + let token = &self.id_to_token[token_id as usize]; + match token.ty { + TokenType::Normal => { + ret.append(&mut unescape_whitespace(&token.text)); + } + TokenType::Unknown => { + assert_eq!(token.text.len(), 3); + ret.extend_from_slice(&[0xE2, 0x96, 0x85]); + } + TokenType::Byte => { + ret.push(self.token_to_byte(token_id)); + } + TokenType::Control | TokenType::UserDefined | TokenType::Unused => {} + } + } } - - vec.append(&mut self.id_to_token[token as usize].to_vec()); + _ => unimplemented!(), } - vec + ret } +} +impl EmbeddedTokenizer { + fn byte_to_token(&self, ch: u8) -> TokenId { + let token = format!("<0x{ch:02X}>"); + self.token_to_id.get(token.as_bytes()).copied().unwrap() + } + + fn token_to_byte(&self, token_id: TokenId) -> u8 { + let data = &self.id_to_token[token_id as usize]; + assert_eq!(data.ty, TokenType::Byte); - pub(crate) fn iter(&self) -> impl Iterator + '_ { - self.id_to_token - .iter() - .zip(self.id_to_token_score.iter()) - .map(|(token, score)| (token.clone(), *score)) + match self.model { + GgufEmbeddedTokenizerModel::Llama => { + u8::from_str_radix(std::str::from_utf8(&data.text[3..5]).unwrap(), 16).unwrap() + } + _ => unimplemented!(), + } } } @@ -191,6 +264,7 @@ impl GgufEmbeddedTokenizer<'_> { } /// Typesafe tokenizer models. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub enum GgufEmbeddedTokenizerModel { /// Llama style SentencePiece (tokens and scores extracted from HF `tokenizer.model`) Llama, @@ -217,7 +291,9 @@ impl FromStr for GgufEmbeddedTokenizerModel { /// The type of a token. 
#[allow(missing_docs)] -pub enum GgufEmbeddedTokenizerTokenType { +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub enum TokenType { + #[default] Normal, Unknown, Control, @@ -225,7 +301,7 @@ pub enum GgufEmbeddedTokenizerTokenType { Unused, Byte, } -impl TryFrom for GgufEmbeddedTokenizerTokenType { +impl TryFrom for TokenType { type Error = u32; fn try_from(value: u32) -> Result { @@ -240,3 +316,217 @@ impl TryFrom for GgufEmbeddedTokenizerTokenType { } } } + +#[derive(Clone)] +struct Symbol { + prev: Option, + next: Option, + text: Token, + n: usize, +} + +struct LlmBigramSpm { + left: usize, + right: usize, + score: f32, + size: usize, +} +impl PartialOrd for LlmBigramSpm { + fn partial_cmp(&self, other: &Self) -> Option { + match self.score.partial_cmp(&other.score) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.left.partial_cmp(&other.left) + } +} +impl Ord for LlmBigramSpm { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} +impl PartialEq for LlmBigramSpm { + fn eq(&self, other: &Self) -> bool { + (self.score < other.score) || (self.score == other.score && self.left > other.left) + } +} +impl Eq for LlmBigramSpm {} + +impl LlmBigramSpm { + fn new(left: usize, right: usize, score: f32, size: usize) -> Self { + LlmBigramSpm { + left, + right, + score, + size, + } + } +} + +struct TokenizerSpm<'a> { + vocab: &'a EmbeddedTokenizer, + symbols: Vec, + work_queue: BinaryHeap, + rev_merge: HashMap, +} + +impl<'a> TokenizerSpm<'a> { + fn new(vocab: &'a EmbeddedTokenizer) -> Self { + TokenizerSpm { + vocab, + symbols: Vec::new(), + work_queue: BinaryHeap::new(), + rev_merge: HashMap::new(), + } + } + + fn tokenize(&mut self, text: &[u8]) -> Vec { + let mut output = vec![]; + + let mut index = 0; + let mut offs = 0; + while offs < text.len() { + let sym = Symbol { + prev: if index == 0 { None } else { Some(index - 1) }, + next: if offs == text.len() - 1 { + None + } else { + Some(index + 1) + }, + text: text[offs..].to_vec(), + n: std::cmp::min(text.len() - offs, utf8_len(text[offs])), + }; + offs += sym.n; + index += 1; + self.symbols.push(sym); + } + + for i in 1..self.symbols.len() { + self.try_add_bigram(Some(i - 1), Some(i)); + } + + while let Some(bigram) = self.work_queue.pop() { + let mut left_sym = self.symbols[bigram.left as usize].clone(); + let mut right_sym = self.symbols[bigram.right as usize].clone(); + + if left_sym.n == 0 || right_sym.n == 0 || left_sym.n + right_sym.n != bigram.size { + continue; + } + + left_sym.n += right_sym.n; + right_sym.n = 0; + + left_sym.next = right_sym.next; + if let Some(next) = right_sym.next { + self.symbols[next].prev = Some(bigram.left); + } + + let left_sym_prev = left_sym.prev; + let left_sym_next = left_sym.next; + self.symbols[bigram.left as usize] = left_sym; + self.symbols[bigram.right as usize] = right_sym; + + self.try_add_bigram(left_sym_prev, Some(bigram.left)); + self.try_add_bigram(Some(bigram.left), left_sym_next); + } + + let mut i = Some(0); + while let Some(idx) = i { + if idx >= self.symbols.len() { + break; + } + + let symbol = &self.symbols[idx as usize]; + self.resegment(symbol, &mut output); + i = symbol.next; + } + + output + } + + fn resegment(&self, symbol: &Symbol, output: &mut Vec) { + let text = &symbol.text; + let token = self.vocab.token_to_id.get(text); + + if let Some(&token_id) = token { + output.push(token_id); + return; + } + + if let Some(p) = self.rev_merge.get(text) { + self.resegment(&self.symbols[p.0 as usize], 
output); + self.resegment(&self.symbols[p.1 as usize], output); + } else { + for ch in text { + let token_id = self.vocab.byte_to_token(*ch); + output.push(token_id); + } + } + } + + fn try_add_bigram(&mut self, left: Option, right: Option) { + let Some((left, right)) = left.zip(right) else { + return; + }; + + let mut text = self.symbols[left].text.clone(); + text.extend_from_slice(&self.symbols[right].text); + + if let Some(&token_id) = self.vocab.token_to_id.get(&text) { + if (token_id as usize) < self.vocab.id_to_token.len() { + let tok_data = &self.vocab.id_to_token[token_id as usize]; + let bigram = LlmBigramSpm::new(left, right, tok_data.score, text.len()); + self.work_queue.push(bigram); + self.rev_merge.insert(text.clone(), (left, right)); + } + } + } +} + +fn utf8_len(src: u8) -> usize { + const LOOKUP: &[u8] = &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]; + let highbits: u8 = src >> 4; + LOOKUP[highbits as usize] as usize +} + +fn escape_whitespace(text: &[u8]) -> Vec { + let mut out = vec![]; + + for &b in text { + if b == b' ' { + out.extend_from_slice(&[0xE2, 0x96, 0x81]); + } else { + out.push(b); + } + } + + out +} + +fn unescape_whitespace(text: &[u8]) -> Vec { + let mut out = vec![]; + let mut buffer: Vec = vec![]; + + for &b in text { + if b == 0xE2 { + // If the current byte is 0xE2, start buffering and check for the sequence. + buffer.push(b); + } else if buffer.len() == 1 && b == 0x96 { + // If the previous byte was 0xE2 and the current byte is 0x96, continue buffering. + buffer.push(b); + } else if buffer.len() == 2 && b == 0x81 { + // If the previous bytes were 0xE2 and 0x96 and the current byte is 0x81, replace with space and reset buffer. + out.push(b' '); + buffer.clear(); + } else { + // If no match, flush the buffer and append the current byte. + out.append(&mut buffer); + out.push(b); + } + } + + // If there are any remaining bytes in the buffer, append them. + out.append(&mut buffer); + + out +} diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 8837e1ef..11ff1574 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -2,7 +2,7 @@ use std::{error::Error, fmt::Display, path::PathBuf, str::FromStr}; -use ggml::format::gguf::Gguf; +use ggml::format::gguf::{Gguf, MetadataError}; use thiserror::Error; mod embedded; @@ -45,6 +45,9 @@ pub enum TokenizerLoadError { #[error("no tokenizer was found, including in the model file")] /// No tokenizer was found, including in the model file. NoTokenizerFound, + #[error("{0}")] + /// An error occured with retrieving data from the metadata. + MetadataError(#[from] MetadataError), } /// Used to identify where the tokenizer that errored came from. 
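// For reference, the SentencePiece-style pre-processing ported above encodes
// every space as the three-byte UTF-8 sequence 0xE2 0x96 0x81 (U+2581, "▁"),
// and `unescape_whitespace` reverses it during decode. A minimal round trip,
// shown only to illustrate that convention:
//
//     let escaped = escape_whitespace(b" hello world");
//     assert_eq!(escaped, "\u{2581}hello\u{2581}world".as_bytes().to_vec());
//     assert_eq!(unescape_whitespace(&escaped), b" hello world".to_vec());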
@@ -136,18 +139,7 @@ impl TokenizerSource { .into() } - Self::Embedded => { - let mut tokenizer = EmbeddedTokenizer::default(); - let tok = GgufEmbeddedTokenizer::from_metadata(&gguf.metadata).map_err(|_| { - // TODO: consider passing the error along - TokenizerLoadError::NoTokenizerFound - })?; - for (i, (token, score)) in tok.tokens.iter().zip(tok.scores.iter()).enumerate() { - tokenizer.push_token(i as u32, token.as_bytes().to_vec(), *score); - } - - tokenizer.into() - } + Self::Embedded => EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into(), }) } } diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 192eec8c..26f2f252 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -416,7 +416,7 @@ impl llm_base::Hyperparameters for Hyperparameters { Ok(Self { // TODO: handle models without an embedded vocabulary vocabulary_count: metadata - .get_with_type("tokenizer.ggml.tokens", |v| v.as_array())? + .get_with_ref_type("tokenizer.ggml.tokens", |v| v.as_array())? .len(), embedding_length: metadata.get_countable("llama.embedding_length")?, head_count: metadata.get_countable("llama.attention.head_count")?, From f82751709e9de620d9c3ca062048ebe51b8cba54 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 23 Oct 2023 01:46:18 +0200 Subject: [PATCH 18/33] fix(llm): embedded tokenizer decode --- crates/llm-base/src/inference_session.rs | 27 ++++++++--------------- crates/llm-base/src/tokenizer/embedded.rs | 10 ++++----- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 67408b34..5d20aaac 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -317,14 +317,11 @@ impl InferenceSession { for &tk in batch { let should_call_callback = Some(tk) != model.bot_token_id(); - let mut token = match model.tokenizer() { - crate::Tokenizer::Embedded(_) => model.tokenizer().token(tk as usize).to_vec(), - crate::Tokenizer::HuggingFace(_) => { - let mut tokens = self.tokens.clone(); - tokens.push(tk); + let mut token = { + let mut tokens = self.tokens.clone(); + tokens.push(tk); - get_newly_decoded_portion_huggingface(model, tokens, &self.decoded_tokens) - } + get_newly_decoded_portion(model, tokens, &self.decoded_tokens) }; if should_call_callback { @@ -407,16 +404,7 @@ impl InferenceSession { if next_token as TokenId == model.eot_token_id() { Err(InferenceError::EndOfText) } else { - let res = match model.tokenizer() { - crate::Tokenizer::Embedded(_) => { - model.tokenizer().token(next_token as usize).to_vec() - } - crate::Tokenizer::HuggingFace(_) => get_newly_decoded_portion_huggingface( - model, - self.tokens.clone(), - &self.decoded_tokens, - ), - }; + let res = get_newly_decoded_portion(model, self.tokens.clone(), &self.decoded_tokens); self.decoded_tokens.append(&mut res.clone()); Ok(res) @@ -664,7 +652,10 @@ impl Drop for InferenceSession { } } -fn get_newly_decoded_portion_huggingface( +// TODO: Cache results and/or find a more intelligent way to do this. +// At present, this will decode *all* tokens generated by the model, which is +// not ideal when it gets run on each new token. Perhaps only consider the last three for decoding? 
+fn get_newly_decoded_portion( model: &dyn Model, tokens: Vec, decoded_tokens: &[u8], diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs index 791505cf..bca26d84 100644 --- a/crates/llm-base/src/tokenizer/embedded.rs +++ b/crates/llm-base/src/tokenizer/embedded.rs @@ -82,8 +82,8 @@ impl EmbeddedTokenizer { }) .unwrap_or_default(); let types = metadata - .get_array_with_type("tokenizer.ggml.token_types", |v| { - v.as_array()?.as_uint32_array() + .get_array_with_type("tokenizer.ggml.token_type", |v| { + v.as_array()?.as_int32_array() }) .unwrap_or_default(); @@ -301,10 +301,10 @@ pub enum TokenType { Unused, Byte, } -impl TryFrom for TokenType { - type Error = u32; +impl TryFrom for TokenType { + type Error = i32; - fn try_from(value: u32) -> Result { + fn try_from(value: i32) -> Result { match value { 1 => Ok(Self::Normal), 2 => Ok(Self::Unknown), From 58cb8cce4efe268c9477d246d325660a8e22a5b2 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 23 Oct 2023 02:48:37 +0200 Subject: [PATCH 19/33] wip(llm): reconvert ggml tokenizer with GPT-4 --- crates/llm-base/src/tokenizer/embedded.rs | 121 ++++++++++------------ 1 file changed, 53 insertions(+), 68 deletions(-) diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs index bca26d84..8a8119a1 100644 --- a/crates/llm-base/src/tokenizer/embedded.rs +++ b/crates/llm-base/src/tokenizer/embedded.rs @@ -319,60 +319,50 @@ impl TryFrom for TokenType { #[derive(Clone)] struct Symbol { - prev: Option, - next: Option, - text: Token, + prev: isize, + next: isize, + text: Vec, n: usize, } struct LlmBigramSpm { - left: usize, - right: usize, + left: isize, + right: isize, score: f32, size: usize, } impl PartialOrd for LlmBigramSpm { fn partial_cmp(&self, other: &Self) -> Option { - match self.score.partial_cmp(&other.score) { - Some(core::cmp::Ordering::Equal) => {} - ord => return ord, - } - self.left.partial_cmp(&other.left) + Some(self.cmp(other)) } } impl Ord for LlmBigramSpm { fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).unwrap() + self.score + .partial_cmp(&other.score) + .unwrap_or(Ordering::Equal) + .then_with(|| other.left.cmp(&self.left)) } } + impl PartialEq for LlmBigramSpm { fn eq(&self, other: &Self) -> bool { - (self.score < other.score) || (self.score == other.score && self.left > other.left) + self.score == other.score && self.left == other.left } } -impl Eq for LlmBigramSpm {} -impl LlmBigramSpm { - fn new(left: usize, right: usize, score: f32, size: usize) -> Self { - LlmBigramSpm { - left, - right, - score, - size, - } - } -} +impl Eq for LlmBigramSpm {} struct TokenizerSpm<'a> { vocab: &'a EmbeddedTokenizer, symbols: Vec, work_queue: BinaryHeap, - rev_merge: HashMap, + rev_merge: HashMap, } impl<'a> TokenizerSpm<'a> { fn new(vocab: &'a EmbeddedTokenizer) -> Self { - TokenizerSpm { + Self { vocab, symbols: Vec::new(), work_queue: BinaryHeap::new(), @@ -382,19 +372,19 @@ impl<'a> TokenizerSpm<'a> { fn tokenize(&mut self, text: &[u8]) -> Vec { let mut output = vec![]; - let mut index = 0; let mut offs = 0; while offs < text.len() { + let len = text[offs..].len(); let sym = Symbol { - prev: if index == 0 { None } else { Some(index - 1) }, - next: if offs == text.len() - 1 { - None + text: text[offs..offs + len].to_vec(), + n: len.min(text.len() - offs), + prev: index - 1, + next: if offs + len == text.len() { + -1 } else { - Some(index + 1) + index + 1 }, - text: text[offs..].to_vec(), - n: std::cmp::min(text.len() - offs, 
utf8_len(text[offs])), }; offs += sym.n; index += 1; @@ -402,7 +392,7 @@ impl<'a> TokenizerSpm<'a> { } for i in 1..self.symbols.len() { - self.try_add_bigram(Some(i - 1), Some(i)); + self.try_add_bigram((i - 1) as isize, i as isize); } while let Some(bigram) = self.work_queue.pop() { @@ -417,78 +407,73 @@ impl<'a> TokenizerSpm<'a> { right_sym.n = 0; left_sym.next = right_sym.next; - if let Some(next) = right_sym.next { - self.symbols[next].prev = Some(bigram.left); + if right_sym.next >= 0 { + self.symbols[right_sym.next as usize].prev = bigram.left; } let left_sym_prev = left_sym.prev; let left_sym_next = left_sym.next; + self.symbols[bigram.left as usize] = left_sym; self.symbols[bigram.right as usize] = right_sym; - self.try_add_bigram(left_sym_prev, Some(bigram.left)); - self.try_add_bigram(Some(bigram.left), left_sym_next); + self.try_add_bigram(left_sym_prev, bigram.left); + self.try_add_bigram(bigram.left, left_sym_next); } - let mut i = Some(0); - while let Some(idx) = i { - if idx >= self.symbols.len() { - break; - } - - let symbol = &self.symbols[idx as usize]; + let mut i = 0; + while i != -1 { + let symbol = &self.symbols[i as usize]; self.resegment(symbol, &mut output); i = symbol.next; } - output } fn resegment(&self, symbol: &Symbol, output: &mut Vec) { - let text = &symbol.text; - let token = self.vocab.token_to_id.get(text); - - if let Some(&token_id) = token { + let text = symbol.text.clone(); + if let Some(&token_id) = self.vocab.token_to_id.get(&text) { output.push(token_id); return; } - if let Some(p) = self.rev_merge.get(text) { - self.resegment(&self.symbols[p.0 as usize], output); - self.resegment(&self.symbols[p.1 as usize], output); + if let Some(&(left, right)) = self.rev_merge.get(&text) { + self.resegment(&self.symbols[left as usize], output); + self.resegment(&self.symbols[right as usize], output); } else { - for ch in text { - let token_id = self.vocab.byte_to_token(*ch); + for &ch in &text { + let token_id = self.vocab.byte_to_token(ch); output.push(token_id); } } } - fn try_add_bigram(&mut self, left: Option, right: Option) { - let Some((left, right)) = left.zip(right) else { + fn try_add_bigram(&mut self, left: isize, right: isize) { + if left == -1 || right == -1 { return; - }; - - let mut text = self.symbols[left].text.clone(); - text.extend_from_slice(&self.symbols[right].text); + } + let text = [ + self.symbols[left as usize].text.clone(), + self.symbols[right as usize].text.clone(), + ] + .concat(); if let Some(&token_id) = self.vocab.token_to_id.get(&text) { if (token_id as usize) < self.vocab.id_to_token.len() { let tok_data = &self.vocab.id_to_token[token_id as usize]; - let bigram = LlmBigramSpm::new(left, right, tok_data.score, text.len()); + let bigram = LlmBigramSpm { + left, + right, + score: tok_data.score, + size: text.len(), + }; self.work_queue.push(bigram); - self.rev_merge.insert(text.clone(), (left, right)); + self.rev_merge.insert(text, (left, right)); } } } } -fn utf8_len(src: u8) -> usize { - const LOOKUP: &[u8] = &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]; - let highbits: u8 = src >> 4; - LOOKUP[highbits as usize] as usize -} - fn escape_whitespace(text: &[u8]) -> Vec { let mut out = vec![]; From e4db5b9db8c38e385ecfe4a61d2c06609896dc59 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 23 Oct 2023 02:58:20 +0200 Subject: [PATCH 20/33] fix(ggml/llmb): use IndexMap for GGUF --- Cargo.lock | 14 ++++++++------ Cargo.toml | 10 ++++++++-- crates/ggml/Cargo.toml | 1 + crates/ggml/src/format/gguf/metadata.rs | 5 +++-- 
crates/ggml/src/format/gguf/mod.rs | 12 +++++------- crates/llm-base/Cargo.toml | 11 +++++++---- crates/llm-base/src/lora.rs | 8 +++----- 7 files changed, 35 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 811ef654..af44f099 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1616,6 +1616,7 @@ version = "0.2.0-dev" dependencies = [ "anyhow", "ggml-sys", + "indexmap 2.0.2", "memmap2", "rand", "thiserror", @@ -1773,9 +1774,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" [[package]] name = "heck" @@ -1947,12 +1948,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.0.0" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.2", ] [[package]] @@ -2157,6 +2158,7 @@ dependencies = [ "bytemuck", "ggml", "half", + "indexmap 2.0.2", "llm-samplers", "memmap2", "partial_sort", @@ -3822,7 +3824,7 @@ version = "0.19.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" dependencies = [ - "indexmap 2.0.0", + "indexmap 2.0.2", "toml_datetime", "winnow", ] diff --git a/Cargo.toml b/Cargo.toml index ae5b22f7..9eb459ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ "crates/llm", "crates/llm-base", "crates/models/*", - "binaries/*" + "binaries/*", ] resolver = "2" default-members = ["binaries/llm-cli", "crates/llm"] @@ -33,6 +33,7 @@ memmap2 = "0.5.10" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = { version = "0.1", features = ["log"] } llm-samplers = "=0.0.6" +indexmap = "2.0.2" # Config for 'cargo dist' [workspace.metadata.dist] @@ -45,7 +46,12 @@ ci = ["github"] # The installers to generate for each app installers = ["shell", "powershell"] # Target platforms to build apps for (Rust target-triple syntax) -targets = ["x86_64-unknown-linux-gnu", "x86_64-apple-darwin", "x86_64-pc-windows-msvc", "aarch64-apple-darwin"] +targets = [ + "x86_64-unknown-linux-gnu", + "x86_64-apple-darwin", + "x86_64-pc-windows-msvc", + "aarch64-apple-darwin", +] # The profile that 'cargo dist' will build with [profile.dist] diff --git a/crates/ggml/Cargo.toml b/crates/ggml/Cargo.toml index 03aeb92a..108a98e2 100644 --- a/crates/ggml/Cargo.toml +++ b/crates/ggml/Cargo.toml @@ -11,6 +11,7 @@ ggml-sys = { path = "sys", version = "0.2.0-dev" } thiserror = { workspace = true } memmap2 = { workspace = true } +indexmap = { workspace = true } [dev-dependencies] rand = { workspace = true } diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index 35041fc5..244d9e70 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -1,5 +1,6 @@ -use std::{collections::HashMap, io::BufRead}; +use std::io::BufRead; +use indexmap::IndexMap; use thiserror::Error; use crate::util; @@ -7,7 +8,7 @@ use crate::util; use super::{GgufContext, GgufLoadError}; #[derive(Debug, Clone, PartialEq)] -pub struct Metadata(pub 
HashMap); +pub struct Metadata(pub IndexMap); impl Metadata { pub fn iter(&self) -> impl Iterator { self.0.iter() diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index 0b5da22b..ae878d8d 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -1,13 +1,11 @@ #![allow(missing_docs)] -use std::{ - collections::HashMap, - io::{BufRead, Seek}, -}; +use std::io::{BufRead, Seek}; use super::{data_size, header_size, ContainerType, ContainerTypeReadError}; use crate::{util, ElementType}; +use indexmap::IndexMap; use thiserror::Error; mod metadata; @@ -49,7 +47,7 @@ pub enum GgufSaveError { // TODO! } -pub type TensorInfos = HashMap; +pub type TensorInfos = IndexMap; #[derive(Debug, Clone, PartialEq)] pub struct Gguf { @@ -74,7 +72,7 @@ impl Gguf { let tensor_count = util::read_length(reader, ctx.use_64_bit_length)?; let metadata_kv_count = util::read_length(reader, ctx.use_64_bit_length)?; - let mut metadata = HashMap::with_capacity(metadata_kv_count); + let mut metadata = IndexMap::with_capacity(metadata_kv_count); for _ in 0..metadata_kv_count { let (key, value) = MetadataValue::read_key_value(&ctx, reader)?; metadata.insert(key, value); @@ -86,7 +84,7 @@ impl Gguf { .and_then(|v| v.as_uint32()) .unwrap_or(DEFAULT_ALIGNMENT) as u64; - let mut tensor_infos = HashMap::with_capacity(tensor_count); + let mut tensor_infos = IndexMap::with_capacity(tensor_count); for _ in 0..tensor_count { let (key, value) = TensorInfo::read_name_value(&ctx, reader)?; tensor_infos.insert(key, value); diff --git a/crates/llm-base/Cargo.toml b/crates/llm-base/Cargo.toml index badcbdc6..0474d1d6 100644 --- a/crates/llm-base/Cargo.toml +++ b/crates/llm-base/Cargo.toml @@ -17,16 +17,19 @@ bytemuck = { workspace = true } rand = { workspace = true } serde = { workspace = true } thiserror = { workspace = true } +indexmap = { workspace = true } +memmap2 = { workspace = true } +tracing = { workspace = true } +llm-samplers = { workspace = true } partial_sort = "0.2.0" serde_bytes = "0.11" -memmap2 = { workspace = true } half = "2" -tokenizers = {version="0.13.4", default-features=false, features=["onig"]} +tokenizers = { version = "0.13.4", default-features = false, features = [ + "onig", +] } regex = "1.8" -tracing = { workspace = true } -llm-samplers = { workspace = true } [features] tokenizers-remote = ["tokenizers/http"] diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index c2010767..ea1386c1 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -8,10 +8,8 @@ use ggml::{ format::gguf::{Gguf, Metadata, TensorInfo}, GraphExecutionPlan, }; -use std::{ - collections::{HashMap, HashSet}, - path::PathBuf, -}; +use indexmap::IndexMap; +use std::{collections::HashSet, path::PathBuf}; #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] /// Parameters for a [LoRA](https://arxiv.org/abs/2106.09685) adapter. @@ -50,7 +48,7 @@ pub struct LoraAdapter { /// Scaling to apply to the LoRA weights. pub scaling: f32, /// The tensors of the LoRA. - pub tensors: HashMap, + pub tensors: IndexMap, /// Names of the tensors that should be patched. pub tensors_to_patch: HashSet, /// Source containing the LoRA weights. 
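
Patch 20 above swaps `HashMap` for `indexmap::IndexMap` in both the GGUF metadata table and the tensor-info table, so iteration now follows the order in which entries were read from the file rather than an arbitrary hash order. A minimal sketch of the behavioural difference, using the `indexmap` crate directly; the tensor names are borrowed from the llama diff later in this series, while the offsets are purely illustrative:

use indexmap::IndexMap;

fn main() {
    // Insertion order is preserved, so entries come back in the order they
    // appeared in the file -- the property the GGUF loader relies on when it
    // walks tensor infos against their recorded offsets.
    let mut tensor_offsets: IndexMap<String, u64> = IndexMap::new();
    tensor_offsets.insert("token_embd.weight".to_string(), 0);      // illustrative offset
    tensor_offsets.insert("output_norm.weight".to_string(), 4096);  // illustrative offset
    tensor_offsets.insert("output.weight".to_string(), 8192);       // illustrative offset

    for (name, offset) in &tensor_offsets {
        println!("{name} @ {offset}");
    }
    // Prints the three tensors in insertion (file) order, whereas a HashMap
    // would yield them in an unspecified order from run to run.
}
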
From 89960611102c2954c7357f8c5e3c6c68901b5167 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 29 Oct 2023 19:55:01 +0100 Subject: [PATCH 21/33] fix(llmb): disable embedded tokenizer --- crates/ggml/src/format/gguf/metadata.rs | 11 ++-- crates/llm-base/src/tokenizer/embedded.rs | 20 +++--- crates/llm-base/src/tokenizer/mod.rs | 79 +++++++++++++++-------- crates/llm/src/loader.rs | 2 +- crates/models/llama/src/lib.rs | 5 +- 5 files changed, 75 insertions(+), 42 deletions(-) diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index 244d9e70..40a9be2f 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -26,6 +26,10 @@ impl Metadata { self.0.get(key) } + pub fn contains_key(&self, key: &str) -> bool { + self.0.contains_key(key) + } + pub fn get(&self, key: &str) -> Result<&MetadataValue, MetadataError> { self.get_optional(key) .ok_or_else(|| MetadataError::MissingKey { @@ -72,8 +76,8 @@ impl Metadata { }) } - // TODO: see if we can generalize this with `ToOwned` or something? - pub fn get_string(&self, key: &str) -> Result { + // TODO: consider + pub fn get_str(&self, key: &str) -> Result<&str, MetadataError> { let metadata_value = self.get(key)?; Ok(metadata_value .as_string() @@ -81,8 +85,7 @@ impl Metadata { key: key.to_string(), expected_type: MetadataValueType::String, actual_type: metadata_value.value_type(), - })? - .to_string()) + })?) } pub fn get_countable(&self, key: &str) -> Result { diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs index 8a8119a1..02acb4c3 100644 --- a/crates/llm-base/src/tokenizer/embedded.rs +++ b/crates/llm-base/src/tokenizer/embedded.rs @@ -31,11 +31,11 @@ pub struct EmbeddedTokenizer { model: GgufEmbeddedTokenizerModel, bos_id: TokenId, - eos_id: TokenId, - unknown_id: TokenId, + _eos_id: TokenId, + _unknown_id: TokenId, linefeed_id: TokenId, - separator_id: Option, - padding_id: Option, + _separator_id: Option, + _padding_id: Option, } #[derive(Debug, Clone, Default)] struct TokenData { @@ -44,6 +44,10 @@ struct TokenData { ty: TokenType, } impl EmbeddedTokenizer { + pub(crate) fn is_present_in_metadata(metadata: &Metadata) -> bool { + metadata.contains_key("tokenizer.ggml.scores") + } + pub(crate) fn from_metadata(metadata: &Metadata) -> Result { let tok = GgufEmbeddedTokenizer::from_metadata(metadata)?; @@ -110,11 +114,11 @@ impl EmbeddedTokenizer { id_to_token, model: GgufEmbeddedTokenizerModel::Llama, bos_id, - eos_id, - unknown_id, + _eos_id: eos_id, + _unknown_id: unknown_id, linefeed_id: 0, - separator_id, - padding_id, + _separator_id: separator_id, + _padding_id: padding_id, }; tokenizer.linefeed_id = tokenizer.byte_to_token(b'\n'); diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 11ff1574..0f53aba8 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -42,9 +42,12 @@ pub enum TokenizerLoadError { /// The error that occurred during loading. error: Box, }, - #[error("no tokenizer was found, including in the model file")] - /// No tokenizer was found, including in the model file. - NoTokenizerFound, + #[error("no supported tokenizers were found, including in the model file: {unsupported_tokenizers:?}")] + /// No supported tokenizers were found, including in the model file. + NoSupportedTokenizersFound { + /// The list of tokenizers that were found, but not supported. 
+ unsupported_tokenizers: Vec, + }, #[error("{0}")] /// An error occured with retrieving data from the metadata. MetadataError(#[from] MetadataError), @@ -75,6 +78,10 @@ impl Display for HuggingFaceTokenizerErrorSource { } } +/// At the time of writing, the embedded tokenizer is not enabled as it has +/// some bugs. We're just not enabling the option while it's broken. +const EMBEDDED_TOKENIZER_ENABLED: bool = false; + #[derive(Clone, Debug, PartialEq)] /// The source of a tokenizer. pub enum TokenizerSource { @@ -95,9 +102,6 @@ pub enum TokenizerSource { /// and may store files locally, so it is not recommended for production use. #[cfg(feature = "tokenizers-remote")] HuggingFaceRemote(String), - // - // TODO: Support embedded huggingface tokenizer from GGUF - // } impl TokenizerSource { /// Retrieve the tokenizer from the source. @@ -105,9 +109,9 @@ impl TokenizerSource { /// Note that this may make a blocking HTTP request to Hugging Face to retrieve the tokenizer. /// if `self` is `Self::HuggingFaceRemote`. pub fn retrieve(self, gguf: &Gguf) -> Result { - Ok(match self { + match self { #[cfg(feature = "tokenizers-remote")] - Self::HuggingFaceRemote(identifier) => HuggingFaceTokenizer::new( + Self::HuggingFaceRemote(identifier) => Ok(HuggingFaceTokenizer::new( tokenizers::Tokenizer::from_pretrained(&identifier, None).map_err(|error| { TokenizerLoadError::HuggingFaceTokenizerError { tokenizer_source: HuggingFaceTokenizerErrorSource::Remote( @@ -117,33 +121,54 @@ impl TokenizerSource { } })?, ) - .into(), + .into()), - Self::HuggingFaceTokenizerFile(path) => { - HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_file(&path).map_err( - |error| TokenizerLoadError::HuggingFaceTokenizerError { - tokenizer_source: HuggingFaceTokenizerErrorSource::File(path.clone()), - error: error.into(), - }, - )?) - .into() - } - - Self::HuggingFaceTokenizerString(s) => { - HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_str(&s).map_err(|error| { + Self::HuggingFaceTokenizerFile(path) => Ok(HuggingFaceTokenizer::new( + tokenizers::Tokenizer::from_file(&path).map_err(|error| { TokenizerLoadError::HuggingFaceTokenizerError { - tokenizer_source: HuggingFaceTokenizerErrorSource::String, + tokenizer_source: HuggingFaceTokenizerErrorSource::File(path.clone()), error: error.into(), } - })?) - .into() + })?, + ) + .into()), + + Self::HuggingFaceTokenizerString(s) => Ok(Self::load_huggingface_json(&s)?), + + Self::Embedded => { + if let Ok(hf) = gguf.metadata.get_str("tokenizer.huggingface.json") { + Ok(Self::load_huggingface_json(hf)?) + } else { + if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) { + if EMBEDDED_TOKENIZER_ENABLED { + Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into()) + } else { + Err(TokenizerLoadError::NoSupportedTokenizersFound { + unsupported_tokenizers: vec!["embedded".to_owned()], + }) + } + } else { + Err(TokenizerLoadError::NoSupportedTokenizersFound { + unsupported_tokenizers: vec![], + }) + } + } } + } + } - Self::Embedded => EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into(), - }) + fn load_huggingface_json(tokenizer_json: &str) -> Result { + Ok( + HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_str(tokenizer_json).map_err( + |error| TokenizerLoadError::HuggingFaceTokenizerError { + tokenizer_source: HuggingFaceTokenizerErrorSource::String, + error: error.into(), + }, + )?) + .into(), + ) } } - /// Encapsulates the tokenizer for a model, and provides methods to tokenize text. pub enum Tokenizer { /// The vocabulary built-in to the model. 
diff --git a/crates/llm/src/loader.rs b/crates/llm/src/loader.rs index 5397f7c1..6282a076 100644 --- a/crates/llm/src/loader.rs +++ b/crates/llm/src/loader.rs @@ -197,7 +197,7 @@ pub fn load( let architecture = gguf .metadata - .get_string("general.architecture")? + .get_str("general.architecture")? .parse::()?; let tokenizer = tokenizer_source.retrieve(&gguf)?; diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 26f2f252..9b9c9185 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -424,8 +424,9 @@ impl llm_base::Hyperparameters for Hyperparameters { block_count: metadata.get_countable("llama.block_count")?, file_type: FileType::read_for_hyperparameters(metadata)?, tensor_data_layout: metadata - .get_string("llama.tensor_data_layout") - .unwrap_or(META_TENSOR_DATA_LAYOUT.to_string()), + .get_str("llama.tensor_data_layout") + .unwrap_or(META_TENSOR_DATA_LAYOUT) + .to_string(), }) } From 34379acab14039cb8b9cb1c214f119d0bbf40db7 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 29 Oct 2023 21:15:17 +0100 Subject: [PATCH 22/33] refactor: move loading logic back into llmb, simplify --- binaries/llm-cli/src/main.rs | 10 +- binaries/llm-test/src/common.rs | 2 +- crates/llm-base/src/lib.rs | 6 +- crates/llm-base/src/loader.rs | 379 ++++++++++++++++++++++++++---- crates/llm-base/src/lora.rs | 25 +- crates/llm-base/src/model/mod.rs | 156 ++++--------- crates/llm-base/src/quantize.rs | 10 +- crates/llm/src/lib.rs | 12 +- crates/llm/src/loader.rs | 390 +++---------------------------- crates/models/bloom/src/lib.rs | 4 +- crates/models/falcon/src/lib.rs | 4 +- crates/models/gpt2/src/lib.rs | 4 +- crates/models/gptj/src/lib.rs | 4 +- crates/models/gptneox/src/lib.rs | 4 +- crates/models/llama/src/lib.rs | 96 +++----- crates/models/mpt/src/lib.rs | 4 +- 16 files changed, 471 insertions(+), 639 deletions(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index f8bbdac0..d19933af 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -1,12 +1,8 @@ -use std::{ - convert::Infallible, - fs::File, - io::{BufReader, BufWriter}, -}; +use std::{convert::Infallible, fs::File, io::BufReader}; use clap::Parser; use cli_args::Args; -use color_eyre::eyre::{self, Context, ContextCompat}; +use color_eyre::eyre; use is_terminal::IsTerminal; use llm::ggml_format::gguf; @@ -223,7 +219,7 @@ fn prompt_tokens(args: &cli_args::PromptTokens) -> eyre::Result<()> { // struct QuantizeVisitor<'a>(&'a cli_args::Quantize); // impl llm::ModelArchitectureVisitor> for QuantizeVisitor<'_> { -// fn visit(&mut self) -> eyre::Result<()> { +// fn visit(&mut self) -> eyre::Result<()> { // let args = self.0; // let mut source: BufReader = BufReader::new(std::fs::File::open(&args.source)?); diff --git a/binaries/llm-test/src/common.rs b/binaries/llm-test/src/common.rs index 63ea41d4..f910d095 100644 --- a/binaries/llm-test/src/common.rs +++ b/binaries/llm-test/src/common.rs @@ -12,7 +12,7 @@ pub(super) fn can_send(model: Box) -> anyhow::Result> model } -// pub(super) fn can_roundtrip_hyperparameters( +// pub(super) fn can_roundtrip_hyperparameters( // model: &M, // ) -> anyhow::Result<()> { // fn test_hyperparameters(hyperparameters: &M) -> anyhow::Result<()> { diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index 8796bbe1..9cdd60bc 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -29,12 +29,10 @@ pub use inference_session::{ ModelKVMemoryType, RewindError, SnapshotError, }; pub use 
llm_samplers::prelude::{Sampler, SamplerChain}; -pub use loader::{ - load_known_internal, ContainerType, FileMagic, FileType, FileTypeFormat, LoadKnownError, -}; +pub use loader::{ContainerType, FileMagic, FileType, FileTypeFormat}; pub use lora::{LoraAdapter, LoraParameters}; pub use memmap2::Mmap; -pub use model::{Hyperparameters, KnownModel, Model, ModelContext, ModelParameters, OutputRequest}; +pub use model::{Model, ModelContext, ModelParameters, OutputRequest}; pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use regex::Regex; pub use tokenizer::{ diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index a765d77c..e79999fc 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -2,21 +2,23 @@ use std::{ fmt::{Display, Formatter}, - io::{BufRead, Seek, SeekFrom}, + fs::File, + io::{BufRead, BufReader, Seek, SeekFrom}, path::Path, sync::Arc, }; use crate::{ - model::{Hyperparameters, HyperparametersReadError}, - KnownModel, LoraAdapter, ModelContext, ModelParameters, Tokenizer, + model::{HyperparametersReadError, ModelData, ModelLoadArgs, ModelLoadError}, + LoraAdapter, Model, ModelContext, ModelParameters, TokenizerLoadError, TokenizerSource, }; pub use ggml::{format::gguf::MetadataError, format::ContainerType, util::FileMagic}; use ggml::{ - format::gguf::{Gguf, Metadata, TensorInfo}, + format::gguf::{Gguf, GgufLoadError, Metadata, MetadataValue, MetadataValueType, TensorInfo}, sys::llama::llama_ftype, Context, MAX_NAME_LENGTH, }; +use memmap2::Mmap; use thiserror::Error; #[derive(Debug, PartialEq, Clone, Copy, Eq, Default)] @@ -194,63 +196,347 @@ impl Display for FileTypeFormat { pub trait Source: BufRead + Seek {} impl Source for S {} -/// Errors that can occur when loading a known model. +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum LoadProgress<'a> { + /// The hyperparameters have been loaded from the model. + HyperparametersLoaded, + /// The context has been created. + ContextSize { + /// The size of the context. + bytes: usize, + }, + /// A tensor was patched with a LoRA. + LoraApplied { + /// The name of the patched tensor. + name: &'a str, + /// LoRA file the patch was applied from. + source: &'a Path, + }, + /// A tensor from the current part has been loaded. + TensorLoaded { + /// The current tensor (0-indexed). + current_tensor: usize, + /// The number of total tensors. + tensor_count: usize, + }, + /// A model part has finished fully loading. + Loaded { + /// The number of bytes in the part. + file_size: u64, + /// The number of tensors in the part. + tensor_count: usize, + }, +} + #[derive(Error, Debug)] -pub enum LoadKnownError { - /// Failed to read the hyperparameters +/// Errors encountered during the loading process. +pub enum LoadError { + #[error("the file does not exist")] + /// The file does not exist. + FileDoesNotExist, + #[error("could not open file")] + /// A file failed to open. + OpenFileFailed { + /// The original error. + source: std::io::Error, + }, + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. 
+ InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("invalid magic value {magic}")] + /// An invalid magic value was encountered during the loading process. + InvalidMagic { + /// The magic value that was encountered. + magic: FileMagic, + }, + #[error("invalid file format {container_type:?}")] + /// The version of the format is not supported by this version of `llm`. + InvalidFormatVersion { + /// The format that was encountered. + container_type: ContainerType, + }, + /// The tensor `tensor_name` had an unsupported element type. + #[error("invalid element type {element_type} for tensor `{tensor_name}`")] + UnsupportedElementType { + /// The name of the tensor. + tensor_name: String, + /// The element type that was encountered. + element_type: u32, + }, + /// The tokenizer could not be loaded. + #[error("could not load tokenizer: {0}")] + TokenizerLoadFail(#[from] TokenizerLoadError), + /// The quantization version was missing, despite this model containing quantized tensors. + #[error("quantization version was missing, despite model containing quantized tensors")] + MissingQuantizationVersion, + /// The quantization version is not supported by this version of `llm`. + #[error("quantization version {quantization_version:?} is not supported")] + UnsupportedQuantizationVersion { + /// The quantization version that was encountered. + quantization_version: MetadataValue, + }, + /// The model expected a metadata key-value pair, but the key was missing. + #[error("missing metadata key {key:?}")] + MissingMetadataKey { + /// The key that was missing. + key: String, + }, + /// The metadata key-value pair was not of the expected type. + #[error("metadata key {key:?} was not of the expected type")] + InvalidMetadataType { + /// The key with the invalid type. + key: String, + /// The expected type. + expected_type: MetadataValueType, + /// The actual type. + actual_type: MetadataValueType, + }, + /// The file type within the model was not supported by this version of `llm`. + #[error("file type {file_type} is not supported")] + UnsupportedFileType { + /// The file type (ignoring the quantization version) that was encountered. + file_type: llama_ftype, + }, + /// The architecture in the file is not known to the loader. + #[error("unknown architecture {architecture}")] + UnknownArchitecture { + /// The architecture that was encountered. + architecture: String, + }, + /// An error occurred while reading the hyperparameters. #[error("{0}")] HyperparametersReadError(#[from] HyperparametersReadError), - /// Failed to load the tensors + /// An error occurred while loading the concrete model. 
#[error("{0}")] - TensorLoadError(#[from] TensorLoadError), + ModelLoadError(#[from] ModelLoadError), +} +impl From for LoadError { + fn from(value: GgufLoadError) -> Self { + match value { + GgufLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { magic }, + GgufLoadError::InvalidFormatVersion(container_type) => { + LoadError::InvalidFormatVersion { container_type } + } + GgufLoadError::Io(err) => LoadError::Io(err), + GgufLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err), + GgufLoadError::InvalidIntegerConversion(err) => { + LoadError::InvalidIntegerConversion(err) + } + GgufLoadError::UnsupportedElementType { tensor_name, ftype } => { + LoadError::UnsupportedElementType { + tensor_name, + element_type: ftype, + } + } + } + } +} +impl From for LoadError { + fn from(value: MetadataError) -> Self { + Self::HyperparametersReadError(HyperparametersReadError::MetadataError(value)) + } } -/// Each variant represents a step within loading a known model. -#[derive(Debug, Copy, Clone)] -#[doc(hidden)] -pub enum LoadKnownProgress<'a> { - /// A LoRA has been applied. - LoraApplied { name: &'a str, source: &'a Path }, - /// A tensor has been loaded. - TensorLoaded { current_tensor: usize }, +/// When given args, attempt to instantiate a model. +pub type ModelLoadCallback = fn(ModelLoadArgs) -> Result, ModelLoadError>; + +/// A factory that can retrieve the constructor for a given model architecture. +pub trait ModelFactory { + /// For a given architecture name, return a function that will load the model, + /// or `None` if the architecture is not supported. + fn load(&self, architecture: &str) -> Option; } -/// Internal function that takes all of the state that can be derived without -/// knowing a concrete type and loads a concrete model. A *lot* of precondition -/// logic is done in `llm`. -// TODO: think about this design. Do we want to let people to be able to load -// known models directly? -#[doc(hidden)] -#[allow(clippy::too_many_arguments)] -pub fn load_known_internal( - source: &mut dyn Source, - gguf: &Gguf, - tokenizer: Tokenizer, - context: Context, - lora_adapters: Option>, - progress_callback: &mut dyn FnMut(LoadKnownProgress), +/// Loads the specified GGUF model from disk, determining its architecture from the metadata. +/// +/// This method returns a [`Box`], which means that the model will have single ownership. +/// If you'd like to share ownership (i.e. to use the model in multiple threads), we +/// suggest using [`Arc::from(Box)`](https://doc.rust-lang.org/std/sync/struct.Arc.html#impl-From%3CBox%3CT,+Global%3E%3E-for-Arc%3CT%3E) +/// to convert the [`Box`] into an [`Arc`](std::sync::Arc) after loading. 
+pub fn load( + path: &Path, + tokenizer_source: TokenizerSource, params: ModelParameters, -) -> Result { - let hyperparameters = ::read_gguf(&gguf.metadata)?; - let tl = ModelTensorLoader { - tensor_loader: TensorLoader { - source, - gguf: &gguf, - context, - }, - lora_adapters, - progress_callback, - loaded_tensor_count: 0, + model_factory: impl ModelFactory, + mut load_progress_callback: impl FnMut(LoadProgress), +) -> Result, LoadError> { + if !path.exists() { + return Err(LoadError::FileDoesNotExist); + } + + let file = File::open(path).map_err(|e| LoadError::OpenFileFailed { source: e })?; + let mut reader = BufReader::new(&file); + tracing::trace!("Read model file from {:?}", path); + + let gguf = Gguf::load(&mut reader)?; + tracing::trace!("Loaded GGML model from reader"); + + let architecture = gguf.metadata.get_str("general.architecture")?; + let tokenizer = tokenizer_source.retrieve(&gguf)?; + + let quantization_version = gguf.metadata.get_optional("general.quantization_version"); + tracing::trace!( + "Determined quantization version of model as {:?}", + quantization_version + ); + + // TODO: this is temporary while we figure out how to handle this + let any_quantized = gguf + .tensor_infos + .values() + .any(|t| t.element_type.is_quantized()); + if any_quantized { + match quantization_version { + Some(MetadataValue::UInt32(2)) => { + // Currently supported version + } + Some(quantization_version) => { + return Err(LoadError::UnsupportedQuantizationVersion { + quantization_version: quantization_version.clone(), + }) + } + None => return Err(LoadError::MissingQuantizationVersion), + } + } + + let use_mmap = params.prefer_mmap && params.lora_adapters.is_none(); + + let ctx_size = gguf + .tensor_infos + .values() + .map(|ti| ti.calc_absolute_size(use_mmap)) + .sum::(); + tracing::trace!("Context size: {:?}", ctx_size); + + let mut lora_adapters: Option> = None; + if let Some(lora_paths) = ¶ms.lora_adapters { + let adapters: Result, _> = lora_paths + .iter() + .map(|lora_path| { + // Read the LoRA file + let lora_file = File::open(lora_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + })?; + let mut lora_reader = BufReader::new(&lora_file); + let gguf = Gguf::load(&mut lora_reader)?; + + // Collect the names of the tensors that should be patched + let tensors_to_patch = gguf + .tensor_infos + .keys() + .filter_map(|k| Some(k.rsplit_once('.')?.0.to_owned())) + .collect(); + + tracing::trace!("Loaded LoRA weights"); + // Return the LoRA patches + #[allow(unreachable_code)] + Ok::<_, LoadError>(LoraAdapter { + tensors: gguf.tensor_infos.clone(), + tensors_to_patch, + source: Box::new(lora_reader), + path: lora_path.to_owned(), + gguf, + scaling: todo!("Calculate scaling from LoRA file metadata (GGUF does not have standardised metadata yet)"), + }) + }) + .collect(); + lora_adapters = Some(adapters?); + } + + (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size }); + let (context, file_size) = if use_mmap { + unsafe { + let mmap = Mmap::map(&file)?; + let file_size = mmap.len() as u64; + (Context::new_with_mmap(mmap), file_size) + } + } else { + (Context::new_with_allocate(ctx_size), file.metadata()?.len()) }; - Ok(KnownModel::new(hyperparameters, params, tokenizer, tl)?) 
+ let model_constructor = + model_factory + .load(architecture) + .ok_or_else(|| LoadError::UnknownArchitecture { + architecture: architecture.to_string(), + })?; + let model = (model_constructor)(ModelLoadArgs { + gguf: &gguf, + data: ModelData { params, tokenizer }, + tensor_loader: ModelTensorLoader { + tensor_loader: TensorLoader { + source: &mut reader, + gguf: &gguf, + context, + }, + lora_adapters, + progress_callback: &mut load_progress_callback, + loaded_tensor_count: 0, + }, + })?; + + (load_progress_callback)(LoadProgress::Loaded { + file_size, + tensor_count: gguf.tensor_infos.len(), + }); + + tracing::trace!("Loaded model"); + + Ok(model) +} + +/// A implementation for `load_progress_callback` that outputs to `stdout`. +pub fn load_progress_callback_stdout(progress: LoadProgress) { + match progress { + LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"), + LoadProgress::ContextSize { bytes } => println!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::TensorLoaded { + current_tensor, + tensor_count, + .. + } => { + let current_tensor = current_tensor + 1; + if current_tensor % 8 == 0 { + println!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::Loaded { + file_size: byte_size, + tensor_count, + } => { + println!("Loading of model complete"); + println!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + LoadProgress::LoraApplied { name, source } => { + println!( + "Patched tensor {} via LoRA from '{}'", + name, + source.file_name().unwrap().to_str().unwrap() + ); + } + }; } /// A helper struct for loading tensors from a model. pub struct ModelTensorLoader<'a> { pub(crate) tensor_loader: TensorLoader<'a>, pub(crate) lora_adapters: Option>, - pub(crate) progress_callback: &'a mut dyn FnMut(LoadKnownProgress), + pub(crate) progress_callback: &'a mut dyn FnMut(LoadProgress), pub(crate) loaded_tensor_count: usize, } impl ModelTensorLoader<'_> { @@ -261,7 +547,7 @@ impl ModelTensorLoader<'_> { if let Some(lora_adapters) = &mut self.lora_adapters { for lora_adapter in lora_adapters { lora_adapter.patch(name, info, &mut tensor)?; - (self.progress_callback)(LoadKnownProgress::LoraApplied { + (self.progress_callback)(LoadProgress::LoraApplied { name, source: &lora_adapter.path, }); @@ -269,8 +555,9 @@ impl ModelTensorLoader<'_> { } self.loaded_tensor_count += 1; - (self.progress_callback)(LoadKnownProgress::TensorLoaded { + (self.progress_callback)(LoadProgress::TensorLoaded { current_tensor: self.loaded_tensor_count, + tensor_count: self.tensor_loader.gguf.tensor_infos.len(), }); Ok(tensor) diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index ea1386c1..07bf97c5 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -1,11 +1,7 @@ -use crate::{ - loader::{Source, TensorLoadError, TensorLoader}, - model::{HyperparametersReadError, HyperparametersWriteError}, - FileType, Hyperparameters, -}; +use crate::loader::{Source, TensorLoadError, TensorLoader}; use ggml::{ - format::gguf::{Gguf, Metadata, TensorInfo}, + format::gguf::{Gguf, TensorInfo}, GraphExecutionPlan, }; use indexmap::IndexMap; @@ -25,23 +21,6 @@ impl LoraParameters { (self.alpha as f32) / (self.r as f32) } } -impl Hyperparameters for LoraParameters { - fn read_gguf(metadata: &Metadata) -> Result { - todo!() - } - - fn write_gguf(&self, metadata: &mut Metadata) -> Result<(), HyperparametersWriteError> { - todo!() - } - - fn 
file_type(&self) -> Option { - None - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - None - } -} /// [LoRA](https://arxiv.org/abs/2106.09685) adapter for a model. pub struct LoraAdapter { diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index a2640516..916f62e1 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -4,7 +4,7 @@ use std::{fmt::Debug, path::PathBuf, sync::Arc}; use ggml::{ accelerator::Backend, - format::gguf::{Metadata, MetadataError}, + format::gguf::{Gguf, MetadataError}, sys::llama::llama_ftype, }; use regex::Regex; @@ -13,26 +13,47 @@ use thiserror::Error; use crate::{ loader::{ModelTensorLoader, TensorLoadError}, tokenizer::TokenId, - FileType, InferenceSession, InferenceSessionConfig, Tokenizer, + InferenceSession, InferenceSessionConfig, Tokenizer, }; /// Common functions for model evaluation pub mod common; +/// All of the arguments required to load a model. +pub struct ModelLoadArgs<'a> { + /// The GGUF metadata for the model. + pub gguf: &'a Gguf, + /// Model metadata. + pub data: ModelData, + /// The tensor loader to use for the model. + pub tensor_loader: ModelTensorLoader<'a>, +} + +/// Model data that is required for all models. +pub struct ModelData { + /// Any parameters that control the behaviour of the model. + pub params: ModelParameters, + /// The tokenizer to use for the model. + pub tokenizer: Tokenizer, +} + +/// An error encountered while loading a concrete model. +#[derive(Error, Debug)] +pub enum ModelLoadError { + /// An error occurred while loading the model's tensors. + #[error("{0}")] + TensorLoadError(#[from] TensorLoadError), + /// An error occurred while reading the model's hyperparameters. + #[error("{0}")] + HyperparametersReadError(#[from] HyperparametersReadError), +} + /// Interfaces for creating and interacting with a large language model with a known type /// of [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)). -pub trait KnownModel: Send + Sync { - /// Hyperparameters for the model. - type Hyperparameters: Hyperparameters; - +pub trait Model: Send + Sync { /// Creates a new model from the provided [ModelParameters] hyperparameters. /// This function is called by the [load](crate::loader::load) function. - fn new( - hyperparameters: Self::Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: ModelTensorLoader, - ) -> Result + fn new(args: ModelLoadArgs) -> Result where Self: Sized; @@ -50,15 +71,19 @@ pub trait KnownModel: Send + Sync { output_request: &mut OutputRequest, ); - /// Get the hyperparameters for this model. - fn hyperparameters(&self) -> &Self::Hyperparameters; + /// Get the data for this model. + fn data(&self) -> &ModelData; /// Get the tokenizer for this model. - fn tokenizer(&self) -> &Tokenizer; + fn tokenizer(&self) -> &Tokenizer { + &self.data().tokenizer + } /// Get the context size (configured with [ModelParameters::context_size]) used by /// this model. - fn context_size(&self) -> usize; + fn context_size(&self) -> usize { + self.data().params.context_size + } /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers. fn bot_token_id(&self) -> Option; @@ -67,10 +92,10 @@ pub trait KnownModel: Send + Sync { fn eot_token_id(&self) -> TokenId; /// Get the list of regexes to use to determine if a tensor in this model should be quantized. 
- fn quantize_tensors() -> Vec; + fn quantize_tensors(&self) -> Vec; /// Get the list of regexes to use to determine if a tensor in this model should not be quantized. - fn skip_quantize_tensors() -> Vec; + fn skip_quantize_tensors(&self) -> Vec; /// Returns whether the model supports deleting tokens. fn supports_rewind(&self) -> bool { @@ -79,91 +104,8 @@ pub trait KnownModel: Send + Sync { } } -/// A type-erased model to allow for interacting with a model without knowing -/// its hyperparameters. -pub trait Model: Send + Sync { - /// Starts a new `InferenceSession` for this model. - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession; - - /// This function is called by the provided [InferenceSession]; it will use this model - /// to generate output by evaluating the `input_tokens`. - /// The [OutputRequest] is used to specify additional data to fetch from the - /// model. - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ); - - /// Get the tokenizer for this model. - fn tokenizer(&self) -> &Tokenizer; - - /// Get the context size (configured with [ModelParameters::context_size]) used by - /// this model. - fn context_size(&self) -> usize; - - /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers. - fn bot_token_id(&self) -> Option; - - /// Get the end of text/end of string token ID. This value is defined by model implementers. - fn eot_token_id(&self) -> TokenId; - - /// Returns whether the model supports deleting tokens. - fn supports_rewind(&self) -> bool; -} -impl> Model for M { - fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { - KnownModel::start_session(self, config) - } - - fn evaluate( - &self, - session: &mut InferenceSession, - input_tokens: &[TokenId], - output_request: &mut OutputRequest, - ) { - KnownModel::evaluate(self, session, input_tokens, output_request) - } - - fn tokenizer(&self) -> &Tokenizer { - KnownModel::tokenizer(self) - } - - fn context_size(&self) -> usize { - KnownModel::context_size(self) - } - - fn bot_token_id(&self) -> Option { - KnownModel::bot_token_id(self) - } - - fn eot_token_id(&self) -> TokenId { - KnownModel::eot_token_id(self) - } - - fn supports_rewind(&self) -> bool { - KnownModel::supports_rewind(self) - } -} - -/// Implemented by model hyperparameters for interacting with hyperparameters -/// without knowing what they are, as well as writing/reading them as required. -pub trait Hyperparameters: Sized + Default + Debug + PartialEq + Eq { - /// Read the parameters from GGUF metadata. - fn read_gguf(metadata: &Metadata) -> Result; - - /// Write the parameters to GGUF metadata. - fn write_gguf(&self, metadata: &mut Metadata) -> Result<(), HyperparametersWriteError>; - - /// Get the filetype of the model. - fn file_type(&self) -> Option; - - /// Get mutable access to filetype of the model. - fn file_type_mut(&mut self) -> Option<&mut FileType>; -} #[derive(Error, Debug)] -/// Reported from functions that write +/// Reported from functions that read hyperparameters pub enum HyperparametersReadError { #[error("{0}")] /// A metadata error. @@ -175,16 +117,6 @@ pub enum HyperparametersReadError { file_type: llama_ftype, }, } -#[derive(Error, Debug)] -/// Reported from functions that write -pub enum HyperparametersWriteError { - #[error("non-specific I/O error")] - /// A non-specific IO error. 
- Io(#[from] std::io::Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), -} /// Parameters for model-wide behaviour. #[derive(Debug, Clone)] diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index 95d30e50..faf1ee2a 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -3,10 +3,7 @@ //! Implements quantization of weights. -use crate::{ - loader::FileTypeFormat, model::HyperparametersWriteError, Hyperparameters, KnownModel, - Tokenizer, -}; +use crate::{loader::FileTypeFormat, Model, Tokenizer}; use ggml::format::gguf::GgufSaveError; use half::f16; use regex::Regex; @@ -114,9 +111,6 @@ pub enum QuantizeError { /// The element type. element_type: ggml::Type, }, - /// An error was encountered while writing the hyperparameters. - #[error("an error was encountered while writing the hyperparameters")] - HyperparametersWriteError(#[source] HyperparametersWriteError), /// An attempt was made to save a model with a container type that does not /// support vocabulary scoring, despite the model having a scored vocabulary. #[error("container type does not support vocabulary scoring")] @@ -140,7 +134,7 @@ impl QuantizeError { } /// Quantizes a model. -pub fn quantize( +pub fn quantize( reader: &mut R, writer: &mut W, tokenizer: Tokenizer, diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 02b2ef4c..f7bf0d03 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -80,11 +80,11 @@ pub use llm_base::{ ggml::accelerator::get_accelerator as ggml_get_accelerator, ggml::accelerator::Accelerator as GgmlAccelerator, ggml::format as ggml_format, ggml::RoPEOverrides, quantize, samplers, tokenizer, ElementType, FileMagic, FileType, - FileTypeFormat, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, - InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, - InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, Model, - ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, - RewindError, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, + FileTypeFormat, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, + InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, + InferenceSnapshotRef, InferenceStats, InvalidTokenBias, Model, ModelKVMemoryType, + ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, RewindError, + SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource, }; @@ -185,7 +185,7 @@ define_models!( /// Used to dispatch some code based on the model architecture. pub trait ModelArchitectureVisitor { /// Visit a model architecture. - fn visit(self) -> R; + fn visit(self) -> R; } /// An unsupported model architecture was specified. 
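
Patch 22 folds the old `KnownModel`/`Model` split into a single object-safe `Model` trait and routes architecture dispatch through the new `ModelFactory`, while the wrapper in crates/llm/src/loader.rs below keeps the same four-argument `load` entry point. A minimal calling sketch, assuming the `llm` crate's root re-exports (`load`, `load_progress_callback_stdout`) and the `Default` impl for `ModelParameters` are unchanged by this series; the model path is a placeholder:

use std::path::Path;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical model file; any GGUF model with a supported architecture works.
    let model = llm::load(
        Path::new("models/llama-7b.Q4_0.gguf"),
        llm::TokenizerSource::Embedded, // per patch 21, falls back to an embedded HF JSON if present
        llm::ModelParameters::default(),
        llm::load_progress_callback_stdout,
    )?;

    // Architecture-specific behaviour is now reached through the unified
    // `Model` trait object rather than a concrete `KnownModel` type.
    println!("context size: {}", model.context_size());
    Ok(())
}
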
diff --git a/crates/llm/src/loader.rs b/crates/llm/src/loader.rs index 6282a076..bc4c871c 100644 --- a/crates/llm/src/loader.rs +++ b/crates/llm/src/loader.rs @@ -1,178 +1,17 @@ -use std::{fs::File, io::BufReader, path::Path}; +use std::path::Path; +pub use llm_base::loader::{load_progress_callback_stdout, LoadError, LoadProgress}; use llm_base::{ - ggml::{ - format::gguf::{Gguf, GgufLoadError, MetadataValue, MetadataValueType}, - sys::llama::llama_ftype, - Context, - }, - loader::{LoadKnownProgress, Source}, - loader::{MetadataError, TensorLoadError}, - model::HyperparametersReadError, - ContainerType, FileMagic, KnownModel, LoadKnownError, LoraAdapter, Mmap, Model, - ModelParameters, Tokenizer, TokenizerLoadError, TokenizerSource, + loader::{ModelFactory, ModelLoadCallback}, + model::{ModelLoadArgs, ModelLoadError}, + Model, ModelParameters, TokenizerSource, }; -use thiserror::Error; -use tracing::log; +use crate::{ModelArchitecture, ModelArchitectureVisitor}; -use crate::{ModelArchitecture, ModelArchitectureVisitor, UnsupportedModelArchitecture}; - -/// Each variant represents a step within the process of loading the model. -/// These can be used to report progress to the user. -#[derive(Clone, PartialEq, Eq, Debug)] -pub enum LoadProgress<'a> { - /// The hyperparameters have been loaded from the model. - HyperparametersLoaded, - /// The context has been created. - ContextSize { - /// The size of the context. - bytes: usize, - }, - /// A tensor was patched with a LoRA. - LoraApplied { - /// The name of the patched tensor. - name: &'a str, - /// LoRA file the patch was applied from. - source: &'a Path, - }, - /// A tensor from the current part has been loaded. - TensorLoaded { - /// The current tensor (0-indexed). - current_tensor: usize, - /// The number of total tensors. - tensor_count: usize, - }, - /// A model part has finished fully loading. - Loaded { - /// The number of bytes in the part. - file_size: u64, - /// The number of tensors in the part. - tensor_count: usize, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the loading process. -pub enum LoadError { - #[error("the file does not exist")] - /// The file does not exist. - FileDoesNotExist, - #[error("could not open file")] - /// A file failed to open. - OpenFileFailed { - /// The original error. - source: std::io::Error, - }, - #[error("non-specific I/O error")] - /// A non-specific IO error. - Io(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("invalid magic value {magic}")] - /// An invalid magic value was encountered during the loading process. - InvalidMagic { - /// The magic value that was encountered. - magic: FileMagic, - }, - #[error("invalid file format {container_type:?}")] - /// The version of the format is not supported by this version of `llm`. - InvalidFormatVersion { - /// The format that was encountered. - container_type: ContainerType, - }, - /// The tensor `tensor_name` had an unsupported element type. - #[error("invalid element type {element_type} for tensor `{tensor_name}`")] - UnsupportedElementType { - /// The name of the tensor. - tensor_name: String, - /// The element type that was encountered. 
- element_type: u32, - }, - /// The tokenizer could not be loaded. - #[error("could not load tokenizer: {0}")] - TokenizerLoadFail(#[from] TokenizerLoadError), - /// The quantization version was missing, despite this model containing quantized tensors. - #[error("quantization version was missing, despite model containing quantized tensors")] - MissingQuantizationVersion, - /// The quantization version is not supported by this version of `llm`. - #[error("quantization version {quantization_version:?} is not supported")] - UnsupportedQuantizationVersion { - /// The quantization version that was encountered. - quantization_version: MetadataValue, - }, - /// The model expected a metadata key-value pair, but the key was missing. - #[error("missing metadata key {key:?}")] - MissingMetadataKey { - /// The key that was missing. - key: String, - }, - /// The metadata key-value pair was not of the expected type. - #[error("metadata key {key:?} was not of the expected type")] - InvalidMetadataType { - /// The key with the invalid type. - key: String, - /// The expected type. - expected_type: MetadataValueType, - /// The actual type. - actual_type: MetadataValueType, - }, - /// The file type within the model was not supported by this version of `llm`. - #[error("file type {file_type} is not supported")] - UnsupportedFileType { - /// The file type (ignoring the quantization version) that was encountered. - file_type: llama_ftype, - }, - /// The architecture specified in this model is not supported by `llm`. - #[error("architecture is not supported: {0}")] - UnsupportedArchitecture(#[from] UnsupportedModelArchitecture), - /// An error occurred while reading the hyperparameters. - #[error("{0}")] - HyperparametersReadError(HyperparametersReadError), - /// An error occurred while reading the tensors. - #[error("{0}")] - TensorLoadError(TensorLoadError), -} -impl From for LoadError { - fn from(value: GgufLoadError) -> Self { - match value { - GgufLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { magic }, - GgufLoadError::InvalidFormatVersion(container_type) => { - LoadError::InvalidFormatVersion { container_type } - } - GgufLoadError::Io(err) => LoadError::Io(err), - GgufLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err), - GgufLoadError::InvalidIntegerConversion(err) => { - LoadError::InvalidIntegerConversion(err) - } - GgufLoadError::UnsupportedElementType { tensor_name, ftype } => { - LoadError::UnsupportedElementType { - tensor_name, - element_type: ftype, - } - } - } - } -} -impl From for LoadError { - fn from(value: LoadKnownError) -> Self { - match value { - LoadKnownError::HyperparametersReadError(e) => Self::HyperparametersReadError(e), - LoadKnownError::TensorLoadError(e) => Self::TensorLoadError(e), - } - } -} -impl From for LoadError { - fn from(value: MetadataError) -> Self { - Self::HyperparametersReadError(HyperparametersReadError::MetadataError(value)) - } -} - -/// Loads the specified GGUF model from disk, determining its architecture from the metadata. +/// Loads the specified GGUF model from disk, determining its architecture from the metadata, +/// and loading it with one of the supported modules. If you want to load a custom model, +/// consider using [llm_base::loader::load] directly. /// /// This method returns a [`Box`], which means that the model will have single ownership. /// If you'd like to share ownership (i.e. 
to use the model in multiple threads), we @@ -182,200 +21,35 @@ pub fn load( path: &Path, tokenizer_source: TokenizerSource, params: ModelParameters, - mut load_progress_callback: impl FnMut(LoadProgress), + load_progress_callback: impl FnMut(LoadProgress), ) -> Result, LoadError> { - if !path.exists() { - return Err(LoadError::FileDoesNotExist); - } - - let file = File::open(path).map_err(|e| LoadError::OpenFileFailed { source: e })?; - let mut reader = BufReader::new(&file); - log::trace!("Read model file from {:?}", path); - - let gguf = Gguf::load(&mut reader)?; - log::trace!("Loaded GGML model from reader"); - - let architecture = gguf - .metadata - .get_str("general.architecture")? - .parse::()?; - - let tokenizer = tokenizer_source.retrieve(&gguf)?; - - let quantization_version = gguf.metadata.get_optional("general.quantization_version"); - log::trace!( - "Determined quantization version of model as {:?}", - quantization_version - ); - - // TODO: this is temporary while we figure out how to handle this - let any_quantized = gguf - .tensor_infos - .values() - .any(|t| t.element_type.is_quantized()); - if any_quantized { - match quantization_version { - Some(MetadataValue::UInt32(2)) => { - // Currently supported version - } - Some(quantization_version) => { - return Err(LoadError::UnsupportedQuantizationVersion { - quantization_version: quantization_version.clone(), - }) - } - None => return Err(LoadError::MissingQuantizationVersion), - } - } - - let use_mmap = params.prefer_mmap && params.lora_adapters.is_none(); - - let ctx_size = gguf - .tensor_infos - .values() - .map(|ti| ti.calc_absolute_size(use_mmap)) - .sum::(); - log::trace!("Context size: {:?}", ctx_size); - - let mut lora_adapters: Option> = None; - if let Some(lora_paths) = ¶ms.lora_adapters { - let adapters: Result, _> = lora_paths - .iter() - .map(|lora_path| { - // Read the LoRA file - let lora_file = File::open(lora_path).map_err(|e| LoadError::OpenFileFailed { - source: e, - })?; - let mut lora_reader = BufReader::new(&lora_file); - let gguf = Gguf::load(&mut lora_reader)?; - - // Collect the names of the tensors that should be patched - let tensors_to_patch = gguf - .tensor_infos - .keys() - .filter_map(|k| Some(k.rsplit_once('.')?.0.to_owned())) - .collect(); - - log::trace!("Loaded LoRA weights"); - // Return the LoRA patches - #[allow(unreachable_code)] - Ok::<_, LoadError>(LoraAdapter { - tensors: gguf.tensor_infos.clone(), - tensors_to_patch, - source: Box::new(lora_reader), - path: lora_path.to_owned(), - gguf, - scaling: todo!("Calculate scaling from LoRA file metadata (GGUF does not have standardised metadata yet)"), - }) - }) - .collect(); - lora_adapters = Some(adapters?); - } - - (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size }); - let (context, file_size) = if use_mmap { - unsafe { - let mmap = Mmap::map(&file)?; - let file_size = mmap.len() as u64; - (Context::new_with_mmap(mmap), file_size) - } - } else { - (Context::new_with_allocate(ctx_size), file.metadata()?.len()) - }; - - let model = architecture.visit(LoadVisitor { - source: &mut reader, - gguf: &gguf, - tokenizer, - context, - lora_adapters, - load_progress_callback: &mut load_progress_callback, + Ok(llm_base::loader::load( + path, + tokenizer_source, params, - })?; - - (load_progress_callback)(LoadProgress::Loaded { - file_size, - tensor_count: gguf.tensor_infos.len(), - }); - - log::trace!("Loaded model"); - - Ok(model) + VisitorModelFactory, + load_progress_callback, + )?) 
} -struct LoadVisitor<'a, F: FnMut(LoadProgress)> { - source: &'a mut dyn Source, - gguf: &'a Gguf, - tokenizer: Tokenizer, - context: Context, - lora_adapters: Option>, - load_progress_callback: F, - params: ModelParameters, +struct VisitorModelFactory; +impl ModelFactory for VisitorModelFactory { + fn load(&self, architecture: &str) -> Option { + let architecture = architecture.parse::().ok()?; + Some(architecture.visit(VisitorModelFactoryVisitor)) + } } -impl<'a, F: FnMut(LoadProgress)> ModelArchitectureVisitor, LoadError>> - for LoadVisitor<'a, F> -{ - fn visit(mut self) -> Result, LoadError> { - let model = Box::new(llm_base::load_known_internal::( - self.source, - self.gguf, - self.tokenizer, - self.context, - self.lora_adapters, - &mut |step| { - (self.load_progress_callback)(match step { - LoadKnownProgress::LoraApplied { name, source } => { - LoadProgress::LoraApplied { name, source } - } - LoadKnownProgress::TensorLoaded { current_tensor } => { - LoadProgress::TensorLoaded { - current_tensor, - tensor_count: self.gguf.tensor_infos.len(), - } - } - }) - }, - self.params, - )?); - Ok(model) +struct VisitorModelFactoryVisitor; +impl ModelArchitectureVisitor for VisitorModelFactoryVisitor { + fn visit(self) -> ModelLoadCallback { + Self::new_for_model:: } } - -/// A implementation for `load_progress_callback` that outputs to `stdout`. -pub fn load_progress_callback_stdout(progress: LoadProgress) { - match progress { - LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"), - LoadProgress::ContextSize { bytes } => println!( - "ggml ctx size = {:.2} MB\n", - bytes as f64 / (1024.0 * 1024.0) - ), - LoadProgress::TensorLoaded { - current_tensor, - tensor_count, - .. - } => { - let current_tensor = current_tensor + 1; - if current_tensor % 8 == 0 { - println!("Loaded tensor {current_tensor}/{tensor_count}"); - } - } - LoadProgress::Loaded { - file_size: byte_size, - tensor_count, - } => { - println!("Loading of model complete"); - println!( - "Model size = {:.2} MB / num tensors = {}", - byte_size as f64 / 1024.0 / 1024.0, - tensor_count - ); - } - LoadProgress::LoraApplied { name, source } => { - println!( - "Patched tensor {} via LoRA from '{}'", - name, - source.file_name().unwrap().to_str().unwrap() - ); - } - }; +impl VisitorModelFactoryVisitor { + fn new_for_model( + args: ModelLoadArgs, + ) -> Result, ModelLoadError> { + Ok(M::new(args).map(Box::new)?) 
+ } } diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index 3880f8c4..d8c6f822 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -5,7 +5,7 @@ // use llm_base::{ // ggml, // model::{common, HyperparametersWriteError}, -// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, // ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, // }; @@ -41,7 +41,7 @@ // unsafe impl Send for Bloom {} // unsafe impl Sync for Bloom {} -// impl KnownModel for Bloom { +// impl Model for Bloom { // type Hyperparameters = Hyperparameters; // fn new( diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs index db26a6d1..79d118ee 100644 --- a/crates/models/falcon/src/lib.rs +++ b/crates/models/falcon/src/lib.rs @@ -11,7 +11,7 @@ // use llm_base::{ // ggml, // model::{common, HyperparametersWriteError}, -// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, LoadError, // ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, // }; @@ -43,7 +43,7 @@ // unsafe impl Send for Falcon {} // unsafe impl Sync for Falcon {} -// impl KnownModel for Falcon { +// impl Model for Falcon { // type Hyperparameters = Hyperparameters; // fn new( diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index 59933b67..b370a7ab 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -5,7 +5,7 @@ // use llm_base::{ // ggml, // model::{common, HyperparametersWriteError}, -// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, LoadError, // ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, // }; @@ -42,7 +42,7 @@ // unsafe impl Send for Gpt2 {} // unsafe impl Sync for Gpt2 {} -// impl KnownModel for Gpt2 { +// impl Model for Gpt2 { // type Hyperparameters = Hyperparameters; // fn new( diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index dd70728f..03da600c 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -7,7 +7,7 @@ // use llm_base::{ // ggml, // model::{common, HyperparametersWriteError}, -// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, LoadError, // ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, // }; @@ -41,7 +41,7 @@ // unsafe impl Send for GptJ {} // unsafe impl Sync for GptJ {} -// impl KnownModel for GptJ { +// impl Model for GptJ { // type Hyperparameters = Hyperparameters; // fn new( diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 93ecf6cf..436939ed 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -8,7 +8,7 @@ // use llm_base::{ // ggml, // model::{common, HyperparametersWriteError}, -// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, LoadError, // ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, // }; @@ 
-41,7 +41,7 @@ // unsafe impl Send for GptNeoX {} // unsafe impl Sync for GptNeoX {} -// impl KnownModel for GptNeoX { +// impl Model for GptNeoX { // type Hyperparameters = Hyperparameters; // fn new( diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 9b9c9185..46076113 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -3,10 +3,9 @@ use llm_base::{ ggml::{self, format::gguf::Metadata}, - loader::{ModelTensorLoader, TensorLoadError}, - model::{common, HyperparametersReadError, HyperparametersWriteError}, - FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, ModelContext, - ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, + model::{common, HyperparametersReadError, ModelData, ModelLoadArgs, ModelLoadError}, + FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, ModelContext, + OutputRequest, Regex, TokenId, }; const META_TENSOR_DATA_LAYOUT: &str = "Meta AI original pth"; @@ -16,9 +15,8 @@ const META_TENSOR_DATA_LAYOUT: &str = "Meta AI original pth"; /// # Safety /// This implements [Send] and [Sync] as it is immutable after construction. pub struct Llama { - params: ModelParameters, + data: ModelData, hyperparameters: Hyperparameters, - tokenizer: Tokenizer, // model-global weights // weighted token embeddings wte: ggml::Tensor, @@ -37,23 +35,19 @@ pub struct Llama { unsafe impl Send for Llama {} unsafe impl Sync for Llama {} -impl KnownModel for Llama { - type Hyperparameters = Hyperparameters; - - fn new( - hyperparameters: Self::Hyperparameters, - params: ModelParameters, - tokenizer: Tokenizer, - tensor_loader: ModelTensorLoader, - ) -> Result { +impl Model for Llama { + fn new(args: ModelLoadArgs) -> Result { + let hyperparameters = Hyperparameters::read(&args.gguf.metadata)?; assert_eq!(hyperparameters.tensor_data_layout, META_TENSOR_DATA_LAYOUT); - let mut tl = tensor_loader; + let mut tl = args.tensor_loader; // model-global weights let wte = tl.load("token_embd.weight")?; - let backend = params.backend(0); + let data = args.data; + + let backend = data.params.backend(0); let norm = tl.load("output_norm.weight")?.transfer_to(backend); let output = tl.load("output.weight")?.transfer_to(backend); @@ -61,7 +55,7 @@ impl KnownModel for Llama { let mut blocks = Vec::new(); for i in 0..hyperparameters.block_count { - let backend = params.backend(i); + let backend = data.params.backend(i); let block = Block { attn_n: tl @@ -97,9 +91,8 @@ impl KnownModel for Llama { let context = tl.finish(); Ok(Self { + data, hyperparameters, - params, - tokenizer, wte, norm, output, @@ -112,7 +105,7 @@ impl KnownModel for Llama { fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { InferenceSession::new( config, - &self.params, + &self.data.params, self.hyperparameters.block_count, self.hyperparameters.embedding_length, self.hyperparameters.vocabulary_count, @@ -128,7 +121,8 @@ impl KnownModel for Llama { ) { let input_len = input_tokens.len(); let session_len = session.n_past; - let ctx_size = self.params.context_size; + let params = &self.data.params; + let ctx_size = params.context_size; let Hyperparameters { vocabulary_count, @@ -151,7 +145,7 @@ impl KnownModel for Llama { let mut gf = ctx0.create_compute_graph(); for il in 0..block_count { - ctx0.set_offloading(self.params.should_offload(il)); + ctx0.set_offloading(params.should_offload(il)); let input_self_attention = input_layer.share(); let mut current: ggml::Tensor; @@ -166,7 +160,7 @@ impl KnownModel 
for Llama { // self-attention // compute Q and K and RoPE them - let overrides = self.params.rope_overrides.as_ref(); + let overrides = params.rope_overrides.as_ref(); let n_embd_head = embedding_length / head_count; let q_current = ctx0 .op_rope_inplace( @@ -360,16 +354,8 @@ impl KnownModel for Llama { ); } - fn hyperparameters(&self) -> &Self::Hyperparameters { - &self.hyperparameters - } - - fn tokenizer(&self) -> &Tokenizer { - &self.tokenizer - } - - fn context_size(&self) -> usize { - self.params.context_size + fn data(&self) -> &ModelData { + &self.data } fn bot_token_id(&self) -> Option { @@ -377,14 +363,14 @@ impl KnownModel for Llama { } fn eot_token_id(&self) -> TokenId { - self.tokenizer.id("".as_bytes()).unwrap_or(2) + self.tokenizer().id("".as_bytes()).unwrap_or(2) } - fn quantize_tensors() -> Vec { + fn quantize_tensors(&self) -> Vec { vec![Regex::new(".*weight").unwrap()] } - fn skip_quantize_tensors() -> Vec { + fn skip_quantize_tensors(&self) -> Vec { vec![] } @@ -393,26 +379,25 @@ impl KnownModel for Llama { } } -/// LLaMA [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) #[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct Hyperparameters { +struct Hyperparameters { /// Size of the model's vocabulary - pub vocabulary_count: usize, + vocabulary_count: usize, /// Size of the model's embedding layer - pub embedding_length: usize, + embedding_length: usize, /// The number of attention heads - pub head_count: usize, + head_count: usize, /// The number of grouped-query attention heads - pub head_count_kv: usize, + head_count_kv: usize, /// Number of layers in the model - pub block_count: usize, + block_count: usize, /// file_type - pub file_type: Option, + file_type: Option, /// The tensor data layout that this model was encoded with - pub tensor_data_layout: String, + tensor_data_layout: String, } -impl llm_base::Hyperparameters for Hyperparameters { - fn read_gguf(metadata: &Metadata) -> Result { +impl Hyperparameters { + pub fn read(metadata: &Metadata) -> Result { Ok(Self { // TODO: handle models without an embedded vocabulary vocabulary_count: metadata @@ -430,21 +415,8 @@ impl llm_base::Hyperparameters for Hyperparameters { }) } - fn write_gguf(&self, metadata: &mut Metadata) -> Result<(), HyperparametersWriteError> { - todo!() - } - - fn file_type(&self) -> Option { - self.file_type - } - - fn file_type_mut(&mut self) -> Option<&mut FileType> { - self.file_type.as_mut() - } -} -impl Hyperparameters { /// Returns the number of grouped-query attention heads. 
- pub fn grouped_query_attention(&self) -> usize { + fn grouped_query_attention(&self) -> usize { self.head_count / self.head_count_kv } } diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index 2b7db5b8..cfeeefb6 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -5,7 +5,7 @@ // use llm_base::{ // ggml::{self}, // model::{common, HyperparametersWriteError}, -// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, LoadError, // ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer, // }; @@ -35,7 +35,7 @@ // unsafe impl Send for Mpt {} // unsafe impl Sync for Mpt {} -// impl KnownModel for Mpt { +// impl Model for Mpt { // type Hyperparameters = Hyperparameters; // fn new( From a4bbdbf4829f7e2eff21278fe5acd7010ffda99c Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 30 Oct 2023 00:02:01 +0100 Subject: [PATCH 23/33] feat: implement GGUF write / llm gguf rebuild --- binaries/llm-cli/src/cli_args.rs | 19 +- binaries/llm-cli/src/main.rs | 43 ++- crates/ggml/src/format/gguf/metadata.rs | 335 +++++++++++++----------- crates/ggml/src/format/gguf/mod.rs | 114 +++++++- crates/ggml/src/util.rs | 134 ++++++++-- crates/llm-base/src/quantize.rs | 1 - 6 files changed, 449 insertions(+), 197 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 97fa8a85..290d0630 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -24,8 +24,11 @@ pub enum Args { Perplexity(Box), #[command()] - /// Get information about a GGML model. - Info(Box), + /// Interact with a GGUF model. + Gguf { + #[command(subcommand)] + gguf: Gguf, + }, #[command()] /// Dumps the prompt to console and exits, first as a comma-separated list of token IDs @@ -111,6 +114,12 @@ pub struct Perplexity { pub prompt: Prompt, } +#[derive(Parser, Debug)] +pub enum Gguf { + Info(Info), + Rebuild(Rebuild), +} + #[derive(Parser, Debug)] pub struct Info { #[command(flatten)] @@ -125,6 +134,12 @@ pub struct Info { pub tokenizer: bool, } +#[derive(Parser, Debug)] +pub struct Rebuild { + pub input: PathBuf, + pub output: PathBuf, +} + #[derive(Parser, Debug)] pub struct PromptTokens { #[command(flatten)] diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index d19933af..5c7468e1 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -1,4 +1,8 @@ -use std::{convert::Infallible, fs::File, io::BufReader}; +use std::{ + convert::Infallible, + fs::File, + io::{BufReader, BufWriter, Read, Seek}, +}; use clap::Parser; use cli_args::Args; @@ -28,7 +32,7 @@ fn main() -> eyre::Result<()> { match args { Args::Infer(args) => infer(&args), Args::Perplexity(args) => perplexity(&args), - Args::Info(args) => info(&args), + Args::Gguf { gguf: args } => gguf(&args), Args::PromptTokens(args) => prompt_tokens(&args), Args::Repl(args) => interactive::repl(&args), Args::Chat(args) => interactive::chat(&args), @@ -128,6 +132,13 @@ fn perplexity(args: &cli_args::Perplexity) -> eyre::Result<()> { Ok(()) } +fn gguf(args: &cli_args::Gguf) -> eyre::Result<()> { + match args { + cli_args::Gguf::Info(args) => info(args), + cli_args::Gguf::Rebuild(args) => rebuild(args), + } +} + fn info(args: &cli_args::Info) -> eyre::Result<()> { let model_path = &args.model_and_tokenizer.model_path; @@ -174,10 +185,11 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { 
log::info!("Tensors:"); for (name, tensor) in &gguf.tensor_infos { log::info!( - "- {} ({:?} {:?})", + "- {} ({:?} {:?}) @ 0x{:X}", name, tensor.element_type, - tensor.dimensions + tensor.dimensions, + tensor.offset ); } } @@ -185,6 +197,29 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { Ok(()) } +fn rebuild(args: &cli_args::Rebuild) -> eyre::Result<()> { + let input = File::open(&args.input)?; + let mut reader = BufReader::new(&input); + let gguf = gguf::Gguf::load(&mut reader)?; + + let mut output = File::create(&args.output)?; + let mut writer = BufWriter::new(&mut output); + gguf.save(&mut writer, |writer, name, _info| { + let reader = &mut reader; + let original_info = gguf.tensor_infos.get(name).unwrap(); + + reader.seek(std::io::SeekFrom::Start( + gguf.tensor_data_position + original_info.offset, + ))?; + + std::io::copy(&mut reader.take(original_info.calc_size() as u64), writer)?; + + Ok(()) + })?; + + Ok(()) +} + fn prompt_tokens(args: &cli_args::PromptTokens) -> eyre::Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; let model = args.model_load.load(false)?; diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index 40a9be2f..5849b85e 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -1,4 +1,4 @@ -use std::io::BufRead; +use std::io::{self, BufRead, Write}; use indexmap::IndexMap; use thiserror::Error; @@ -37,7 +37,7 @@ impl Metadata { }) } - pub fn get_with_type<'a, T: MetadataValueTypeFromRustType>( + pub fn get_with_type<'a, T: ToMetadataValue>( &'a self, key: &'a str, getter: impl Fn(&MetadataValue) -> Option, @@ -50,7 +50,7 @@ impl Metadata { }) } - pub fn get_with_ref_type<'a, T: MetadataValueTypeFromRustType>( + pub fn get_with_ref_type<'a, T: ToMetadataValue>( &'a self, key: &'a str, getter: impl Fn(&MetadataValue) -> Option<&T>, @@ -63,7 +63,7 @@ impl Metadata { }) } - pub fn get_array_with_type<'a, T: MetadataValueTypeFromRustType>( + pub fn get_array_with_type<'a, T: ToMetadataValue>( &'a self, key: &'a str, getter: impl Fn(&MetadataValue) -> Option<&[T]>, @@ -139,19 +139,32 @@ pub enum MetadataValueType { /// Implemented in GGUFv2. Float64 = 12, } -pub trait MetadataValueTypeFromRustType { +pub trait ToMetadataValue { fn value_type() -> MetadataValueType; + fn to_value(self) -> MetadataValue; +} +pub trait ToMetadataArrayValue { + fn to_array_value(self) -> MetadataArrayValue; } macro_rules! impl_value_boilerplate { ($($value_type:ident($rust_type:ty)),*) => { $( - impl MetadataValueTypeFromRustType for $rust_type { + impl ToMetadataValue for $rust_type { fn value_type() -> MetadataValueType { MetadataValueType::$value_type } + + fn to_value(self) -> MetadataValue { + MetadataValue::$value_type(self) + } } - )* + impl ToMetadataArrayValue for Vec<$rust_type> { + fn to_array_value(self) -> MetadataArrayValue { + MetadataArrayValue::$value_type(self) + } + } + )* impl TryFrom for MetadataValueType { type Error = (); @@ -167,7 +180,19 @@ macro_rules! impl_value_boilerplate { Err(()) } } - + impl MetadataValueType { + fn read_value( + self, + ctx: &GgufContext, + reader: &mut dyn BufRead, + ) -> Result { + use MetadataValueType as MVT; + + Ok(match self { + $(MVT::$value_type => <$rust_type>::read(ctx, reader)?.to_value(),)* + }) + } + } #[derive(Debug, Clone, PartialEq)] pub enum MetadataValue { @@ -175,14 +200,31 @@ macro_rules! 
impl_value_boilerplate { $value_type($rust_type), )* } - - // Public impl MetadataValue { pub fn value_type(&self) -> MetadataValueType { match self { $(MetadataValue::$value_type(_) => MetadataValueType::$value_type),* } } + + fn write(&self, ctx: &GgufContext, writer: &mut dyn Write) -> io::Result<()> { + match self { + $(MetadataValue::$value_type(v) => v.write(ctx, writer)),* + } + } + } + + #[derive(Debug, Clone, PartialEq)] + pub enum MetadataArrayValue { + $($value_type(Vec<$rust_type>),)* + } + impl MetadataArrayValue { + /// Returns the length of the array. + pub fn len(&self) -> usize { + match self { + $(Self::$value_type(v) => v.len(),)* + } + } } }; } @@ -295,7 +337,6 @@ impl MetadataValue { } } } -// Private impl MetadataValue { pub(super) fn read_key_value( ctx: &GgufContext, @@ -304,105 +345,25 @@ impl MetadataValue { let key = util::read_string(reader, ctx.use_64_bit_length)?; let value_type = MetadataValueType::try_from(util::read_u32(reader)?) .expect("TODO: handle invalid value types"); - let value = Self::read_value(ctx, reader, value_type)?; + let value = value_type.read_value(ctx, reader)?; Ok((key, value)) } - fn read_value( - ctx: &GgufContext, - reader: &mut dyn BufRead, - value_type: MetadataValueType, - ) -> Result { - match value_type { - MetadataValueType::UInt8 => Self::read_u8(ctx, reader).map(MetadataValue::UInt8), - MetadataValueType::Int8 => Self::read_i8(ctx, reader).map(MetadataValue::Int8), - MetadataValueType::UInt16 => Self::read_u16(ctx, reader).map(MetadataValue::UInt16), - MetadataValueType::Int16 => Self::read_i16(ctx, reader).map(MetadataValue::Int16), - MetadataValueType::UInt32 => Self::read_u32(ctx, reader).map(MetadataValue::UInt32), - MetadataValueType::Int32 => Self::read_i32(ctx, reader).map(MetadataValue::Int32), - MetadataValueType::Float32 => Self::read_f32(ctx, reader).map(MetadataValue::Float32), - MetadataValueType::Bool => Self::read_bool(ctx, reader).map(MetadataValue::Bool), - MetadataValueType::String => Self::read_string(ctx, reader).map(MetadataValue::String), - MetadataValueType::Array => Self::read_array(ctx, reader).map(MetadataValue::Array), - MetadataValueType::UInt64 => Self::read_u64(ctx, reader).map(MetadataValue::UInt64), - MetadataValueType::Int64 => Self::read_i64(ctx, reader).map(MetadataValue::Int64), - MetadataValueType::Float64 => Self::read_f64(ctx, reader).map(MetadataValue::Float64), - } - } - - fn read_u8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_bytes::<1>(reader)?[0]) - } - - fn read_i8(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_bytes::<1>(reader)?[0] as i8) - } - - fn read_u16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(u16::from_le_bytes(util::read_bytes::<2>(reader)?)) - } - - fn read_i16(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(i16::from_le_bytes(util::read_bytes::<2>(reader)?)) - } - - fn read_u32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_u32(reader)?) - } - - fn read_i32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_i32(reader)?) - } - - fn read_f32(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_f32(reader)?) - } - - fn read_bool(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_bool(reader)?) - } - - fn read_string(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_string(reader, ctx.use_64_bit_length)?) 
- } - - fn read_array( + pub(super) fn write_key_value( + &self, ctx: &GgufContext, - reader: &mut dyn BufRead, - ) -> Result { - MetadataArrayValue::read_value(ctx, reader) - } - - fn read_u64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_u64(reader)?) - } - - fn read_i64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_i64(reader)?) - } + writer: &mut dyn Write, + key: &str, + ) -> io::Result<()> { + util::write_string(writer, ctx.use_64_bit_length, key)?; + util::write_u32(writer, self.value_type() as u32)?; + self.write(ctx, writer)?; - fn read_f64(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { - Ok(util::read_f64(reader)?) + Ok(()) } } -#[derive(Debug, Clone, PartialEq)] -pub enum MetadataArrayValue { - UInt8(Vec), - Int8(Vec), - UInt16(Vec), - Int16(Vec), - UInt32(Vec), - Int32(Vec), - Float32(Vec), - Bool(Vec), - String(Vec), - Array(Vec), - UInt64(Vec), - Int64(Vec), - Float64(Vec), -} // Public impl MetadataArrayValue { pub fn as_uint8_array(&self) -> Option<&[u8]> { @@ -496,70 +457,126 @@ impl MetadataArrayValue { } } } -impl MetadataArrayValue { - fn read_value(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result { + +// Shared +trait ValueIO { + fn read(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result + where + Self: Sized; + fn write(&self, ctx: &GgufContext, writer: &mut dyn Write) -> io::Result<()>; +} +macro_rules! impl_value_io_boilerplate { + ($($value_type:ident($rust_type:ty, $read_method:ident, $write_method:ident)),*) => { + $( + impl ValueIO for $rust_type { + fn read(_ctx: &GgufContext, reader: &mut dyn BufRead) -> Result + where + Self: Sized, + { + Ok(util::$read_method(reader)?) + } + + fn write(&self, _ctx: &GgufContext, writer: &mut dyn Write) -> io::Result<()> { + util::$write_method(writer, *self) + } + } + )* + }; +} +impl_value_io_boilerplate! { + UInt8(u8, read_u8, write_u8), + Int8(i8, read_i8, write_i8), + UInt16(u16, read_u16, write_u16), + Int16(i16, read_i16, write_i16), + UInt32(u32, read_u32, write_u32), + Int32(i32, read_i32, write_i32), + Float32(f32, read_f32, write_f32), + Bool(bool, read_bool, write_bool), + UInt64(u64, read_u64, write_u64), + Int64(i64, read_i64, write_i64), + Float64(f64, read_f64, write_f64) +} +impl ValueIO for String { + fn read(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result + where + Self: Sized, + { + Ok(util::read_string(reader, ctx.use_64_bit_length)?) + } + + fn write(&self, ctx: &GgufContext, writer: &mut dyn Write) -> io::Result<()> { + util::write_string(writer, ctx.use_64_bit_length, self) + } +} +impl ValueIO for MetadataArrayValue { + fn read(ctx: &GgufContext, reader: &mut dyn BufRead) -> Result + where + Self: Sized, + { let value_type = MetadataValueType::try_from(util::read_u32(reader)?) 
.expect("TODO: handle invalid value types"); let length = util::read_length(reader, ctx.use_64_bit_length)?; - struct ArrayReader<'a> { - ctx: &'a GgufContext, - reader: &'a mut dyn BufRead, - length: usize, - } - impl ArrayReader<'_> { - fn read( - &mut self, - value_reader: impl Fn(&GgufContext, &mut dyn BufRead) -> Result, - value_constructor: impl Fn(Vec) -> MetadataArrayValue, - ) -> Result { - (0..self.length) - .map(|_| value_reader(self.ctx, self.reader)) - .collect::, _>>() - .map(value_constructor) - } - } + use MetadataValueType as MVT; + return match value_type { + MVT::UInt8 => read_array::(ctx, reader, length), + MVT::Int8 => read_array::(ctx, reader, length), + MVT::UInt16 => read_array::(ctx, reader, length), + MVT::Int16 => read_array::(ctx, reader, length), + MVT::UInt32 => read_array::(ctx, reader, length), + MVT::Int32 => read_array::(ctx, reader, length), + MVT::Float32 => read_array::(ctx, reader, length), + MVT::Bool => read_array::(ctx, reader, length), + MVT::String => read_array::(ctx, reader, length), + MVT::Array => read_array::(ctx, reader, length), + MVT::UInt64 => read_array::(ctx, reader, length), + MVT::Int64 => read_array::(ctx, reader, length), + MVT::Float64 => read_array::(ctx, reader, length), + }; - let mut reader = ArrayReader { - ctx, - reader, - length, + fn read_array( + ctx: &GgufContext, + reader: &mut dyn BufRead, + length: usize, + ) -> Result + where + Vec: ToMetadataArrayValue, + { + (0..length) + .map(|_| T::read(ctx, reader)) + .collect::, _>>() + .map(|v| v.to_array_value()) + } + } + + fn write(&self, ctx: &GgufContext, writer: &mut dyn Write) -> io::Result<()> { + return match self { + MetadataArrayValue::UInt8(v) => write_array(ctx, writer, v), + MetadataArrayValue::Int8(v) => write_array(ctx, writer, v), + MetadataArrayValue::UInt16(v) => write_array(ctx, writer, v), + MetadataArrayValue::Int16(v) => write_array(ctx, writer, v), + MetadataArrayValue::UInt32(v) => write_array(ctx, writer, v), + MetadataArrayValue::Int32(v) => write_array(ctx, writer, v), + MetadataArrayValue::Float32(v) => write_array(ctx, writer, v), + MetadataArrayValue::Bool(v) => write_array(ctx, writer, v), + MetadataArrayValue::String(v) => write_array(ctx, writer, v), + MetadataArrayValue::Array(v) => write_array(ctx, writer, v), + MetadataArrayValue::UInt64(v) => write_array(ctx, writer, v), + MetadataArrayValue::Int64(v) => write_array(ctx, writer, v), + MetadataArrayValue::Float64(v) => write_array(ctx, writer, v), }; - use MetadataValue as MV; - use MetadataValueType as MVT; - match value_type { - MVT::UInt8 => reader.read(MV::read_u8, Self::UInt8), - MVT::Int8 => reader.read(MV::read_i8, Self::Int8), - MVT::UInt16 => reader.read(MV::read_u16, Self::UInt16), - MVT::Int16 => reader.read(MV::read_i16, Self::Int16), - MVT::UInt32 => reader.read(MV::read_u32, Self::UInt32), - MVT::Int32 => reader.read(MV::read_i32, Self::Int32), - MVT::Float32 => reader.read(MV::read_f32, Self::Float32), - MVT::Bool => reader.read(MV::read_bool, Self::Bool), - MVT::String => reader.read(MV::read_string, Self::String), - MVT::Array => reader.read(MV::read_array, Self::Array), - MVT::UInt64 => reader.read(MV::read_u64, Self::UInt64), - MVT::Int64 => reader.read(MV::read_i64, Self::Int64), - MVT::Float64 => reader.read(MV::read_f64, Self::Float64), - } - } - - /// Returns the length of the array. 
- pub fn len(&self) -> usize { - match self { - Self::UInt8(v) => v.len(), - Self::Int8(v) => v.len(), - Self::UInt16(v) => v.len(), - Self::Int16(v) => v.len(), - Self::UInt32(v) => v.len(), - Self::Int32(v) => v.len(), - Self::Float32(v) => v.len(), - Self::Bool(v) => v.len(), - Self::String(v) => v.len(), - Self::Array(v) => v.len(), - Self::UInt64(v) => v.len(), - Self::Int64(v) => v.len(), - Self::Float64(v) => v.len(), + + fn write_array( + ctx: &GgufContext, + writer: &mut dyn Write, + array: &[T], + ) -> io::Result<()> { + util::write_u32(writer, T::value_type() as u32)?; + util::write_length(writer, ctx.use_64_bit_length, array.len())?; + for value in array { + value.write(ctx, writer)?; + } + Ok(()) } } } diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index ae878d8d..ebb2812e 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -1,10 +1,11 @@ #![allow(missing_docs)] -use std::io::{BufRead, Seek}; +use std::io::{BufRead, BufWriter, Seek, Write}; use super::{data_size, header_size, ContainerType, ContainerTypeReadError}; use crate::{util, ElementType}; +use ggml_sys::ggml_type; use indexmap::IndexMap; use thiserror::Error; @@ -90,10 +91,7 @@ impl Gguf { tensor_infos.insert(key, value); } - let tensor_data_position = { - let stream_position = reader.stream_position()?; - stream_position + (alignment - (stream_position % alignment)) % alignment - }; + let tensor_data_position = align_offset(reader.stream_position()?, alignment); Ok(Gguf { metadata, @@ -101,6 +99,93 @@ impl Gguf { tensor_data_position, }) } + + /// Saves the GGUF file to the given writer. + /// + /// `get_tensor_size` is a function that returns the size of a tensor's data in bytes. + /// `write_tensor_data` is a function that writes the tensor's data to the writer; the data + /// must be the same length as the value returned by `get_tensor_size`. + /// + /// The `offset` in `TensorInfo` will be ignored and the correct offset will be calculated + /// automatically. + pub fn save( + &self, + writer: &mut BufWriter, + mut write_tensor_data: impl FnMut(&mut BufWriter, &str, &TensorInfo) -> std::io::Result<()>, + ) -> std::io::Result<()> { + // Write header + let container = ContainerType::Gguf(2); + container.write(writer)?; + + let ctx = GgufContext { + use_64_bit_length: true, + }; + + util::write_length(writer, ctx.use_64_bit_length, self.tensor_infos.len())?; + util::write_length(writer, ctx.use_64_bit_length, self.metadata.0.len())?; + + // Write metadata + for (key, value) in &self.metadata.0 { + value.write_key_value(&ctx, writer, key)?; + } + + // Write tensor infos + let alignment = self + .metadata + .get_optional("general.alignment") + .and_then(|v| v.as_uint32()) + .unwrap_or(DEFAULT_ALIGNMENT) as u64; + + // Pre-plan the write before writing the tensor data. 
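+        // (The header serializes each TensorInfo's offset before any tensor data is
+        // written, so the offsets have to be computed up front, aligning every
+        // tensor's start to `alignment` as we go.)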
+ #[derive(Debug)] + struct TensorWrite { + name: String, + info: TensorInfo, + size: usize, + } + let mut tensors = vec![]; + let mut next_offset = 0; + for (name, tensor_info) in &self.tensor_infos { + let size = tensor_info.calc_size(); + tensors.push(TensorWrite { + name: name.clone(), + info: TensorInfo { + offset: next_offset, + ..tensor_info.clone() + }, + size, + }); + + next_offset = align_offset(next_offset + size as u64, alignment); + } + + for write in &tensors { + write.info.write_name_value(&ctx, writer, &write.name)?; + } + + // Write tensors + let stream_position = writer.stream_position()?; + let tensor_data_position = align_offset(stream_position, alignment); + assert!(tensor_data_position > stream_position); + util::write_zero_bytes(writer, (tensor_data_position - stream_position) as usize)?; + + for write in &tensors { + write_tensor_data(writer, &write.name, &write.info)?; + + let stream_position = writer.stream_position()?; + assert!( + stream_position == tensor_data_position + write.info.offset + write.size as u64 + ); + let next_position = align_offset(stream_position, alignment); + util::write_zero_bytes(writer, (next_position - stream_position) as usize)?; + } + + Ok(()) + } +} + +fn align_offset(offset: u64, alignment: u64) -> u64 { + offset + (alignment - (offset % alignment)) % alignment } struct GgufContext { @@ -147,6 +232,25 @@ impl TensorInfo { )) } + fn write_name_value( + &self, + ctx: &GgufContext, + writer: &mut dyn Write, + name: &str, + ) -> std::io::Result<()> { + util::write_string(writer, ctx.use_64_bit_length, name)?; + + util::write_u32(writer, self.dimensions.len().try_into().unwrap())?; + for dimension in &self.dimensions { + util::write_length(writer, ctx.use_64_bit_length, *dimension)?; + } + + util::write_u32(writer, ggml_type::from(self.element_type) as u32)?; + util::write_u64(writer, self.offset)?; + + Ok(()) + } + /// Calculate the size of the tensor's values in bytes. pub fn calc_size(&self) -> usize { data_size(self.element_type, self.dimensions.iter().product()) diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index 7d65dfb2..ff73b7b7 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -2,7 +2,7 @@ use std::{ fmt, - io::{BufRead, Write}, + io::{self, BufRead, Write}, }; /// Helper struct that wraps the magic number of a file format, @@ -19,48 +19,69 @@ impl fmt::Debug for FileMagic { } } +/// +/// READERS +/// + /// Read a fixed-size array of bytes from a reader. -pub fn read_bytes(reader: &mut dyn BufRead) -> Result<[u8; N], std::io::Error> { +pub fn read_bytes(reader: &mut dyn BufRead) -> io::Result<[u8; N]> { let mut bytes = [0u8; N]; reader.read_exact(&mut bytes)?; Ok(bytes) } +/// Read a `i8` from a reader. +pub fn read_i8(reader: &mut dyn BufRead) -> io::Result { + Ok(i8::from_le_bytes(read_bytes::<1>(reader)?)) +} + +/// Read a `u8` from a reader. +pub fn read_u8(reader: &mut dyn BufRead) -> io::Result { + Ok(u8::from_le_bytes(read_bytes::<1>(reader)?)) +} + +/// Read a `i16` from a reader. +pub fn read_i16(reader: &mut dyn BufRead) -> io::Result { + Ok(i16::from_le_bytes(read_bytes::<2>(reader)?)) +} + +/// Read a `u16` from a reader. +pub fn read_u16(reader: &mut dyn BufRead) -> io::Result { + Ok(u16::from_le_bytes(read_bytes::<2>(reader)?)) +} + /// Read a `i32` from a reader. -pub fn read_i32(reader: &mut dyn BufRead) -> Result { +pub fn read_i32(reader: &mut dyn BufRead) -> io::Result { Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Read a `u32` from a reader. 
-pub fn read_u32(reader: &mut dyn BufRead) -> Result { +pub fn read_u32(reader: &mut dyn BufRead) -> io::Result { Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Read a `i64` from a reader. -pub fn read_i64(reader: &mut dyn BufRead) -> Result { +pub fn read_i64(reader: &mut dyn BufRead) -> io::Result { Ok(i64::from_le_bytes(read_bytes::<8>(reader)?)) } /// Read a `u64` from a reader. -pub fn read_u64(reader: &mut dyn BufRead) -> Result { +pub fn read_u64(reader: &mut dyn BufRead) -> io::Result { Ok(u64::from_le_bytes(read_bytes::<8>(reader)?)) } /// Read a `f32` from a reader. -pub fn read_f32(reader: &mut dyn BufRead) -> Result { +pub fn read_f32(reader: &mut dyn BufRead) -> io::Result { Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Read a `f64` from a reader. -pub fn read_f64(reader: &mut dyn BufRead) -> Result { +pub fn read_f64(reader: &mut dyn BufRead) -> io::Result { Ok(f64::from_le_bytes(read_bytes::<8>(reader)?)) } /// Read an integer (32-bit or 64-bit) from a reader, and convert it to a usize. -pub fn read_length( - reader: &mut dyn BufRead, - use_64_bit_length: bool, -) -> Result { +pub fn read_length(reader: &mut dyn BufRead, use_64_bit_length: bool) -> io::Result { let len: usize = if use_64_bit_length { read_u64(reader)?.try_into() } else { @@ -71,7 +92,7 @@ pub fn read_length( } /// Read a `bool` represented as an `i32` from a reader. -pub fn read_bool(reader: &mut dyn BufRead) -> Result { +pub fn read_bool(reader: &mut dyn BufRead) -> io::Result { let val = i32::from_le_bytes(read_bytes::<4>(reader)?); match val { 0 => Ok(false), @@ -84,20 +105,14 @@ pub fn read_bool(reader: &mut dyn BufRead) -> Result { } /// Read a variable-length array of bytes from a reader. -pub fn read_bytes_with_len( - reader: &mut dyn BufRead, - len: usize, -) -> Result, std::io::Error> { +pub fn read_bytes_with_len(reader: &mut dyn BufRead, len: usize) -> io::Result> { let mut bytes = vec![0u8; len]; reader.read_exact(&mut bytes)?; Ok(bytes) } /// Read a string from a reader. -pub fn read_string( - reader: &mut dyn BufRead, - use_64_bit_length: bool, -) -> Result { +pub fn read_string(reader: &mut dyn BufRead, use_64_bit_length: bool) -> io::Result { let len = read_length(reader, use_64_bit_length)?; let mut bytes = read_bytes_with_len(reader, len)?; // The GGUF C writer prior to `llama.cpp@103cfafc774f6feb3172b5d4d39681c965b17eba` @@ -111,29 +126,96 @@ pub fn read_string( .expect("string was not valid utf-8 (TODO: make this a library error)")) } +/// +/// WRITERS +/// + +/// Write a `i8` from a writer. +pub fn write_i8(writer: &mut dyn Write, value: i8) -> io::Result<()> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `u8` from a writer. +pub fn write_u8(writer: &mut dyn Write, value: u8) -> io::Result<()> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `i16` from a writer. +pub fn write_i16(writer: &mut dyn Write, value: i16) -> io::Result<()> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `u16` from a writer. +pub fn write_u16(writer: &mut dyn Write, value: u16) -> io::Result<()> { + writer.write_all(&value.to_le_bytes()) +} + /// Write a `i32` from a writer. -pub fn write_i32(writer: &mut dyn Write, value: i32) -> Result<(), std::io::Error> { +pub fn write_i32(writer: &mut dyn Write, value: i32) -> io::Result<()> { writer.write_all(&value.to_le_bytes()) } /// Write a `u32` from a writer. 
-pub fn write_u32(writer: &mut dyn Write, value: u32) -> Result<(), std::io::Error> { +pub fn write_u32(writer: &mut dyn Write, value: u32) -> io::Result<()> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `i64` from a writer. +pub fn write_i64(writer: &mut dyn Write, value: i64) -> io::Result<()> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `u64` from a writer. +pub fn write_u64(writer: &mut dyn Write, value: u64) -> io::Result<()> { writer.write_all(&value.to_le_bytes()) } /// Write a `f32` from a writer. -pub fn write_f32(writer: &mut dyn Write, value: f32) -> Result<(), std::io::Error> { +pub fn write_f32(writer: &mut dyn Write, value: f32) -> io::Result<()> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `f64` from a writer. +pub fn write_f64(writer: &mut dyn Write, value: f64) -> io::Result<()> { writer.write_all(&value.to_le_bytes()) } /// Write a `bool` represented as an `i32` to a writer. -pub fn write_bool(writer: &mut dyn Write, value: bool) -> Result<(), std::io::Error> { +pub fn write_bool(writer: &mut dyn Write, value: bool) -> io::Result<()> { let int_value: i32 = if value { 1 } else { 0 }; writer.write_all(&int_value.to_le_bytes()) } +/// Write an integer (32-bit or 64-bit) to a writer, and convert it from a usize. +pub fn write_length(writer: &mut dyn Write, use_64_bit_length: bool, len: usize) -> io::Result<()> { + if use_64_bit_length { + write_u64(writer, len as u64) + } else { + write_u32(writer, len as u32) + } +} + +/// Read a string from a reader. +pub fn write_string( + writer: &mut dyn Write, + use_64_bit_length: bool, + value: &str, +) -> io::Result<()> { + write_length(writer, use_64_bit_length, value.len())?; + writer.write_all(value.as_bytes()) +} + +/// Write N zero bytes to a writer. +// TODO: is there a more efficient way to do this? +pub fn write_zero_bytes(writer: &mut dyn Write, n: usize) -> io::Result<()> { + for _ in 0..n { + writer.write_all(&[0u8])?; + } + Ok(()) +} + // NOTE: Implementation from #![feature(buf_read_has_data_left)] /// Check if there is any data left in the reader. -pub fn has_data_left(reader: &mut impl BufRead) -> Result { +pub fn has_data_left(reader: &mut impl BufRead) -> io::Result { reader.fill_buf().map(|b| !b.is_empty()) } diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index faf1ee2a..0aa8cc61 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -16,7 +16,6 @@ use std::{ use thiserror::Error; #[derive(Clone, Debug)] - /// Progress of quantization. pub enum QuantizeProgress<'a> { /// Hyperparameters have been loaded. 
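An aside on the GGUF writer added in the patch above: the layout of the tensor data section hinges on the small `align_offset` helper. The following is a minimal, self-contained sketch (not part of the patch; the tensor sizes are invented purely for illustration) of how consecutive tensors get padded out to the default 32-byte alignment:

    // Standalone sketch of the offset planning done by `Gguf::save` above.
    fn align_offset(offset: u64, alignment: u64) -> u64 {
        offset + (alignment - (offset % alignment)) % alignment
    }

    fn main() {
        let alignment = 32; // gguf::DEFAULT_ALIGNMENT
        let sizes = [100u64, 64, 7]; // hypothetical tensor data sizes in bytes
        let mut next_offset = 0u64;
        for (i, size) in sizes.iter().enumerate() {
            println!("tensor {i}: offset {next_offset}, size {size}");
            next_offset = align_offset(next_offset + size, alignment);
        }
        // Prints offsets 0, 128 and 192: every tensor starts on a 32-byte
        // boundary, mirroring the pre-planning loop in `Gguf::save`.
    }

The same helper is reused on the load path to compute `tensor_data_position` from the stream position once the tensor infos have been read.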
From df1aa0e4e35b9226591232be18dce35b3bc48ec2 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 30 Oct 2023 00:21:06 +0100 Subject: [PATCH 24/33] feat(llm): implement gguf add-hf-tokenizer --- binaries/llm-cli/src/cli_args.rs | 8 ++++++ binaries/llm-cli/src/main.rs | 34 ++++++++++++++++++++++--- crates/ggml/src/format/gguf/metadata.rs | 6 ++++- crates/llm-base/src/tokenizer/mod.rs | 1 + 4 files changed, 45 insertions(+), 4 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 290d0630..6c72ff9e 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -118,6 +118,7 @@ pub struct Perplexity { pub enum Gguf { Info(Info), Rebuild(Rebuild), + AddHfTokenizer(AddHfTokenizer), } #[derive(Parser, Debug)] @@ -140,6 +141,13 @@ pub struct Rebuild { pub output: PathBuf, } +#[derive(Parser, Debug)] +pub struct AddHfTokenizer { + pub input: PathBuf, + pub output: PathBuf, + pub tokenizer: String, +} + #[derive(Parser, Debug)] pub struct PromptTokens { #[command(flatten)] diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 5c7468e1..f3a9adf3 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -2,6 +2,7 @@ use std::{ convert::Infallible, fs::File, io::{BufReader, BufWriter, Read, Seek}, + path::Path, }; use clap::Parser; @@ -136,6 +137,7 @@ fn gguf(args: &cli_args::Gguf) -> eyre::Result<()> { match args { cli_args::Gguf::Info(args) => info(args), cli_args::Gguf::Rebuild(args) => rebuild(args), + cli_args::Gguf::AddHfTokenizer(args) => add_hf_tokenizer(args), } } @@ -198,12 +200,38 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { } fn rebuild(args: &cli_args::Rebuild) -> eyre::Result<()> { - let input = File::open(&args.input)?; + rebuild_with_mutation(&args.input, &args.output, |_| Ok(())) +} + +fn add_hf_tokenizer(args: &cli_args::AddHfTokenizer) -> eyre::Result<()> { + let tokenizer = + llm::tokenizer::huggingface_tokenizers::Tokenizer::from_pretrained(&args.tokenizer, None) + .unwrap(); + + rebuild_with_mutation(&args.input, &args.output, move |gguf| { + let tokenizer = tokenizer.to_string(false).unwrap(); + gguf.metadata + .insert("tokenizer.huggingface.json", tokenizer); + + Ok(()) + }) +} + +fn rebuild_with_mutation( + input: &Path, + output: &Path, + mut mutator: impl FnMut(&mut gguf::Gguf) -> eyre::Result<()>, +) -> eyre::Result<()> { + eyre::ensure!(input != output, "input and output must be different files"); + + let input = File::open(input)?; let mut reader = BufReader::new(&input); - let gguf = gguf::Gguf::load(&mut reader)?; + let mut gguf = gguf::Gguf::load(&mut reader)?; - let mut output = File::create(&args.output)?; + let mut output = File::create(output)?; let mut writer = BufWriter::new(&mut output); + + mutator(&mut gguf)?; gguf.save(&mut writer, |writer, name, _info| { let reader = &mut reader; let original_info = gguf.tensor_infos.get(name).unwrap(); diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index 5849b85e..70e20347 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -76,7 +76,7 @@ impl Metadata { }) } - // TODO: consider + // TODO: consider finding a way to automate getting with traits pub fn get_str(&self, key: &str) -> Result<&str, MetadataError> { let metadata_value = self.get(key)?; Ok(metadata_value @@ -100,6 +100,10 @@ impl Metadata { }), } } + + pub fn insert(&mut self, key: &str, value: T) { + self.0.insert(key.to_owned(), 
value.to_value()); + } } #[repr(u32)] diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 0f53aba8..afa8c9d6 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -9,6 +9,7 @@ mod embedded; pub use embedded::*; mod huggingface; pub use huggingface::*; +pub use tokenizers as huggingface_tokenizers; /// The identifier of a token in a tokenizer. pub type TokenId = u32; From d5e7b61da123e47185155e73d2cf9f5627e6a6da Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 30 Oct 2023 18:36:44 +0100 Subject: [PATCH 25/33] fix(gguf): add support for ggufv3 --- crates/ggml/src/format/gguf/mod.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index ebb2812e..6543acf5 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -62,12 +62,19 @@ impl Gguf { ContainerTypeReadError::InvalidMagic(magic) => GgufLoadError::InvalidMagic(magic), ContainerTypeReadError::Io(io) => GgufLoadError::Io(io), })?; - if ![ContainerType::Gguf(1), ContainerType::Gguf(2)].contains(&container) { + if ![ + ContainerType::Gguf(1), + ContainerType::Gguf(2), + ContainerType::Gguf(3), + ] + .contains(&container) + { return Err(GgufLoadError::InvalidFormatVersion(container)); } let ctx = GgufContext { - use_64_bit_length: container == ContainerType::Gguf(2), + use_64_bit_length: container == ContainerType::Gguf(2) + || container == ContainerType::Gguf(3), }; let tensor_count = util::read_length(reader, ctx.use_64_bit_length)?; From 5457414d729e8840b382b84b4412e43866c15405 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 30 Oct 2023 18:37:09 +0100 Subject: [PATCH 26/33] fix(gguf): load bools correctly --- crates/ggml/src/util.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index ff73b7b7..bdddbeca 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -91,15 +91,16 @@ pub fn read_length(reader: &mut dyn BufRead, use_64_bit_length: bool) -> io::Res Ok(len) } -/// Read a `bool` represented as an `i32` from a reader. +/// Read a `bool` represented as a single byte from a reader. pub fn read_bool(reader: &mut dyn BufRead) -> io::Result { - let val = i32::from_le_bytes(read_bytes::<4>(reader)?); + let val = read_bytes::<1>(reader)?[0]; + match val { 0 => Ok(false), 1 => Ok(true), _ => Err(std::io::Error::new( std::io::ErrorKind::InvalidData, - format!("Invalid i32 value for bool: '{}'", val), + format!("Invalid value for bool: '{}'", val), )), } } From 61140763e505acec5a838ee696a7ee905aaba639 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 30 Oct 2023 19:12:04 +0100 Subject: [PATCH 27/33] feat(llm): get GPT-NeoX loading again --- crates/ggml/src/format/gguf/mod.rs | 1 + crates/llm/Cargo.toml | 2 +- crates/models/gptneox/src/lib.rs | 970 ++++++++++++++--------------- crates/models/llama/src/lib.rs | 20 +- 4 files changed, 465 insertions(+), 528 deletions(-) diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index 6543acf5..c58e7276 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -13,6 +13,7 @@ mod metadata; pub use metadata::*; pub const DEFAULT_ALIGNMENT: u32 = 32; +pub const META_TENSOR_DATA_LAYOUT: &str = "Meta AI original pth"; #[derive(Debug, Error)] /// Errors that can occur while loading a model. 
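Before the rest of the GPT-NeoX patch below, a brief aside on the two small gguf fixes above (GGUFv3 acceptance and the one-byte bool read). The rules they encode are easy to state in isolation; here is a self-contained sketch of those behaviours (simplified, not the crate's actual `util` API surface):

    use std::io::{self, Read};

    // GGUF v1 uses 32-bit lengths; v2 and v3 use 64-bit lengths.
    fn use_64_bit_length(gguf_version: u32) -> Option<bool> {
        match gguf_version {
            1 => Some(false),
            2 | 3 => Some(true),
            _ => None, // treated as an invalid format version
        }
    }

    // A GGUF bool is a single byte that must be exactly 0 or 1.
    fn read_bool(reader: &mut dyn Read) -> io::Result<bool> {
        let mut byte = [0u8; 1];
        reader.read_exact(&mut byte)?;
        match byte[0] {
            0 => Ok(false),
            1 => Ok(true),
            other => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid value for bool: {other}"),
            )),
        }
    }

    fn main() -> io::Result<()> {
        assert_eq!(use_64_bit_length(3), Some(true));
        assert!(read_bool(&mut &[1u8][..])?);
        assert!(read_bool(&mut &[2u8][..]).is_err());
        Ok(())
    }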
diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index 43f3eba7..d21b6bcd 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -35,7 +35,7 @@ default = ["models", "tokenizers-remote"] tokenizers-remote = ["llm-base/tokenizers-remote"] -models = ["llama"] #, "gpt2", "gptj", "bloom", "gptneox", "mpt"] +models = ["llama", "gptneox"] #, "gpt2", "gptj", "bloom", "mpt"] llama = ["dep:llm-llama"] gpt2 = ["dep:llm-gpt2"] gptj = ["dep:llm-gptj"] diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 436939ed..acc8f8cf 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -1,515 +1,455 @@ -// //! An implementation of [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox) for the `llm` ecosystem. -// //! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model. -// #![deny(missing_docs)] - -// use std::error::Error; - -// use ggml::Tensor; -// use llm_base::{ -// ggml, -// model::{common, HyperparametersWriteError}, -// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, LoadError, -// ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, -// }; - -// /// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox) -// /// -// /// # Safety -// /// This implements [Send] and [Sync] as it is immutable after construction. -// pub struct GptNeoX { -// params: ModelParameters, - -// hyperparameters: Hyperparameters, -// tokenizer: Tokenizer, - -// // model-global weights -// // normalization gain & bias -// ln_f_g: Tensor, -// ln_f_b: Tensor, -// // weight token embeddings -// wte: Tensor, -// // language model head gain -// lmh_g: Tensor, - -// // weights for the model -// layers: Vec, - -// // must be kept alive for the model -// context: ModelContext, -// } - -// unsafe impl Send for GptNeoX {} -// unsafe impl Sync for GptNeoX {} - -// impl Model for GptNeoX { -// type Hyperparameters = Hyperparameters; - -// fn new( -// hyperparameters: Hyperparameters, -// params: ModelParameters, -// tokenizer: Tokenizer, -// tensor_loader: impl TensorLoader, -// ) -> Result -// where -// Self: Sized, -// { -// let mut tl = tensor_loader; - -// // model-global weights -// let wte = tl.load("gpt_neox.embed_in.weight")?; - -// let backend = params.backend(0); - -// let ln_f_g = tl -// .load("gpt_neox.final_layer_norm.weight")? -// .transfer_to(backend); -// let ln_f_b = tl -// .load("gpt_neox.final_layer_norm.bias")? -// .transfer_to(backend); -// let lmh_g = tl.load("embed_out.weight")?.transfer_to(backend); - -// let mut layers = Vec::new(); -// for i in 0..hyperparameters.n_layer { -// let backend = params.backend(i); -// let layer = Layer { -// ln_1_g: tl -// .load(&format!("gpt_neox.layers.{i}.input_layernorm.weight"))? -// .transfer_to(backend), -// ln_1_b: tl -// .load(&format!("gpt_neox.layers.{i}.input_layernorm.bias"))? -// .transfer_to(backend), - -// c_attn_attn_w: tl -// .load(&format!( -// "gpt_neox.layers.{i}.attention.query_key_value.weight" -// ))? -// .transfer_to(backend), -// c_attn_attn_b: tl -// .load(&format!( -// "gpt_neox.layers.{i}.attention.query_key_value.bias" -// ))? -// .transfer_to(backend), - -// c_attn_proj_w: tl -// .load(&format!("gpt_neox.layers.{i}.attention.dense.weight"))? -// .transfer_to(backend), -// c_attn_proj_b: tl -// .load(&format!("gpt_neox.layers.{i}.attention.dense.bias"))? 
-// .transfer_to(backend), - -// ln_2_g: tl -// .load(&format!( -// "gpt_neox.layers.{i}.post_attention_layernorm.weight" -// ))? -// .transfer_to(backend), -// ln_2_b: tl -// .load(&format!( -// "gpt_neox.layers.{i}.post_attention_layernorm.bias" -// ))? -// .transfer_to(backend), - -// c_mlp_fc_w: tl -// .load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight"))? -// .transfer_to(backend), -// c_mlp_fc_b: tl -// .load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias"))? -// .transfer_to(backend), - -// c_mlp_proj_w: tl -// .load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight"))? -// .transfer_to(backend), -// c_mlp_proj_b: tl -// .load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias"))? -// .transfer_to(backend), -// }; - -// layers.push(layer); -// } - -// let context = tl.finish(); - -// Ok(GptNeoX { -// hyperparameters, -// params, -// tokenizer, -// ln_f_g, -// ln_f_b, -// wte, -// lmh_g, -// layers, -// context, -// }) -// } - -// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { -// InferenceSession::new( -// config, -// &self.params, -// self.hyperparameters.n_layer, -// self.hyperparameters.n_embd, -// self.hyperparameters.n_vocab, -// ) -// } - -// // allow snake case here as its a one-to-one mapping of the original names -// #[allow(non_snake_case)] -// fn evaluate( -// &self, -// session: &mut InferenceSession, -// input_tokens: &[TokenId], -// output_request: &mut OutputRequest, -// ) { -// let n = input_tokens.len(); -// let n_past = session.n_past; -// let n_ctx = self.params.context_size; - -// let Hyperparameters { -// n_embd, -// n_head, -// n_vocab, -// n_layer, -// n_rot, -// use_parallel_residual, -// .. -// } = self.hyperparameters; - -// let outputs = session.compute(self.context.clone(), input_tokens, |builder| { -// let mut ctx0 = builder.ctx0.borrow_mut(); -// let embd = builder.embd; -// let mut input_layer = ctx0.op_get_rows(&self.wte, embd); -// let (memory_k_size, memory_v_size) = ( -// builder.memory_k.element_size(), -// builder.memory_v.element_size(), -// ); - -// let mut gf = ctx0.create_compute_graph(); - -// for il in 0..n_layer { -// ctx0.set_offloading(self.params.should_offload(il)); -// // attention uses first scratch buffer -// ctx0.use_scratch(builder.get_scratch(0)); - -// // self-attention -// let mut current = ctx0.op_norm(&input_layer); -// current = ctx0.op_add( -// &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), -// &self.layers[il].ln_1_b, -// ); - -// // self-attention compute QKV -// current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, ¤t); -// current = ctx0.op_add(¤t, &self.layers[il].c_attn_attn_b); - -// let nb = current.get_nb()[1]; -// let f32_size = std::mem::size_of::(); - -// let mut qcur = ctx0.op_cont(&ctx0.op_view_3d( -// ¤t, -// (n_embd / n_head, n_head, n), -// (nb / n_head, nb), -// 0, -// )); -// let mut kcur = ctx0.op_cont(&ctx0.op_view_3d( -// ¤t, -// (n_embd / n_head, n_head, n), -// (nb / n_head, nb), -// f32_size * n_embd / n_head, -// )); -// let mut vcur = ctx0.op_cont(&ctx0.op_view_3d( -// ¤t, -// (n_embd / n_head, n_head, n), -// (nb / n_head, nb), -// 2 * f32_size * n_embd / n_head, -// )); - -// // self-attention using mode = 2 for GPT-NeoX mode -// let overrides = self.params.rope_overrides.as_ref(); -// qcur = ctx0.op_rope_inplace(&qcur, n_past, n_rot, 2, overrides); -// kcur = ctx0.op_rope_inplace(&kcur, n_past, n_rot, 2, overrides); - -// // store key and value to memory -// vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, n_embd, n)); - -// let k = 
ctx0.op_view_1d( -// builder.memory_k, -// n * n_embd, -// (memory_k_size * n_embd) * (il * n_ctx + n_past), -// ); - -// let v = ctx0.op_view_2d( -// builder.memory_v, -// (n, n_embd), -// n_ctx * memory_v_size, -// (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size, -// ); - -// gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); -// gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); - -// // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) -// let Q = ctx0.op_permute(&qcur, (0, 2, 1, 3)); -// // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) -// let K = ctx0.op_permute( -// &ctx0.op_reshape_3d( -// &ctx0.op_view_1d( -// builder.memory_k, -// (n_past + n) * n_embd, -// il * n_ctx * memory_k_size * n_embd, -// ), -// n_embd / n_head, -// n_head, -// n_past + n, -// ), -// (0, 2, 1, 3), -// ); - -// // K * Q -// let KQ = ctx0.op_mul_mat(&K, &Q); - -// // KQ_scaled = KQ / sqrt(n_embd/n_head) -// let KQ_scaled = ctx0.op_scale_inplace( -// &KQ, -// &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)), -// ); - -// // KQ_masked = mask_past(KQ_scaled) -// let KQ_masked = ctx0.op_diag_mask_inf_inplace(&KQ_scaled, n_past); - -// // KQ = soft_max(KQ_masked) -// let KQ_softmax = ctx0.op_soft_max_inplace(&KQ_masked); - -// // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() -// let V = ctx0.op_view_3d( -// builder.memory_v, -// (n_past + n, n_embd / n_head, n_head), -// ( -// n_ctx * memory_v_size, -// n_ctx * memory_v_size * n_embd / n_head, -// ), -// il * n_ctx * memory_v_size * n_embd, -// ); - -// // KQV = transpose(V) * KQ_soft_max -// let KQV = ctx0.op_mul_mat(&V, &KQ_softmax); -// // KQV_merged = KQV.permute(0, 2, 1, 3) -// let KQV_merged = ctx0.op_permute(&KQV, (0, 2, 1, 3)); - -// // cur = KQV_merged.contiguous().view(n_embd, N) -// current = ctx0.op_cpy(&KQV_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); - -// // self-attention projection -// current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); -// current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); - -// // use the second scratch for the feed forward -// ctx0.use_scratch(builder.get_scratch(1)); - -// let feedforward_input: Tensor; -// if !use_parallel_residual { -// feedforward_input = ctx0.op_add(¤t, &input_layer); -// current = feed_forward_network(&ctx0, &self.layers[il], &feedforward_input); -// // input for next layer -// input_layer = ctx0.op_add(¤t, &feedforward_input); -// } else { -// // calculate with parallel residual -// feedforward_input = current.share(); - -// // this is independent of the self-attention result, so it could be done in parallel to the self-attention -// // note here we pass inpL instead of cur -// current = feed_forward_network(&ctx0, &self.layers[il], &input_layer); - -// // layer input + FF -// current = ctx0.op_add(¤t, &feedforward_input); - -// // input for next layer -// input_layer = ctx0.op_add(¤t, &input_layer); -// } -// } - -// // use the first scratch for the norm -// ctx0.use_scratch(builder.get_scratch(0)); - -// // normalize the output -// input_layer = ctx0.op_norm(&input_layer); -// // inpL = ln_f_g*inpL + ln_f_b -// input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); - -// let embeddings_tensor: ggml::Tensor = input_layer.share(); - -// // Disable the scratchbuffer -// ctx0.use_scratch(None); -// ctx0.set_offloading(false); -// // apply language model head -// input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); - -// ( -// 
gf, -// GraphOutputs { -// result: input_layer, -// embedding_result: embeddings_tensor, -// }, -// ) -// }); - -// // finish evaluation -// common::read_last_token(session, &outputs.result, n_vocab, n); -// common::extract_logits(output_request, &outputs.result, n_vocab, n); -// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); -// } - -// fn hyperparameters(&self) -> &Self::Hyperparameters { -// &self.hyperparameters -// } - -// fn tokenizer(&self) -> &Tokenizer { -// &self.tokenizer -// } - -// fn context_size(&self) -> usize { -// self.params.context_size -// } - -// fn bot_token_id(&self) -> Option { -// None -// } - -// fn eot_token_id(&self) -> TokenId { -// self.tokenizer.id("<|endoftext|>".as_bytes()).unwrap() -// } - -// fn quantize_tensors() -> Vec { -// vec![Regex::new(".*weight").unwrap()] -// } - -// fn skip_quantize_tensors() -> Vec { -// vec![] -// } - -// fn supports_rewind(&self) -> bool { -// true -// } -// } - -// /// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) -// #[derive(Debug, PartialEq, Eq, Clone, Copy)] -// pub struct Hyperparameters { -// /// Size of the model's vocabulary -// pub n_vocab: usize, -// /// Size of the model's context -// pub n_ctx: usize, -// /// Size of the model's embedding layer -// pub n_embd: usize, -// /// n_head -// pub n_head: usize, -// /// Number of layers in the model -// pub n_layer: usize, -// /// n_rot -// pub n_rot: usize, -// /// Whether to use a "parallel" formulation in each Transformer layer. -// /// This is on for most models, but is off for some e.g. RedPajama. -// pub use_parallel_residual: bool, -// /// file_type -// pub file_type: FileType, -// } - -// impl Default for Hyperparameters { -// fn default() -> Self { -// Self { -// n_vocab: Default::default(), -// n_ctx: Default::default(), -// n_embd: Default::default(), -// n_head: Default::default(), -// n_layer: Default::default(), -// n_rot: Default::default(), -// file_type: Default::default(), -// use_parallel_residual: true, -// } -// } -// } - -// impl llm_base::Hyperparameters for Hyperparameters { -// fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { -// Ok(Hyperparameters { -// n_vocab: util::read_i32(reader)?.try_into()?, -// n_ctx: util::read_i32(reader)?.try_into()?, -// n_embd: util::read_i32(reader)?.try_into()?, -// n_head: util::read_i32(reader)?.try_into()?, -// n_layer: util::read_i32(reader)?.try_into()?, -// n_rot: util::read_i32(reader)?.try_into()?, -// use_parallel_residual: util::read_bool(reader)?, -// file_type: util::read_filetype(reader)?, -// }) -// } - -// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { -// util::write_i32(writer, self.n_vocab.try_into()?)?; -// util::write_i32(writer, self.n_ctx.try_into()?)?; -// util::write_i32(writer, self.n_embd.try_into()?)?; -// util::write_i32(writer, self.n_head.try_into()?)?; -// util::write_i32(writer, self.n_layer.try_into()?)?; -// util::write_i32(writer, self.n_rot.try_into()?)?; -// util::write_bool(writer, self.use_parallel_residual)?; -// util::write_i32(writer, self.file_type.into())?; -// Ok(()) -// } - -// fn n_vocabulary(&self) -> usize { -// self.n_vocab -// } - -// fn file_type(&self) -> Option { -// Some(self.file_type) -// } - -// fn file_type_mut(&mut self) -> Option<&mut FileType> { -// Some(&mut self.file_type) -// } -// } - -// struct Layer { -// // pre-normalization -// ln_1_g: Tensor, -// ln_1_b: Tensor, - -// // attention -// c_attn_attn_w: 
Tensor, -// c_attn_attn_b: Tensor, - -// c_attn_proj_w: Tensor, -// c_attn_proj_b: Tensor, - -// // post normalization -// ln_2_g: Tensor, -// ln_2_b: Tensor, - -// // feed-forward -// c_mlp_fc_w: Tensor, -// c_mlp_fc_b: Tensor, - -// c_mlp_proj_w: Tensor, -// c_mlp_proj_b: Tensor, -// } - -// fn feed_forward_network(context: &ggml::Context, layer: &Layer, input: &Tensor) -> Tensor { -// let mut current = context.op_norm(input); - -// //gain and bias -// current = context.op_add(&context.op_mul(¤t, &layer.ln_2_g), &layer.ln_2_b); - -// // apply weights -// current = context.op_mul_mat(&layer.c_mlp_fc_w, ¤t); - -// // apply bias -// current = context.op_add(¤t, &layer.c_mlp_fc_b); - -// // GELU activation -// current = context.op_gelu(¤t); - -// // projection -// // cur = proj_w*cur + proj_b -// current = context.op_mul_mat(&layer.c_mlp_proj_w, ¤t); - -// current = context.op_add(¤t, &layer.c_mlp_proj_b); - -// current -// } +//! An implementation of [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox) for the `llm` ecosystem. +//! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model. +#![deny(missing_docs)] + +use ggml::Tensor; +use llm_base::{ + ggml::{ + self, + format::gguf::{Metadata, MetadataValue, META_TENSOR_DATA_LAYOUT}, + }, + model::{common, HyperparametersReadError, ModelData, ModelLoadArgs, ModelLoadError}, + FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, ModelContext, + OutputRequest, Regex, TokenId, +}; + +/// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox) +/// +/// # Safety +/// This implements [Send] and [Sync] as it is immutable after construction. +pub struct GptNeoX { + data: ModelData, + hyperparameters: Hyperparameters, + + // model-global weights + // normalization gain & bias + ln_f_g: Tensor, + ln_f_b: Tensor, + // weight token embeddings + wte: Tensor, + // language model head gain + lmh_g: Tensor, + + // weights for the model + layers: Vec, + + // must be kept alive for the model + context: ModelContext, +} + +unsafe impl Send for GptNeoX {} +unsafe impl Sync for GptNeoX {} + +impl Model for GptNeoX { + fn new(args: ModelLoadArgs) -> Result { + let hyperparameters = Hyperparameters::read(&args.gguf.metadata)?; + + let mut tl = args.tensor_loader; + + // model-global weights + let wte = tl.load("token_embd.weight")?; + + let data = args.data; + let backend = data.params.backend(0); + + let ln_f_g = tl.load("output_norm.weight")?.transfer_to(backend); + let ln_f_b = tl.load("output_norm.bias")?.transfer_to(backend); + let lmh_g = tl.load("output.weight")?.transfer_to(backend); + + let mut layers = Vec::new(); + for i in 0..hyperparameters.block_count { + let backend = data.params.backend(i); + let block = Block { + ln_1_g: tl + .load(&format!("blk.{i}.attn_norm.weight"))? + .transfer_to(backend), + ln_1_b: tl + .load(&format!("blk.{i}.attn_norm.bias"))? + .transfer_to(backend), + + c_attn_attn_w: tl + .load(&format!("blk.{i}.attn_qkv.weight"))? + .transfer_to(backend), + c_attn_attn_b: tl + .load(&format!("blk.{i}.attn_qkv.bias"))? + .transfer_to(backend), + + c_attn_proj_w: tl + .load(&format!("blk.{i}.attn_output.weight"))? + .transfer_to(backend), + c_attn_proj_b: tl + .load(&format!("blk.{i}.attn_output.bias"))? + .transfer_to(backend), + + ln_2_g: tl + .load(&format!("blk.{i}.ffn_norm.weight"))? + .transfer_to(backend), + ln_2_b: tl + .load(&format!("blk.{i}.ffn_norm.bias"))? 
+ .transfer_to(backend), + + c_mlp_fc_w: tl + .load(&format!("blk.{i}.ffn_up.weight"))? + .transfer_to(backend), + c_mlp_fc_b: tl + .load(&format!("blk.{i}.ffn_up.bias"))? + .transfer_to(backend), + + c_mlp_proj_w: tl + .load(&format!("blk.{i}.ffn_down.weight"))? + .transfer_to(backend), + c_mlp_proj_b: tl + .load(&format!("blk.{i}.ffn_down.bias"))? + .transfer_to(backend), + }; + + layers.push(block); + } + + let context = tl.finish(); + + Ok(GptNeoX { + data, + hyperparameters, + ln_f_g, + ln_f_b, + wte, + lmh_g, + layers, + context, + }) + } + + fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { + InferenceSession::new( + config, + &self.data.params, + self.hyperparameters.block_count, + self.hyperparameters.embedding_length, + self.tokenizer().len(), + ) + } + + // allow snake case here as its a one-to-one mapping of the original names + #[allow(non_snake_case)] + fn evaluate( + &self, + session: &mut InferenceSession, + input_tokens: &[TokenId], + output_request: &mut OutputRequest, + ) { + let n = input_tokens.len(); + let n_past = session.n_past; + let params = &self.data.params; + let ctx_size = params.context_size; + + let vocabulary_count = self.tokenizer().len(); + + let Hyperparameters { + embedding_length, + head_count, + block_count, + use_parallel_residual, + .. + } = self.hyperparameters; + + let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let mut ctx0 = builder.ctx0.borrow_mut(); + let embd = builder.embd; + let mut input_layer = ctx0.op_get_rows(&self.wte, embd); + let (memory_k_size, memory_v_size) = ( + builder.memory_k.element_size(), + builder.memory_v.element_size(), + ); + + let mut gf = ctx0.create_compute_graph(); + + for il in 0..block_count { + ctx0.set_offloading(params.should_offload(il)); + // attention uses first scratch buffer + ctx0.use_scratch(builder.get_scratch(0)); + + // self-attention + let mut current = ctx0.op_norm(&input_layer); + current = ctx0.op_add( + &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), + &self.layers[il].ln_1_b, + ); + + // self-attention compute QKV + current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, ¤t); + current = ctx0.op_add(¤t, &self.layers[il].c_attn_attn_b); + + let nb = current.get_nb()[1]; + let f32_size = std::mem::size_of::(); + + let mut qcur = ctx0.op_cont(&ctx0.op_view_3d( + ¤t, + (embedding_length / head_count, head_count, n), + (nb / head_count, nb), + 0, + )); + let mut kcur = ctx0.op_cont(&ctx0.op_view_3d( + ¤t, + (embedding_length / head_count, head_count, n), + (nb / head_count, nb), + f32_size * embedding_length / head_count, + )); + let mut vcur = ctx0.op_cont(&ctx0.op_view_3d( + ¤t, + (embedding_length / head_count, head_count, n), + (nb / head_count, nb), + 2 * f32_size * embedding_length / head_count, + )); + + // self-attention using mode = 2 for GPT-NeoX mode + let overrides = params.rope_overrides.as_ref(); + let n_embd_head = embedding_length / head_count; + qcur = ctx0.op_rope_inplace(&qcur, n_past, n_embd_head, 2, overrides); + kcur = ctx0.op_rope_inplace(&kcur, n_past, n_embd_head, 2, overrides); + + // store key and value to memory + vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, embedding_length, n)); + + let k = ctx0.op_view_1d( + builder.memory_k, + n * embedding_length, + (memory_k_size * embedding_length) * (il * ctx_size + n_past), + ); + + let v = ctx0.op_view_2d( + builder.memory_v, + (n, embedding_length), + ctx_size * memory_v_size, + (il * ctx_size) * memory_v_size * embedding_length + n_past * memory_v_size, + ); + + 
gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); + gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v)); + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + let Q = ctx0.op_permute(&qcur, (0, 2, 1, 3)); + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + let K = ctx0.op_permute( + &ctx0.op_reshape_3d( + &ctx0.op_view_1d( + builder.memory_k, + (n_past + n) * embedding_length, + il * ctx_size * memory_k_size * embedding_length, + ), + embedding_length / head_count, + head_count, + n_past + n, + ), + (0, 2, 1, 3), + ); + + // K * Q + let KQ = ctx0.op_mul_mat(&K, &Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + let KQ_scaled = ctx0.op_scale_inplace( + &KQ, + &ctx0.new_f32(1f32 / f32::sqrt(embedding_length as f32 / head_count as f32)), + ); + + // KQ_masked = mask_past(KQ_scaled) + let KQ_masked = ctx0.op_diag_mask_inf_inplace(&KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + let KQ_softmax = ctx0.op_soft_max_inplace(&KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + let V = ctx0.op_view_3d( + builder.memory_v, + (n_past + n, embedding_length / head_count, head_count), + ( + ctx_size * memory_v_size, + ctx_size * memory_v_size * embedding_length / head_count, + ), + il * ctx_size * memory_v_size * embedding_length, + ); + + // KQV = transpose(V) * KQ_soft_max + let KQV = ctx0.op_mul_mat(&V, &KQ_softmax); + // KQV_merged = KQV.permute(0, 2, 1, 3) + let KQV_merged = ctx0.op_permute(&KQV, (0, 2, 1, 3)); + + // cur = KQV_merged.contiguous().view(n_embd, N) + current = ctx0.op_cpy( + &KQV_merged, + &ctx0.new_tensor_2d(ggml::Type::F32, embedding_length, n), + ); + + // self-attention projection + current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); + current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); + + // use the second scratch for the feed forward + ctx0.use_scratch(builder.get_scratch(1)); + + let feedforward_input: Tensor; + if !use_parallel_residual { + feedforward_input = ctx0.op_add(¤t, &input_layer); + current = feed_forward_network(&ctx0, &self.layers[il], &feedforward_input); + // input for next layer + input_layer = ctx0.op_add(¤t, &feedforward_input); + } else { + // calculate with parallel residual + feedforward_input = current.share(); + + // this is independent of the self-attention result, so it could be done in parallel to the self-attention + // note here we pass inpL instead of cur + current = feed_forward_network(&ctx0, &self.layers[il], &input_layer); + + // layer input + FF + current = ctx0.op_add(¤t, &feedforward_input); + + // input for next layer + input_layer = ctx0.op_add(¤t, &input_layer); + } + } + + // use the first scratch for the norm + ctx0.use_scratch(builder.get_scratch(0)); + + // normalize the output + input_layer = ctx0.op_norm(&input_layer); + // inpL = ln_f_g*inpL + ln_f_b + input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); + + let embeddings_tensor: ggml::Tensor = input_layer.share(); + + // Disable the scratchbuffer + ctx0.use_scratch(None); + ctx0.set_offloading(false); + // apply language model head + input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); + + ( + gf, + GraphOutputs { + result: input_layer, + embedding_result: embeddings_tensor, + }, + ) + }); + + // finish evaluation + common::read_last_token(session, &outputs.result, vocabulary_count, n); + common::extract_logits(output_request, &outputs.result, vocabulary_count, n); + common::extract_embeddings( + output_request, + 
&outputs.embedding_result, + embedding_length, + n, + ); + } + + fn data(&self) -> &ModelData { + &self.data + } + + fn bot_token_id(&self) -> Option { + None + } + + fn eot_token_id(&self) -> TokenId { + self.tokenizer().id("<|endoftext|>".as_bytes()).unwrap() + } + + fn quantize_tensors(&self) -> Vec { + vec![Regex::new(".*weight").unwrap()] + } + + fn skip_quantize_tensors(&self) -> Vec { + vec![] + } + + fn supports_rewind(&self) -> bool { + true + } +} + +/// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Hyperparameters { + /// Size of the model's embedding layer + embedding_length: usize, + /// n_head + head_count: usize, + /// Number of blocks in the model + block_count: usize, + /// Whether to use a "parallel" formulation in each Transformer layer. + /// This is on for most models, but is off for some e.g. RedPajama. + use_parallel_residual: bool, + /// file_type + file_type: Option, + /// The tensor data layout that this model was encoded with + tensor_data_layout: String, +} + +impl Hyperparameters { + fn read(metadata: &Metadata) -> Result { + Ok(Self { + embedding_length: metadata.get_countable("gptneox.embedding_length")?, + head_count: metadata.get_countable("gptneox.attention.head_count")?, + block_count: metadata.get_countable("gptneox.block_count")?, + use_parallel_residual: metadata + .get_with_type("gptneox.use_parallel_residual", MetadataValue::as_bool)?, + file_type: FileType::read_for_hyperparameters(metadata)?, + tensor_data_layout: metadata + .get_str("llama.tensor_data_layout") + .unwrap_or(META_TENSOR_DATA_LAYOUT) + .to_string(), + }) + } +} + +struct Block { + // pre-normalization + ln_1_g: Tensor, + ln_1_b: Tensor, + + // attention + c_attn_attn_w: Tensor, + c_attn_attn_b: Tensor, + + c_attn_proj_w: Tensor, + c_attn_proj_b: Tensor, + + // post normalization + ln_2_g: Tensor, + ln_2_b: Tensor, + + // feed-forward + c_mlp_fc_w: Tensor, + c_mlp_fc_b: Tensor, + + c_mlp_proj_w: Tensor, + c_mlp_proj_b: Tensor, +} + +fn feed_forward_network(context: &ggml::Context, layer: &Block, input: &Tensor) -> Tensor { + let mut current = context.op_norm(input); + + //gain and bias + current = context.op_add(&context.op_mul(¤t, &layer.ln_2_g), &layer.ln_2_b); + + // apply weights + current = context.op_mul_mat(&layer.c_mlp_fc_w, ¤t); + + // apply bias + current = context.op_add(¤t, &layer.c_mlp_fc_b); + + // GELU activation + current = context.op_gelu(¤t); + + // projection + // cur = proj_w*cur + proj_b + current = context.op_mul_mat(&layer.c_mlp_proj_w, ¤t); + + current = context.op_add(¤t, &layer.c_mlp_proj_b); + + current +} diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 46076113..ab502aa2 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -2,14 +2,15 @@ #![deny(missing_docs)] use llm_base::{ - ggml::{self, format::gguf::Metadata}, + ggml::{ + self, + format::gguf::{Metadata, META_TENSOR_DATA_LAYOUT}, + }, model::{common, HyperparametersReadError, ModelData, ModelLoadArgs, ModelLoadError}, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, Model, ModelContext, OutputRequest, Regex, TokenId, }; -const META_TENSOR_DATA_LAYOUT: &str = "Meta AI original pth"; - /// The LLaMA model. 
Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) /// /// # Safety @@ -108,7 +109,7 @@ impl Model for Llama { &self.data.params, self.hyperparameters.block_count, self.hyperparameters.embedding_length, - self.hyperparameters.vocabulary_count, + self.tokenizer().len(), ) } @@ -124,8 +125,9 @@ impl Model for Llama { let params = &self.data.params; let ctx_size = params.context_size; + let vocabulary_count = self.tokenizer().len(); + let Hyperparameters { - vocabulary_count, embedding_length, head_count, head_count_kv, @@ -381,15 +383,13 @@ impl Model for Llama { #[derive(Debug, Default, PartialEq, Eq, Clone)] struct Hyperparameters { - /// Size of the model's vocabulary - vocabulary_count: usize, /// Size of the model's embedding layer embedding_length: usize, /// The number of attention heads head_count: usize, /// The number of grouped-query attention heads head_count_kv: usize, - /// Number of layers in the model + /// Number of blocks in the model block_count: usize, /// file_type file_type: Option, @@ -399,10 +399,6 @@ struct Hyperparameters { impl Hyperparameters { pub fn read(metadata: &Metadata) -> Result { Ok(Self { - // TODO: handle models without an embedded vocabulary - vocabulary_count: metadata - .get_with_ref_type("tokenizer.ggml.tokens", |v| v.as_array())? - .len(), embedding_length: metadata.get_countable("llama.embedding_length")?, head_count: metadata.get_countable("llama.attention.head_count")?, head_count_kv: metadata.get_countable("llama.attention.head_count_kv")?, From 8961ff78bfad453dfe78a93c540506c97192101e Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 30 Oct 2023 20:09:27 +0100 Subject: [PATCH 28/33] feat(llm): get GPT-NeoX closer to working --- crates/models/gptneox/src/lib.rs | 72 +++++++++++++++++++------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index acc8f8cf..1000d3a4 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -31,7 +31,7 @@ pub struct GptNeoX { lmh_g: Tensor, // weights for the model - layers: Vec, + blocks: Vec, // must be kept alive for the model context: ModelContext, @@ -56,7 +56,7 @@ impl Model for GptNeoX { let ln_f_b = tl.load("output_norm.bias")?.transfer_to(backend); let lmh_g = tl.load("output.weight")?.transfer_to(backend); - let mut layers = Vec::new(); + let mut blocks = Vec::new(); for i in 0..hyperparameters.block_count { let backend = data.params.backend(i); let block = Block { @@ -103,7 +103,7 @@ impl Model for GptNeoX { .transfer_to(backend), }; - layers.push(block); + blocks.push(block); } let context = tl.finish(); @@ -115,7 +115,7 @@ impl Model for GptNeoX { ln_f_b, wte, lmh_g, - layers, + blocks, context, }) } @@ -150,6 +150,7 @@ impl Model for GptNeoX { head_count, block_count, use_parallel_residual, + rope_dimension_count, .. 
} = self.hyperparameters; @@ -172,41 +173,44 @@ impl Model for GptNeoX { // self-attention let mut current = ctx0.op_norm(&input_layer); current = ctx0.op_add( - &ctx0.op_mul(¤t, &self.layers[il].ln_1_g), - &self.layers[il].ln_1_b, + &ctx0.op_mul(¤t, &ctx0.op_repeat(&self.blocks[il].ln_1_g, ¤t)), + &ctx0.op_repeat(&self.blocks[il].ln_1_b, ¤t), ); // self-attention compute QKV - current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_attn_attn_b); + current = ctx0.op_mul_mat(&self.blocks[il].c_attn_attn_w, ¤t); + current = ctx0.op_add( + ¤t, + &ctx0.op_repeat(&self.blocks[il].c_attn_attn_b, ¤t), + ); let nb = current.get_nb()[1]; let f32_size = std::mem::size_of::(); + let n_embd_head = embedding_length / head_count; let mut qcur = ctx0.op_cont(&ctx0.op_view_3d( ¤t, - (embedding_length / head_count, head_count, n), + (n_embd_head, head_count, n), (nb / head_count, nb), 0, )); let mut kcur = ctx0.op_cont(&ctx0.op_view_3d( ¤t, - (embedding_length / head_count, head_count, n), + (n_embd_head, head_count, n), (nb / head_count, nb), - f32_size * embedding_length / head_count, + f32_size * n_embd_head, )); let mut vcur = ctx0.op_cont(&ctx0.op_view_3d( ¤t, - (embedding_length / head_count, head_count, n), + (n_embd_head, head_count, n), (nb / head_count, nb), - 2 * f32_size * embedding_length / head_count, + 2 * f32_size * n_embd_head, )); // self-attention using mode = 2 for GPT-NeoX mode let overrides = params.rope_overrides.as_ref(); - let n_embd_head = embedding_length / head_count; - qcur = ctx0.op_rope_inplace(&qcur, n_past, n_embd_head, 2, overrides); - kcur = ctx0.op_rope_inplace(&kcur, n_past, n_embd_head, 2, overrides); + qcur = ctx0.op_rope_inplace(&qcur, n_past, rope_dimension_count, 2, overrides); + kcur = ctx0.op_rope_inplace(&kcur, n_past, rope_dimension_count, 2, overrides); // store key and value to memory vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, embedding_length, n)); @@ -282,8 +286,11 @@ impl Model for GptNeoX { ); // self-attention projection - current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, ¤t); - current = ctx0.op_add(¤t, &self.layers[il].c_attn_proj_b); + current = ctx0.op_mul_mat(&self.blocks[il].c_attn_proj_w, ¤t); + current = ctx0.op_add( + &ctx0.op_repeat(&self.blocks[il].c_attn_proj_b, ¤t), + ¤t, + ); // use the second scratch for the feed forward ctx0.use_scratch(builder.get_scratch(1)); @@ -291,7 +298,7 @@ impl Model for GptNeoX { let feedforward_input: Tensor; if !use_parallel_residual { feedforward_input = ctx0.op_add(¤t, &input_layer); - current = feed_forward_network(&ctx0, &self.layers[il], &feedforward_input); + current = feed_forward_network(&ctx0, &self.blocks[il], &feedforward_input); // input for next layer input_layer = ctx0.op_add(¤t, &feedforward_input); } else { @@ -300,7 +307,7 @@ impl Model for GptNeoX { // this is independent of the self-attention result, so it could be done in parallel to the self-attention // note here we pass inpL instead of cur - current = feed_forward_network(&ctx0, &self.layers[il], &input_layer); + current = feed_forward_network(&ctx0, &self.blocks[il], &input_layer); // layer input + FF current = ctx0.op_add(¤t, &feedforward_input); @@ -316,7 +323,10 @@ impl Model for GptNeoX { // normalize the output input_layer = ctx0.op_norm(&input_layer); // inpL = ln_f_g*inpL + ln_f_b - input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); + input_layer = ctx0.op_add( + &ctx0.op_mul(&ctx0.op_repeat(&input_layer, &input_layer), 
&self.ln_f_g), + &ctx0.op_repeat(&self.ln_f_b, &input_layer), + ); let embeddings_tensor: ggml::Tensor = input_layer.share(); @@ -383,6 +393,8 @@ pub struct Hyperparameters { /// Whether to use a "parallel" formulation in each Transformer layer. /// This is on for most models, but is off for some e.g. RedPajama. use_parallel_residual: bool, + // RoPE dimension count + rope_dimension_count: usize, /// file_type file_type: Option, /// The tensor data layout that this model was encoded with @@ -397,6 +409,7 @@ impl Hyperparameters { block_count: metadata.get_countable("gptneox.block_count")?, use_parallel_residual: metadata .get_with_type("gptneox.use_parallel_residual", MetadataValue::as_bool)?, + rope_dimension_count: metadata.get_countable("gptneox.rope.dimension_count")?, file_type: FileType::read_for_hyperparameters(metadata)?, tensor_data_layout: metadata .get_str("llama.tensor_data_layout") @@ -430,26 +443,29 @@ struct Block { c_mlp_proj_b: Tensor, } -fn feed_forward_network(context: &ggml::Context, layer: &Block, input: &Tensor) -> Tensor { +fn feed_forward_network(context: &ggml::Context, block: &Block, input: &Tensor) -> Tensor { let mut current = context.op_norm(input); - //gain and bias - current = context.op_add(&context.op_mul(¤t, &layer.ln_2_g), &layer.ln_2_b); + // gain and bias + current = context.op_add( + &context.op_mul(&context.op_repeat(&block.ln_2_g, ¤t), ¤t), + &context.op_repeat(&block.ln_2_b, ¤t), + ); // apply weights - current = context.op_mul_mat(&layer.c_mlp_fc_w, ¤t); + current = context.op_mul_mat(&block.c_mlp_fc_w, ¤t); // apply bias - current = context.op_add(¤t, &layer.c_mlp_fc_b); + current = context.op_add(&context.op_repeat(&block.c_mlp_fc_b, ¤t), ¤t); // GELU activation current = context.op_gelu(¤t); // projection // cur = proj_w*cur + proj_b - current = context.op_mul_mat(&layer.c_mlp_proj_w, ¤t); + current = context.op_mul_mat(&block.c_mlp_proj_w, ¤t); - current = context.op_add(¤t, &layer.c_mlp_proj_b); + current = context.op_add(&context.op_repeat(&block.c_mlp_proj_b, ¤t), ¤t); current } From be709edd20cec26afa55b2be8a0eb99d9c162827 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 31 Oct 2023 00:09:21 +0100 Subject: [PATCH 29/33] feat(llm): more attempted GPT-NeoX fixes --- crates/models/gptneox/src/lib.rs | 53 +++++++++++++++++--------------- crates/models/llama/src/lib.rs | 1 + 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 1000d3a4..8786e512 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -138,8 +138,8 @@ impl Model for GptNeoX { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let n = input_tokens.len(); - let n_past = session.n_past; + let input_len = input_tokens.len(); + let session_len = session.n_past; let params = &self.data.params; let ctx_size = params.context_size; @@ -157,7 +157,9 @@ impl Model for GptNeoX { let outputs = session.compute(self.context.clone(), input_tokens, |builder| { let mut ctx0 = builder.ctx0.borrow_mut(); let embd = builder.embd; + let mut input_layer = ctx0.op_get_rows(&self.wte, embd); + let (memory_k_size, memory_v_size) = ( builder.memory_k.element_size(), builder.memory_v.element_size(), @@ -173,15 +175,15 @@ impl Model for GptNeoX { // self-attention let mut current = ctx0.op_norm(&input_layer); current = ctx0.op_add( - &ctx0.op_mul(¤t, &ctx0.op_repeat(&self.blocks[il].ln_1_g, ¤t)), + &ctx0.op_mul(&ctx0.op_repeat(&self.blocks[il].ln_1_g, ¤t), ¤t), 
&ctx0.op_repeat(&self.blocks[il].ln_1_b, ¤t), ); // self-attention compute QKV current = ctx0.op_mul_mat(&self.blocks[il].c_attn_attn_w, ¤t); current = ctx0.op_add( - ¤t, &ctx0.op_repeat(&self.blocks[il].c_attn_attn_b, ¤t), + ¤t, ); let nb = current.get_nb()[1]; @@ -190,42 +192,43 @@ impl Model for GptNeoX { let n_embd_head = embedding_length / head_count; let mut qcur = ctx0.op_cont(&ctx0.op_view_3d( ¤t, - (n_embd_head, head_count, n), + (n_embd_head, head_count, input_len), (nb / head_count, nb), 0, )); let mut kcur = ctx0.op_cont(&ctx0.op_view_3d( ¤t, - (n_embd_head, head_count, n), + (n_embd_head, head_count, input_len), (nb / head_count, nb), f32_size * n_embd_head, )); let mut vcur = ctx0.op_cont(&ctx0.op_view_3d( ¤t, - (n_embd_head, head_count, n), + (n_embd_head, head_count, input_len), (nb / head_count, nb), 2 * f32_size * n_embd_head, )); // self-attention using mode = 2 for GPT-NeoX mode let overrides = params.rope_overrides.as_ref(); - qcur = ctx0.op_rope_inplace(&qcur, n_past, rope_dimension_count, 2, overrides); - kcur = ctx0.op_rope_inplace(&kcur, n_past, rope_dimension_count, 2, overrides); + qcur = ctx0.op_rope_inplace(&qcur, session_len, rope_dimension_count, 2, overrides); + kcur = ctx0.op_rope_inplace(&kcur, session_len, rope_dimension_count, 2, overrides); // store key and value to memory - vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, embedding_length, n)); + vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, embedding_length, input_len)); let k = ctx0.op_view_1d( builder.memory_k, - n * embedding_length, - (memory_k_size * embedding_length) * (il * ctx_size + n_past), + input_len * embedding_length, + (memory_k_size * embedding_length) * (il * ctx_size + session_len), ); let v = ctx0.op_view_2d( builder.memory_v, - (n, embedding_length), + (input_len, embedding_length), ctx_size * memory_v_size, - (il * ctx_size) * memory_v_size * embedding_length + n_past * memory_v_size, + (il * ctx_size) * memory_v_size * embedding_length + + session_len * memory_v_size, ); gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); @@ -238,12 +241,12 @@ impl Model for GptNeoX { &ctx0.op_reshape_3d( &ctx0.op_view_1d( builder.memory_k, - (n_past + n) * embedding_length, + (session_len + input_len) * embedding_length, il * ctx_size * memory_k_size * embedding_length, ), - embedding_length / head_count, + n_embd_head, head_count, - n_past + n, + session_len + input_len, ), (0, 2, 1, 3), ); @@ -258,7 +261,7 @@ impl Model for GptNeoX { ); // KQ_masked = mask_past(KQ_scaled) - let KQ_masked = ctx0.op_diag_mask_inf_inplace(&KQ_scaled, n_past); + let KQ_masked = ctx0.op_diag_mask_inf_inplace(&KQ_scaled, session_len); // KQ = soft_max(KQ_masked) let KQ_softmax = ctx0.op_soft_max_inplace(&KQ_masked); @@ -266,10 +269,10 @@ impl Model for GptNeoX { // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() let V = ctx0.op_view_3d( builder.memory_v, - (n_past + n, embedding_length / head_count, head_count), + (session_len + input_len, n_embd_head, head_count), ( ctx_size * memory_v_size, - ctx_size * memory_v_size * embedding_length / head_count, + ctx_size * memory_v_size * n_embd_head, ), il * ctx_size * memory_v_size * embedding_length, ); @@ -282,7 +285,7 @@ impl Model for GptNeoX { // cur = KQV_merged.contiguous().view(n_embd, N) current = ctx0.op_cpy( &KQV_merged, - &ctx0.new_tensor_2d(ggml::Type::F32, embedding_length, n), + &ctx0.new_tensor_2d(ggml::Type::F32, embedding_length, input_len), ); // self-attention projection @@ -324,7 +327,7 @@ impl Model for 
GptNeoX { input_layer = ctx0.op_norm(&input_layer); // inpL = ln_f_g*inpL + ln_f_b input_layer = ctx0.op_add( - &ctx0.op_mul(&ctx0.op_repeat(&input_layer, &input_layer), &self.ln_f_g), + &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer), &ctx0.op_repeat(&self.ln_f_b, &input_layer), ); @@ -346,13 +349,13 @@ impl Model for GptNeoX { }); // finish evaluation - common::read_last_token(session, &outputs.result, vocabulary_count, n); - common::extract_logits(output_request, &outputs.result, vocabulary_count, n); + common::read_last_token(session, &outputs.result, vocabulary_count, input_len); + common::extract_logits(output_request, &outputs.result, vocabulary_count, input_len); common::extract_embeddings( output_request, &outputs.embedding_result, embedding_length, - n, + input_len, ); } diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index ab502aa2..bbc843d0 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -135,6 +135,7 @@ impl Model for Llama { file_type: _, tensor_data_layout: _, } = self.hyperparameters; + let embedding_length_gqa = embedding_length / self.hyperparameters.grouped_query_attention(); From a728852615ffae037b3b6e4fd485e4678158e6df Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 31 Oct 2023 00:42:46 +0100 Subject: [PATCH 30/33] fix(cli): in info, elide known large items --- binaries/llm-cli/src/main.rs | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index f3a9adf3..08657207 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -1,5 +1,6 @@ use std::{ convert::Infallible, + fmt, fs::File, io::{BufReader, BufWriter, Read, Seek}, path::Path, @@ -9,7 +10,7 @@ use clap::Parser; use cli_args::Args; use color_eyre::eyre; use is_terminal::IsTerminal; -use llm::ggml_format::gguf; +use llm::ggml_format::gguf::{self, MetadataValue}; mod cli_args; mod interactive; @@ -150,11 +151,29 @@ fn info(args: &cli_args::Info) -> eyre::Result<()> { log::info!("Non-array parameters:"); for (metadata_key, metadata_value) in gguf.metadata.iter() { - if metadata_value.as_array().is_some() { - continue; + struct ValueDisplay<'a>(Option<&'a MetadataValue>); + impl fmt::Debug for ValueDisplay<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(value) = self.0 { + write!(f, "{:?}", value) + } else { + write!(f, "[elided due to size]") + } + } } - log::info!("- {}: {:?}", metadata_key, metadata_value); + let elide_due_to_size = + metadata_value.as_array().is_some() || metadata_key == "tokenizer.huggingface.json"; + + log::info!( + "- {}: {:?}", + metadata_key, + ValueDisplay(if elide_due_to_size { + None + } else { + Some(metadata_value) + }) + ); } if let Ok(tokenizer) = llm::tokenizer::GgufEmbeddedTokenizer::from_metadata(&gguf.metadata) { From 5ed38be5c6d7b820bbf3b77a482c9418c2e4440f Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 1 Nov 2023 00:39:25 +0100 Subject: [PATCH 31/33] fix(llmb): usercallback show error --- crates/llm-base/src/inference_session.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 5d20aaac..e45f0106 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -685,7 +685,7 @@ pub enum InferenceError { /// /// Note that this error *can* be ignored and inference can continue, but the 
results are not guaranteed to be sensical. EndOfText, - #[error("the user-specified callback returned an error")] + #[error("the user-specified callback returned an error: {0}")] /// The user-specified callback returned an error. UserCallback(Box), /// Sampling returned an error. From ab956c930dcd8610989d48b5102b0695007107de Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 12 Nov 2023 23:00:17 +0100 Subject: [PATCH 32/33] Merge in develop --- Cargo.lock | 106 +-- Cargo.toml | 4 +- binaries/generate-ggml-bindings/src/main.rs | 5 + binaries/llm-cli/src/cli_args.rs | 15 +- binaries/llm-cli/src/interactive.rs | 2 +- binaries/llm-test/configs/mpt.json | 2 +- binaries/llm-test/src/inference.rs | 10 +- crates/ggml/src/accelerator/metal.rs | 13 +- crates/ggml/src/accelerator/mod.rs | 1 + crates/ggml/src/context.rs | 89 +- crates/ggml/src/lib.rs | 154 +++- crates/ggml/src/tensor.rs | 19 +- crates/ggml/sys/build.rs | 13 +- crates/ggml/sys/llama-cpp | 2 +- crates/ggml/sys/src/cuda.rs | 43 +- crates/ggml/sys/src/lib.rs | 932 ++++++++++++++++---- crates/ggml/sys/src/llama.rs | 13 +- crates/ggml/sys/src/metal.rs | 36 +- crates/llm-base/src/inference_session.rs | 395 +++++++-- crates/llm-base/src/lib.rs | 2 +- crates/llm-base/src/loader.rs | 12 +- crates/llm-base/src/lora.rs | 3 +- crates/llm-base/src/samplers.rs | 52 +- crates/llm/Cargo.toml | 4 +- crates/llm/src/lib.rs | 1 + crates/models/bert/Cargo.toml | 14 + crates/models/bert/src/lib.rs | 464 ++++++++++ crates/models/bloom/src/lib.rs | 21 +- crates/models/falcon/src/lib.rs | 31 +- crates/models/gpt2/src/lib.rs | 28 +- crates/models/gptj/src/lib.rs | 22 +- crates/models/gptneox/src/lib.rs | 32 +- crates/models/llama/src/lib.rs | 28 +- crates/models/mpt/src/lib.rs | 32 +- 34 files changed, 2069 insertions(+), 531 deletions(-) create mode 100644 crates/models/bert/Cargo.toml create mode 100644 crates/models/bert/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index af44f099..17611878 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,9 +89,9 @@ dependencies = [ [[package]] name = "addr2line" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] @@ -312,7 +312,7 @@ dependencies = [ "log", "parking", "polling", - "rustix 0.37.21", + "rustix 0.37.27", "slab", "socket2", "waker-fn", @@ -340,7 +340,7 @@ dependencies = [ "cfg-if", "event-listener", "futures-lite", - "rustix 0.37.21", + "rustix 0.37.27", "signal-hook", "windows-sys 0.48.0", ] @@ -412,17 +412,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -431,9 +420,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" dependencies = [ "addr2line", "cc", @@ -502,9 +491,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] 
name = "bitflags" -version = "2.3.3" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "block" @@ -841,13 +830,13 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "colored" -version = "2.0.0" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd" +checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6" dependencies = [ - "atty", + "is-terminal", "lazy_static", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -1418,7 +1407,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", - "rustix 0.38.1", + "rustix 0.38.19", "windows-sys 0.48.0", ] @@ -1641,9 +1630,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.27.3" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" [[package]] name = "gl_generator" @@ -1784,15 +1773,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.3.1" @@ -1995,7 +1975,7 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", "windows-sys 0.48.0", ] @@ -2008,12 +1988,12 @@ checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" [[package]] name = "is-terminal" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24fddda5af7e54bf7da53067d6e802dbcc381d0a8eef629df528e3ebf68755cb" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.1", - "rustix 0.38.1", + "hermit-abi", + "rustix 0.38.19", "windows-sys 0.48.0", ] @@ -2101,9 +2081,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libloading" @@ -2123,9 +2103,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.4.3" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" [[package]] name = "llm" @@ -2134,6 +2114,7 @@ dependencies 
= [ "bytesize", "clap", "llm-base", + "llm-bert", "llm-bloom", "llm-falcon", "llm-gpt2", @@ -2171,6 +2152,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "llm-bert" +version = "0.2.0-dev" +dependencies = [ + "bytemuck", + "llm-base", + "tracing", +] + [[package]] name = "llm-bloom" version = "0.2.0-dev" @@ -2248,9 +2238,9 @@ dependencies = [ [[package]] name = "llm-samplers" -version = "0.0.6" +version = "0.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7553f60d113c9cdc6a5402456a31cd9a273bef79f6f16d8a4f7b4bedf5f754b2" +checksum = "7e85df656cd89e7702cb56171d75aa77c7bec828af7d2054d9987c34411cf896" dependencies = [ "anyhow", "num-traits", @@ -2586,7 +2576,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", ] @@ -2695,9 +2685,9 @@ dependencies = [ [[package]] name = "object" -version = "0.31.1" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" dependencies = [ "memchr", ] @@ -3176,9 +3166,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.37.21" +version = "0.37.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25693a73057a1b4cb56179dd3c7ea21a7c6c5ee7d85781f5749b46f34b79c" +checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" dependencies = [ "bitflags 1.3.2", "errno", @@ -3190,14 +3180,14 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.1" +version = "0.38.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc6396159432b5c8490d4e301d8c705f61860b8b6c863bf79942ce5401968f3" +checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.3", + "linux-raw-sys 0.4.11", "windows-sys 0.48.0", ] @@ -3510,9 +3500,9 @@ dependencies = [ [[package]] name = "spinoff" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fee259f96b31e7a18657d11741fe30d63f98e07de70e7a19d2b705ab9b331cdc" +checksum = "20aa2ed67fbb202e7b716ff8bfc6571dd9301617767380197d701c31124e88f6" dependencies = [ "colored", "once_cell", @@ -3604,7 +3594,7 @@ dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.3.5", - "rustix 0.37.21", + "rustix 0.37.27", "windows-sys 0.48.0", ] diff --git a/Cargo.toml b/Cargo.toml index 9eb459ff..6c215d50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,12 +27,12 @@ anyhow = "1.0" rustyline = { version = "11.0.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } -spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] } +spinoff = { version = "0.8.0", default-features = false, features = ["dots2"] } clap = { version = "4.1.8", features = ["derive"] } memmap2 = "0.5.10" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = { version = "0.1", features = ["log"] } -llm-samplers = "=0.0.6" +llm-samplers = "=0.0.7" indexmap = "2.0.2" # Config for 'cargo dist' diff --git a/binaries/generate-ggml-bindings/src/main.rs 
b/binaries/generate-ggml-bindings/src/main.rs index 39acbb86..1878f471 100644 --- a/binaries/generate-ggml-bindings/src/main.rs +++ b/binaries/generate-ggml-bindings/src/main.rs @@ -27,6 +27,8 @@ fn generate_main(ggml_path: &Path, src_path: &Path) { .allowlist_file(r".*ggml.h") .header(ggml_path.join("k_quants.h").to_string_lossy()) .allowlist_file(r".*k_quants.h") + .header(ggml_path.join("ggml-alloc.h").to_string_lossy()) + .allowlist_file(r".*ggml-alloc.h") // Suppress some warnings .raw_line("#![allow(non_upper_case_globals)]") .raw_line("#![allow(non_camel_case_types)]") @@ -88,6 +90,9 @@ fn generate_metal(ggml_path: &Path, src_path: &Path) { generate_extra("metal", ggml_path, src_path, |b| { b.header(ggml_path.join("ggml-metal.h").to_string_lossy()) .allowlist_file(r".*ggml-metal\.h") + .raw_line("use super::ggml_cgraph;") + .raw_line("use super::ggml_log_callback;") + .raw_line("use super::ggml_tensor;") }); } diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 6c72ff9e..da513180 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -312,6 +312,15 @@ pub struct Generate { /// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen. /// p(0.95): The cumulative probability after which no more tokens are kept for sampling. /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// top_a (default: disabled) - This sampler prunes tokens that don't meet a threshold based on the most probable token. The formula is `a1 * pow(max_prob, a2)`. See https://github.com/BlinkDL/RWKV-LM#the-top-a-sampling-method for more information. + /// a1(0.0): Threshold scale. A reasonable value is 0.2. Setting either a1 or a2 to 0 disables the sampler. + /// a2(0.0): Threshold power. A reasonable value is 2. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// min_p (default: disabled) - This sampler prunes tokens that don't meet a certain percentage of the most probable token. For example if `p` is `0.05` then after `min_keep` is satisfied, other tokens must be at least 5% of the most probable token. See https://github.com/ggerganov/llama.cpp/issues/3483 for more information. + /// p(0.0): Probability threshold. 0.05 to 0.2 are good starting values to try. Setting this to 0 disables the sampler. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. #[arg(long = "sampler", short = 's', verbatim_doc_comment)] pub sampler_options: Vec, @@ -544,7 +553,7 @@ impl ModelLoad { let tokenizer_source = match self.model_and_tokenizer.to_source() { Ok(vs) => vs, Err(err) => { - if let Some(sp) = sp.take() { + if let Some(mut sp) = sp.take() { sp.fail(&format!("Failed to load tokenizer: {}", err)); } return Err(err); @@ -596,7 +605,7 @@ impl ModelLoad { file_size, tensor_count, } => { - if let Some(sp) = sp.take() { + if let Some(mut sp) = sp.take() { sp.success(&format!( "Loaded {tensor_count} tensors ({}) after {}ms", bytesize::to_string(file_size, false), @@ -611,7 +620,7 @@ impl ModelLoad { if model.is_err() { // If we've failed at loading the model, we probably haven't stopped the spinner yet. // Cancel it now if needed. 
- if let Some(sp) = sp { + if let Some(mut sp) = sp { sp.fail("Failed to load model") } } diff --git a/binaries/llm-cli/src/interactive.rs b/binaries/llm-cli/src/interactive.rs index 4657bc9d..3ad7e486 100644 --- a/binaries/llm-cli/src/interactive.rs +++ b/binaries/llm-cli/src/interactive.rs @@ -141,7 +141,7 @@ fn feed_prompt_with_spinner( prompt.insert(0, '\n'); } - let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); + let mut sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); let result = session.feed_prompt( model, &prompt, diff --git a/binaries/llm-test/configs/mpt.json b/binaries/llm-test/configs/mpt.json index 37b39bf3..31540573 100644 --- a/binaries/llm-test/configs/mpt.json +++ b/binaries/llm-test/configs/mpt.json @@ -5,7 +5,7 @@ { "Inference": { "input": "When a llama rides a crab, ", - "output": "When a llama rides a crab,  the llama is called the \"crab rider\".\nThe crabs are very popular in South America, especially Brazil. They have been used as transportation for many years and they can carry up to five people at once!", + "output": "When a llama rides a crab,  the llama is called the \"crab rider\"\nThe Llamas are an animal that can be found in The Maze. They have no special abilities, but they do drop Llamaskin and occasionally some other items when killed by players or monsters alike (see below). It's unknown if there was ever any sort of breeding system for these animals as it seems to only exist on this one world so far; however their existence has been confirmed through player reports from multiple worlds where people claim having seen them before being able see anything else about what happened after seeing just 1-2 at most per game session which makes me believe", "maximum_token_count": 128 } }, diff --git a/binaries/llm-test/src/inference.rs b/binaries/llm-test/src/inference.rs index a9ace889..3666167e 100644 --- a/binaries/llm-test/src/inference.rs +++ b/binaries/llm-test/src/inference.rs @@ -92,14 +92,14 @@ fn run_inference( // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` // at all #[derive(Debug, Default)] -struct DeterministicSampler(SampleGreedy); +struct DeterministicSampler(SampleGreedy); -impl Sampler for DeterministicSampler { +impl Sampler for DeterministicSampler { fn sample<'a>( &mut self, - res: &mut dyn HasSamplerResources, - logits: &'a mut Logits, - ) -> anyhow::Result<&'a mut Logits> { + res: &mut dyn HasSamplerResources, + logits: &'a mut Logits, + ) -> anyhow::Result<&'a mut Logits> { let mut flat_bias = Default::default(); // This might look a little weird, but it's necessary because the resource diff --git a/crates/ggml/src/accelerator/metal.rs b/crates/ggml/src/accelerator/metal.rs index 8fced466..a15e39f1 100644 --- a/crates/ggml/src/accelerator/metal.rs +++ b/crates/ggml/src/accelerator/metal.rs @@ -14,8 +14,8 @@ pub struct MetalContext { impl MetalContext { /// Create a new Metal context - pub fn new(n_threads: usize) -> Self { - let raw = unsafe { metal::ggml_metal_init(n_threads.try_into().unwrap()) }; + pub fn new() -> Self { + let raw = unsafe { metal::ggml_metal_init(1) }; MetalContext { contexts: vec![], @@ -83,19 +83,14 @@ impl MetalContext { unsafe { metal::ggml_metal_graph_compute( self.ptr.as_ptr(), - graph.inner as *mut ggml_sys::ggml_cgraph as *mut metal::ggml_cgraph, + graph.inner as *mut ggml_sys::ggml_cgraph, ); } } /// Reads a tensor from Metal pub fn get_tensor(&self, tensor: &Tensor) { - unsafe { - 
metal::ggml_metal_get_tensor( - self.ptr.as_ptr(), - tensor.ptr.as_ptr() as *mut metal::ggml_tensor, - ) - } + unsafe { metal::ggml_metal_get_tensor(self.ptr.as_ptr(), tensor.ptr.as_ptr()) } } } diff --git a/crates/ggml/src/accelerator/mod.rs b/crates/ggml/src/accelerator/mod.rs index 2e1cef17..731de9bc 100644 --- a/crates/ggml/src/accelerator/mod.rs +++ b/crates/ggml/src/accelerator/mod.rs @@ -71,6 +71,7 @@ pub fn initialize(device: i32) { //TODO: Make this configurable sys::cuda::ggml_init_cublas(); sys::cuda::ggml_cuda_set_main_device(device); + sys::cuda::ggml_cuda_set_mul_mat_q(true); let split = 1.0f32; sys::cuda::ggml_cuda_set_tensor_split(&split as *const f32); } diff --git a/crates/ggml/src/context.rs b/crates/ggml/src/context.rs index 2f2d04f0..96f81b4f 100644 --- a/crates/ggml/src/context.rs +++ b/crates/ggml/src/context.rs @@ -21,7 +21,7 @@ pub struct Context { /// allocated tensors. Tensors are owned by the object, so a [`Tensor`] /// contains a `Weak` reference underneath and doesn't let you do anything /// with it if the underlying context has been deallocated. - inner: Arc, + pub inner: Arc, /// The storage for this context. This is stored so that the buffer can be dropped when the context is dropped. storage: Option, @@ -31,7 +31,7 @@ pub struct Context { } /// Contains state shared between a context and its tensors -pub(crate) struct ContextInner { +pub struct ContextInner { pub ptr: NonNull, /// Offloaded tensors. Used to free them when the context is dropped. @@ -73,7 +73,12 @@ impl ContextInner { /// Controls how the context uses memory. pub enum ContextStorage { /// Use the provided buffer as memory. - Buffer(Buffer), + Buffer { + /// The buffer to use as memory. + buffer: Buffer, + /// Whether to allocate tensors into this buffer. + allocate: bool, + }, /// Use the provided memory mapped file as memory. Mmap(Mmap), /// Allocate `mem_size` bytes of memory. @@ -94,7 +99,10 @@ impl ContextStorage { /// Returns the `Buffer` if this is a `Buffer` variant. pub fn as_buffer(&self) -> Option<&Buffer> { match self { - Self::Buffer(v) => Some(v), + Self::Buffer { + buffer: v, + allocate: _, + } => Some(v), _ => None, } } @@ -115,7 +123,16 @@ impl PartialEq for ContextStorage { fn eq(&self, other: &Self) -> bool { use ContextStorage::*; match (self, other) { - (Buffer(l0), Buffer(r0)) => l0 == r0, + ( + Buffer { + buffer: l0, + allocate: l1, + }, + Buffer { + buffer: r0, + allocate: r1, + }, + ) => l0 == r0 && l1 == r1, (Mmap(l0), Mmap(r0)) => l0.as_ptr() == r0.as_ptr(), (Allocate { mem_size: l }, Allocate { mem_size: r }) => l == r, _ => false, @@ -130,10 +147,10 @@ impl Context { /// Creates a new [Context] with the given storage. pub fn new(storage: ContextStorage) -> Self { let init_params = match &storage { - ContextStorage::Buffer(buffer) => sys::ggml_init_params { + ContextStorage::Buffer { buffer, allocate } => sys::ggml_init_params { mem_size: buffer.size(), mem_buffer: buffer.data, - no_alloc: false, + no_alloc: !allocate, }, ContextStorage::Mmap(mmap) => sys::ggml_init_params { mem_size: mmap.len(), @@ -160,8 +177,8 @@ impl Context { /// Creates a new [Context] with the specified buffer. /// The buffer will be used by GGML. - pub fn new_with_buffer(buffer: Buffer) -> Self { - Self::new(ContextStorage::Buffer(buffer)) + pub fn new_with_buffer(buffer: Buffer, allocate: bool) -> Self { + Self::new(ContextStorage::Buffer { buffer, allocate }) } /// Creates a new [Context] with the specified memory mapped file. 
@@ -206,28 +223,6 @@ impl Context { unsafe { sys::ggml_used_mem(self.as_ptr()) } } - /// Sets the scratch buffer to be used by this [Context]. - /// - /// If `scratch_buffer` is `None`, the scratch buffer will be disabled. - pub fn use_scratch<'a>(&'a self, scratch_buffer: Option<&'a Buffer>) { - let (size, data) = if let Some(buffer) = scratch_buffer { - (buffer.size(), buffer.data) - } else { - (0, std::ptr::null_mut()) - }; - // SAFETY: this just passes (most likely uninitialized) memory buffer to the ggml C API - unsafe { - sys::ggml_set_scratch( - self.as_ptr(), - sys::ggml_scratch { - offs: 0, - size, - data, - }, - ); - } - } - /// Creates a new 1D tensor. pub fn new_tensor_1d(&self, typ: Type, ne0: usize) -> Tensor { let raw = unsafe { sys::ggml_new_tensor_1d(self.as_ptr(), typ.into(), usize_to_i64(ne0)) }; @@ -271,6 +266,12 @@ impl Context { pub fn storage(&self) -> &ContextStorage { self.storage.as_ref().unwrap() } + + /// Set all values of the tensor with the specified value. + pub fn set_f32(&self, a: &Tensor, x: f32) -> Tensor { + let raw = unsafe { sys::ggml_set_f32(a.ptr.as_ptr(), x) }; + self.new_tensor_raw(raw) + } } // Operations impl Context { @@ -288,7 +289,7 @@ impl Context { /// Creates a new tensor with the values of `a`, but normalized. pub fn op_norm(&self, a: &Tensor) -> Tensor { - let tensor = unsafe { sys::ggml_norm(self.as_ptr(), a.ptr.as_ptr()) }; + let tensor = unsafe { sys::ggml_norm(self.as_ptr(), a.ptr.as_ptr(), crate::DEFAULT_EPS) }; self.new_tensor_raw(tensor) } @@ -623,6 +624,30 @@ impl Context { }; self.new_tensor_raw(tensor) } + + /// Creates a new tensor with the square of `a` + pub fn op_sqr(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sqr(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Creates a new tensor with the square-root of `a` + pub fn op_sqrt(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sqrt(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Unknown + pub fn op_sum(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sum(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Unknown + pub fn op_div(&self, a: &Tensor, b: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_div(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } } // Public to this crate methods impl Context { diff --git a/crates/ggml/src/lib.rs b/crates/ggml/src/lib.rs index 1f4ac50f..fc99a0c0 100644 --- a/crates/ggml/src/lib.rs +++ b/crates/ggml/src/lib.rs @@ -10,6 +10,8 @@ use std::{ alloc::Layout, os::raw::{c_int, c_void}, + ptr::NonNull, + sync::Arc, }; mod context; @@ -41,7 +43,13 @@ pub const OBJECT_SIZE: usize = sys::GGML_OBJECT_SIZE; pub const MAX_NAME_LENGTH: usize = sys::GGML_MAX_NAME as usize; /// Default epsilon to use for RMS computation. -pub const DEFAULT_EPS: f32 = sys::llama::LLAMA_DEFAULT_RMS_EPS as f32; +pub const DEFAULT_EPS: f32 = 0.000005; + +/// Maximum number of nodes in a `ggml` graph. +pub const MAX_NODES: usize = sys::GGML_MAX_NODES as usize; + +/// Alignment used for the Tensors in a `ggml` graph. +pub const TENSOR_ALIGNMENT: usize = 32; /// Value overrides to use for RoPE. /// @@ -192,10 +200,8 @@ impl Type { } } -/// A buffer of memory that can be used as a scratch buffer for a [Context]. -/// -/// See [Context::use_scratch]. -#[derive(PartialEq, Eq)] +/// A buffer of memory that can be used as a buffer for a [Context] or [GraphAllocator]. 
+#[derive(PartialEq, Eq, Debug)] pub struct Buffer { data: *mut c_void, layout: Layout, @@ -216,10 +222,27 @@ impl Buffer { } } + /// Creates a new buffer of the specified size, without aligning it. + pub fn new_unaligned(size: usize) -> Self { + let layout = Layout::from_size_align(size, 1).unwrap(); + + unsafe { + Buffer { + data: std::alloc::alloc(layout).cast(), + layout, + } + } + } + /// Returns the size of the buffer in bytes pub fn size(&self) -> usize { self.layout.size() } + + /// Returns a pointer to the data in this buffer. + pub fn data(&mut self) -> *mut c_void { + self.data + } } impl Drop for Buffer { @@ -245,6 +268,37 @@ impl ComputationGraph { pub fn build_forward_expand(&mut self, tensor: &Tensor) { unsafe { sys::ggml_build_forward_expand(self.inner, tensor.ptr.as_ptr()) } } + + /// Returns the leafs in this graph. + pub fn leafs(&self, context: &Context) -> Vec { + let mut wrapped_leafs: Vec = vec![]; + unsafe { + for leaf in self.inner.as_ref().unwrap().leafs { + if !leaf.is_null() { + wrapped_leafs.push(Tensor { + ptr: NonNull::new(leaf).expect("Should not be null"), + inner: Arc::downgrade(&context.inner), + }) + } + } + wrapped_leafs + } + } + /// Returns the nodes in this graph. + pub fn nodes(&self, context: &Context) -> Vec { + let mut wrapped_nodes: Vec = vec![]; + unsafe { + for leaf in self.inner.as_ref().unwrap().leafs { + if !leaf.is_null() { + wrapped_nodes.push(Tensor { + ptr: NonNull::new(leaf).expect("Should not be null"), + inner: Arc::downgrade(&context.inner), + }) + } + } + wrapped_nodes + } + } } /// A `ggml` execution plan. Contains the information needed to execute a computation graph. @@ -262,30 +316,79 @@ impl GraphExecutionPlan { } } - /// Creates a [Type::I8] work buffer with size `plan.work_size` for this [GraphExecutionPlan] in the given [Context]. - fn create_work_buffer(&mut self, context: &Context) -> Tensor { - context.new_tensor_1d(Type::I8, self.inner.work_size) - } + /// Execute this [GraphExecutionPlan] in the given [Context]. + pub fn execute(&mut self, buffer: &mut Vec) { + if self.inner.work_size > 0 { + buffer.resize(self.inner.work_size, 0); + self.inner.work_data = buffer.as_mut_ptr().cast(); + } - /// Assign a work buffer to this [GraphExecutionPlan]. - fn assign_work_buffer(&mut self, buffer: &mut Tensor) { - assert!( - buffer.get_type() == Type::I8, - "Work buffer must be of type i8" - ); unsafe { - self.inner.work_data = buffer.data().cast(); + sys::ggml_graph_compute(self.inner_graph, &mut self.inner); } } +} - /// Execute this [GraphExecutionPlan] in the given [Context]. - pub fn execute(&mut self, context: &Context) { - let mut work_buffer = self.create_work_buffer(context); - self.assign_work_buffer(&mut work_buffer); +#[derive(PartialEq, Eq, Debug)] +/// Acts as a RAII-guard over a `sys::ggml_allocr`, allocating via +/// `ggml_allocr_new` and dropping via `ggml_allocr_free`. +/// Used to allocate the memory used by a computational graph. +pub struct GraphAllocator { + /// The underlying `sys::ggml_allocr` pointer. + pub ptr: *mut sys::ggml_allocr, + /// The buffer used by this allocator. + pub buffer: Buffer, +} - unsafe { - sys::ggml_graph_compute(self.inner_graph, &mut self.inner); - } +impl GraphAllocator { + /// Create a new allocator with the specified buffer. + pub fn new(buffer: Buffer, tensor_alignment: usize) -> Self { + let ptr = unsafe { sys::ggml_allocr_new(buffer.data, buffer.size(), tensor_alignment) }; + Self { ptr, buffer } + } + + /// Create a new allocator to measure a computational graph. 
+ pub fn new_measurement(tensor_alignment: usize) -> Self { + let ptr = unsafe { sys::ggml_allocr_new_measure(tensor_alignment) }; + let buffer = Buffer::new(tensor_alignment); + Self { ptr, buffer } + } + + /// Allocates a computational graph in the allocator and returns the size in bytes. + pub fn allocate_graph(&self, graph: &ComputationGraph) -> usize { + unsafe { sys::ggml_allocr_alloc_graph(self.ptr, graph.inner) } + } + + /// Resets the allocator for a new forward pass. + pub fn reset(&self) { + unsafe { sys::ggml_allocr_reset(self.ptr) } + } + + /// Returns true if the allocator is in measuring mode. + pub fn in_measuring_mode(&self) -> bool { + unsafe { sys::ggml_allocr_is_measure(self.ptr) } + } + + /// Allocates memory for a given tensor in the allocator. + pub fn allocate(&self, tensor: &Tensor) { + unsafe { sys::ggml_allocr_alloc(self.ptr, tensor.ptr.as_ptr()) } + } + + /// Switches the buffer used by the allocator. + pub fn resize_buffer(&mut self, graph_size: usize, tensor_alignment: usize) { + // Free the old allocator + unsafe { sys::ggml_allocr_free(self.ptr) } + //Resize the buffer + self.buffer = Buffer::new_unaligned(graph_size); + // Create a new allocator with the new buffer + self.ptr = + unsafe { sys::ggml_allocr_new(self.buffer.data, self.buffer.size(), tensor_alignment) }; + } +} + +impl Drop for GraphAllocator { + fn drop(&mut self) { + unsafe { sys::ggml_allocr_free(self.ptr) } } } @@ -408,3 +511,8 @@ pub fn cpu_has_gpublas() -> bool { pub fn graph_overhead() -> usize { unsafe { sys::ggml_graph_overhead() } } + +/// Returns the tensor overhead in bytes. +pub fn tensor_overhead() -> usize { + unsafe { sys::ggml_tensor_overhead() } +} diff --git a/crates/ggml/src/tensor.rs b/crates/ggml/src/tensor.rs index 33d7114c..ee5354c2 100644 --- a/crates/ggml/src/tensor.rs +++ b/crates/ggml/src/tensor.rs @@ -52,6 +52,11 @@ impl Tensor { }) } + /// Returns true if the 'extra' field of this tensor is set. e.g. by ggml-cuda + pub fn has_extras(&self) -> bool { + self.with_alive_ctx(|| unsafe { !self.ptr.as_ref().extra.is_null() }) + } + /// Sets the tensor's acceleration backend and moves the tensor's data to the new backend. pub fn transfer_to(mut self, backend: Backend) -> Tensor { self.with_alive_ctx_mut(|t| { @@ -88,7 +93,7 @@ impl Tensor { self.with_alive_ctx(|| { #[cfg(feature = "cublas")] unsafe { - sys::cuda::ggml_cuda_assign_buffers(self.ptr.as_ptr()); + sys::cuda::ggml_cuda_assign_buffers_no_alloc(self.ptr.as_ptr()); } }) } @@ -111,6 +116,18 @@ impl Tensor { }) } + /// If ggml-sys is compiled with CUDA support, this function will set the tensor's scratch offset. + /// If not, this is a no-op. + #[allow(unused_variables)] + pub fn assign_scratch_offset(&self, offset: usize) { + self.with_alive_ctx(|| { + #[cfg(feature = "cublas")] + unsafe { + sys::cuda::ggml_cuda_assign_scratch_offset(self.ptr.as_ptr(), offset); + } + }) + } + /// Creates a shared copy of this tensor pointer. pub fn share(&self) -> Self { Tensor { diff --git a/crates/ggml/sys/build.rs b/crates/ggml/sys/build.rs index 736fa156..ba7e876b 100644 --- a/crates/ggml/sys/build.rs +++ b/crates/ggml/sys/build.rs @@ -12,8 +12,13 @@ fn main() { let mut builder = cc::Build::new(); let build = builder - .files(["llama-cpp/ggml.c", "llama-cpp/k_quants.c"]) + .files([ + "llama-cpp/ggml.c", + "llama-cpp/k_quants.c", + "llama-cpp/ggml-alloc.c", + ]) .define("GGML_USE_K_QUANTS", None) + .define("QK_K", Some("256")) .includes(["llama-cpp"]); // This is a very basic heuristic for applying compile flags. 
@@ -77,9 +82,9 @@ fn main() { if compiler.is_like_clang() || compiler.is_like_gnu() { if target_os == "macos" { build.flag("-mcpu=apple-m1"); - build.flag("-mfpu=neon"); } else if std::env::var("HOST") == std::env::var("TARGET") { build.flag("-mcpu=native"); + build.flag("-mfpu=neon"); } build.flag("-pthread"); } @@ -87,6 +92,10 @@ fn main() { _ => {} } + if compiler.is_like_gnu() && target_os == "linux" { + build.define("_GNU_SOURCE", None); + } + if is_release { build.define("NDEBUG", None); } diff --git a/crates/ggml/sys/llama-cpp b/crates/ggml/sys/llama-cpp index 8183159c..da040034 160000 --- a/crates/ggml/sys/llama-cpp +++ b/crates/ggml/sys/llama-cpp @@ -1 +1 @@ -Subproject commit 8183159cf3def112f6d1fe94815fce70e1bffa12 +Subproject commit da0400344be12074e67dcabc565140289cf7efaa diff --git a/crates/ggml/sys/src/cuda.rs b/crates/ggml/sys/src/cuda.rs index a9ae1a8d..5208b66e 100644 --- a/crates/ggml/sys/src/cuda.rs +++ b/crates/ggml/sys/src/cuda.rs @@ -3,15 +3,17 @@ use super::ggml_compute_params; use super::ggml_tensor; +pub const GGML_CUDA_NAME: &[u8; 5usize] = b"CUDA\0"; +pub const GGML_CUBLAS_NAME: &[u8; 7usize] = b"cuBLAS\0"; pub const GGML_CUDA_MAX_DEVICES: u32 = 16; extern "C" { pub fn ggml_init_cublas(); } extern "C" { - pub fn ggml_cuda_set_tensor_split(tensor_split: *const f32); + pub fn ggml_cuda_host_malloc(size: usize) -> *mut ::std::os::raw::c_void; } extern "C" { - pub fn ggml_cuda_mul(src0: *const ggml_tensor, src1: *const ggml_tensor, dst: *mut ggml_tensor); + pub fn ggml_cuda_host_free(ptr: *mut ::std::os::raw::c_void); } extern "C" { pub fn ggml_cuda_can_mul_mat( @@ -21,26 +23,7 @@ extern "C" { ) -> bool; } extern "C" { - pub fn ggml_cuda_mul_mat_get_wsize( - src0: *const ggml_tensor, - src1: *const ggml_tensor, - dst: *mut ggml_tensor, - ) -> usize; -} -extern "C" { - pub fn ggml_cuda_mul_mat( - src0: *const ggml_tensor, - src1: *const ggml_tensor, - dst: *mut ggml_tensor, - wdata: *mut ::std::os::raw::c_void, - wsize: usize, - ); -} -extern "C" { - pub fn ggml_cuda_host_malloc(size: usize) -> *mut ::std::os::raw::c_void; -} -extern "C" { - pub fn ggml_cuda_host_free(ptr: *mut ::std::os::raw::c_void); + pub fn ggml_cuda_set_tensor_split(tensor_split: *const f32); } extern "C" { pub fn ggml_cuda_transform_tensor(data: *mut ::std::os::raw::c_void, tensor: *mut ggml_tensor); @@ -57,6 +40,12 @@ extern "C" { extern "C" { pub fn ggml_cuda_assign_buffers_force_inplace(tensor: *mut ggml_tensor); } +extern "C" { + pub fn ggml_cuda_assign_buffers_no_alloc(tensor: *mut ggml_tensor); +} +extern "C" { + pub fn ggml_cuda_assign_scratch_offset(tensor: *mut ggml_tensor, offset: usize); +} extern "C" { pub fn ggml_cuda_set_main_device(main_device: ::std::os::raw::c_int); } @@ -75,3 +64,13 @@ extern "C" { tensor: *mut ggml_tensor, ) -> bool; } +extern "C" { + pub fn ggml_cuda_get_device_count() -> ::std::os::raw::c_int; +} +extern "C" { + pub fn ggml_cuda_get_device_description( + device: ::std::os::raw::c_int, + description: *mut ::std::os::raw::c_char, + description_size: usize, + ); +} diff --git a/crates/ggml/sys/src/lib.rs b/crates/ggml/sys/src/lib.rs index 77b47802..71b34251 100644 --- a/crates/ggml/sys/src/lib.rs +++ b/crates/ggml/sys/src/lib.rs @@ -22,12 +22,17 @@ pub const GGML_MAX_NODES: u32 = 4096; pub const GGML_MAX_PARAMS: u32 = 256; pub const GGML_MAX_CONTEXTS: u32 = 64; pub const GGML_MAX_SRC: u32 = 6; -pub const GGML_MAX_NAME: u32 = 48; +pub const GGML_MAX_NAME: u32 = 64; pub const GGML_MAX_OP_PARAMS: u32 = 32; pub const GGML_DEFAULT_N_THREADS: u32 = 4; +pub const 
GGML_MEM_ALIGN: u32 = 16; pub const GGML_EXIT_SUCCESS: u32 = 0; pub const GGML_EXIT_ABORTED: u32 = 1; +pub const GGUF_MAGIC: u32 = 1179993927; +pub const GGUF_VERSION: u32 = 2; +pub const GGUF_DEFAULT_ALIGNMENT: u32 = 32; pub const GGML_GRAPH_HASHTABLE_SIZE: u32 = 8273; +pub const GGML_N_TASKS_MAX: i32 = -1; pub const QK_K: u32 = 256; pub const K_SCALE_SIZE: u32 = 12; pub type ggml_fp16_t = u16; @@ -103,49 +108,58 @@ pub const ggml_op_GGML_OP_MEAN: ggml_op = 13; pub const ggml_op_GGML_OP_ARGMAX: ggml_op = 14; pub const ggml_op_GGML_OP_REPEAT: ggml_op = 15; pub const ggml_op_GGML_OP_REPEAT_BACK: ggml_op = 16; -pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 17; -pub const ggml_op_GGML_OP_NORM: ggml_op = 18; -pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 19; -pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 20; -pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 21; -pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 22; -pub const ggml_op_GGML_OP_SCALE: ggml_op = 23; -pub const ggml_op_GGML_OP_SET: ggml_op = 24; -pub const ggml_op_GGML_OP_CPY: ggml_op = 25; -pub const ggml_op_GGML_OP_CONT: ggml_op = 26; -pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 27; -pub const ggml_op_GGML_OP_VIEW: ggml_op = 28; -pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 29; -pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 30; -pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 31; -pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 32; -pub const ggml_op_GGML_OP_DIAG: ggml_op = 33; -pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 34; -pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 35; -pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 36; -pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 37; -pub const ggml_op_GGML_OP_ROPE: ggml_op = 38; -pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 39; -pub const ggml_op_GGML_OP_ALIBI: ggml_op = 40; -pub const ggml_op_GGML_OP_CLAMP: ggml_op = 41; -pub const ggml_op_GGML_OP_CONV_1D: ggml_op = 42; -pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 43; -pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 44; -pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 45; -pub const ggml_op_GGML_OP_FLASH_ATTN: ggml_op = 46; -pub const ggml_op_GGML_OP_FLASH_FF: ggml_op = 47; -pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 48; -pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 49; -pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 50; -pub const ggml_op_GGML_OP_UNARY: ggml_op = 51; -pub const ggml_op_GGML_OP_MAP_UNARY: ggml_op = 52; -pub const ggml_op_GGML_OP_MAP_BINARY: ggml_op = 53; -pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 54; -pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 55; -pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 56; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 57; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 58; -pub const ggml_op_GGML_OP_COUNT: ggml_op = 59; +pub const ggml_op_GGML_OP_CONCAT: ggml_op = 17; +pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 18; +pub const ggml_op_GGML_OP_NORM: ggml_op = 19; +pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 20; +pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 21; +pub const ggml_op_GGML_OP_GROUP_NORM: ggml_op = 22; +pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 23; +pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 24; +pub const ggml_op_GGML_OP_SCALE: ggml_op = 25; +pub const ggml_op_GGML_OP_SET: ggml_op = 26; +pub const ggml_op_GGML_OP_CPY: ggml_op = 27; +pub const ggml_op_GGML_OP_CONT: ggml_op = 28; +pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 29; +pub const ggml_op_GGML_OP_VIEW: ggml_op = 30; +pub const 
ggml_op_GGML_OP_PERMUTE: ggml_op = 31; +pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 32; +pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 33; +pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 34; +pub const ggml_op_GGML_OP_DIAG: ggml_op = 35; +pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 36; +pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 37; +pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 38; +pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 39; +pub const ggml_op_GGML_OP_ROPE: ggml_op = 40; +pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 41; +pub const ggml_op_GGML_OP_ALIBI: ggml_op = 42; +pub const ggml_op_GGML_OP_CLAMP: ggml_op = 43; +pub const ggml_op_GGML_OP_CONV_1D: ggml_op = 44; +pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 45; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_2D: ggml_op = 46; +pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 47; +pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 48; +pub const ggml_op_GGML_OP_UPSCALE: ggml_op = 49; +pub const ggml_op_GGML_OP_FLASH_ATTN: ggml_op = 50; +pub const ggml_op_GGML_OP_FLASH_FF: ggml_op = 51; +pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 52; +pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 53; +pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 54; +pub const ggml_op_GGML_OP_GET_REL_POS: ggml_op = 55; +pub const ggml_op_GGML_OP_ADD_REL_POS: ggml_op = 56; +pub const ggml_op_GGML_OP_UNARY: ggml_op = 57; +pub const ggml_op_GGML_OP_MAP_UNARY: ggml_op = 58; +pub const ggml_op_GGML_OP_MAP_BINARY: ggml_op = 59; +pub const ggml_op_GGML_OP_MAP_CUSTOM1_F32: ggml_op = 60; +pub const ggml_op_GGML_OP_MAP_CUSTOM2_F32: ggml_op = 61; +pub const ggml_op_GGML_OP_MAP_CUSTOM3_F32: ggml_op = 62; +pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 63; +pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 64; +pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 65; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 66; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 67; +pub const ggml_op_GGML_OP_COUNT: ggml_op = 68; pub type ggml_op = ::std::os::raw::c_uint; pub const ggml_unary_op_GGML_UNARY_OP_ABS: ggml_unary_op = 0; pub const ggml_unary_op_GGML_UNARY_OP_SGN: ggml_unary_op = 1; @@ -157,11 +171,15 @@ pub const ggml_unary_op_GGML_UNARY_OP_RELU: ggml_unary_op = 6; pub const ggml_unary_op_GGML_UNARY_OP_GELU: ggml_unary_op = 7; pub const ggml_unary_op_GGML_UNARY_OP_GELU_QUICK: ggml_unary_op = 8; pub const ggml_unary_op_GGML_UNARY_OP_SILU: ggml_unary_op = 9; -pub type ggml_unary_op = ::std::os::raw::c_int; +pub type ggml_unary_op = ::std::os::raw::c_uint; pub const ggml_object_type_GGML_OBJECT_TENSOR: ggml_object_type = 0; pub const ggml_object_type_GGML_OBJECT_GRAPH: ggml_object_type = 1; pub const ggml_object_type_GGML_OBJECT_WORK_BUFFER: ggml_object_type = 2; -pub type ggml_object_type = ::std::os::raw::c_int; +pub type ggml_object_type = ::std::os::raw::c_uint; +pub const ggml_log_level_GGML_LOG_LEVEL_ERROR: ggml_log_level = 2; +pub const ggml_log_level_GGML_LOG_LEVEL_WARN: ggml_log_level = 3; +pub const ggml_log_level_GGML_LOG_LEVEL_INFO: ggml_log_level = 4; +pub type ggml_log_level = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_object { @@ -253,8 +271,10 @@ pub struct ggml_tensor { pub perf_runs: ::std::os::raw::c_int, pub perf_cycles: i64, pub perf_time_us: i64, + pub view_src: *mut ggml_tensor, + pub view_offs: usize, pub data: *mut ::std::os::raw::c_void, - pub name: [::std::os::raw::c_char; 48usize], + pub name: [::std::os::raw::c_char; 64usize], pub extra: *mut ::std::os::raw::c_void, pub 
padding: [::std::os::raw::c_char; 4usize], } @@ -264,7 +284,7 @@ fn bindgen_test_layout_ggml_tensor() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 272usize, + 304usize, concat!("Size of: ", stringify!(ggml_tensor)) ); assert_eq!( @@ -403,8 +423,28 @@ fn bindgen_test_layout_ggml_tensor() { ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).data) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).view_src) as usize - ptr as usize }, 200usize, + concat!( + "Offset of field: ", + stringify!(ggml_tensor), + "::", + stringify!(view_src) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).view_offs) as usize - ptr as usize }, + 208usize, + concat!( + "Offset of field: ", + stringify!(ggml_tensor), + "::", + stringify!(view_offs) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).data) as usize - ptr as usize }, + 216usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -414,7 +454,7 @@ fn bindgen_test_layout_ggml_tensor() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).name) as usize - ptr as usize }, - 208usize, + 224usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -424,7 +464,7 @@ fn bindgen_test_layout_ggml_tensor() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).extra) as usize - ptr as usize }, - 256usize, + 288usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -434,7 +474,7 @@ fn bindgen_test_layout_ggml_tensor() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).padding) as usize - ptr as usize }, - 264usize, + 296usize, concat!( "Offset of field: ", stringify!(ggml_tensor), @@ -443,7 +483,7 @@ fn bindgen_test_layout_ggml_tensor() { ) ); } -pub const GGML_TENSOR_SIZE: usize = 272; +pub const GGML_TENSOR_SIZE: usize = 304; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_cplan { @@ -867,6 +907,9 @@ extern "C" { extern "C" { pub fn ggml_nbytes(tensor: *const ggml_tensor) -> usize; } +extern "C" { + pub fn ggml_nbytes_pad(tensor: *const ggml_tensor) -> usize; +} extern "C" { pub fn ggml_nbytes_split( tensor: *const ggml_tensor, @@ -909,6 +952,9 @@ extern "C" { extern "C" { pub fn ggml_is_permuted(tensor: *const ggml_tensor) -> bool; } +extern "C" { + pub fn ggml_are_same_shape(t0: *const ggml_tensor, t1: *const ggml_tensor) -> bool; +} extern "C" { pub fn ggml_tensor_overhead() -> usize; } @@ -991,7 +1037,7 @@ extern "C" { pub fn ggml_dup_tensor(ctx: *mut ggml_context, src: *const ggml_tensor) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_view_tensor(ctx: *mut ggml_context, src: *const ggml_tensor) -> *mut ggml_tensor; + pub fn ggml_view_tensor(ctx: *mut ggml_context, src: *mut ggml_tensor) -> *mut ggml_tensor; } extern "C" { pub fn ggml_get_tensor( @@ -1187,6 +1233,13 @@ extern "C" { b: *mut ggml_tensor, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_concat( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_abs(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } @@ -1256,10 +1309,14 @@ extern "C" { ) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_norm(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; + pub fn ggml_norm(ctx: *mut ggml_context, a: *mut ggml_tensor, eps: f32) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_norm_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; + pub fn ggml_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + eps: f32, + ) -> *mut ggml_tensor; } extern "C" { pub fn 
ggml_rms_norm(ctx: *mut ggml_context, a: *mut ggml_tensor, eps: f32) @@ -1272,11 +1329,26 @@ extern "C" { eps: f32, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_group_norm( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_groups: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_group_norm_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_groups: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_rms_norm_back( ctx: *mut ggml_context, a: *mut ggml_tensor, b: *mut ggml_tensor, + eps: f32, ) -> *mut ggml_tensor; } extern "C" { @@ -1591,6 +1663,16 @@ extern "C" { freq_scale: f32, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_rope_xpos_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_past: ::std::os::raw::c_int, + n_dims: ::std::os::raw::c_int, + base: f32, + down: bool, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_rope_back( ctx: *mut ggml_context, @@ -1599,6 +1681,10 @@ extern "C" { n_dims: ::std::os::raw::c_int, mode: ::std::os::raw::c_int, n_ctx: ::std::os::raw::c_int, + freq_base: f32, + freq_scale: f32, + xpos_base: f32, + xpos_down: bool, ) -> *mut ggml_tensor; } extern "C" { @@ -1628,6 +1714,15 @@ extern "C" { d0: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_conv_1d_ph( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s: ::std::os::raw::c_int, + d: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_conv_2d( ctx: *mut ggml_context, @@ -1642,18 +1737,31 @@ extern "C" { ) -> *mut ggml_tensor; } extern "C" { - pub fn ggml_conv_1d_ph( + pub fn ggml_conv_2d_sk_p0( ctx: *mut ggml_context, a: *mut ggml_tensor, b: *mut ggml_tensor, - s: ::std::os::raw::c_int, - d: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_conv_2d_s1_ph( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_conv_transpose_2d_p0( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + stride: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } pub const ggml_op_pool_GGML_OP_POOL_MAX: ggml_op_pool = 0; pub const ggml_op_pool_GGML_OP_POOL_AVG: ggml_op_pool = 1; pub const ggml_op_pool_GGML_OP_POOL_COUNT: ggml_op_pool = 2; -pub type ggml_op_pool = ::std::os::raw::c_int; +pub type ggml_op_pool = ::std::os::raw::c_uint; extern "C" { pub fn ggml_pool_1d( ctx: *mut ggml_context, @@ -1677,6 +1785,13 @@ extern "C" { p1: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_upscale( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + scale_factor: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_flash_attn( ctx: *mut ggml_context, @@ -1722,6 +1837,44 @@ extern "C" { w: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +extern "C" { + pub fn ggml_unary( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_unary_op, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_unary_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + op: ggml_unary_op, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_get_rel_pos( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + qh: ::std::os::raw::c_int, + kh: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_add_rel_pos( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + pw: *mut ggml_tensor, + ph: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_add_rel_pos_inplace( + ctx: *mut 
ggml_context, + a: *mut ggml_tensor, + pw: *mut ggml_tensor, + ph: *mut ggml_tensor, + ) -> *mut ggml_tensor; +} pub type ggml_unary_op_f32_t = ::std::option::Option< unsafe extern "C" fn(arg1: ::std::os::raw::c_int, arg2: *mut f32, arg3: *const f32), >; @@ -1750,20 +1903,6 @@ pub type ggml_custom3_op_f32_t = ::std::option::Option< arg4: *const ggml_tensor, ), >; -extern "C" { - pub fn ggml_unary( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - op: ggml_unary_op, - ) -> *mut ggml_tensor; -} -extern "C" { - pub fn ggml_unary_inplace( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - op: ggml_unary_op, - ) -> *mut ggml_tensor; -} extern "C" { pub fn ggml_map_unary_f32( ctx: *mut ggml_context, @@ -1842,6 +1981,96 @@ extern "C" { fun: ggml_custom3_op_f32_t, ) -> *mut ggml_tensor; } +pub type ggml_custom1_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +pub type ggml_custom2_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + b: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +pub type ggml_custom3_op_t = ::std::option::Option< + unsafe extern "C" fn( + dst: *mut ggml_tensor, + a: *const ggml_tensor, + b: *const ggml_tensor, + c: *const ggml_tensor, + ith: ::std::os::raw::c_int, + nth: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ), +>; +extern "C" { + pub fn ggml_map_custom1( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + fun: ggml_custom1_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom1_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + fun: ggml_custom1_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom2( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + fun: ggml_custom2_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom2_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + fun: ggml_custom2_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom3( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + fun: ggml_custom3_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} +extern "C" { + pub fn ggml_map_custom3_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + c: *mut ggml_tensor, + fun: ggml_custom3_op_t, + n_tasks: ::std::os::raw::c_int, + userdata: *mut ::std::os::raw::c_void, + ) -> *mut ggml_tensor; +} extern "C" { pub fn ggml_cross_entropy_loss( ctx: *mut ggml_context, @@ -1863,6 +2092,14 @@ extern "C" { extern "C" { pub fn ggml_build_forward_expand(cgraph: *mut ggml_cgraph, tensor: *mut ggml_tensor); } +extern "C" { + pub fn ggml_build_backward_expand( + ctx: *mut ggml_context, + gf: *mut ggml_cgraph, + gb: *mut ggml_cgraph, + keep: bool, + ); +} extern "C" { pub fn ggml_build_forward(tensor: *mut ggml_tensor) -> ggml_cgraph; } @@ -1952,6 +2189,15 @@ pub const 
ggml_opt_result_GGML_LINESEARCH_MAXIMUM_STEP: ggml_opt_result = -126; pub const ggml_opt_result_GGML_LINESEARCH_MAXIMUM_ITERATIONS: ggml_opt_result = -125; pub const ggml_opt_result_GGML_LINESEARCH_INVALID_PARAMETERS: ggml_opt_result = -124; pub type ggml_opt_result = ::std::os::raw::c_int; +pub type ggml_opt_callback = + ::std::option::Option; +pub type ggml_log_callback = ::std::option::Option< + unsafe extern "C" fn( + level: ggml_log_level, + text: *const ::std::os::raw::c_char, + user_data: *mut ::std::os::raw::c_void, + ), +>; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_opt_params { @@ -1971,12 +2217,14 @@ pub struct ggml_opt_params__bindgen_ty_1 { pub n_iter: ::std::os::raw::c_int, pub sched: f32, pub decay: f32, + pub decay_min_ndim: ::std::os::raw::c_int, pub alpha: f32, pub beta1: f32, pub beta2: f32, pub eps: f32, pub eps_f: f32, pub eps_g: f32, + pub gclip: f32, } #[test] fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { @@ -1985,7 +2233,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 36usize, + 44usize, concat!("Size of: ", stringify!(ggml_opt_params__bindgen_ty_1)) ); assert_eq!( @@ -2024,8 +2272,18 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).alpha) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).decay_min_ndim) as usize - ptr as usize }, 12usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_params__bindgen_ty_1), + "::", + stringify!(decay_min_ndim) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).alpha) as usize - ptr as usize }, + 16usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2035,7 +2293,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).beta1) as usize - ptr as usize }, - 16usize, + 20usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2045,7 +2303,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).beta2) as usize - ptr as usize }, - 20usize, + 24usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2055,7 +2313,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).eps) as usize - ptr as usize }, - 24usize, + 28usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2065,7 +2323,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).eps_f) as usize - ptr as usize }, - 28usize, + 32usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2075,7 +2333,7 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).eps_g) as usize - ptr as usize }, - 32usize, + 36usize, concat!( "Offset of field: ", stringify!(ggml_opt_params__bindgen_ty_1), @@ -2083,6 +2341,16 @@ fn bindgen_test_layout_ggml_opt_params__bindgen_ty_1() { stringify!(eps_g) ) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).gclip) as usize - ptr as usize }, + 40usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_params__bindgen_ty_1), + "::", + stringify!(gclip) + ) + ); } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -2209,7 +2477,7 @@ fn bindgen_test_layout_ggml_opt_params() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), 
- 96usize, + 104usize, concat!("Size of: ", stringify!(ggml_opt_params)) ); assert_eq!( @@ -2299,7 +2567,7 @@ fn bindgen_test_layout_ggml_opt_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).lbfgs) as usize - ptr as usize }, - 60usize, + 68usize, concat!( "Offset of field: ", stringify!(ggml_opt_params), @@ -2316,19 +2584,16 @@ pub struct ggml_opt_context { pub iter: ::std::os::raw::c_int, pub nx: i64, pub just_initialized: bool, + pub loss_before: f32, + pub loss_after: f32, pub adam: ggml_opt_context__bindgen_ty_1, pub lbfgs: ggml_opt_context__bindgen_ty_2, } #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_opt_context__bindgen_ty_1 { - pub x: *mut ggml_tensor, - pub g1: *mut ggml_tensor, - pub g2: *mut ggml_tensor, pub m: *mut ggml_tensor, pub v: *mut ggml_tensor, - pub mh: *mut ggml_tensor, - pub vh: *mut ggml_tensor, pub pf: *mut ggml_tensor, pub fx_best: f32, pub fx_prev: f32, @@ -2341,7 +2606,7 @@ fn bindgen_test_layout_ggml_opt_context__bindgen_ty_1() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 80usize, + 40usize, concat!("Size of: ", stringify!(ggml_opt_context__bindgen_ty_1)) ); assert_eq!( @@ -2350,113 +2615,63 @@ fn bindgen_test_layout_ggml_opt_context__bindgen_ty_1() { concat!("Alignment of ", stringify!(ggml_opt_context__bindgen_ty_1)) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).x) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).m) as usize - ptr as usize }, 0usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(x) + stringify!(m) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).g1) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).v) as usize - ptr as usize }, 8usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(g1) + stringify!(v) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).g2) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).pf) as usize - ptr as usize }, 16usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(g2) + stringify!(pf) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).m) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).fx_best) as usize - ptr as usize }, 24usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(m) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).v) as usize - ptr as usize }, - 32usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(v) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).mh) as usize - ptr as usize }, - 40usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(mh) + stringify!(fx_best) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).vh) as usize - ptr as usize }, - 48usize, + unsafe { ::std::ptr::addr_of!((*ptr).fx_prev) as usize - ptr as usize }, + 28usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(vh) + stringify!(fx_prev) ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).pf) as usize - ptr as usize }, - 56usize, + unsafe { ::std::ptr::addr_of!((*ptr).n_no_improvement) as usize - ptr as usize }, + 32usize, concat!( "Offset of field: ", stringify!(ggml_opt_context__bindgen_ty_1), "::", - stringify!(pf) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).fx_best) as usize - ptr as usize }, - 64usize, - 
concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(fx_best) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).fx_prev) as usize - ptr as usize }, - 68usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(fx_prev) - ) - ); - assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).n_no_improvement) as usize - ptr as usize }, - 72usize, - concat!( - "Offset of field: ", - stringify!(ggml_opt_context__bindgen_ty_1), - "::", - stringify!(n_no_improvement) + stringify!(n_no_improvement) ) ); } @@ -2662,7 +2877,7 @@ fn bindgen_test_layout_ggml_opt_context() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 312usize, + 288usize, concat!("Size of: ", stringify!(ggml_opt_context)) ); assert_eq!( @@ -2692,7 +2907,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).iter) as usize - ptr as usize }, - 104usize, + 112usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2702,7 +2917,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).nx) as usize - ptr as usize }, - 112usize, + 120usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2712,7 +2927,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).just_initialized) as usize - ptr as usize }, - 120usize, + 128usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2720,9 +2935,29 @@ fn bindgen_test_layout_ggml_opt_context() { stringify!(just_initialized) ) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).loss_before) as usize - ptr as usize }, + 132usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_context), + "::", + stringify!(loss_before) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).loss_after) as usize - ptr as usize }, + 136usize, + concat!( + "Offset of field: ", + stringify!(ggml_opt_context), + "::", + stringify!(loss_after) + ) + ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).adam) as usize - ptr as usize }, - 128usize, + 144usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2732,7 +2967,7 @@ fn bindgen_test_layout_ggml_opt_context() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).lbfgs) as usize - ptr as usize }, - 208usize, + 184usize, concat!( "Offset of field: ", stringify!(ggml_opt_context), @@ -2773,6 +3008,8 @@ extern "C" { f: *mut ggml_tensor, gf: *mut ggml_cgraph, gb: *mut ggml_cgraph, + callback: ggml_opt_callback, + callback_data: *mut ::std::os::raw::c_void, ) -> ggml_opt_result; } extern "C" { @@ -2830,6 +3067,282 @@ extern "C" { hist: *mut i64, ) -> usize; } +pub const gguf_type_GGUF_TYPE_UINT8: gguf_type = 0; +pub const gguf_type_GGUF_TYPE_INT8: gguf_type = 1; +pub const gguf_type_GGUF_TYPE_UINT16: gguf_type = 2; +pub const gguf_type_GGUF_TYPE_INT16: gguf_type = 3; +pub const gguf_type_GGUF_TYPE_UINT32: gguf_type = 4; +pub const gguf_type_GGUF_TYPE_INT32: gguf_type = 5; +pub const gguf_type_GGUF_TYPE_FLOAT32: gguf_type = 6; +pub const gguf_type_GGUF_TYPE_BOOL: gguf_type = 7; +pub const gguf_type_GGUF_TYPE_STRING: gguf_type = 8; +pub const gguf_type_GGUF_TYPE_ARRAY: gguf_type = 9; +pub const gguf_type_GGUF_TYPE_UINT64: gguf_type = 10; +pub const gguf_type_GGUF_TYPE_INT64: gguf_type = 11; +pub const gguf_type_GGUF_TYPE_FLOAT64: gguf_type = 12; +pub const gguf_type_GGUF_TYPE_COUNT: gguf_type = 13; +pub type gguf_type = ::std::os::raw::c_uint; 
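The raw GGUF bindings declared below can be combined to read model metadata. This is an illustrative unsafe FFI sketch, not part of the patch; it assumes the items are in scope (e.g. via the sys crate), and the helper name and key are hypothetical.

use std::ffi::CString;

/// Illustrative only: read a single u32 metadata value through the raw GGUF API.
unsafe fn read_u32_kv(path: &str, key: &str) -> Option<u32> {
    let fname = CString::new(path).ok()?;
    let key = CString::new(key).ok()?;
    let params = gguf_init_params {
        no_alloc: true,             // metadata only, do not allocate tensor data
        ctx: std::ptr::null_mut(),  // no ggml_context needed for metadata
    };
    let ctx = gguf_init_from_file(fname.as_ptr(), params);
    if ctx.is_null() {
        return None;
    }
    let idx = gguf_find_key(ctx, key.as_ptr());
    let value = if idx >= 0 && gguf_get_kv_type(ctx, idx) == gguf_type_GGUF_TYPE_UINT32 {
        Some(gguf_get_val_u32(ctx, idx))
    } else {
        None
    };
    gguf_free(ctx);
    value
}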
+#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct gguf_context { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct gguf_init_params { + pub no_alloc: bool, + pub ctx: *mut *mut ggml_context, +} +#[test] +fn bindgen_test_layout_gguf_init_params() { + const UNINIT: ::std::mem::MaybeUninit = ::std::mem::MaybeUninit::uninit(); + let ptr = UNINIT.as_ptr(); + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(gguf_init_params)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(gguf_init_params)) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).no_alloc) as usize - ptr as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(gguf_init_params), + "::", + stringify!(no_alloc) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).ctx) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(gguf_init_params), + "::", + stringify!(ctx) + ) + ); +} +extern "C" { + pub fn gguf_init_empty() -> *mut gguf_context; +} +extern "C" { + pub fn gguf_init_from_file( + fname: *const ::std::os::raw::c_char, + params: gguf_init_params, + ) -> *mut gguf_context; +} +extern "C" { + pub fn gguf_free(ctx: *mut gguf_context); +} +extern "C" { + pub fn gguf_type_name(type_: gguf_type) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_version(ctx: *const gguf_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_alignment(ctx: *const gguf_context) -> usize; +} +extern "C" { + pub fn gguf_get_data_offset(ctx: *const gguf_context) -> usize; +} +extern "C" { + pub fn gguf_get_data(ctx: *const gguf_context) -> *mut ::std::os::raw::c_void; +} +extern "C" { + pub fn gguf_get_n_kv(ctx: *const gguf_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_find_key( + ctx: *const gguf_context, + key: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_key( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_kv_type(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> gguf_type; +} +extern "C" { + pub fn gguf_get_arr_type(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> gguf_type; +} +extern "C" { + pub fn gguf_get_val_u8(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u8; +} +extern "C" { + pub fn gguf_get_val_i8(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i8; +} +extern "C" { + pub fn gguf_get_val_u16(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u16; +} +extern "C" { + pub fn gguf_get_val_i16(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i16; +} +extern "C" { + pub fn gguf_get_val_u32(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u32; +} +extern "C" { + pub fn gguf_get_val_i32(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i32; +} +extern "C" { + pub fn gguf_get_val_f32(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> f32; +} +extern "C" { + pub fn gguf_get_val_u64(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> u64; +} +extern "C" { + pub fn gguf_get_val_i64(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> i64; +} +extern "C" { + pub fn gguf_get_val_f64(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> f64; +} +extern "C" { + pub fn gguf_get_val_bool(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> bool; +} +extern "C" { + pub fn gguf_get_val_str( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> 
*const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_arr_n( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_arr_data( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_void; +} +extern "C" { + pub fn gguf_get_arr_str( + ctx: *const gguf_context, + key_id: ::std::os::raw::c_int, + i: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_get_n_tensors(ctx: *const gguf_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_find_tensor( + ctx: *const gguf_context, + name: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn gguf_get_tensor_offset(ctx: *const gguf_context, i: ::std::os::raw::c_int) -> usize; +} +extern "C" { + pub fn gguf_get_tensor_name( + ctx: *const gguf_context, + i: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_char; +} +extern "C" { + pub fn gguf_set_val_u8(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u8); +} +extern "C" { + pub fn gguf_set_val_i8(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i8); +} +extern "C" { + pub fn gguf_set_val_u16(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u16); +} +extern "C" { + pub fn gguf_set_val_i16(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i16); +} +extern "C" { + pub fn gguf_set_val_u32(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u32); +} +extern "C" { + pub fn gguf_set_val_i32(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i32); +} +extern "C" { + pub fn gguf_set_val_f32(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: f32); +} +extern "C" { + pub fn gguf_set_val_u64(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: u64); +} +extern "C" { + pub fn gguf_set_val_i64(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: i64); +} +extern "C" { + pub fn gguf_set_val_f64(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: f64); +} +extern "C" { + pub fn gguf_set_val_bool(ctx: *mut gguf_context, key: *const ::std::os::raw::c_char, val: bool); +} +extern "C" { + pub fn gguf_set_val_str( + ctx: *mut gguf_context, + key: *const ::std::os::raw::c_char, + val: *const ::std::os::raw::c_char, + ); +} +extern "C" { + pub fn gguf_set_arr_data( + ctx: *mut gguf_context, + key: *const ::std::os::raw::c_char, + type_: gguf_type, + data: *const ::std::os::raw::c_void, + n: ::std::os::raw::c_int, + ); +} +extern "C" { + pub fn gguf_set_arr_str( + ctx: *mut gguf_context, + key: *const ::std::os::raw::c_char, + data: *mut *const ::std::os::raw::c_char, + n: ::std::os::raw::c_int, + ); +} +extern "C" { + pub fn gguf_set_kv(ctx: *mut gguf_context, src: *mut gguf_context); +} +extern "C" { + pub fn gguf_add_tensor(ctx: *mut gguf_context, tensor: *const ggml_tensor); +} +extern "C" { + pub fn gguf_set_tensor_type( + ctx: *mut gguf_context, + name: *const ::std::os::raw::c_char, + type_: ggml_type, + ); +} +extern "C" { + pub fn gguf_set_tensor_data( + ctx: *mut gguf_context, + name: *const ::std::os::raw::c_char, + data: *const ::std::os::raw::c_void, + size: usize, + ); +} +extern "C" { + pub fn gguf_write_to_file( + ctx: *const gguf_context, + fname: *const ::std::os::raw::c_char, + only_meta: bool, + ); +} +extern "C" { + pub fn gguf_get_meta_size(ctx: *const gguf_context) -> usize; +} +extern "C" { + pub fn gguf_get_meta_data(ctx: *const gguf_context, data: *mut 
::std::os::raw::c_void); +} extern "C" { pub fn ggml_cpu_has_avx() -> ::std::os::raw::c_int; } @@ -2854,6 +3367,9 @@ extern "C" { extern "C" { pub fn ggml_cpu_has_arm_fma() -> ::std::os::raw::c_int; } +extern "C" { + pub fn ggml_cpu_has_metal() -> ::std::os::raw::c_int; +} extern "C" { pub fn ggml_cpu_has_f16c() -> ::std::os::raw::c_int; } @@ -2878,6 +3394,9 @@ extern "C" { extern "C" { pub fn ggml_cpu_has_sse3() -> ::std::os::raw::c_int; } +extern "C" { + pub fn ggml_cpu_has_ssse3() -> ::std::os::raw::c_int; +} extern "C" { pub fn ggml_cpu_has_vsx() -> ::std::os::raw::c_int; } @@ -2898,6 +3417,10 @@ pub type ggml_vec_dot_t = ::std::option::Option< #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_type_traits_t { + pub type_name: *const ::std::os::raw::c_char, + pub blck_size: ::std::os::raw::c_int, + pub type_size: usize, + pub is_quantized: bool, pub to_float: ggml_to_float_t, pub from_float: ggml_from_float_t, pub from_float_reference: ggml_from_float_t, @@ -2910,7 +3433,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 40usize, + 72usize, concat!("Size of: ", stringify!(ggml_type_traits_t)) ); assert_eq!( @@ -2919,8 +3442,48 @@ fn bindgen_test_layout_ggml_type_traits_t() { concat!("Alignment of ", stringify!(ggml_type_traits_t)) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).to_float) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).type_name) as usize - ptr as usize }, 0usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(type_name) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).blck_size) as usize - ptr as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(blck_size) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).type_size) as usize - ptr as usize }, + 16usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(type_size) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).is_quantized) as usize - ptr as usize }, + 24usize, + concat!( + "Offset of field: ", + stringify!(ggml_type_traits_t), + "::", + stringify!(is_quantized) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).to_float) as usize - ptr as usize }, + 32usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2930,7 +3493,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).from_float) as usize - ptr as usize }, - 8usize, + 40usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2940,7 +3503,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).from_float_reference) as usize - ptr as usize }, - 16usize, + 48usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2950,7 +3513,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).vec_dot) as usize - ptr as usize }, - 24usize, + 56usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2960,7 +3523,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).vec_dot_type) as usize - ptr as usize }, - 32usize, + 64usize, concat!( "Offset of field: ", stringify!(ggml_type_traits_t), @@ -2970,7 +3533,7 @@ fn bindgen_test_layout_ggml_type_traits_t() { ); } extern "C" { - pub fn ggml_internal_get_type_traits(i: ggml_type) -> ggml_type_traits_t; 
+ pub fn ggml_internal_get_type_traits(type_: ggml_type) -> ggml_type_traits_t; } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -3513,3 +4076,40 @@ extern "C" { hist: *mut i64, ) -> usize; } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct ggml_allocr { + _unused: [u8; 0], +} +extern "C" { + pub fn ggml_allocr_new( + data: *mut ::std::os::raw::c_void, + size: usize, + alignment: usize, + ) -> *mut ggml_allocr; +} +extern "C" { + pub fn ggml_allocr_new_measure(alignment: usize) -> *mut ggml_allocr; +} +extern "C" { + pub fn ggml_allocr_set_parse_seq( + alloc: *mut ggml_allocr, + list: *const ::std::os::raw::c_int, + n: ::std::os::raw::c_int, + ); +} +extern "C" { + pub fn ggml_allocr_free(alloc: *mut ggml_allocr); +} +extern "C" { + pub fn ggml_allocr_is_measure(alloc: *mut ggml_allocr) -> bool; +} +extern "C" { + pub fn ggml_allocr_reset(alloc: *mut ggml_allocr); +} +extern "C" { + pub fn ggml_allocr_alloc(alloc: *mut ggml_allocr, tensor: *mut ggml_tensor); +} +extern "C" { + pub fn ggml_allocr_alloc_graph(alloc: *mut ggml_allocr, graph: *mut ggml_cgraph) -> usize; +} diff --git a/crates/ggml/sys/src/llama.rs b/crates/ggml/sys/src/llama.rs index a8aa42ef..d3552cd9 100644 --- a/crates/ggml/sys/src/llama.rs +++ b/crates/ggml/sys/src/llama.rs @@ -1,18 +1,10 @@ /* automatically generated by rust-bindgen 0.65.1 */ pub const LLAMA_MAX_DEVICES: u32 = 1; -pub const LLAMA_FILE_MAGIC_GGJT: u32 = 1734830708; -pub const LLAMA_FILE_MAGIC_GGLA: u32 = 1734831201; -pub const LLAMA_FILE_MAGIC_GGMF: u32 = 1734831462; -pub const LLAMA_FILE_MAGIC_GGML: u32 = 1734831468; +pub const LLAMA_DEFAULT_SEED: u32 = 4294967295; pub const LLAMA_FILE_MAGIC_GGSN: u32 = 1734833006; -pub const LLAMA_FILE_VERSION: u32 = 3; -pub const LLAMA_FILE_MAGIC: u32 = 1734830708; -pub const LLAMA_FILE_MAGIC_UNVERSIONED: u32 = 1734831468; pub const LLAMA_SESSION_MAGIC: u32 = 1734833006; pub const LLAMA_SESSION_VERSION: u32 = 1; -pub const LLAMA_DEFAULT_SEED: u32 = 4294967295; -pub const LLAMA_DEFAULT_RMS_EPS: f64 = 0.000005; pub const LLAMA_FTYPE_ALL_F32: llama_ftype = 0; pub const LLAMA_FTYPE_MOSTLY_F16: llama_ftype = 1; pub const LLAMA_FTYPE_MOSTLY_Q4_0: llama_ftype = 2; @@ -30,4 +22,5 @@ pub const LLAMA_FTYPE_MOSTLY_Q4_K_M: llama_ftype = 15; pub const LLAMA_FTYPE_MOSTLY_Q5_K_S: llama_ftype = 16; pub const LLAMA_FTYPE_MOSTLY_Q5_K_M: llama_ftype = 17; pub const LLAMA_FTYPE_MOSTLY_Q6_K: llama_ftype = 18; -pub type llama_ftype = ::std::os::raw::c_int; +pub const LLAMA_FTYPE_GUESSED: llama_ftype = 1024; +pub type llama_ftype = ::std::os::raw::c_uint; diff --git a/crates/ggml/sys/src/metal.rs b/crates/ggml/sys/src/metal.rs index bbd16034..464db3ce 100644 --- a/crates/ggml/sys/src/metal.rs +++ b/crates/ggml/sys/src/metal.rs @@ -1,15 +1,16 @@ /* automatically generated by rust-bindgen 0.65.1 */ +use super::ggml_cgraph; +use super::ggml_log_callback; +use super::ggml_tensor; + pub const GGML_METAL_MAX_BUFFERS: u32 = 16; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct ggml_tensor { - _unused: [u8; 0], -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct ggml_cgraph { - _unused: [u8; 0], +pub const GGML_METAL_MAX_COMMAND_BUFFERS: u32 = 32; +extern "C" { + pub fn ggml_metal_log_set_callback( + log_callback: ggml_log_callback, + user_data: *mut ::std::os::raw::c_void, + ); } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -22,6 +23,12 @@ extern "C" { extern "C" { pub fn ggml_metal_free(ctx: *mut ggml_metal_context); } +extern "C" { + pub fn ggml_metal_host_malloc(n: usize) -> *mut ::std::os::raw::c_void; +} +extern "C" { + pub fn 
ggml_metal_host_free(data: *mut ::std::os::raw::c_void); +} extern "C" { pub fn ggml_metal_set_n_cb(ctx: *mut ggml_metal_context, n_cb: ::std::os::raw::c_int); } @@ -41,10 +48,17 @@ extern "C" { pub fn ggml_metal_get_tensor(ctx: *mut ggml_metal_context, t: *mut ggml_tensor); } extern "C" { - pub fn ggml_metal_graph_find_concurrency(ctx: *mut ggml_metal_context, gf: *mut ggml_cgraph); + pub fn ggml_metal_graph_find_concurrency( + ctx: *mut ggml_metal_context, + gf: *mut ggml_cgraph, + check_mem: bool, + ); +} +extern "C" { + pub fn ggml_metal_if_optimized(ctx: *mut ggml_metal_context) -> ::std::os::raw::c_int; } extern "C" { - pub fn ggml_metal_if_optimized(ctx: *mut ggml_metal_context) -> bool; + pub fn ggml_metal_get_concur_list(ctx: *mut ggml_metal_context) -> *mut ::std::os::raw::c_int; } extern "C" { pub fn ggml_metal_graph_compute(ctx: *mut ggml_metal_context, gf: *mut ggml_cgraph); diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index e45f0106..5340ec11 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -1,4 +1,4 @@ -use ggml::{Buffer, ComputationGraph, Context, GraphExecutionPlan, Tensor}; +use ggml::{Buffer, ComputationGraph, Context, GraphAllocator, GraphExecutionPlan, Tensor}; use serde::Serialize; use std::{cell::RefCell, fmt::Display, sync::Arc}; use thiserror::Error; @@ -12,21 +12,6 @@ use crate::{ TokenId, TokenUtf8Buffer, TokenizationError, }; -// The size of a scratch buffer used for inference. This is used for temporary -// storage of intermediate results during inference. -// -// The specific value was copied from `llama.cpp`. -const SCRATCH_SIZE: usize = 512 * 1024 * 1024; - -type ScratchBuffers = [ggml::Buffer; 2]; - -fn scratch_buffers() -> ScratchBuffers { - [ - ggml::Buffer::new(SCRATCH_SIZE), - ggml::Buffer::new(SCRATCH_SIZE), - ] -} - /// Result of graph building pub struct GraphOutputs { /// The output containing the model's result @@ -34,6 +19,9 @@ pub struct GraphOutputs { /// The output containing embeddings pub embedding_result: Tensor, + + /// The length of the output + pub output_length: usize, } /// An inference session represents the state of the text generation. This holds @@ -66,7 +54,7 @@ pub struct InferenceSession { /// How many tokens have been fed into the model's working memory so far. #[doc(hidden)] - pub n_past: usize, + n_past: usize, /// How much memory is required per token for the temporary context used /// during inference. @@ -90,21 +78,32 @@ pub struct InferenceSession { n_embd: usize, - scratch: ScratchBuffers, + /// Allocator used by this session + allocator: GraphAllocator, + + ///Context size of this session + context_size: usize, + + /// Work buffer for graph planing + work_buffer: Vec, + + /// If the session can use the gpu + use_gpu: bool, } pub struct BuildContext<'session> { //FIXME: Borrowing issue, dont know how to fix it pub ctx0: RefCell<&'session mut Context>, + pub allocator: RefCell<&'session GraphAllocator>, pub embd: &'session Tensor, pub memory_k: &'session Tensor, pub memory_v: &'session Tensor, - pub scratch: &'session ScratchBuffers, + pub n_past: usize, } impl<'session> BuildContext<'session> { - pub fn get_scratch(&self, idx: usize) -> Option<&Buffer> { - Some(&self.scratch[idx]) + pub fn input_length(&self) -> usize { + self.embd.nelements() } } @@ -124,7 +123,7 @@ impl InferenceSession { .. 
} = *params; - let context_byte_size = { + let cache_byte_size = { let mut size = 0; size += mulf!( context_size, @@ -138,53 +137,48 @@ impl InferenceSession { n_embd, ggml::type_sizef(config.memory_v_type.into()) ); // memory_v - size += (5 + 10 * n_layer) * 256; // object overhead + size += 2 * 1024 * 1024; // overhead size }; + log::info!( + "Allocating {:.2} MB for KV-memory", + cache_byte_size / (1024 * 1024) + ); + if use_gpu { ggml::accelerator::initialize(0); - ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024); + ggml::accelerator::set_scratch_size(0); } // TODO: revisit this with `Rc`, maybe? We should be able to prove that the session // context is only accessed from one thread at a time, but I've already spent enough // time on this as-is. #[allow(clippy::arc_with_non_send_sync)] - let session_ctx = Arc::new(ggml::Context::new_with_allocate(context_byte_size)); + let session_ctx = Arc::new(ggml::Context::new_with_allocate(cache_byte_size)); // Initialize key + value memory tensors let n_mem = n_layer * context_size; let n_elements = n_embd * n_mem; let (memory_k, memory_v) = kv_memory(&session_ctx, &config, use_gpu, n_elements); - let scratch = scratch_buffers(); - - // Allocate buffer for storing intermediate values during evaluation (ctx0 backing) - // For the first run, we need to guess a maximum buffer size so we can measure - // the actual memory consumption of the temporary ggml context. - // - // These numbers are from `llama.cpp`, and could potentially be more efficient. - let buf_size = { - let buf_size_mb = if n_layer >= 80 { - 1536 - } else if n_layer >= 60 { - 1280 - } else { - 1024 - }; - buf_size_mb * 1024 * 1024 + ggml::graph_overhead() - }; - + // Allocate buffer for storing tensor and graph structs + let buf_size = ggml::graph_overhead() + (ggml::tensor_overhead() * ggml::MAX_NODES); let eval = Buffer::new(buf_size); - let ctx0 = ggml::Context::new_with_buffer(eval); + log::info!( + "Allocating {:.2} MB for eval-context", + buf_size / (1024 * 1024) + ); + let ctx0 = ggml::Context::new_with_buffer(eval, false); + + let allocator = GraphAllocator::new_measurement(ggml::TENSOR_ALIGNMENT); // Set up Metal support #[cfg(feature = "metal")] let metal_context = { if use_gpu { - let mut metal_context = MetalContext::new(config.n_threads); + let mut metal_context = MetalContext::new(); metal_context.add_scratch_buffer(ctx0.storage().as_buffer().unwrap()); for buf in scratch.iter() { @@ -199,7 +193,7 @@ impl InferenceSession { InferenceSession { _session_ctx: session_ctx, - _memory_size: context_byte_size, + _memory_size: cache_byte_size, config, memory_k, memory_v, @@ -212,7 +206,10 @@ impl InferenceSession { metal_context, ctx0, n_embd, - scratch, + allocator, + context_size, + work_buffer: vec![0], + use_gpu, } } @@ -224,24 +221,98 @@ impl InferenceSession { builder: F, ) -> GraphOutputs where - F: FnOnce(BuildContext) -> (ComputationGraph, GraphOutputs), + F: Fn(BuildContext) -> (ComputationGraph, GraphOutputs), { - // Build a graph + // Check if we need to allocate the graph + if self.allocator.in_measuring_mode() { + // Build a graph + self.ctx0.recreate(); + let ctx0 = &mut self.ctx0; + + // If we are in measuring mode, we need to build a "worst case" graph, meaning the input has either `batch_size` or `context_size` tokens. 
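// Illustrative sizing, not from the patch: with e.g. n_batch = 8 and
// context_size = 2048, the worst case measured below is max_n_tokens = 8 new
// tokens on top of max_n_past = 2048 - 8 = 2040 cached tokens, i.e. a full
// context window. A graph measured at that size is an upper bound for every
// later evaluation, which always processes at most that many tokens.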
+ let max_n_tokens = self.config.n_batch.min(self.context_size); + // We assume the history is full + let max_n_past = self.context_size - max_n_tokens; + let embd = ctx0 + .new_tensor_1d(ggml::Type::I32, max_n_tokens) + .set_name("embd"); + + self.allocator.allocate(&embd); + + let bc = BuildContext { + ctx0: RefCell::new(ctx0), + allocator: RefCell::new(&self.allocator), + embd: &embd, + memory_k: &self.memory_k, + memory_v: &self.memory_v, + n_past: max_n_past, + }; + + let (mut worst_case_graph, built_result) = builder(bc); + // Expand the graph + worst_case_graph.build_forward_expand(&built_result.result); + + // Allocate the graph + let graph_size = + self.allocator.allocate_graph(&worst_case_graph) + ggml::TENSOR_ALIGNMENT; + log::info!("Allocating {:.2} MB for graph", graph_size / (1024 * 1024)); + // Pre-allocate the buffer for future use + self.allocator + .resize_buffer(graph_size, ggml::TENSOR_ALIGNMENT); + + if self.use_gpu { + ggml::accelerator::set_scratch_size(graph_size); + } + } + + // Reset the context and allocator self.ctx0.recreate(); + self.allocator.reset(); let ctx0 = &mut self.ctx0; + let mut embd = ctx0 .new_tensor_1d(ggml::Type::I32, input_tokens.len()) .set_name("embd"); + self.allocator.allocate(&embd); + let bc = BuildContext { ctx0: RefCell::new(ctx0), + allocator: RefCell::new(&self.allocator), embd: &embd, memory_k: &self.memory_k, memory_v: &self.memory_v, - scratch: &mut self.scratch, + n_past: self.n_past, }; + let (mut built_gf, built_result) = builder(bc); + // Build the graph + built_gf.build_forward_expand(&built_result.result); + + // Allocate the graph + self.allocator.allocate_graph(&built_gf); + + #[cfg(feature = "cublas")] + { + for mut leaf in built_gf.leafs(&ctx0) { + if leaf.backend() == ggml::accelerator::Backend::Gpu && !leaf.has_extras() { + unsafe { + let offset = leaf.data().offset_from(self.allocator.buffer.data()) as usize; + leaf.assign_scratch_offset(offset); + } + } + } + + for mut node in built_gf.nodes(&ctx0) { + if node.backend() == ggml::accelerator::Backend::Gpu && !node.has_extras() { + unsafe { + let offset = node.data().offset_from(self.allocator.buffer.data()) as usize; + node.assign_scratch_offset(offset); + } + } + } + } // Do Metal'y stuff #[cfg(feature = "metal")] { @@ -253,9 +324,6 @@ impl InferenceSession { // Write input tokens unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; - // Compute the graph - built_gf.build_forward_expand(&built_result.result); - #[cfg(feature = "metal")] { // FIXME can only process one token at a time currently @@ -276,7 +344,7 @@ impl InferenceSession { #[cfg(not(feature = "metal"))] { let mut plan = GraphExecutionPlan::new(&mut built_gf, self.config.n_threads); - plan.execute(ctx0); + plan.execute(&mut self.work_buffer); } // Adjust the required memory per token if we didn't know that already @@ -291,6 +359,7 @@ impl InferenceSession { GraphOutputs { result: built_result.result.share(), embedding_result: built_result.embedding_result.share(), + output_length: input_tokens.len(), } } @@ -303,15 +372,22 @@ impl InferenceSession { output_request: &mut OutputRequest, mut callback: impl FnMut(&[u8]) -> Result, ) -> Result<(), InferenceError> { - let beginning_of_sentence = self.n_past == 0; - - let vocab = model.tokenizer(); - let prompt_tokens = prompt.into().to_tokens(vocab, beginning_of_sentence)?; + let prompt_tokens = self.get_prompt_tokens(model, prompt)?; if self.n_past + prompt_tokens.len() >= model.context_size() { return Err(InferenceError::ContextFull); } + 
        self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
+    }
+
+    fn feed_prompt_tokens<E: std::error::Error + Send + Sync + 'static>(
+        &mut self,
+        model: &dyn Model,
+        output_request: &mut OutputRequest,
+        mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
+        prompt_tokens: Vec<TokenId>,
+    ) -> Result<(), InferenceError> {
         'outer: for batch in prompt_tokens.chunks(self.config.n_batch) {
             model.evaluate(self, batch, output_request);
             for &tk in batch {
@@ -342,10 +418,46 @@ impl InferenceSession {
             }
         }
         log::trace!("Finished feed prompt");
-
         Ok(())
     }
 
+    fn get_prompt_tokens<'a, P: Into<Prompt<'a>>>(
+        &self,
+        model: &dyn Model,
+        prompt: P,
+    ) -> Result<Vec<TokenId>, TokenizationError> {
+        let beginning_of_sentence = self.n_past == 0;
+
+        let vocab = model.tokenizer();
+        prompt.into().to_tokens(vocab, beginning_of_sentence)
+    }
+
+    /// Feed a prompt to the model for this session.
+    /// Same as [Self::feed_prompt], but removes tokens from the history first if the prompt would not fit into the context window.
+    #[instrument(skip_all)]
+    pub fn feed_prompt_with_swap<
+        'a,
+        E: std::error::Error + Send + Sync + 'static,
+        P: Into<Prompt<'a>>,
+    >(
+        &mut self,
+        model: &dyn Model,
+        prompt: P,
+        n_keep: usize,
+        output_request: &mut OutputRequest,
+        mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
+    ) -> Result<(), InferenceError> {
+        let prompt_tokens = self.get_prompt_tokens(model, prompt)?;
+
+        if self.n_past + prompt_tokens.len() >= model.context_size() {
+            let rewind_by = self.n_past + prompt_tokens.len() - model.context_size();
+            self.remove_tokens(model, n_keep, rewind_by)
+                .map_err(|_e| InferenceError::ContextFull)?;
+        }
+
+        self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
+    }
+
     /// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
     pub fn rewind(&mut self, model: &dyn Model, num: usize) -> Result<Vec<TokenId>, RewindError> {
         if !model.supports_rewind() {
@@ -357,7 +469,7 @@ impl InferenceSession {
         }
 
         // Remove the tokens from self.tokens.
-        let token_start = self.n_past - num;
+        let token_start = self.tokens.len() - num;
         let deleted_tokens: Vec<_> = self.tokens.drain(token_start..).collect();
 
         // Remove the corresponding chars from decoded
@@ -373,6 +485,46 @@ impl InferenceSession {
         Ok(deleted_tokens)
     }
 
+    /// Removes `num` tokens starting at `start_from` from the buffer. Similar to [Self::rewind].
+    fn remove_tokens(
+        &mut self,
+        model: &dyn Model,
+        start_from: usize,
+        num: usize,
+    ) -> Result<Vec<TokenId>, RewindError> {
+        if !model.supports_rewind() {
+            return Err(RewindError::UnsupportedArchitecture);
+        }
+
+        if start_from + num >= self.n_past {
+            return Err(RewindError::NotEnoughTokens);
+        }
+
+        // Remove the tokens from self.tokens.
+        let end = start_from + num;
+        let deleted_tokens: Vec<_> = self.tokens.drain(start_from..end).collect();
+
+        // Remove the corresponding chars from decoded
+        let mut decoded_start = 0;
+        let mut decoded_end = 0;
+        if start_from != 0 {
+            for id in &self.tokens[0..start_from] {
+                decoded_start += model.tokenizer().token(*id as usize).len();
+            }
+            decoded_end += decoded_start;
+        }
+
+        for id in &deleted_tokens {
+            decoded_end += model.tokenizer().token(*id as usize).len();
+        }
+        self.decoded_tokens.drain(decoded_start..decoded_end);
+
+        // Decrement the n_past tokens counter.
+        self.n_past -= num;
+
+        Ok(deleted_tokens)
+    }
+
     /// Infer the next token for this session.
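// Worked illustration of the swap path above (hypothetical values, not from
// the patch): suppose context_size = 2048, n_past = 2000, and a 100-token
// prompt is fed with n_keep = 4. feed_prompt_with_swap computes
//     rewind_by = 2000 + 100 - 2048 = 52
// and calls remove_tokens(model, 4, 52), which drains tokens[4..56] and the
// matching bytes of decoded_tokens, leaving n_past = 1948 so the new prompt
// fits into the window again.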
#[instrument(level = "trace", skip_all)] pub fn infer_next_token( @@ -429,19 +581,7 @@ impl InferenceSession { ) -> Result { let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX); if request.play_back_previous_tokens { - // "Play back" the existing tokens, so that loading from an inference snapshot works - // as expected. - let mut token_utf8_buf = TokenUtf8Buffer::new(); - for token_id in &self.tokens { - // Buffer the token until it's valid UTF-8, then call the callback. - if let Some(tokens) = - token_utf8_buf.push(&model.tokenizer().token(*token_id as usize)) - { - if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) { - return Err(InferenceError::UserCallback(Box::new(e))); - } - } - } + self.play_back_previous_tokens(model, &mut callback)? } log::trace!( "Starting inference request with max_token_count: {}", @@ -466,10 +606,25 @@ impl InferenceSession { stats.feed_prompt_duration = start_at.elapsed().unwrap(); stats.prompt_tokens = self.n_past; - // After the prompt is consumed, sample tokens by repeatedly calling - // `infer_next_token`. We generate tokens until the model returns an - // EndOfText token, or we run out of space in the context window, - // or we reach the specified limit. + self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?; + stats.predict_duration = start_at.elapsed().unwrap(); + stats.predict_tokens = self.n_past; + + Ok(stats) + } + + /// sample tokens by repeatedly calling + /// [Self::infer_next_token]. Generate tokens until the model returns an + /// EndOfText token, or we run out of space in the context window, + /// or we reach the specified limit. + fn infer_tokens( + &mut self, + model: &dyn Model, + rng: &mut impl rand::Rng, + mut callback: impl FnMut(InferenceResponse) -> Result, + maximum_token_count: usize, + parameters: &InferenceParameters, + ) -> Result<(), InferenceError> { let mut tokens_processed = 0; let mut token_utf8_buf = TokenUtf8Buffer::new(); while tokens_processed < maximum_token_count { @@ -493,6 +648,79 @@ impl InferenceSession { tokens_processed += 1; } + Ok(()) + } + + /// "Play back" the existing tokens, so that loading from an inference snapshot works + /// as expected. + fn play_back_previous_tokens( + &mut self, + model: &dyn Model, + mut callback: impl FnMut(InferenceResponse) -> Result, + ) -> Result<(), InferenceError> { + let mut token_utf8_buf = TokenUtf8Buffer::new(); + for token_id in &self.tokens { + // Buffer the token until it's valid UTF-8, then call the callback. + if let Some(tokens) = token_utf8_buf.push(&model.tokenizer().token(*token_id as usize)) + { + if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) { + return Err(InferenceError::UserCallback(Box::new(e))); + } + } + } + Ok(()) + } + + /// Generate text by using the provided [Model] to evaluate the `prompt`. + /// Works the same way as [Self::infer] except has infinite text generation via context swapping + #[instrument(skip_all)] + pub fn infer_with_swap( + &mut self, + model: &dyn Model, + rng: &mut impl rand::Rng, + request: &InferenceRequest, + n_keep: usize, + output_request: &mut OutputRequest, + mut callback: impl FnMut(InferenceResponse) -> Result, + ) -> Result { + let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX); + if request.play_back_previous_tokens { + self.play_back_previous_tokens(model, &mut callback)? 
+ } + + // infinite text generation via context swapping + // if we run out of context: + // - take the n_keep first tokens from the original prompt + // - remove half of the tokens after n_keep ((n_ctx - n_keep) / 2) + if self.n_past >= model.context_size() { + self.remove_tokens(model, n_keep, (self.n_past - n_keep) / 2) + .map_err(|_e| InferenceError::ContextFull)?; + } + + log::trace!( + "Starting inference request with max_token_count: {}", + maximum_token_count + ); + + let mut stats = InferenceStats::default(); + let start_at = std::time::SystemTime::now(); + + let parameters = request.parameters; + + // Feed the initial prompt through the transformer, to update its + // context window with new data, if necessary. + if !request.prompt.is_empty() { + self.feed_prompt( + model, + request.prompt, + output_request, + feed_prompt_callback(&mut callback), + )?; + } + stats.feed_prompt_duration = start_at.elapsed().unwrap(); + stats.prompt_tokens = self.n_past; + + self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?; stats.predict_duration = start_at.elapsed().unwrap(); stats.predict_tokens = self.n_past; @@ -596,7 +824,8 @@ impl InferenceSession { npast: self.n_past, config: self.config, tokens: self.tokens.clone(), - logits: self.last_logits.clone(), + decoded_tokens: self.decoded_tokens.clone(), + last_logits: self.last_logits.clone(), memory_k, memory_v, } @@ -628,6 +857,7 @@ impl InferenceSession { session.n_past = snapshot.npast; session.tokens = snapshot.tokens; + session.decoded_tokens = snapshot.decoded_tokens; session.last_logits = snapshot.last_logits; Ok(session) @@ -736,8 +966,10 @@ pub struct InferenceSnapshotRef<'a> { pub config: InferenceSessionConfig, /// All tokens generated by this inference session. pub tokens: Vec, + /// All decoded tokens generated by this inference session. + pub decoded_tokens: Vec, /// The vector of logits that was produced after the last inference. - pub logits: Vec, + pub last_logits: Vec, /// The contents of the 'key' memory tensor. #[serde(with = "serde_bytes")] pub memory_k: &'a [u8], @@ -754,7 +986,8 @@ impl InferenceSnapshotRef<'_> { npast: self.npast, config: self.config, tokens: self.tokens.clone(), - last_logits: self.logits.clone(), + decoded_tokens: self.decoded_tokens.clone(), + last_logits: self.last_logits.clone(), memory_k: self.memory_k.to_vec(), memory_v: self.memory_v.to_vec(), } @@ -772,6 +1005,8 @@ pub struct InferenceSnapshot { pub config: InferenceSessionConfig, /// All tokens generated by this inference session. pub tokens: Vec, + /// All decoded tokens generated by this inference session. + pub decoded_tokens: Vec, /// The vector of logits that was produced after the last inference. pub last_logits: Vec, /// The contents of the 'key' memory tensor. diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index 9cdd60bc..ebf71e77 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -57,7 +57,7 @@ pub struct InferenceParameters { /// This can be anything that implements [Sampler]. Refer to /// the `llm-samplers` documentation for possible samplers and suggested /// combinations: - pub sampler: Arc>>, + pub sampler: Arc>, } //Since Sampler implements Send and Sync, InferenceParameters should too. diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index e79999fc..52732155 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -29,22 +29,22 @@ pub struct FileType { /// The quantization version. 
pub quantization_version: u32, } -impl From for i32 { +impl From for llama_ftype { fn from(value: FileType) -> Self { - (value.quantization_version * ggml::QNT_VERSION_FACTOR) as i32 + (value.quantization_version * ggml::QNT_VERSION_FACTOR) as llama_ftype + llama_ftype::from(value.format) } } -impl TryFrom for FileType { +impl TryFrom for FileType { type Error = llama_ftype; - fn try_from(value: i32) -> Result { + fn try_from(value: llama_ftype) -> Result { let format = FileTypeFormat::try_from(((value as u32) % ggml::QNT_VERSION_FACTOR) as llama_ftype)?; Ok(Self { format, - quantization_version: (value as u32) / ggml::QNT_VERSION_FACTOR, + quantization_version: value / ggml::QNT_VERSION_FACTOR, }) } } @@ -63,7 +63,7 @@ impl FileType { .get_optional("general.file_type") .and_then(|v| v.as_uint32()) .map(|v| { - FileType::try_from(v as i32).map_err(|ftype| { + FileType::try_from(v as llama_ftype).map_err(|ftype| { HyperparametersReadError::UnsupportedFileType { file_type: ftype } }) }) diff --git a/crates/llm-base/src/lora.rs b/crates/llm-base/src/lora.rs index 07bf97c5..c034fdaa 100644 --- a/crates/llm-base/src/lora.rs +++ b/crates/llm-base/src/lora.rs @@ -108,8 +108,9 @@ impl LoraAdapter { gf.build_forward_expand(&output); //TODO: maybe pass the model's thread count to this context + let mut work_buffer = vec![0u8]; let mut plan = GraphExecutionPlan::new(&mut gf, 8); - plan.execute(&patch_context); + plan.execute(&mut work_buffer); // Overwrite the original tensor. // The `output` and the `target_tensor` are not from the same context, diff --git a/crates/llm-base/src/samplers.rs b/crates/llm-base/src/samplers.rs index 7a179f0b..f0b07b9e 100644 --- a/crates/llm-base/src/samplers.rs +++ b/crates/llm-base/src/samplers.rs @@ -59,7 +59,7 @@ pub enum SamplingError { /// to ensure a valid configuration. pub struct ConfiguredSamplers { /// A builder from the `llm-samplers` crate. - pub builder: SamplerChainBuilder, + pub builder: SamplerChainBuilder, /// Mirostat 1 is present. pub mirostat1: bool, /// Mirostat 2 is present. @@ -74,15 +74,17 @@ pub struct ConfiguredSamplers { /// We call a configuration of samplers that run in a certain order a "chain". /// Here is a description of the default chain `llm` uses: /// -/// 1. Repetition (present by default, multiple allowed) -/// 2. Frequency/Presence (optional, multiple allowed) -/// 3. Sequence Repetition (optional, multiple allowed) -/// 4. Top-K (present by default - incompatible with Mirostat) -/// 5. Tail Free (optional - incompatible with Mirostat) -/// 6. Locally Typical (optional - incompatible with Mirostat) -/// 7. Top-P (present by default - incompatible with Mirostat) -/// 8. Temperature (present by default) -/// 9. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. +/// 1. Repetition (present by default, multiple allowed) +/// 2. Frequency/Presence (optional, multiple allowed) +/// 3. Sequence Repetition (optional, multiple allowed) +/// 4. Top-K (present by default - incompatible with Mirostat) +/// 5. Tail Free (optional - incompatible with Mirostat) +/// 6. Locally Typical (optional - incompatible with Mirostat) +/// 7. Top-P (present by default - incompatible with Mirostat) +/// 8. Top-A (optional - incompatible with Mirostat) +/// 9. Min-P (optional - incompatible with Mirostat) +/// 10. Temperature (present by default) +/// 11. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. 
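To make the chain described above concrete, here is a rough sketch of how the default chain ends up in InferenceParameters, mirroring the default_samplers helper later in this patch. The module paths, the visibility of ensure_default_slots and builder, and InferenceParameters having only a sampler field are assumptions here:

use std::sync::{Arc, Mutex};
use llm_base::{samplers::ConfiguredSamplers, InferenceParameters};

fn parameters_with_default_chain() -> InferenceParameters {
    // Fill the "present by default" slots (repetition, top-k, top-p, temperature,
    // random distribution) and collapse them into a single shared chain.
    let mut configured = ConfiguredSamplers::default();
    configured.ensure_default_slots();

    InferenceParameters {
        sampler: Arc::new(Mutex::new(configured.builder.into_chain())),
    }
}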
/// /// Samplers listed as "present by default" but incompatible with Mirostat will /// only be enabled by default if there is no Mirostat sampler enabled. @@ -142,6 +144,20 @@ impl Default for ConfiguredSamplers { Option::::None, ), ), + ( + "topa", + SamplerSlot::new_single( + || Box::new(SampleTopA::default().a1(0.0).a2(0.0)), + Option::::None, + ), + ), + ( + "minp", + SamplerSlot::new_single( + || Box::new(SampleMinP::default().p(0.0)), + Option::::None, + ), + ), ( "temperature", SamplerSlot::new_single( @@ -203,7 +219,7 @@ impl ConfiguredSamplers { ))? } else if (self.mirostat1 || self.mirostat2) && self.incompat_mirostat { Err(SamplerConfigurationError::SamplerCombinationError( - "Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), + "Cannot enable top-p, top-k, top-a, min-p, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), ))? } Ok(()) @@ -245,7 +261,9 @@ impl FromStr for ConfiguredSamplers { .inspect(|(name, _slot)| match name.as_str() { "mirostat1" => result.mirostat1 = true, "mirostat2" => result.mirostat2 = true, - "topp" | "topk" | "locallytypical" | "tailfree" => result.incompat_mirostat = true, + "topa" | "minp" | "topp" | "topk" | "locallytypical" | "tailfree" => { + result.incompat_mirostat = true + } _ => (), }) .collect::>(); @@ -269,7 +287,7 @@ impl FromStr for ConfiguredSamplers { /// Sample a token. This convenience function handles building /// the sampler resources and logits objects the sampler needs. pub fn sample_token( - mut sampler: impl Sampler, + mut sampler: impl Sampler, rng: &mut impl rand::Rng, previous_tokens: &[TokenId], last_logits: impl IntoIterator, @@ -297,7 +315,7 @@ pub fn build_sampler( n_vocab: usize, bias: &[(TokenId, f32)], args: &[impl AsRef], -) -> Result>>, SamplerConfigurationError> { +) -> Result>, SamplerConfigurationError> { let mut samplers = SamplerChain::new(); if !bias.is_empty() { @@ -326,7 +344,7 @@ pub fn build_sampler( } /// Get the default sampler chain. 
-pub fn default_samplers() -> Arc>> { +pub fn default_samplers() -> Arc> { let mut result = ConfiguredSamplers::default(); result.ensure_default_slots(); Arc::new(Mutex::new(result.builder.into_chain())) @@ -349,8 +367,6 @@ impl<'pt, 'r> fmt::Debug for SamplerResources<'pt, 'r> { } impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { - type TokenId = TokenId; - fn with_rng_mut( &mut self, fun: &mut dyn FnMut(&mut dyn rand::RngCore), @@ -359,7 +375,7 @@ impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { Ok(()) } - fn with_last_tokens(&self, fun: &mut dyn FnMut(&[Self::TokenId])) -> Result<(), SamplerError> { + fn with_last_tokens(&self, fun: &mut dyn FnMut(&[TokenId])) -> Result<(), SamplerError> { fun(self.previous_tokens); Ok(()) } diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index d21b6bcd..5db0ec0b 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -16,6 +16,7 @@ llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" } llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" } llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" } llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" } +llm-bert = { path = "../models/bert", optional = true, version = "0.2.0-dev" } serde = { workspace = true } tracing = { workspace = true } @@ -35,13 +36,14 @@ default = ["models", "tokenizers-remote"] tokenizers-remote = ["llm-base/tokenizers-remote"] -models = ["llama", "gptneox"] #, "gpt2", "gptj", "bloom", "mpt"] +models = ["llama", "gptneox"] #, "gpt2", "gptj", "bloom", "mpt", "bert"] llama = ["dep:llm-llama"] gpt2 = ["dep:llm-gpt2"] gptj = ["dep:llm-gptj"] bloom = ["dep:llm-bloom"] gptneox = ["dep:llm-gptneox"] mpt = ["dep:llm-mpt"] +bert = ["dep:llm-bert"] # Falcon is off by default. See `llm_falcon`'s module documentation for more information. falcon = ["dep:llm-falcon"] diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index f7bf0d03..d588d034 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -173,6 +173,7 @@ macro_rules! define_models { } define_models!( + (bert, "bert", Bert, llm_bert, "Bert"), (bloom, "bloom", Bloom, llm_bloom, "BLOOM"), (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"), (gptj, "gptj", GptJ, llm_gptj, "GPT-J"), diff --git a/crates/models/bert/Cargo.toml b/crates/models/bert/Cargo.toml new file mode 100644 index 00000000..0be81b40 --- /dev/null +++ b/crates/models/bert/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "llm-bert" +version = "0.2.0-dev" +license = { workspace = true } +repository = { workspace = true } +description = "An implementation of BERT for the `llm` ecosystem." +edition = "2021" +readme = "../../../README.md" + +[dependencies] +bytemuck.workspace = true +llm-base = { path = "../../llm-base", version = "0.2.0-dev" } +tracing = { version = "0.1", features = ["log"] } + diff --git a/crates/models/bert/src/lib.rs b/crates/models/bert/src/lib.rs new file mode 100644 index 00000000..b9bf1c63 --- /dev/null +++ b/crates/models/bert/src/lib.rs @@ -0,0 +1,464 @@ +// //! An implementation of [BERT](https://huggingface.co/docs/transformers/model_doc/bert) for the `llm` ecosystem. 
+// #![deny(missing_docs)] + +// use std::error::Error; + +// use llm_base::{ +// ggml, +// model::{common, HyperparametersWriteError}, +// util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, LoadError, Model, +// ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, +// }; + +// /// The BERT model. +// /// +// /// # Safety +// /// This implements [Send] and [Sync] as it is immutable after construction. +// pub struct Bert { +// params: ModelParameters, +// hyperparameters: Hyperparameters, +// tokenizer: Tokenizer, + +// word_embeddings: ggml::Tensor, +// token_type_embeddings: ggml::Tensor, +// position_embeddings: ggml::Tensor, +// ln_e_w: ggml::Tensor, +// ln_e_b: ggml::Tensor, + +// // weights for the model +// layers: Vec, + +// // must be kept alive for the model +// context: ModelContext, +// } + +// unsafe impl Send for Bert {} +// unsafe impl Sync for Bert {} + +// /// BERT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +// #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +// pub struct Hyperparameters { +// /// Size of the model's vocabulary +// pub n_vocab: usize, + +// /// Maximum number of tokens +// pub n_max_tokens: usize, + +// /// Size of the model's embedding layer +// pub n_embd: usize, + +// /// n_head +// pub n_intermediate: usize, + +// /// Number of attention heads +// pub n_head: usize, + +// /// Number of layers in the model +// pub n_layer: usize, + +// /// file_type +// pub file_type: FileType, +// } + +// impl Model for Bert { +// type Hyperparameters = Hyperparameters; + +// fn new( +// hyperparameters: Self::Hyperparameters, +// params: ModelParameters, +// tokenizer: Tokenizer, +// tensor_loader: impl TensorLoader, +// ) -> Result { +// let mut tl = tensor_loader; + +// let word_embeddings = tl.load("embeddings.word_embeddings.weight")?; +// let token_type_embeddings = tl.load("embeddings.token_type_embeddings.weight")?; +// let position_embeddings = tl.load("embeddings.position_embeddings.weight")?; + +// let ln_e_w = tl.load("embeddings.LayerNorm.weight")?; +// let ln_e_b = tl.load("embeddings.LayerNorm.bias")?; + +// let mut layers = Vec::new(); + +// for i in 0..hyperparameters.n_layer { +// let backend = params.backend(i); + +// let layer = Layer { +// ln_att_w: tl +// .load(&format!( +// "encoder.layer.{i}.attention.output.LayerNorm.weight" +// ))? +// .transfer_to(backend), +// ln_att_b: tl +// .load(&format!( +// "encoder.layer.{i}.attention.output.LayerNorm.bias" +// ))? +// .transfer_to(backend), + +// // attention +// q_w: tl +// .load(&format!("encoder.layer.{i}.attention.self.query.weight"))? +// .transfer_to(backend), +// q_b: tl +// .load(&format!("encoder.layer.{i}.attention.self.query.bias"))? +// .transfer_to(backend), +// k_w: tl +// .load(&format!("encoder.layer.{i}.attention.self.key.weight"))? +// .transfer_to(backend), +// k_b: tl +// .load(&format!("encoder.layer.{i}.attention.self.key.bias"))? +// .transfer_to(backend), +// v_w: tl +// .load(&format!("encoder.layer.{i}.attention.self.value.weight"))? +// .transfer_to(backend), +// v_b: tl +// .load(&format!("encoder.layer.{i}.attention.self.value.bias"))? +// .transfer_to(backend), + +// o_w: tl +// .load(&format!("encoder.layer.{i}.attention.output.dense.weight"))? +// .transfer_to(backend), +// o_b: tl +// .load(&format!("encoder.layer.{i}.attention.output.dense.bias"))? 
+// .transfer_to(backend), + +// // ff +// ff_i_w: tl +// .load(&format!("encoder.layer.{i}.intermediate.dense.weight"))? +// .transfer_to(backend), +// ff_i_b: tl +// .load(&format!("encoder.layer.{i}.intermediate.dense.bias"))? +// .transfer_to(backend), + +// ln_out_w: tl +// .load(&format!("encoder.layer.{i}.output.LayerNorm.weight"))? +// .transfer_to(backend), +// ln_out_b: tl +// .load(&format!("encoder.layer.{i}.output.LayerNorm.bias"))? +// .transfer_to(backend), +// ff_o_w: tl +// .load(&format!("encoder.layer.{i}.output.dense.weight"))? +// .transfer_to(backend), +// ff_o_b: tl +// .load(&format!("encoder.layer.{i}.output.dense.bias"))? +// .transfer_to(backend), +// }; + +// layers.push(layer); +// } +// let context = tl.finish(); + +// Ok(Self { +// ln_e_b, +// ln_e_w, +// position_embeddings, +// token_type_embeddings, +// word_embeddings, +// hyperparameters, +// params, +// tokenizer, +// layers, +// context, +// }) +// } + +// /// Starts a new `InferenceSession` for this model. +// fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { +// InferenceSession::new( +// config, +// &self.params, +// self.hyperparameters.n_layer, +// self.hyperparameters.n_embd, +// self.hyperparameters.n_vocab, +// ) +// } + +// #[tracing::instrument(level = "trace", skip_all)] +// fn evaluate( +// &self, +// session: &mut InferenceSession, +// input_tokens: &[TokenId], +// output_request: &mut OutputRequest, +// ) { +// let input_len = input_tokens.len(); +// let _ctx_size = self.params.context_size; + +// let Hyperparameters { +// n_vocab, +// n_max_tokens: _, +// n_embd, +// n_intermediate: _, +// n_head, +// n_layer, +// file_type: _, +// } = self.hyperparameters; + +// let d_head = n_embd / n_head; + +// let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let mut ctx0 = builder.ctx0.borrow_mut(); +// let gf = ctx0.create_compute_graph(); + +// let embd = builder.embd; + +// let mut input_layer = ctx0.op_get_rows(&self.word_embeddings, embd); + +// // IL = word_embeddings + token_types + position_embeddingso +// { +// // token-types: a zero tensor +// let mut token_types = ctx0.new_tensor_1d(llm_base::ElementType::I32, input_len); +// token_types.zero_data(); + +// // position embeddings: another tensor +// let position_buf: Vec = (0..input_len as i32).collect(); +// let mut positions = ctx0.new_tensor_1d(llm_base::ElementType::I32, input_len); +// unsafe { positions.write_data(bytemuck::cast_slice(&position_buf)) }; + +// // IL += token_types +// input_layer = ctx0.op_add( +// &input_layer, +// &ctx0.op_get_rows(&self.token_type_embeddings, &token_types), +// ); + +// // IL += position_embeddings +// input_layer = ctx0.op_add( +// &input_layer, +// &ctx0.op_get_rows(&self.position_embeddings, &positions), +// ); +// } + +// // embd norm +// { +// input_layer = ctx0.op_norm(&input_layer); +// input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_e_w), &self.ln_e_b); +// } + +// for il in 0..n_layer { +// ctx0.set_offloading(self.params.should_offload(il)); + +// let mut current = input_layer.share(); + +// // self-attention +// { +// let q_current = ctx0.op_reshape_3d( +// &ctx0.op_add( +// &ctx0.op_mul_mat(&self.layers[il].q_w, ¤t), +// &self.layers[il].q_b, +// ), +// d_head, +// n_head, +// input_len, +// ); +// let q = ctx0.op_permute(&q_current, (0, 2, 1, 3)); + +// let k_current = ctx0.op_reshape_3d( +// &ctx0.op_add( +// &ctx0.op_mul_mat(&self.layers[il].k_w, ¤t), +// &self.layers[il].k_b, +// ), +// d_head, +// n_head, 
+// input_len, +// ); +// let k = ctx0.op_permute(&k_current, (0, 2, 1, 3)); + +// let v_current = ctx0.op_reshape_3d( +// &ctx0.op_add( +// &ctx0.op_mul_mat(&self.layers[il].v_w, ¤t), +// &self.layers[il].v_b, +// ), +// d_head, +// n_head, +// input_len, +// ); +// let mut v = ctx0.op_permute(&v_current, (0, 2, 1, 3)); + +// let mut kq = ctx0.op_mul_mat(&k, &q); + +// // TODO: look into op_scale_inplace and op_soft_max_inplace +// kq = ctx0.op_scale( +// &kq, +// &ctx0.new_f32(1.0 / ((n_embd as f32 / n_head as f32).sqrt())), +// ); +// kq = ctx0.op_soft_max(&kq); + +// v = ctx0.op_cont(&ctx0.op_transpose(&v)); + +// let mut kqv = ctx0.op_mul_mat(&v, &kq); +// kqv = ctx0.op_permute(&kqv, (0, 2, 1, 3)); + +// current = ctx0.op_cpy( +// &kqv, +// &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), +// ); +// } + +// // attention output +// current = ctx0.op_add( +// &ctx0.op_mul_mat(&self.layers[il].o_w, ¤t), +// &self.layers[il].o_b, +// ); + +// // re-add the layer input +// current = ctx0.op_add(¤t, &input_layer); + +// // attention norm +// { +// current = ctx0.op_norm(¤t); +// current = ctx0.op_add( +// &ctx0.op_mul(¤t, &self.layers[il].ln_att_w), +// &self.layers[il].ln_att_b, +// ); +// } + +// let att_output = current.share(); + +// // intermediate output +// current = ctx0.op_mul_mat(&self.layers[il].ff_i_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].ff_i_b); +// current = ctx0.op_gelu(¤t); + +// // layer output +// current = ctx0.op_mul_mat(&self.layers[il].ff_o_w, ¤t); +// current = ctx0.op_add(¤t, &self.layers[il].ff_o_b); + +// // attentions bypass the intermediate layer +// current = ctx0.op_add(&att_output, ¤t); + +// // output norm +// { +// current = ctx0.op_norm(¤t); +// current = ctx0.op_add( +// &ctx0.op_mul(¤t, &self.layers[il].ln_out_w), +// &self.layers[il].ln_out_b, +// ); +// } + +// // input for next layer +// input_layer = current; +// } +// input_layer = ctx0.op_cont(&ctx0.op_transpose(&input_layer)); + +// ctx0.set_offloading(false); +// // pooler +// let mut sum = ctx0.new_tensor_2d(llm_base::ElementType::F32, input_len, 1); +// sum = ctx0.set_f32(&sum, 1.0 / (input_len as f32)); +// input_layer = ctx0.op_mul_mat(&input_layer, &sum); + +// // normalizer +// let length = ctx0.op_sqrt(&ctx0.op_sum(&ctx0.op_sqr(&input_layer))); + +// input_layer = ctx0.op_scale(&input_layer, &ctx0.op_div(&ctx0.new_f32(1.0), &length)); + +// ( +// gf, +// GraphOutputs { +// result: input_layer.share(), +// embedding_result: input_layer.share(), +// output_length: input_len, +// }, +// ) +// }); + +// // finish evaluation +// common::read_last_token(session, &outputs.result, n_vocab, input_len); +// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); +// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, 1); +// } + +// fn hyperparameters(&self) -> &Self::Hyperparameters { +// &self.hyperparameters +// } + +// fn tokenizer(&self) -> &Tokenizer { +// &self.tokenizer +// } + +// fn context_size(&self) -> usize { +// self.params.context_size +// } + +// fn bot_token_id(&self) -> Option { +// self.tokenizer.id("[PAD]".as_bytes()) +// } + +// fn eot_token_id(&self) -> TokenId { +// self.tokenizer.id("".as_bytes()).unwrap_or(2) +// } + +// fn quantize_tensors() -> Vec { +// vec![Regex::new(".*weight").unwrap()] +// } + +// fn skip_quantize_tensors() -> Vec { +// vec![] +// } + +// fn supports_rewind(&self) -> bool { +// true +// } +// } + +// impl llm_base::Hyperparameters for Hyperparameters { +// fn read_ggml(reader: &mut 
dyn std::io::BufRead) -> Result { +// Ok(Hyperparameters { +// n_vocab: util::read_i32(reader)?.try_into()?, +// n_max_tokens: util::read_i32(reader)?.try_into()?, +// n_embd: util::read_i32(reader)?.try_into()?, +// n_intermediate: util::read_i32(reader)?.try_into()?, +// n_head: util::read_i32(reader)?.try_into()?, +// n_layer: util::read_i32(reader)?.try_into()?, +// file_type: util::read_filetype(reader)?, +// }) +// } + +// fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { +// util::write_i32(writer, self.n_vocab.try_into()?)?; +// util::write_i32(writer, self.n_max_tokens.try_into()?)?; +// util::write_i32(writer, self.n_embd.try_into()?)?; +// util::write_i32(writer, self.n_intermediate.try_into()?)?; +// util::write_i32(writer, self.n_head.try_into()?)?; +// util::write_i32(writer, self.n_layer.try_into()?)?; +// util::write_i32(writer, self.file_type.into())?; +// Ok(()) +// } + +// fn n_vocabulary(&self) -> usize { +// self.n_vocab +// } + +// fn file_type(&self) -> Option { +// Some(self.file_type) +// } + +// fn file_type_mut(&mut self) -> Option<&mut FileType> { +// Some(&mut self.file_type) +// } +// } + +// struct Layer { +// // normalization +// ln_att_w: ggml::Tensor, +// ln_att_b: ggml::Tensor, + +// ln_out_w: ggml::Tensor, +// ln_out_b: ggml::Tensor, + +// // attention +// q_w: ggml::Tensor, +// q_b: ggml::Tensor, +// k_w: ggml::Tensor, +// k_b: ggml::Tensor, +// v_w: ggml::Tensor, +// v_b: ggml::Tensor, + +// o_w: ggml::Tensor, +// o_b: ggml::Tensor, + +// // ff +// ff_i_w: ggml::Tensor, +// ff_i_b: ggml::Tensor, + +// ff_o_w: ggml::Tensor, +// ff_o_b: ggml::Tensor, +// } diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index d8c6f822..41ce7262 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -119,8 +119,6 @@ // input_tokens: &[TokenId], // output_request: &mut OutputRequest, // ) { -// let input_len = input_tokens.len(); -// let session_len = session.n_past; // let ctx_size = self.params.context_size; // let Hyperparameters { @@ -133,6 +131,8 @@ // } = self.hyperparameters; // let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let session_len = builder.n_past; +// let input_len = builder.input_length(); // let ctx0 = builder.ctx0.borrow(); // let (memory_k_size, memory_v_size) = ( // builder.memory_k.element_size(), @@ -331,14 +331,25 @@ // GraphOutputs { // result: input_layer, // embedding_result: embeddings_tensor, +// output_length: input_len, // }, // ) // }); // // finish evaluation -// common::read_last_token(session, &outputs.result, n_vocab, input_len); -// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); -// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); +// common::extract_logits( +// output_request, +// &outputs.result, +// n_vocab, +// outputs.output_length, +// ); +// common::extract_embeddings( +// output_request, +// &outputs.embedding_result, +// n_embd, +// outputs.output_length, +// ); // } // fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs index 79d118ee..e4f570d6 100644 --- a/crates/models/falcon/src/lib.rs +++ b/crates/models/falcon/src/lib.rs @@ -156,8 +156,6 @@ // input_tokens: &[TokenId], // output_request: &mut OutputRequest, // ) { -// let input_len = 
input_tokens.len(); -// let session_len = session.n_past; // let ctx_size = self.params.context_size; // let Hyperparameters { @@ -170,9 +168,12 @@ // } = self.hyperparameters; // let head_dim = n_embd / n_head; -// let n = input_len; // let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let input_len = builder.input_length(); +// let n = input_len; +// let session_len = builder.n_past; + // let mut ctx0 = builder.ctx0.borrow_mut(); // let embd = builder.embd; // let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd); @@ -192,7 +193,6 @@ // for il in 0..n_layer { // // attention uses first scratch buffer -// ctx0.use_scratch(builder.get_scratch(0)); // ctx0.set_offloading(self.params.should_offload(il)); // // self-attention @@ -319,9 +319,6 @@ // // projection // current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); -// // feed forward uses second scratch buffer -// ctx0.use_scratch(builder.get_scratch(1)); - // let inp_ff = layernorm_output.share(); // let attn_out = // ctx0.op_cpy(¤t, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n)); @@ -336,8 +333,6 @@ // input_layer = current.share(); // } -// ctx0.use_scratch(builder.get_scratch(0)); - // // norm // input_layer = ctx0.op_norm(&input_layer); @@ -349,7 +344,6 @@ // let embeddings_tensor: ggml::Tensor = input_layer.share(); // ctx0.set_offloading(false); -// ctx0.use_scratch(None); // // lm_head // input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer); @@ -359,14 +353,25 @@ // GraphOutputs { // result: input_layer, // embedding_result: embeddings_tensor, +// output_length: n, // }, // ) // }); // // finish evaluation -// common::read_last_token(session, &outputs.result, n_vocab, input_len); -// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); -// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); +// common::extract_logits( +// output_request, +// &outputs.result, +// n_vocab, +// outputs.output_length, +// ); +// common::extract_embeddings( +// output_request, +// &outputs.embedding_result, +// n_embd, +// outputs.output_length, +// ); // } // fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index b370a7ab..dfac1064 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -141,8 +141,6 @@ // input_tokens: &[TokenId], // output_request: &mut OutputRequest, // ) { -// let input_len = input_tokens.len(); -// let session_len = session.n_past; // let ctx_size = self.params.context_size; // let Hyperparameters { @@ -154,6 +152,8 @@ // } = self.hyperparameters; // let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let input_len = builder.input_length(); +// let session_len = builder.n_past; // let mut ctx0 = builder.ctx0.borrow_mut(); // let (memory_k_size, memory_v_size) = ( // builder.memory_k.element_size(), @@ -174,7 +174,7 @@ // let mut gf = ctx0.create_compute_graph(); // for il in 0..n_layer { // ctx0.set_offloading(self.params.should_offload(il)); -// ctx0.use_scratch(builder.get_scratch(0)); + // // norm // let mut current = ctx0.op_norm(&input_layer); // current = ctx0.op_add( @@ -281,8 +281,6 @@ // // feed-forward // let ff_in = current.share(); -// ctx0.use_scratch(builder.get_scratch(1)); - // // feed-forward normalization // current = ctx0.op_norm(&ff_in); // current = ctx0.op_add( @@ -305,13 
+303,10 @@ // input_layer = ctx0.op_add(¤t, &ff_in); // } -// ctx0.use_scratch(builder.get_scratch(0)); - // // normalization // input_layer = ctx0.op_norm(&input_layer); // input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b); -// ctx0.use_scratch(None); // ctx0.set_offloading(false); // let embeddings_tensor: ggml::Tensor = input_layer.share(); @@ -324,14 +319,25 @@ // GraphOutputs { // result: input_layer, // embedding_result: embeddings_tensor, +// output_length: input_len, // }, // ) // }); // // finish evaluation -// common::read_last_token(session, &outputs.result, n_vocab, input_len); -// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); -// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); +// common::extract_logits( +// output_request, +// &outputs.result, +// n_vocab, +// outputs.output_length, +// ); +// common::extract_embeddings( +// output_request, +// &outputs.embedding_result, +// n_embd, +// outputs.output_length, +// ); // } // fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index 03da600c..9baff116 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -137,8 +137,6 @@ // input_tokens: &[TokenId], // output_request: &mut OutputRequest, // ) { -// let input_len = input_tokens.len(); -// let session_len = session.n_past; // let ctx_size = self.params.context_size; // let Hyperparameters { @@ -151,6 +149,9 @@ // } = self.hyperparameters; // let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let input_len = builder.input_length(); +// let session_len = builder.n_past; + // let mut ctx0 = builder.ctx0.borrow_mut(); // let (memory_k_size, memory_v_size) = ( // builder.memory_k.element_size(), @@ -300,14 +301,25 @@ // GraphOutputs { // result: input_layer, // embedding_result: embeddings_tensor, +// output_length: input_len, // }, // ) // }); // // finish evaluation -// common::read_last_token(session, &outputs.result, n_vocab, input_len); -// common::extract_logits(output_request, &outputs.result, n_vocab, input_len); -// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len); +// common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); +// common::extract_logits( +// output_request, +// &outputs.result, +// n_vocab, +// outputs.output_length, +// ); +// common::extract_embeddings( +// output_request, +// &outputs.embedding_result, +// n_embd, +// outputs.output_length, +// ); // } // fn hyperparameters(&self) -> &Self::Hyperparameters { diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 8786e512..6b7b0086 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -138,8 +138,6 @@ impl Model for GptNeoX { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; let params = &self.data.params; let ctx_size = params.context_size; @@ -155,6 +153,9 @@ impl Model for GptNeoX { } = self.hyperparameters; let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let input_len = builder.input_length(); + let session_len = builder.n_past; + let mut ctx0 = builder.ctx0.borrow_mut(); let embd = builder.embd; @@ -169,8 +170,6 @@ impl Model for 
GptNeoX { for il in 0..block_count { ctx0.set_offloading(params.should_offload(il)); - // attention uses first scratch buffer - ctx0.use_scratch(builder.get_scratch(0)); // self-attention let mut current = ctx0.op_norm(&input_layer); @@ -295,9 +294,6 @@ impl Model for GptNeoX { ¤t, ); - // use the second scratch for the feed forward - ctx0.use_scratch(builder.get_scratch(1)); - let feedforward_input: Tensor; if !use_parallel_residual { feedforward_input = ctx0.op_add(¤t, &input_layer); @@ -320,9 +316,6 @@ impl Model for GptNeoX { } } - // use the first scratch for the norm - ctx0.use_scratch(builder.get_scratch(0)); - // normalize the output input_layer = ctx0.op_norm(&input_layer); // inpL = ln_f_g*inpL + ln_f_b @@ -333,8 +326,6 @@ impl Model for GptNeoX { let embeddings_tensor: ggml::Tensor = input_layer.share(); - // Disable the scratchbuffer - ctx0.use_scratch(None); ctx0.set_offloading(false); // apply language model head input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer); @@ -344,18 +335,29 @@ impl Model for GptNeoX { GraphOutputs { result: input_layer, embedding_result: embeddings_tensor, + output_length: input_len, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, vocabulary_count, input_len); - common::extract_logits(output_request, &outputs.result, vocabulary_count, input_len); + common::read_last_token( + session, + &outputs.result, + vocabulary_count, + outputs.output_length, + ); + common::extract_logits( + output_request, + &outputs.result, + vocabulary_count, + outputs.output_length, + ); common::extract_embeddings( output_request, &outputs.embedding_result, embedding_length, - input_len, + outputs.output_length, ); } diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index bbc843d0..b282ecdb 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -120,8 +120,6 @@ impl Model for Llama { input_tokens: &[TokenId], output_request: &mut OutputRequest, ) { - let input_len = input_tokens.len(); - let session_len = session.n_past; let params = &self.data.params; let ctx_size = params.context_size; @@ -140,6 +138,8 @@ impl Model for Llama { embedding_length / self.hyperparameters.grouped_query_attention(); let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let session_len = builder.n_past; + let input_len = builder.input_length(); let mut ctx0 = builder.ctx0.borrow_mut(); let embd = builder.embd; @@ -153,8 +153,6 @@ impl Model for Llama { let input_self_attention = input_layer.share(); let mut current: ggml::Tensor; - ctx0.use_scratch(builder.get_scratch(0)); - // norm current = ctx0.op_rms_norm(&input_layer); @@ -294,8 +292,6 @@ impl Model for Llama { // projection (no bias) current = ctx0.op_mul_mat(&self.blocks[il].attn_output, ¤t); - ctx0.use_scratch(builder.get_scratch(1)); - let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); // feed-forward network @@ -322,8 +318,6 @@ impl Model for Llama { input_layer = current; } - ctx0.use_scratch(builder.get_scratch(0)); - // norm input_layer = ctx0.op_rms_norm(&input_layer); @@ -336,24 +330,34 @@ impl Model for Llama { // lm_head input_layer = ctx0.op_mul_mat(&self.output, &input_layer); - ctx0.use_scratch(None); ( gf, GraphOutputs { result: input_layer, embedding_result, + output_length: input_len, }, ) }); // finish evaluation - common::read_last_token(session, &outputs.result, vocabulary_count, input_len); - common::extract_logits(output_request, &outputs.result, vocabulary_count, input_len); + 
common::read_last_token( + session, + &outputs.result, + vocabulary_count, + outputs.output_length, + ); + common::extract_logits( + output_request, + &outputs.result, + vocabulary_count, + outputs.output_length, + ); common::extract_embeddings( output_request, &outputs.embedding_result, embedding_length, - input_len, + outputs.output_length, ); } diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index cfeeefb6..2685fcd1 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -96,8 +96,6 @@ // input_tokens: &[TokenId], // output_request: &mut OutputRequest, // ) { -// let n = input_tokens.len(); -// let session_len = session.n_past; // let ctx_size = self.params.context_size; // let Hyperparameters { @@ -110,6 +108,8 @@ // } = self.hyperparameters; // let outputs = session.compute(self.context.clone(), input_tokens, |builder| { +// let n = builder.input_length(); +// let session_len = builder.n_past; // let ctx0 = builder.ctx0.borrow(); // let (memory_k_size, memory_v_size) = ( // builder.memory_k.element_size(), @@ -123,9 +123,6 @@ // let mut gf = ctx0.create_compute_graph(); // for il in 0..n_layer { -// // attention uses first scratch buffer -// ctx0.use_scratch(builder.get_scratch(0)); - // let mut current = ctx0.op_norm(&input_layer); // current = ctx0.op_mul(¤t, &self.layers[il].norm_1_weight); @@ -213,9 +210,6 @@ // input_layer = ctx0.op_add(&input_layer, ¤t); -// // feed forward uses second scratch buffer -// ctx0.use_scratch(builder.get_scratch(1)); - // current = ctx0.op_norm(&input_layer); // current = ctx0.op_mul(¤t, &self.layers[il].norm_2_weight); @@ -229,17 +223,12 @@ // input_layer = ctx0.op_add(&input_layer, ¤t); // } -// //use scratch buffer 0 for the rest -// ctx0.use_scratch(builder.get_scratch(0)); - // // norm // input_layer = ctx0.op_norm(&input_layer); // input_layer = ctx0.op_mul(&input_layer, &self.norm); // let embeddings_tensor: ggml::Tensor = input_layer.share(); -// // disable scratch buffer for last layer -// ctx0.use_scratch(None); // // output embedding weight tied to input embedding // input_layer = ctx0.op_mul_mat(&self.wte, &input_layer); @@ -248,14 +237,25 @@ // GraphOutputs { // result: input_layer, // embedding_result: embeddings_tensor, +// output_length: n, // }, // ) // }); // // finish evaluation -// common::read_last_token(session, &outputs.result, n_vocab, n); -// common::extract_logits(output_request, &outputs.result, n_vocab, n); -// common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n); +// common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length); +// common::extract_logits( +// output_request, +// &outputs.result, +// n_vocab, +// outputs.output_length, +// ); +// common::extract_embeddings( +// output_request, +// &outputs.embedding_result, +// n_embd, +// outputs.output_length, +// ); // } // fn hyperparameters(&self) -> &Self::Hyperparameters { From 7c3d1cf3a877f139f3d248df269fecc2fd356ce2 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 12 Nov 2023 23:08:14 +0100 Subject: [PATCH 33/33] chore: fix precommit --- crates/ggml/src/format/gguf/metadata.rs | 8 ++++++-- crates/ggml/src/format/gguf/mod.rs | 2 +- crates/llm-base/src/loader.rs | 5 ++--- crates/llm-base/src/tokenizer/embedded.rs | 1 + crates/llm-base/src/tokenizer/mod.rs | 24 +++++++++++------------ crates/llm/Cargo.toml | 2 +- crates/llm/examples/embeddings.rs | 2 +- crates/llm/src/lib.rs | 19 +++++++++--------- crates/llm/src/loader.rs | 4 ++-- 9 files changed, 35 insertions(+), 32 
deletions(-) diff --git a/crates/ggml/src/format/gguf/metadata.rs b/crates/ggml/src/format/gguf/metadata.rs index 70e20347..11da1916 100644 --- a/crates/ggml/src/format/gguf/metadata.rs +++ b/crates/ggml/src/format/gguf/metadata.rs @@ -79,13 +79,13 @@ impl Metadata { // TODO: consider finding a way to automate getting with traits pub fn get_str(&self, key: &str) -> Result<&str, MetadataError> { let metadata_value = self.get(key)?; - Ok(metadata_value + metadata_value .as_string() .ok_or_else(|| MetadataError::InvalidType { key: key.to_string(), expected_type: MetadataValueType::String, actual_type: metadata_value.value_type(), - })?) + }) } pub fn get_countable(&self, key: &str) -> Result { @@ -460,6 +460,10 @@ impl MetadataArrayValue { _ => None, } } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } // Shared diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs index c58e7276..dc2ab9dd 100644 --- a/crates/ggml/src/format/gguf/mod.rs +++ b/crates/ggml/src/format/gguf/mod.rs @@ -253,7 +253,7 @@ impl TensorInfo { util::write_length(writer, ctx.use_64_bit_length, *dimension)?; } - util::write_u32(writer, ggml_type::from(self.element_type) as u32)?; + util::write_u32(writer, ggml_type::from(self.element_type))?; util::write_u64(writer, self.offset)?; Ok(()) diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 52732155..58b4ab4d 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -39,8 +39,7 @@ impl TryFrom for FileType { type Error = llama_ftype; fn try_from(value: llama_ftype) -> Result { - let format = - FileTypeFormat::try_from(((value as u32) % ggml::QNT_VERSION_FACTOR) as llama_ftype)?; + let format = FileTypeFormat::try_from((value % ggml::QNT_VERSION_FACTOR) as llama_ftype)?; Ok(Self { format, @@ -360,7 +359,7 @@ pub trait ModelFactory { /// This method returns a [`Box`], which means that the model will have single ownership. /// If you'd like to share ownership (i.e. to use the model in multiple threads), we /// suggest using [`Arc::from(Box)`](https://doc.rust-lang.org/std/sync/struct.Arc.html#impl-From%3CBox%3CT,+Global%3E%3E-for-Arc%3CT%3E) -/// to convert the [`Box`] into an [`Arc`](std::sync::Arc) after loading. +/// to convert the [`Box`] into an [`Arc`] after loading. pub fn load( path: &Path, tokenizer_source: TokenizerSource, diff --git a/crates/llm-base/src/tokenizer/embedded.rs b/crates/llm-base/src/tokenizer/embedded.rs index 02acb4c3..25387d23 100644 --- a/crates/llm-base/src/tokenizer/embedded.rs +++ b/crates/llm-base/src/tokenizer/embedded.rs @@ -497,6 +497,7 @@ fn unescape_whitespace(text: &[u8]) -> Vec { let mut buffer: Vec = vec![]; for &b in text { + #[allow(clippy::if_same_then_else)] if b == 0xE2 { // If the current byte is 0xE2, start buffering and check for the sequence. 
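// (The escape sequence checked for here is, presumably, the UTF-8 encoding of
// U+2581 "▁", the bytes 0xE2 0x96 0x81, which SentencePiece-style tokenizers
// use as a visible whitespace marker; this function maps it back to a space.)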
buffer.push(b); diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index afa8c9d6..9852993e 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -118,7 +118,7 @@ impl TokenizerSource { tokenizer_source: HuggingFaceTokenizerErrorSource::Remote( identifier.clone(), ), - error: error.into(), + error, } })?, ) @@ -128,7 +128,7 @@ impl TokenizerSource { tokenizers::Tokenizer::from_file(&path).map_err(|error| { TokenizerLoadError::HuggingFaceTokenizerError { tokenizer_source: HuggingFaceTokenizerErrorSource::File(path.clone()), - error: error.into(), + error, } })?, ) @@ -139,20 +139,18 @@ impl TokenizerSource { Self::Embedded => { if let Ok(hf) = gguf.metadata.get_str("tokenizer.huggingface.json") { Ok(Self::load_huggingface_json(hf)?) - } else { - if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) { - if EMBEDDED_TOKENIZER_ENABLED { - Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into()) - } else { - Err(TokenizerLoadError::NoSupportedTokenizersFound { - unsupported_tokenizers: vec!["embedded".to_owned()], - }) - } + } else if EmbeddedTokenizer::is_present_in_metadata(&gguf.metadata) { + if EMBEDDED_TOKENIZER_ENABLED { + Ok(EmbeddedTokenizer::from_metadata(&gguf.metadata)?.into()) } else { Err(TokenizerLoadError::NoSupportedTokenizersFound { - unsupported_tokenizers: vec![], + unsupported_tokenizers: vec!["embedded".to_owned()], }) } + } else { + Err(TokenizerLoadError::NoSupportedTokenizersFound { + unsupported_tokenizers: vec![], + }) } } } @@ -163,7 +161,7 @@ impl TokenizerSource { HuggingFaceTokenizer::new(tokenizers::Tokenizer::from_str(tokenizer_json).map_err( |error| TokenizerLoadError::HuggingFaceTokenizerError { tokenizer_source: HuggingFaceTokenizerErrorSource::String, - error: error.into(), + error, }, )?) .into(), diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index 5db0ec0b..159950f8 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -36,7 +36,7 @@ default = ["models", "tokenizers-remote"] tokenizers-remote = ["llm-base/tokenizers-remote"] -models = ["llama", "gptneox"] #, "gpt2", "gptj", "bloom", "mpt", "bert"] +models = ["llama", "gptneox", "gpt2", "gptj", "bloom", "mpt", "bert"] llama = ["dep:llm-llama"] gpt2 = ["dep:llm-gpt2"] gptj = ["dep:llm-gptj"] diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs index 795d9740..64fc4009 100644 --- a/crates/llm/examples/embeddings.rs +++ b/crates/llm/examples/embeddings.rs @@ -104,7 +104,7 @@ fn main() { fn get_embeddings( model: &dyn llm::Model, - inference_parameters: &llm::InferenceParameters, + _inference_parameters: &llm::InferenceParameters, query: &str, ) -> Vec { let mut session = model.start_session(Default::default()); diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index d588d034..39069f06 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -7,6 +7,7 @@ //! - [GPT-NeoX](llm_gptneox) //! - [LLaMA](llm_llama) //! - [MPT](llm_mpt) +//! - [BERT](llm_bert) //! - Falcon (currently disabled due to incompleteness) //! //! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to @@ -19,7 +20,7 @@ //! use llm::Model; //! //! // load a GGML model from disk -//! let llama = llm::load::( +//! let llama = llm::load( //! // path to GGML file //! std::path::Path::new("/path/to/model"), //! // llm::TokenizerSource @@ -35,7 +36,7 @@ //! let mut session = llama.start_session(Default::default()); //! 
let res = session.infer::( //! // model to use for text generation -//! &llama, +//! llama.as_ref(), //! // randomness provider //! &mut rand::thread_rng(), //! // the prompt to use for text generation, as well as other @@ -94,7 +95,7 @@ pub use loader::{load, load_progress_callback_stdout, LoadError, LoadProgress}; use serde::Serialize; macro_rules! define_models { - ($(($model_lowercase:ident, $model_lowercase_str:literal, $model_pascalcase:ident, $krate_ident:ident, $display_name:literal)),*) => { + ($(($model_lowercase:ident, $model_lowercase_str:literal, $model_pascalcase:ident, $krate_ident:ident, $display_name:literal),)*) => { /// All available models. pub mod models { $( @@ -173,14 +174,14 @@ macro_rules! define_models { } define_models!( - (bert, "bert", Bert, llm_bert, "Bert"), - (bloom, "bloom", Bloom, llm_bloom, "BLOOM"), - (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"), - (gptj, "gptj", GptJ, llm_gptj, "GPT-J"), + // (bert, "bert", Bert, llm_bert, "Bert"), + // (bloom, "bloom", Bloom, llm_bloom, "BLOOM"), + // (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"), + // (gptj, "gptj", GptJ, llm_gptj, "GPT-J"), (gptneox, "gptneox", GptNeoX, llm_gptneox, "GPT-NeoX"), (llama, "llama", Llama, llm_llama, "LLaMA"), - (mpt, "mpt", Mpt, llm_mpt, "MPT"), - (falcon, "falcon", Falcon, llm_falcon, "Falcon") + // (mpt, "mpt", Mpt, llm_mpt, "MPT"), + // (falcon, "falcon", Falcon, llm_falcon, "Falcon"), ); /// Used to dispatch some code based on the model architecture. diff --git a/crates/llm/src/loader.rs b/crates/llm/src/loader.rs index bc4c871c..dd75fd4a 100644 --- a/crates/llm/src/loader.rs +++ b/crates/llm/src/loader.rs @@ -23,13 +23,13 @@ pub fn load( params: ModelParameters, load_progress_callback: impl FnMut(LoadProgress), ) -> Result, LoadError> { - Ok(llm_base::loader::load( + llm_base::loader::load( path, tokenizer_source, params, VisitorModelFactory, load_progress_callback, - )?) + ) } struct VisitorModelFactory;
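Tying the loader changes together, a small sketch of loading a model with the new non-generic llm::load and sharing it via Arc::from, as the loader documentation suggests. The file path is a placeholder and ModelParameters is assumed to implement Default:

use std::sync::Arc;

fn load_shared_model() -> Result<Arc<dyn llm::Model>, llm::LoadError> {
    // `load` returns a Box<dyn Model>; convert it to an Arc to share the model
    // across threads, then pass `model.as_ref()` to `InferenceSession::infer`
    // exactly as in the updated doc example above.
    let model = llm::load(
        std::path::Path::new("/path/to/model"),
        llm::TokenizerSource::Embedded,
        Default::default(),
        llm::load_progress_callback_stdout,
    )?;
    Ok(Arc::from(model))
}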