From dd7aa263ffcbd57e0f9e517996ca2d526f34a17f Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 20 Aug 2023 22:15:40 +0200 Subject: [PATCH] feat(ggml): impl unwired gguf loader --- crates/ggml/examples/gguf.rs | 14 ++ crates/ggml/src/format/ggml/loader.rs | 6 + crates/ggml/src/format/ggml/mod.rs | 13 +- crates/ggml/src/format/gguf/mod.rs | 288 ++++++++++++++++++++++++++ crates/ggml/src/format/mod.rs | 1 + crates/ggml/src/util.rs | 13 ++ 6 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 crates/ggml/examples/gguf.rs create mode 100644 crates/ggml/src/format/gguf/mod.rs diff --git a/crates/ggml/examples/gguf.rs b/crates/ggml/examples/gguf.rs new file mode 100644 index 00000000..f7c5328f --- /dev/null +++ b/crates/ggml/examples/gguf.rs @@ -0,0 +1,14 @@ +use std::io::BufReader; + +use ggml::format::gguf; + +fn main() -> anyhow::Result<()> { + let mut file = BufReader::new(std::fs::File::open( + std::env::args().nth(1).expect("need a file to read"), + )?); + + let gguf = gguf::Gguf::load(&mut file)?; + dbg!(gguf); + + Ok(()) +} diff --git a/crates/ggml/src/format/ggml/loader.rs b/crates/ggml/src/format/ggml/loader.rs index 0700d60e..7b079ee8 100644 --- a/crates/ggml/src/format/ggml/loader.rs +++ b/crates/ggml/src/format/ggml/loader.rs @@ -126,6 +126,9 @@ pub fn load( // Legacy model, set empty score 0. } + ContainerType::Gguf(_) => { + unreachable!("This loader should not be used with GGUF") + } }; handler .vocabulary_token(i, token, token_score) @@ -138,6 +141,9 @@ pub fn load( ContainerType::Ggjt(_version) | ContainerType::Ggla(_version) => { load_weights(reader, handler, true) } + ContainerType::Gguf(_) => { + unreachable!("This loader should not be used with GGUF") + } } } diff --git a/crates/ggml/src/format/ggml/mod.rs b/crates/ggml/src/format/ggml/mod.rs index 6354370f..774ce69d 100644 --- a/crates/ggml/src/format/ggml/mod.rs +++ b/crates/ggml/src/format/ggml/mod.rs @@ -19,6 +19,8 @@ pub const FILE_MAGIC_GGMF: [u8; 4] = *b"fmgg"; pub const FILE_MAGIC_GGJT: [u8; 4] = *b"tjgg"; /// Magic constant for `ggla` files (LoRA adapter). pub const FILE_MAGIC_GGLA: [u8; 4] = *b"algg"; +/// Magic constant for `gguf` files. +pub const FILE_MAGIC_GGUF: [u8; 4] = *b"GGUF"; #[derive(Debug, PartialEq, Clone, Copy)] /// The format of the file containing the model. @@ -27,10 +29,12 @@ pub enum ContainerType { Ggml, /// Legacy format. Introduces versioning. Newer than GGML, older than GGJT. Ggmf(u32), - /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. Current version of the format. + /// [mmap](https://en.wikipedia.org/wiki/Mmap)-able format. Ggjt(u32), /// LoRA adapter format. Ggla(u32), + /// GGUF format. Current version of the format. + Gguf(u32), } impl ContainerType { /// Does this container type support mmap? @@ -40,6 +44,7 @@ impl ContainerType { ContainerType::Ggmf(_) => false, ContainerType::Ggla(_) => false, ContainerType::Ggjt(_) => true, + ContainerType::Gguf(_) => true, } } @@ -63,6 +68,10 @@ impl ContainerType { let version = util::read_u32(reader)?; ContainerType::Ggla(version) } + FILE_MAGIC_GGUF => { + let version = util::read_u32(reader)?; + ContainerType::Gguf(version) + } magic => return Err(LoadError::InvalidMagic(util::FormatMagic(magic))), }; @@ -87,6 +96,8 @@ impl ContainerType { writer.write_all(&FILE_MAGIC_GGLA)?; util::write_u32(writer, *version)?; } + ContainerType::Gguf(version) => { + writer.write_all(&FILE_MAGIC_GGUF)?; util::write_u32(writer, *version)?; } } diff --git a/crates/ggml/src/format/gguf/mod.rs b/crates/ggml/src/format/gguf/mod.rs new file mode 100644 index 00000000..97156ff7 --- /dev/null +++ b/crates/ggml/src/format/gguf/mod.rs @@ -0,0 +1,288 @@ +#![allow(missing_docs)] + +use std::{ + collections::HashMap, + convert::Infallible, + io::{BufRead, Seek}, +}; + +use crate::{util, ElementType}; + +use super::{ggml::ContainerType, LoadError}; + +pub const DEFAULT_ALIGNMENT: u32 = 32; + +#[derive(Debug, Clone, PartialEq)] +pub struct Gguf { + pub metadata: HashMap, + pub tensor_infos: HashMap, + pub tensor_data_position: u64, +} +impl Gguf { + pub fn load(reader: &mut R) -> Result> { + let container = ContainerType::read(reader)?; + if container != ContainerType::Gguf(1) { + return Err(LoadError::InvalidFormatVersion(container)); + } + + let tensor_count = util::read_u32(reader)? as usize; + let metadata_kv_count = util::read_u32(reader)? as usize; + + let mut metadata = HashMap::with_capacity(metadata_kv_count); + for _ in 0..metadata_kv_count { + let (key, value) = MetadataValue::read_key_value(reader)?; + metadata.insert(key, value); + } + + let alignment = metadata + .get("general.alignment") + .and_then(|v| v.as_uint32()) + .unwrap_or(DEFAULT_ALIGNMENT) as u64; + + let mut tensor_infos = HashMap::with_capacity(tensor_count); + for _ in 0..tensor_count { + let (key, value) = TensorInfo::read_name_value(reader)?; + tensor_infos.insert(key, value); + } + + let tensor_data_position = { + let stream_position = reader.stream_position()?; + stream_position + (alignment - (stream_position % alignment)) % alignment + }; + + Ok(Gguf { + metadata, + tensor_infos, + tensor_data_position, + }) + } +} + +#[repr(u32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum MetadataValueType { + // The value is a 8-bit unsigned integer. + UInt8 = 0, + // The value is a 8-bit signed integer. + Int8 = 1, + // The value is a 16-bit unsigned little-endian integer. + UInt16 = 2, + // The value is a 16-bit signed little-endian integer. + Int16 = 3, + // The value is a 32-bit unsigned little-endian integer. + UInt32 = 4, + // The value is a 32-bit signed little-endian integer. + Int32 = 5, + // The value is a 32-bit IEEE754 floating point number. + Float32 = 6, + // The value is a boolean. + // 1-byte value where 0 is false and 1 is true. + // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + Bool = 7, + // The value is a UTF-8 non-null-terminated string, with length prepended. + String = 8, + // The value is an array of other values, with the length and type prepended. + /// + // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. + Array = 9, +} +impl TryFrom for MetadataValueType { + type Error = (); + + fn try_from(value: u32) -> Result { + // TODO: consider a macro solution to this? + for test_value in [ + MetadataValueType::UInt8, + MetadataValueType::Int8, + MetadataValueType::UInt16, + MetadataValueType::Int16, + MetadataValueType::UInt32, + MetadataValueType::Int32, + MetadataValueType::Float32, + MetadataValueType::Bool, + MetadataValueType::String, + MetadataValueType::Array, + ] { + if value == test_value as u32 { + return Ok(test_value); + } + } + Err(()) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum MetadataValue { + UInt8(u8), + Int8(i8), + UInt16(u16), + Int16(i16), + UInt32(u32), + Int32(i32), + Float32(f32), + Bool(bool), + String(String), + Array(MetadataArrayValue), +} +impl MetadataValue { + fn read_key_value(reader: &mut dyn BufRead) -> Result<(String, Self), LoadError> { + let key = util::read_string(reader)?; + let value_type = MetadataValueType::try_from(util::read_u32(reader)?) + .expect("TODO: handle invalid value types"); + + let value = Self::read_value(reader, value_type)?; + + Ok((key, value)) + } + + fn read_value( + reader: &mut dyn BufRead, + value_type: MetadataValueType, + ) -> Result> { + match value_type { + MetadataValueType::UInt8 => Self::read_u8(reader).map(MetadataValue::UInt8), + MetadataValueType::Int8 => Self::read_i8(reader).map(MetadataValue::Int8), + MetadataValueType::UInt16 => Self::read_u16(reader).map(MetadataValue::UInt16), + MetadataValueType::Int16 => Self::read_i16(reader).map(MetadataValue::Int16), + MetadataValueType::UInt32 => Self::read_u32(reader).map(MetadataValue::UInt32), + MetadataValueType::Int32 => Self::read_i32(reader).map(MetadataValue::Int32), + MetadataValueType::Float32 => Self::read_f32(reader).map(MetadataValue::Float32), + MetadataValueType::Bool => Self::read_bool(reader).map(MetadataValue::Bool), + MetadataValueType::String => Self::read_string(reader).map(MetadataValue::String), + MetadataValueType::Array => Self::read_array(reader).map(MetadataValue::Array), + } + } + + fn read_u8(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_bytes::<1>(reader)?[0]) + } + + fn read_i8(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_bytes::<1>(reader)?[0] as i8) + } + + fn read_u16(reader: &mut dyn BufRead) -> Result> { + Ok(u16::from_le_bytes(util::read_bytes::<2>(reader)?)) + } + + fn read_i16(reader: &mut dyn BufRead) -> Result> { + Ok(i16::from_le_bytes(util::read_bytes::<2>(reader)?)) + } + + fn read_u32(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_u32(reader)?) + } + + fn read_i32(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_i32(reader)?) + } + + fn read_f32(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_f32(reader)?) + } + + fn read_bool(reader: &mut dyn BufRead) -> Result> { + let v = Self::read_u8(reader)?; + if v == 0 { + Ok(false) + } else if v == 1 { + Ok(true) + } else { + panic!("TODO: error for invalid bools") + } + } + + fn read_string(reader: &mut dyn BufRead) -> Result> { + Ok(util::read_string(reader)?) + } + + fn read_array(reader: &mut dyn BufRead) -> Result> { + MetadataArrayValue::read_value(reader) + } + + pub fn as_uint32(&self) -> Option { + match self { + Self::UInt32(v) => Some(*v), + _ => None, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum MetadataArrayValue { + UInt8(Vec), + Int8(Vec), + UInt16(Vec), + Int16(Vec), + UInt32(Vec), + Int32(Vec), + Float32(Vec), + Bool(Vec), + String(Vec), + Array(Vec), +} +impl MetadataArrayValue { + fn read_value(reader: &mut dyn BufRead) -> Result> { + let value_type = MetadataValueType::try_from(util::read_u32(reader)?) + .expect("TODO: handle invalid value types"); + let length = util::read_u32(reader)? as usize; + + fn read_array( + reader: &mut dyn BufRead, + length: usize, + value_reader: impl Fn(&mut dyn BufRead) -> Result>, + ) -> Result, LoadError> { + (0..length).map(|_| value_reader(reader)).collect() + } + + use MetadataValue as MV; + use MetadataValueType as MVT; + Ok(match value_type { + MVT::UInt8 => read_array(reader, length, MV::read_u8).map(Self::UInt8), + MVT::Int8 => read_array(reader, length, MV::read_i8).map(Self::Int8), + MVT::UInt16 => read_array(reader, length, MV::read_u16).map(Self::UInt16), + MVT::Int16 => read_array(reader, length, MV::read_i16).map(Self::Int16), + MVT::UInt32 => read_array(reader, length, MV::read_u32).map(Self::UInt32), + MVT::Int32 => read_array(reader, length, MV::read_i32).map(Self::Int32), + MVT::Float32 => read_array(reader, length, MV::read_f32).map(Self::Float32), + MVT::Bool => read_array(reader, length, MV::read_bool).map(Self::Bool), + MVT::String => read_array(reader, length, MV::read_string).map(Self::String), + MVT::Array => read_array(reader, length, MV::read_array).map(Self::Array), + }?) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct TensorInfo { + pub dimensions: Vec, + pub element_type: ElementType, + pub offset: u64, +} +impl TensorInfo { + fn read_name_value(reader: &mut dyn BufRead) -> Result<(String, Self), LoadError> { + let name = util::read_string(reader)?; + + let dimension_count = util::read_u32(reader)? as usize; + let dimensions = (0..dimension_count) + .map(|_| util::read_u32(reader)) + .collect::, _>>()?; + + let element_type = util::read_u32(reader)?; + let element_type = + ElementType::try_from(element_type).map_err(|_| LoadError::UnsupportedElementType { + tensor_name: name.clone(), + ftype: element_type, + })?; + + let offset = util::read_u64(reader)?; + + Ok(( + name, + Self { + dimensions, + element_type, + offset, + }, + )) + } +} diff --git a/crates/ggml/src/format/mod.rs b/crates/ggml/src/format/mod.rs index 5370343f..9fe51d2b 100644 --- a/crates/ggml/src/format/mod.rs +++ b/crates/ggml/src/format/mod.rs @@ -5,6 +5,7 @@ use std::error::Error; use crate::{util::FormatMagic, ElementType}; pub mod ggml; +pub mod gguf; #[derive(Debug, thiserror::Error)] /// Errors that can occur while loading a model. diff --git a/crates/ggml/src/util.rs b/crates/ggml/src/util.rs index 814b349f..7812b71b 100644 --- a/crates/ggml/src/util.rs +++ b/crates/ggml/src/util.rs @@ -36,6 +36,11 @@ pub fn read_u32(reader: &mut dyn BufRead) -> Result { Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) } +/// Read a `u64` from a reader. +pub fn read_u64(reader: &mut dyn BufRead) -> Result { + Ok(u64::from_le_bytes(read_bytes::<8>(reader)?)) +} + /// Read a `f32` from a reader. pub fn read_f32(reader: &mut dyn BufRead) -> Result { Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) @@ -64,6 +69,14 @@ pub fn read_bytes_with_len( Ok(bytes) } +/// Read a string from a reader. +pub fn read_string(reader: &mut dyn BufRead) -> Result { + let len = read_u32(reader)? as usize; + let bytes = read_bytes_with_len(reader, len)?; + Ok(String::from_utf8(bytes) + .expect("string was not valid utf-8 (TODO: make this a library error)")) +} + /// Write a `i32` from a writer. pub fn write_i32(writer: &mut dyn Write, value: i32) -> Result<(), std::io::Error> { writer.write_all(&value.to_le_bytes())