Skip to content

Commit

Permalink
Merge pull request #4 from edgenai/chore/log-callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
pedro-devv authored Feb 5, 2024
2 parents df3be80 + cf5dca8 commit f8c905f
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 29 deletions.
32 changes: 32 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ cmake = "0.1.50"
derive_more = "0.99.17"
link-cplusplus = "1.0.9"
thiserror = "1.0.50"
tokio = { version = "1.34.0" }
tokio = "1.34.0"
tracing = "0.1.40"
wav = "1.0.0"
whisper_cpp = { path = "crates/whisper_cpp", default-features = false }
whisper_cpp_sys = { path = "crates/whisper_cpp_sys", default-features = false }
1 change: 1 addition & 0 deletions crates/whisper_cpp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ publish = true
thiserror = { workspace = true }
derive_more = { workspace = true }
tokio = { workspace = true, features = ["sync", "rt"] }
tracing = { workspace = true }
whisper_cpp_sys = { workspace = true, default-features = false }

[features]
Expand Down
108 changes: 80 additions & 28 deletions crates/whisper_cpp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::ffi::{CStr, CString};
use core::ffi::{c_char, c_int, CStr};
use std::ffi::CString;
use std::num::NonZeroUsize;
use std::ptr::null_mut;
use std::slice;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use derive_more::{Deref, DerefMut};
Expand All @@ -14,11 +16,27 @@ use whisper_cpp_sys::{
whisper_full_get_token_id_from_state, whisper_full_n_segments_from_state,
whisper_full_n_tokens_from_state, whisper_full_params, whisper_full_params__bindgen_ty_1,
whisper_full_params__bindgen_ty_2, whisper_full_with_state,
whisper_init_from_file_with_params_no_state, whisper_init_state,
whisper_init_from_file_with_params_no_state, whisper_init_state, whisper_log_set,
whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH,
whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY, whisper_state, whisper_token,
};

/// Boolean indicating if a logger has already been set using [`whisper_log_set`].
static LOGGER_SET: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

/// Sets the global [`whisper.cpp`][whisper.cpp] logger, if it wasn't set already.
///
/// [whisper.cpp]: https://github.com/ggerganov/whisper.cpp/
fn set_log() {
if !LOGGER_SET.swap(true, Ordering::SeqCst) {
unsafe {
// SAFETY: performs a simple assignment to static variables. Should only execute once
// before any logs are made.
whisper_log_set(Some(internal::whisper_log_callback), null_mut());
}
}
}

#[derive(Clone, Deref, DerefMut)]
struct WhisperContext(*mut whisper_context);

