Merge pull request #6 from binedge/feat/test-harness
feat: `*_async` functions, test harness, other fixes
scriptis authored Nov 8, 2023
2 parents cb62c12 + b97b1e1 commit 3e4aebc
Showing 10 changed files with 1,109 additions and 42 deletions.
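
Before the per-file diffs, a minimal sketch of how the `*_async` entry points touched by this PR are intended to be used from an async context. The model path, the prompt, and the zero-argument `create_session()` call are illustrative assumptions, not code taken from this commit:

```rust
use llama_cpp::{LlamaModel, LlamaSession};

// Sketch only: loads a model and primes a session without blocking the
// async executor. "model.gguf" and the prompt are placeholders.
async fn load_and_prime() -> LlamaSession {
    // `load_from_file_async` runs the blocking load on a worker thread.
    let model = LlamaModel::load_from_file_async("model.gguf")
        .await
        .expect("failed to load model");

    // Assumed to take no arguments for the purpose of this sketch.
    let session = model.create_session();

    // `advance_context_async` tokenizes and feeds the prompt off the
    // runtime's core threads.
    session
        .advance_context_async("Write a haiku about borrow checkers.")
        .await
        .expect("failed to advance context");

    session
}
```
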
820 changes: 812 additions & 8 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -2,5 +2,6 @@
resolver = "2"
members = [
"crates/llama_cpp_sys",
"crates/llama_cpp"
"crates/llama_cpp",
"crates/llama_cpp_tests"
]
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ let mut decoded_tokens = 0;
let mut completions = ctx.get_completions();

while let Some(next_token) = completions.next_token() {
println!("{}", String::from_utf8_lossy(next_token.as_bytes()));
println!("{}", String::from_utf8_lossy(&*next_token.detokenize()));
decoded_tokens += 1;
if decoded_tokens > max_tokens {
break;
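
The README change above follows the `CompletionToken` API change made in `crates/llama_cpp/src/lib.rs` below: token bytes are now returned as an owned buffer by `detokenize()` rather than borrowed via `as_bytes()`. A small sketch of the new call shape; the helper name is invented for illustration:

```rust
use llama_cpp::CompletionToken;

// Sketch: converts one completion token to text under the new API.
// `detokenize()` returns an owned `TinyVec<[u8; 8]>`, which deref-coerces
// back to `&[u8]` for the lossy UTF-8 conversion.
fn token_to_text(token: &CompletionToken<'_>) -> String {
    String::from_utf8_lossy(&*token.detokenize()).into_owned()
}
```
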
73 changes: 72 additions & 1 deletion crates/llama_cpp/CHANGELOG.md
@@ -5,12 +5,82 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v0.1.2 (2023-11-08)

### New Features

- <csr-id-dcfccdf721eb47a364cce5b1c7a54bcf94335ac0/> more `async` function variants
- <csr-id-56285a119633682951f8748e85c6b8988e514232/> add `LlamaSession.model`

### Commit Statistics

<csr-read-only-do-not-edit/>

- 2 commits contributed to the release.
- 2 commits were understood as [conventional](https://www.conventionalcommits.org).
- 0 issues like '(#ID)' were seen in commit messages

### Commit Details

<csr-read-only-do-not-edit/>

<details><summary>view details</summary>

* **Uncategorized**
- More `async` function variants ([`dcfccdf`](https://github.com/binedge/llama_cpp-rs/commit/dcfccdf721eb47a364cce5b1c7a54bcf94335ac0))
- Add `LlamaSession.model` ([`56285a1`](https://github.com/binedge/llama_cpp-rs/commit/56285a119633682951f8748e85c6b8988e514232))
</details>

## v0.1.1 (2023-11-08)

<csr-id-3eddbab3cc35a59acbe66fa4f5333a9ca0edb326/>

### Chore

- <csr-id-3eddbab3cc35a59acbe66fa4f5333a9ca0edb326/> Remove debug binary from Cargo.toml

### New Features

- <csr-id-3bada658c9139af1c3dcdb32c60c222efb87a9f6/> add `LlamaModel::load_from_file_async`

### Bug Fixes

- <csr-id-b676baa3c1a6863c7afd7a88b6f7e8ddd2a1b9bd/> require `llama_context` is accessed from behind a mutex
This solves a race condition when several `get_completions` threads are spawned at the same time
- <csr-id-4eb0bc9800877e460fe0d1d25398f35976b4d730/> `start_completing` should not be invoked on a per-iteration basis
There's still some UB that can be triggered due to llama.cpp's threading model, which needs patching up.

### Commit Statistics

<csr-read-only-do-not-edit/>

- 6 commits contributed to the release.
- 13 days passed between releases.
- 4 commits were understood as [conventional](https://www.conventionalcommits.org).
- 0 issues like '(#ID)' were seen in commit messages

### Commit Details

<csr-read-only-do-not-edit/>

<details><summary>view details</summary>

* **Uncategorized**
- Release llama_cpp_sys v0.2.1, llama_cpp v0.1.1 ([`ef4e3f7`](https://github.com/binedge/llama_cpp-rs/commit/ef4e3f7a3c868a892f26acfae2a5211de4900d1c))
- Add `LlamaModel::load_from_file_async` ([`3bada65`](https://github.com/binedge/llama_cpp-rs/commit/3bada658c9139af1c3dcdb32c60c222efb87a9f6))
- Remove debug binary from Cargo.toml ([`3eddbab`](https://github.com/binedge/llama_cpp-rs/commit/3eddbab3cc35a59acbe66fa4f5333a9ca0edb326))
- Require `llama_context` is accessed from behind a mutex ([`b676baa`](https://github.com/binedge/llama_cpp-rs/commit/b676baa3c1a6863c7afd7a88b6f7e8ddd2a1b9bd))
- `start_completing` should not be invoked on a per-iteration basis ([`4eb0bc9`](https://github.com/binedge/llama_cpp-rs/commit/4eb0bc9800877e460fe0d1d25398f35976b4d730))
- Update to llama.cpp 0a7c980 ([`94d7385`](https://github.com/binedge/llama_cpp-rs/commit/94d7385fefdab42ac6949c6d47c5ed262db08365))
</details>

## v0.1.0 (2023-10-25)

<csr-id-702a6ff49d83b10a0573a5ca1fb419efaa43746e/>
<csr-id-116fe8c82fe2c43bf9041f6dbfe2ed15d00e18e9/>
<csr-id-96548c840d3101091c879648074fa0ed1cee3011/>
<csr-id-a5fb19499ecbb1060ca8211111f186efc6e9b114/>
<csr-id-aa5eed4dcb6f50b25c878e584787211402a9138b/>

### Chore

@@ -34,7 +104,7 @@

<csr-read-only-do-not-edit/>

- 8 commits contributed to the release over the course of 5 calendar days.
- 9 commits contributed to the release over the course of 5 calendar days.
- 6 commits were understood as [conventional](https://www.conventionalcommits.org).
- 1 unique issue was worked on: [#3](https://github.com/binedge/llama_cpp-rs/issues/3)

@@ -47,6 +117,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* **[#3](https://github.com/binedge/llama_cpp-rs/issues/3)**
- Release ([`116fe8c`](https://github.com/binedge/llama_cpp-rs/commit/116fe8c82fe2c43bf9041f6dbfe2ed15d00e18e9))
* **Uncategorized**
- Release llama_cpp v0.1.0 ([`f24c7fe`](https://github.com/binedge/llama_cpp-rs/commit/f24c7fe3ebd851a56301ce3d5a1b4250d2d797b9))
- Add CHANGELOG.md ([`aa5eed4`](https://github.com/binedge/llama_cpp-rs/commit/aa5eed4dcb6f50b25c878e584787211402a9138b))
- Remove `include` from llama_cpp ([`702a6ff`](https://github.com/binedge/llama_cpp-rs/commit/702a6ff49d83b10a0573a5ca1fb419efaa43746e))
- Use SPDX license identifiers ([`2cb06ae`](https://github.com/binedge/llama_cpp-rs/commit/2cb06aea62b892a032f515b78d720acb915f4a22))
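
The v0.1.1 notes above call out that `llama_context` must now be accessed from behind a mutex; together with the `Clone` impl this PR adds to `LlamaSession`, that is the pattern sketched below. This is an illustrative example, not a test from the new harness: the prompt chunks and the function name are placeholders, and it assumes tokio's `macros` feature in the calling crate.

```rust
use llama_cpp::LlamaSession;

// Sketch: `LlamaSession` is cheaply cloneable (its shared state lives behind
// an `Arc`), and the underlying `llama_context` is guarded by a mutex, so
// concurrent calls are serialized rather than racing.
async fn feed_concurrently(session: LlamaSession) {
    let a = session.clone();
    let b = session.clone();

    let (ra, rb) = tokio::join!(
        a.advance_context_async("first chunk of the prompt"),
        b.advance_context_async("second chunk of the prompt"),
    );

    ra.expect("first advance failed");
    rb.expect("second advance failed");
}
```

Note that the mutex only provides mutual exclusion, not ordering: the two chunks may be appended to the context in either order.
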
5 changes: 3 additions & 2 deletions crates/llama_cpp/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "llama_cpp"
version = "0.1.0"
version = "0.1.2"
description = "High-level bindings to llama.cpp with a focus on just being really, really easy to use"
edition = "2021"
authors = ["Dakota Thompson <[email protected]>"]
@@ -13,8 +13,9 @@ publish = true
ctor = "0.2.5"
derive_more = "0.99.17"
flume = "0.11.0"
llama_cpp_sys = { version = "^0.2.0", path = "../llama_cpp_sys" }
llama_cpp_sys = { version = "^0.2.1", path = "../llama_cpp_sys" }
num_cpus = "1.16.0"
thiserror = "1.0.50"
tinyvec = { version = "1.6.0", features = ["alloc"] }
tokio = { version = "1.33.0", features = ["sync", "rt"] }
tracing = "0.1.39"
119 changes: 92 additions & 27 deletions crates/llama_cpp/src/lib.rs
@@ -30,7 +30,7 @@
//! let mut completions = ctx.start_completing();
//!
//! while let Some(next_token) = completions.next_token() {
//! println!("{}", String::from_utf8_lossy(next_token.as_bytes()));
//! println!("{}", String::from_utf8_lossy(&*next_token.detokenize()));
//!
//! decoded_tokens += 1;
//!
@@ -74,10 +74,13 @@
//! [llama.cpp]: https://github.com/ggerganov/llama.cpp/

#![warn(missing_docs)]

use std::ffi::{c_void, CStr, CString};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::{ptr, thread};
use tinyvec::TinyVec;
use tokio::sync::{Mutex, RwLock};

use ctor::{ctor, dtor};
@@ -184,6 +187,7 @@ pub struct LlamaInternalError;
struct LlamaModelInner(*mut llama_model);

unsafe impl Send for LlamaModelInner {}

unsafe impl Sync for LlamaModelInner {}

impl Drop for LlamaModelInner {
@@ -297,7 +301,9 @@ impl LlamaModel {
pub async fn load_from_file_async(file_path: impl AsRef<Path>) -> Result<Self, LlamaLoadError> {
let path = file_path.as_ref().to_owned();

tokio::task::spawn_blocking(move || Self::load_from_file(path)).await.unwrap()
tokio::task::spawn_blocking(move || Self::load_from_file(path))
.await
.unwrap()
}

/// Converts `content` into a vector of tokens that are valid input for this model.
@@ -364,7 +370,13 @@ impl LlamaModel {
token.0
);

unsafe { CStr::from_ptr(llama_token_get_text(**self.model.try_read().unwrap(), token.0)) }.to_bytes()
unsafe {
CStr::from_ptr(llama_token_get_text(
**self.model.try_read().unwrap(),
token.0,
))
}
.to_bytes()
}

/// Creates a new evaluation context for this model.
@@ -384,7 +396,7 @@ impl LlamaModel {
let ctx = unsafe {
// SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live
// for at least the lifetime of `LlamaContext`.
llama_new_context_with_model(**self.model.blocking_read(), params)
llama_new_context_with_model(**self.model.try_read().unwrap(), params)
};

let cpus = num_cpus::get() as u32;
@@ -396,13 +408,14 @@
}

LlamaSession {
model: self.clone(),
inner: Arc::new(Mutex::new(LlamaContextInner { ptr: ctx }) ),
history_size: 0,
inner: Arc::new(LlamaSessionInner {
model: self.clone(),
ctx: Mutex::new(LlamaContextInner { ptr: ctx }),
history_size: AtomicUsize::new(0),
}),
}
}


/// Returns the beginning of sentence (BOS) token for this context.
pub fn bos(&self) -> Token {
self.bos_token
@@ -448,6 +461,7 @@ struct LlamaContextInner {
}

unsafe impl Send for LlamaContextInner {}

unsafe impl Sync for LlamaContextInner {}

impl Drop for LlamaContextInner {
@@ -464,15 +478,21 @@ impl Drop for LlamaContextInner {
///
/// This stores a small amount of state, which is destroyed when the session is dropped.
/// You can create an arbitrary number of sessions for a model using [`LlamaModel::create_session`].
#[derive(Clone)]
pub struct LlamaSession {
inner: Arc<LlamaSessionInner>,
}

/// The cloned part of a [`LlamaSession`].
struct LlamaSessionInner {
/// The model this session was created from.
model: LlamaModel,

/// A pointer to the llama.cpp side of the model context.
inner: Arc<Mutex<LlamaContextInner>>,
ctx: Mutex<LlamaContextInner>,

/// The number of tokens present in this model's context.
history_size: usize,
history_size: AtomicUsize,
}

/// An error raised while advancing the context in a [`LlamaSession`].
@@ -508,7 +528,7 @@ impl LlamaSession {
///
/// The model will generate new tokens from the end of the context.
pub fn advance_context_with_tokens(
&mut self,
&self,
tokens: impl AsRef<[Token]>,
) -> Result<(), LlamaContextError> {
let tokens = tokens.as_ref();
@@ -562,7 +582,7 @@ impl LlamaSession {
if unsafe {
// SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
// initialized above.
llama_decode(self.inner.blocking_lock().ptr, batch)
llama_decode(self.inner.ctx.blocking_lock().ptr, batch)
} != 0
{
return Err(LlamaInternalError.into());
@@ -577,46 +597,89 @@
llama_batch_free(batch)
};

self.history_size += tokens.len();
self.inner
.history_size
.fetch_add(n_tokens, Ordering::SeqCst);

Ok(())
}

/// Advances the inner context of this model with `tokens`.
///
/// This is a thin `tokio::spawn_blocking` wrapper around
/// [`LlamaSession::advance_context_with_tokens`].
pub async fn advance_context_with_tokens_async(
&mut self,
tokens: impl AsRef<[Token]>,
) -> Result<(), LlamaContextError> {
let tokens = tokens.as_ref().to_owned();
let session = self.clone();

tokio::task::spawn_blocking(move || session.advance_context_with_tokens(tokens))
.await
.unwrap()
}

/// Tokenizes and feeds an arbitrary byte buffer `ctx` into this model.
///
/// `ctx` is typically a UTF-8 string, but anything that can be downcast to bytes is accepted.
pub fn advance_context(&mut self, ctx: impl AsRef<[u8]>) -> Result<(), LlamaContextError> {
let tokens = self.model.tokenize_bytes(ctx.as_ref())?.into_boxed_slice();
let tokens = self
.inner
.model
.tokenize_bytes(ctx.as_ref())?
.into_boxed_slice();

self.advance_context_with_tokens(tokens)
}

/// Tokenizes and feeds an arbitrary byte buffer `ctx` into this model.
///
/// This is a thin `tokio::spawn_blocking` wrapper around
/// [`LlamaSession::advance_context`].
pub async fn advance_context_async(
&self,
ctx: impl AsRef<[u8]>,
) -> Result<(), LlamaContextError> {
let ctx = ctx.as_ref().to_owned();
let session = self.clone();

tokio::task::spawn_blocking(move || {
let tokens = session.inner.model.tokenize_bytes(ctx)?.into_boxed_slice();

session.advance_context_with_tokens(tokens)
})
.await
.unwrap()
}

/// Starts generating tokens at the end of the context using llama.cpp's built-in Beam search.
/// This is where you want to be if you just want some completions.
pub fn start_completing(&mut self) -> CompletionHandle {
let (tx, rx) = flume::unbounded();
let history_size = self.inner.history_size.load(Ordering::SeqCst);
let session = self.clone();

info!(
"Generating completions with {} tokens of history",
self.history_size,
);

let past_tokens = self.history_size;
let mutex = self.inner.clone();
info!("Generating completions with {history_size} tokens of history");

thread::spawn(move || unsafe {
llama_beam_search(
mutex.blocking_lock().ptr,
session.inner.ctx.blocking_lock().ptr,
Some(detail::llama_beam_search_callback),
Box::leak(Box::new(detail::BeamSearchState { tx })) as *mut _ as *mut c_void,
1,
past_tokens as i32,
history_size as i32,
32_768,
);
});

CompletionHandle { ctx: self, rx }
}

/// Returns the model this session was created from.
pub fn model(&self) -> LlamaModel {
self.inner.model.clone()
}
}

/// An intermediate token generated during an LLM completion.
Expand All @@ -629,9 +692,11 @@ pub struct CompletionToken<'a> {
}

impl<'a> CompletionToken<'a> {
/// Decodes this token, returning the bytes composing it.
pub fn as_bytes(&self) -> &[u8] {
self.ctx.model.detokenize(self.token)
/// Decodes this token, returning the bytes it is composed of.
pub fn detokenize(&self) -> TinyVec<[u8; 8]> {
let model = self.ctx.model();

model.detokenize(self.token).into()
}

/// Returns this token as an `i32`.
@@ -730,7 +795,7 @@ mod detail {
// SAFETY: beam_views[i] exists where 0 <= i <= n_beams.
*beam_state.beam_views.add(i)
}
.eob = true;
.eob = true;
}
}

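
To round off the `lib.rs` changes, a sketch of a bounded completion loop over the refactored session type. The function name and token cap are illustrative; it assumes the session's context has already been primed with a prompt (for example via `advance_context_async`):

```rust
// Sketch: collects up to `max_tokens` completion tokens into a string.
fn collect_completion(mut session: llama_cpp::LlamaSession, max_tokens: usize) -> String {
    let mut bytes = Vec::new();
    let mut completions = session.start_completing();

    for _ in 0..max_tokens {
        let Some(token) = completions.next_token() else { break };
        // Each token detokenizes to an owned byte buffer.
        bytes.extend_from_slice(&token.detokenize());
    }

    String::from_utf8_lossy(&bytes).into_owned()
}
```

The new `LlamaSession::model()` accessor returns a clone of the owning `LlamaModel`, which is convenient when a caller only holds the session handle.
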