Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
fix(ggml): bindgen issues
Browse files: browse the repository at this point in the history
  • Loading branch information
philpax committed Nov 1, 2023
1 parent e506b0b commit fcbfb4d
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 22 deletions.
3 changes: 3 additions & 0 deletions binaries/generate-ggml-bindings/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ fn generate_metal(ggml_path: &Path, src_path: &Path) {
generate_extra("metal", ggml_path, src_path, |b| {
b.header(ggml_path.join("ggml-metal.h").to_string_lossy())
.allowlist_file(r".*ggml-metal\.h")
.raw_line("use super::ggml_tensor;")
.raw_line("use super::ggml_log_callback;")
.raw_line("use super::ggml_cgraph;")
});
}

Expand Down
13 changes: 4 additions & 9 deletions crates/ggml/src/accelerator/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ pub struct MetalContext {

impl MetalContext {
/// Create a new Metal context
pub fn new(n_threads: usize) -> Self {
let raw = unsafe { metal::ggml_metal_init(n_threads.try_into().unwrap()) };
pub fn new() -> Self {
let raw = unsafe { metal::ggml_metal_init(1) };

MetalContext {
contexts: vec![],
Expand Down Expand Up @@ -83,19 +83,14 @@ impl MetalContext {
unsafe {
metal::ggml_metal_graph_compute(
self.ptr.as_ptr(),
graph.inner as *mut ggml_sys::ggml_cgraph as *mut metal::ggml_cgraph,
graph.inner as *mut ggml_sys::ggml_cgraph,
);
}
}

/// Reads a tensor from Metal
pub fn get_tensor(&self, tensor: &Tensor) {
unsafe {
metal::ggml_metal_get_tensor(
self.ptr.as_ptr(),
tensor.ptr.as_ptr() as *mut metal::ggml_tensor,
)
}
unsafe { metal::ggml_metal_get_tensor(self.ptr.as_ptr(), tensor.ptr.as_ptr()) }
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/ggml/sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,15 @@ pub const ggml_unary_op_GGML_UNARY_OP_RELU: ggml_unary_op = 6;
pub const ggml_unary_op_GGML_UNARY_OP_GELU: ggml_unary_op = 7;
pub const ggml_unary_op_GGML_UNARY_OP_GELU_QUICK: ggml_unary_op = 8;
pub const ggml_unary_op_GGML_UNARY_OP_SILU: ggml_unary_op = 9;
pub type ggml_unary_op = ::std::os::raw::c_int;
pub type ggml_unary_op = ::std::os::raw::c_uint;
pub const ggml_object_type_GGML_OBJECT_TENSOR: ggml_object_type = 0;
pub const ggml_object_type_GGML_OBJECT_GRAPH: ggml_object_type = 1;
pub const ggml_object_type_GGML_OBJECT_WORK_BUFFER: ggml_object_type = 2;
pub type ggml_object_type = ::std::os::raw::c_int;
pub type ggml_object_type = ::std::os::raw::c_uint;
pub const ggml_log_level_GGML_LOG_LEVEL_ERROR: ggml_log_level = 2;
pub const ggml_log_level_GGML_LOG_LEVEL_WARN: ggml_log_level = 3;
pub const ggml_log_level_GGML_LOG_LEVEL_INFO: ggml_log_level = 4;
pub type ggml_log_level = ::std::os::raw::c_int;
pub type ggml_log_level = ::std::os::raw::c_uint;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct ggml_object {
Expand Down Expand Up @@ -1761,7 +1761,7 @@ extern "C" {
pub const ggml_op_pool_GGML_OP_POOL_MAX: ggml_op_pool = 0;
pub const ggml_op_pool_GGML_OP_POOL_AVG: ggml_op_pool = 1;
pub const ggml_op_pool_GGML_OP_POOL_COUNT: ggml_op_pool = 2;
pub type ggml_op_pool = ::std::os::raw::c_int;
pub type ggml_op_pool = ::std::os::raw::c_uint;
extern "C" {
pub fn ggml_pool_1d(
ctx: *mut ggml_context,
Expand Down Expand Up @@ -3081,7 +3081,7 @@ pub const gguf_type_GGUF_TYPE_UINT64: gguf_type = 10;
pub const gguf_type_GGUF_TYPE_INT64: gguf_type = 11;
pub const gguf_type_GGUF_TYPE_FLOAT64: gguf_type = 12;
pub const gguf_type_GGUF_TYPE_COUNT: gguf_type = 13;
pub type gguf_type = ::std::os::raw::c_int;
pub type gguf_type = ::std::os::raw::c_uint;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct gguf_context {
Expand Down
2 changes: 1 addition & 1 deletion crates/ggml/sys/src/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ pub const LLAMA_FTYPE_MOSTLY_Q5_K_S: llama_ftype = 16;
pub const LLAMA_FTYPE_MOSTLY_Q5_K_M: llama_ftype = 17;
pub const LLAMA_FTYPE_MOSTLY_Q6_K: llama_ftype = 18;
pub const LLAMA_FTYPE_GUESSED: llama_ftype = 1024;
pub type llama_ftype = ::std::os::raw::c_int;
pub type llama_ftype = ::std::os::raw::c_uint;
4 changes: 4 additions & 0 deletions crates/ggml/sys/src/metal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
/* automatically generated by rust-bindgen 0.65.1 */

use super::ggml_tensor;
use super::ggml_log_callback;
use super::ggml_cgraph;

pub const GGML_METAL_MAX_BUFFERS: u32 = 16;
pub const GGML_METAL_MAX_COMMAND_BUFFERS: u32 = 32;
extern "C" {
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl InferenceSession {
#[cfg(feature = "metal")]
let metal_context = {
if use_gpu {
let mut metal_context = MetalContext::new(config.n_threads);
let mut metal_context = MetalContext::new();
metal_context.add_scratch_buffer(ctx0.storage().as_buffer().unwrap());

for buf in scratch.iter() {
Expand Down
10 changes: 5 additions & 5 deletions crates/llm-base/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ pub struct FileType {
/// The quantization version.
pub quantization_version: u32,
}
impl From<FileType> for i32 {
impl From<FileType> for u32 {
fn from(value: FileType) -> Self {
(value.quantization_version * ggml::QNT_VERSION_FACTOR) as i32
(value.quantization_version * ggml::QNT_VERSION_FACTOR) as u32
+ ggml::sys::llama::llama_ftype::from(value.format)
}
}
impl TryFrom<i32> for FileType {
impl TryFrom<u32> for FileType {
type Error = ();

fn try_from(value: i32) -> Result<Self, Self::Error> {
fn try_from(value: u32) -> Result<Self, Self::Error> {
let format = FileTypeFormat::try_from(
((value as u32) % ggml::QNT_VERSION_FACTOR) as ggml::sys::llama::llama_ftype,
)?;
Expand Down Expand Up @@ -252,7 +252,7 @@ pub enum LoadError {
#[error("unsupported ftype: {0}")]
/// The `ftype` hyperparameter had an invalid value. This usually means that the format used
/// by this file is unrecognized by this version of `llm`.
UnsupportedFileType(i32),
UnsupportedFileType(u32),
#[error("invalid magic number {magic} for {path:?}")]
/// An invalid magic number was encountered during the loading process.
InvalidMagic {
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{FileType, LoadError};

/// Read the filetype from a reader.
pub fn read_filetype(reader: &mut dyn BufRead) -> Result<FileType, LoadError> {
let ftype = read_i32(reader)?;
let ftype = read_u32(reader)?;
FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))
}

Expand Down

0 comments on commit fcbfb4d

Please sign in to comment.