Expand Down Expand Up @@ -49,9 +67,11 @@ impl WhisperModel {
/// Loads a new *ggml* *whisper* model, given its file path.
#[doc(alias = "whisper_init_from_file_with_params_no_state")]
pub fn new_from_file<P>(model_path: P, use_gpu: bool) -> Result<Self, WhisperError>
where
P: AsRef<std::path::Path>,
where
P: AsRef<std::path::Path>,
{
set_log();

let params = whisper_context_params { use_gpu };

let path_bytes = model_path
Expand Down Expand Up @@ -220,7 +240,7 @@ impl WhisperSession {
self.state.0,
c_params,
samples.as_ptr(),
samples.len() as std::os::raw::c_int,
samples.len() as c_int,
)
};

Expand Down Expand Up @@ -272,10 +292,7 @@ impl WhisperSession {
#[doc(alias = "whisper_full_get_segment_text_from_state")]
pub fn segment_text(&self, segment: u32) -> Result<String, WhisperSessionError> {
let text = unsafe {
let res = whisper_full_get_segment_text_from_state(
self.state.0,
segment as std::os::raw::c_int,
);
let res = whisper_full_get_segment_text_from_state(self.state.0, segment as c_int);
CStr::from_ptr(res.cast_mut())
};

Expand All @@ -285,9 +302,7 @@ impl WhisperSession {
/// Get number of tokens in the specified segment.
#[doc(alias = "whisper_full_n_tokens_from_state")]
pub fn token_count(&self, segment: u32) -> u32 {
let res = unsafe {
whisper_full_n_tokens_from_state(self.state.0, segment as std::os::raw::c_int)
};
let res = unsafe { whisper_full_n_tokens_from_state(self.state.0, segment as c_int) };

res as u32
}
Expand All @@ -302,11 +317,7 @@ impl WhisperSession {
#[doc(alias = "whisper_full_get_token_id_from_state")]
pub fn token_id(&self, segment: u32, token: u32) -> i32 {
unsafe {
whisper_full_get_token_id_from_state(
self.state.0,
segment as std::os::raw::c_int,
token as std::os::raw::c_int,
)
whisper_full_get_token_id_from_state(self.state.0, segment as c_int, token as c_int)
}
}

Expand Down Expand Up @@ -523,7 +534,7 @@ impl WhisperParams {
fn push_str(
storage: &mut Vec<CString>,
value: &str,
) -> Result<*const std::os::raw::c_char, std::ffi::NulError> {
) -> Result<*const c_char, std::ffi::NulError> {
if value.is_empty() {
Ok(null_mut())
} else {
Expand All @@ -540,10 +551,10 @@ impl WhisperParams {
whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH
}
},
n_threads: self.thread_count as std::os::raw::c_int,
n_max_text_ctx: self.max_text_ctx as std::os::raw::c_int,
offset_ms: self.offset_ms as std::os::raw::c_int,
duration_ms: self.duration_ms as std::os::raw::c_int,
n_threads: self.thread_count as c_int,
n_max_text_ctx: self.max_text_ctx as c_int,
offset_ms: self.offset_ms as c_int,
duration_ms: self.duration_ms as c_int,
translate: self.translate,
no_context: self.no_context,
no_timestamps: self.no_timestamps,
Expand All @@ -555,12 +566,12 @@ impl WhisperParams {
token_timestamps: self.token_timestamps,
thold_pt: self.thold_pt,
thold_ptsum: self.thold_ptsum,
max_len: self.max_len as std::os::raw::c_int,
max_len: self.max_len as c_int,
split_on_word: self.split_on_word,
max_tokens: self.max_tokens as std::os::raw::c_int,
max_tokens: self.max_tokens as c_int,
speed_up: self.speed_up,
debug_mode: self.debug_mode,
audio_ctx: self.audio_ctx as std::os::raw::c_int,
audio_ctx: self.audio_ctx as c_int,
tdrz_enable: self.tdrz_enable,
initial_prompt: push_str(&mut v, &self.initial_prompt)?,
prompt_tokens: {
Expand All @@ -570,7 +581,7 @@ impl WhisperParams {
self.prompt_tokens.as_ptr()
}
},
prompt_n_tokens: self.prompt_tokens.len() as std::os::raw::c_int,
prompt_n_tokens: self.prompt_tokens.len() as c_int,
language: push_str(&mut v, &self.language)?,
detect_language: self.detect_language,
suppress_blank: self.suppress_blank,
Expand All @@ -585,7 +596,7 @@ impl WhisperParams {
greedy: {
if let WhisperSampling::Greedy { best_of } = self.strategy {
whisper_full_params__bindgen_ty_1 {
best_of: best_of as std::os::raw::c_int,
best_of: best_of as c_int,
}
} else {
whisper_full_params__bindgen_ty_1 { best_of: 0 }
Expand All @@ -598,7 +609,7 @@ impl WhisperParams {
} = self.strategy
{
whisper_full_params__bindgen_ty_2 {
beam_size: beam_size as std::os::raw::c_int,
beam_size: beam_size as c_int,
patience,
}
} else {
Expand Down Expand Up @@ -726,3 +737,44 @@ impl From<whisper_full_params> for WhisperParams {
}
}
}

mod internal {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

use core::ffi::{c_char, c_void, CStr};

use tracing::{error, info, trace, warn};

use whisper_cpp_sys::{
ggml_log_level, ggml_log_level_GGML_LOG_LEVEL_ERROR, ggml_log_level_GGML_LOG_LEVEL_INFO,
ggml_log_level_GGML_LOG_LEVEL_WARN,
};

#[no_mangle]
pub(crate) unsafe extern "C" fn whisper_log_callback(
level: ggml_log_level,
text: *const c_char,
_user_data: *mut c_void,
) {
let text = unsafe {
// SAFETY: `text` is a NUL-terminated C String.
CStr::from_ptr(text)
};
let text = String::from_utf8_lossy(text.to_bytes());

let text = if let Some(stripped) = text.strip_suffix('\n') {
stripped
} else {
text.as_ref()
};

match level {
ggml_log_level_GGML_LOG_LEVEL_ERROR => error!("ggml: {text}"),
ggml_log_level_GGML_LOG_LEVEL_INFO => info!("ggml: {text}"),
ggml_log_level_GGML_LOG_LEVEL_WARN => warn!("ggml: {text}"),
_ => trace!("ggml: {text}"),
}
}
}

0 comments on commit f8c905f

Please sign in to comment.